"""
core/coherence/redis_master_state.py

RedisMasterState — Versioned OCC (Optimistic Concurrency Control) state store.

Uses Redis WATCH + MULTI/EXEC (pipeline transaction) to ensure that concurrent
agents cannot corrupt shared state. The OCC pattern:

    1. WATCH the key (register a "dirty" detector on the key)
    2. GET current state, verify the caller's expected version matches
    3. MULTI (begin transaction)
    4. SET new state (version+1, patched data)
    5. EXEC (commit) — Redis returns None if the watched key was modified between
       WATCH and EXEC, triggering a WatchError in redis-py

If two agents concurrently read version 5 and both try to write version 6, only
the first EXEC succeeds; the second raises WatchError → CommitResult(conflict=True).

State structure stored in Redis:
    {"version": <int>, "data": <dict>}

Simple patch application (not full RFC 6902 — StateDelta handles that):
    {"op": "replace", "path": "/key", "value": v}  → data[key] = v
    {"op": "add",     "path": "/key", "value": v}  → data[key] = v
    {"op": "remove",  "path": "/key"}               → del data[key]

The "path" is a JSON-Pointer-style string: the leading "/" is stripped and the
remainder is used as a flat dict key — e.g. path "/name" addresses data["name"].

# VERIFICATION_STAMP
# Story: 6.02
# Verified By: parallel-builder
# Verified At: 2026-02-25
# Tests: 9/9
# Coverage: 100%
"""

from __future__ import annotations

import json
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from redis.exceptions import WatchError


# ---------------------------------------------------------------------------
# Dataclasses
# ---------------------------------------------------------------------------


@dataclass
class CommitResult:
    """Outcome reported by the state store's write operations.

    Fields, in constructor order:

    * ``success`` — whether the write was committed.
    * ``new_version`` — version number stored by a committed write; writers
      report ``0`` here when ``success`` is False.
    * ``conflict`` — whether the write was rejected because the stored state
      did not match what the writer expected (a stale version, a key changed
      by a concurrent writer, or a key that already exists).
    """

    # True when the write was committed.
    success: bool
    # Version stored by the committed write; 0 for a rejected write.
    new_version: int
    # True when the write was rejected due to a concurrent/stale-state conflict.
    conflict: bool


# ---------------------------------------------------------------------------
# Main class
# ---------------------------------------------------------------------------


class RedisMasterState:
    """
    Versioned, OCC-protected Redis state store for multi-agent coherence.

    Each session's state is stored at:
        genesis:state:master:<session_id>

    Value format (JSON):
        {"version": <int>, "data": <dict>}

    Usage::

        rms = RedisMasterState(redis_client)
        version, data = await rms.get_snapshot("sess-abc")
        result = await rms.commit_patch("sess-abc", version, patch_ops)
        if result.conflict:
            # Retry: re-read snapshot and rebuild patch
            ...
    """

    KEY_PREFIX = "genesis:state:master:"

    def __init__(self, redis_client: Any) -> None:
        """
        Args:
            redis_client: An async redis.asyncio.Redis instance (or compatible mock).
        """
        self.redis = redis_client

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _key(self, session_id: str) -> str:
        """Build the Redis key for a given session (KEY_PREFIX + session_id)."""
        return f"{self.KEY_PREFIX}{session_id}"

    @staticmethod
    def _apply_patch_simple(data: Dict[str, Any], patch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Apply a simplified patch list to a data dict.

        Operations supported:
            "replace" — set data[key] = value. The key is NOT required to
                        exist beforehand; a missing key is simply created.
            "add"     — set data[key] = value
            "remove"  — delete data[key]; a missing key is ignored.

        The "path" field is an RFC 6902-style JSON Pointer (e.g. "/name").
        Exactly one leading "/" is stripped to derive the flat dict key, so
        "/name" → "name" while "//name" → "/name" (distinct paths never
        collapse onto the same key). A path without a leading "/" is used
        as the key unchanged.

        Args:
            data:  The current state data.
            patch: Operations to apply, in order.

        Returns:
            A new dict — the input is deep-copied first and NOT mutated.
        """
        import copy

        result = copy.deepcopy(data)
        for op_dict in patch:
            op = op_dict.get("op")
            path = op_dict.get("path", "")
            # Drop a single leading slash only. str.lstrip("/") would remove
            # every leading slash and alias "//key" with "/key".
            key = path[1:] if path.startswith("/") else path

            if op in ("replace", "add"):
                result[key] = op_dict.get("value")
            elif op == "remove":
                result.pop(key, None)
            # Unknown ops are silently ignored (caller should validate first)
        return result

    @staticmethod
    def _encode(version: int, data: Dict[str, Any]) -> str:
        """Serialize state to a compact JSON string (no spaces after separators)."""
        payload = {"version": version, "data": data}
        return json.dumps(payload, separators=(",", ":"))

    @staticmethod
    def _decode(raw: bytes | str) -> Tuple[int, Dict[str, Any]]:
        """Deserialize a stored JSON payload to (version, data).

        Raises whatever ``json.loads`` raises on malformed input, and
        ``KeyError`` if the payload lacks "version" or "data".
        """
        state = json.loads(raw)
        return (state["version"], state["data"])

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def get_snapshot(self, session_id: str) -> Tuple[int, Dict[str, Any]]:
        """
        Return the current (version, data) for a session.

        When Redis holds no (or an empty) value for the session key — e.g. a
        fresh session that has never been initialized — returns (0, {}).

        Args:
            session_id: Unique identifier for the session.

        Returns:
            Tuple of (version: int, data: dict).
        """
        raw = await self.redis.get(self._key(session_id))
        if not raw:
            return (0, {})
        return self._decode(raw)

    async def commit_patch(
        self,
        session_id: str,
        version: int,
        patch: List[Dict[str, Any]],
    ) -> CommitResult:
        """
        Apply patch to state using Optimistic Concurrency Control.

        OCC flow:
            1. WATCH key — Redis monitors it for external changes.
            2. GET current state (a missing key counts as version 0, data {}).
            3. Check that current version == expected version.
               If not → version mismatch → UNWATCH and return conflict (not a
               WatchError, just a version check failure — still conflict=True).
            4. Apply patch to produce new data.
            5. MULTI (begin transaction).
            6. SET new state (version+1, new_data).
            7. EXEC — commits if key has not been modified since WATCH.
               WatchError from redis-py → conflict=True.

        Args:
            session_id: Unique identifier for the session.
            version:    The version the caller read from get_snapshot().
                        Must match the current version in Redis for the
                        commit to proceed.
            patch:      List of simplified patch operations (see _apply_patch_simple).

        Returns:
            CommitResult(success=True, new_version=version+1, conflict=False)
            when the write commits; CommitResult(success=False, new_version=0,
            conflict=True) on a version mismatch or a WatchError.
        """
        key = self._key(session_id)

        async with self.redis.pipeline() as pipe:
            try:
                # Step 1: WATCH — register dirty detector
                await pipe.watch(key)

                # Step 2: GET current state (in immediate/non-buffered mode after WATCH)
                raw = await pipe.get(key)
                if raw:
                    current_version, current_data = self._decode(raw)
                else:
                    current_version, current_data = 0, {}

                # Step 3: Version check
                if current_version != version:
                    # Another writer has moved ahead — conflict without even trying EXEC
                    await pipe.unwatch()
                    return CommitResult(success=False, new_version=0, conflict=True)

                # Step 4: Compute new state
                new_data = self._apply_patch_simple(current_data, patch)
                new_version = version + 1
                new_raw = self._encode(new_version, new_data)

                # Step 5+6: MULTI / queue SET (buffered until EXEC, hence no await)
                pipe.multi()
                pipe.set(key, new_raw)

                # Step 7: EXEC — commits atomically if the watched key is unchanged
                await pipe.execute()

                return CommitResult(success=True, new_version=new_version, conflict=False)

            except WatchError:
                # Another writer modified the key between our WATCH and EXEC
                return CommitResult(success=False, new_version=0, conflict=True)

    async def initialize_state(
        self,
        session_id: str,
        initial_data: Dict[str, Any],
    ) -> CommitResult:
        """
        Set the initial state for a session (version=1).

        Uses Redis SET NX (set-if-not-exists) so an existing state is never
        overwritten. If the key already exists, returns conflict=True.

        Args:
            session_id:   Unique identifier for the session.
            initial_data: The initial data dict to store.

        Returns:
            CommitResult(success=True, new_version=1, conflict=False) on success.
            CommitResult(success=False, new_version=0, conflict=True) if key exists.
        """
        key = self._key(session_id)
        new_raw = self._encode(1, initial_data)
        # NX = only set if key does not exist; a falsy reply means it existed.
        created = await self.redis.set(key, new_raw, nx=True)
        if not created:
            # Key already existed — treat as conflict
            return CommitResult(success=False, new_version=0, conflict=True)
        return CommitResult(success=True, new_version=1, conflict=False)
