"""
core/coherence/occ_commit.py

OCCCommitEngine — Barrier Sync + OCC Write.

Orchestrates the full commit cycle for a multi-agent swarm round:

    1. BARRIER: wait for all expected workers to submit their StateDelta
       proposals to the StagingArea (uses StagingArea.wait_for_all).
    2. MERGE: pass collected deltas through the SemanticMergeInterceptor to
       resolve contradictions and produce a single merged patch.
    3. OCC WRITE: commit the merged patch to RedisMasterState with OCC
       retry logic (max 3 attempts on version conflict).
    4. SAGA: record outcome via saga_writer if provided.
    5. LEDGER: emit a merge event to ColdLedger on success if provided.

OCC retry loop (MAX_RETRIES = 3 total attempts):
    On each attempt, the engine re-reads the current state snapshot
    (version + data) from RedisMasterState, re-runs the merge against that
    snapshot, and calls commit_patch with the snapshot's version.  If every
    attempt fails to commit, the call returns
    OccCommitResult(success=False, saga_status="conflict_exhausted").

Dependency injection:
    All collaborators are injected via the constructor.  None of them are
    required at import time — pass mocks in tests.

Usage::

    engine = OCCCommitEngine(
        staging_area=staging,
        merge_interceptor=merger,
        master_state=rms,
        saga_writer=saga,   # optional
        cold_ledger=ledger, # optional
    )
    result = await engine.execute_commit(session_id="sess-abc", expected_workers=3)
    assert result.success

# VERIFICATION_STAMP
# Story: 6.06
# Verified By: parallel-builder
# Verified At: 2026-02-25
# Tests: 14/14
# Coverage: 100%
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import Any, List, Optional

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Number of OCC commit attempts execute_commit() makes before giving up:
# the first attempt plus up to MAX_RETRIES - 1 retries after a failed commit.
MAX_RETRIES = 3


# ---------------------------------------------------------------------------
# Result dataclass
# ---------------------------------------------------------------------------


@dataclass
class OccCommitResult:
    """
    Result of an OCCCommitEngine.execute_commit() call.

    Attributes:
        success:      True when the patch was committed successfully.
        merged_patch: The merged patch list produced by the merge interceptor.
                      Empty list when success=False.
        version:      The new version number after a successful commit.
                      On failure, the version of the last snapshot read
                      (0 if none was read).
        retries:      Zero-based index of the attempt that produced the
                      result.  0 means the first attempt succeeded, or the
                      merge failed on it.  When saga_status is
                      "conflict_exhausted", this instead equals the total
                      number of commit attempts made.
        saga_status:  Saga outcome string:
                        "completed"          — committed successfully.
                        "merge_failed"       — merge interceptor returned failure.
                        "conflict_exhausted" — every OCC commit attempt failed.
                        "unknown"            — initial/unset value (should not
                                               appear in a finished result).
    """

    success: bool
    # Default factory gives every instance its own list (no shared mutable default).
    merged_patch: List[Any] = field(default_factory=list)
    version: int = 0
    retries: int = 0
    saga_status: str = "unknown"


# ---------------------------------------------------------------------------
# Engine
# ---------------------------------------------------------------------------


class OCCCommitEngine:
    """
    Barrier-sync + OCC commit coordinator for multi-agent swarm rounds.

    Constructor args:
        staging_area:       StagingArea — collects worker deltas in Redis.
                            Must expose async wait_for_all(session_id,
                            expected_workers, timeout_ms=...) returning the
                            collected deltas.
        merge_interceptor:  Object with an async merge(deltas, current_state, version)
                            method that returns an object with .success and
                            .merged_patch attributes.
        master_state:       RedisMasterState — versioned OCC state store.
                            Must expose:
                              - async get_snapshot(session_id) -> (version, data)
                              - async commit_patch(session_id, version, patch)
                                -> CommitResult(success, new_version, conflict)
        saga_writer:        Optional object with async close_saga(session_id, status).
        cold_ledger:        Optional object with sync write_event(session_id, event, data).
        barrier_timeout_ms: Keyword-only.  Timeout forwarded to
                            staging_area.wait_for_all (default 60_000).
        max_retries:        Keyword-only.  Total number of OCC commit attempts
                            per execute_commit() call.  None (the default)
                            uses the module-level MAX_RETRIES, looked up at
                            commit time.

    Raises:
        ValueError: if barrier_timeout_ms <= 0 or max_retries < 1.
    """

    def __init__(
        self,
        staging_area: Any,
        merge_interceptor: Any,
        master_state: Any,
        saga_writer: Optional[Any] = None,
        cold_ledger: Optional[Any] = None,
        *,
        barrier_timeout_ms: int = 60_000,
        max_retries: Optional[int] = None,
    ) -> None:
        if barrier_timeout_ms <= 0:
            raise ValueError(
                f"barrier_timeout_ms must be positive, got {barrier_timeout_ms}"
            )
        if max_retries is not None and max_retries < 1:
            raise ValueError(f"max_retries must be >= 1, got {max_retries}")
        self.staging = staging_area
        self.merger = merge_interceptor
        self.master = master_state
        self.saga_writer = saga_writer
        self.ledger = cold_ledger
        self.barrier_timeout_ms = barrier_timeout_ms
        # Left as None when not overridden so MAX_RETRIES is resolved at
        # commit time rather than frozen at construction.
        self._max_retries = max_retries

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def execute_commit(
        self,
        session_id: str,
        expected_workers: int,
    ) -> OccCommitResult:
        """
        Run the full barrier-sync + OCC commit cycle.

        Steps:
            1. Wait for expected_workers to submit deltas (barrier).
            2. On each attempt: re-read the snapshot, merge all deltas
               against it via merge_interceptor, then try the OCC commit.
               Up to max_retries attempts are made (default MAX_RETRIES).
            3. Record the outcome to saga_writer (if provided).
            4. Emit a "saga_committed" event to cold_ledger on success
               (if provided).

        Args:
            session_id:       Session identifier shared by all workers.
            expected_workers: How many worker delta submissions to wait for.

        Returns:
            OccCommitResult describing the outcome of the commit cycle.

        Raises:
            Exceptions from staging_area, merge_interceptor and master_state
            propagate unchanged; in that case the saga is not closed.
            Failures in saga_writer / cold_ledger are logged and swallowed.
        """
        deltas = await self._collect_deltas(session_id, expected_workers)
        max_attempts = (
            self._max_retries if self._max_retries is not None else MAX_RETRIES
        )

        version = 0  # last snapshot version seen; reported if all attempts fail

        for attempt in range(max_attempts):
            # Re-read state every attempt so the merge sees the latest snapshot
            # and the commit is conditioned on the version we just read.
            version, current_state = await self.master.get_snapshot(session_id)

            merge_result = await self.merger.merge(deltas, current_state, version)
            if not merge_result.success:
                logger.warning(
                    "OCCCommitEngine: merge failed on session %s (attempt %d)",
                    session_id,
                    attempt,
                )
                await self._close_saga(session_id, "merge_failed")
                return OccCommitResult(
                    success=False,
                    merged_patch=[],
                    version=version,
                    retries=attempt,
                    saga_status="merge_failed",
                )

            # OCC conditional write: succeeds only if the stored version is
            # still the one we merged against.
            commit_result = await self.master.commit_patch(
                session_id, version, merge_result.merged_patch
            )

            if commit_result.success:
                new_version = self._resolve_new_version(commit_result, version)
                logger.info(
                    "OCCCommitEngine: committed session %s at version %d "
                    "(attempt %d)",
                    session_id,
                    new_version,
                    attempt,
                )
                await self._close_saga(session_id, "completed")
                self._emit_ledger_event(session_id, "saga_committed")
                return OccCommitResult(
                    success=True,
                    merged_patch=merge_result.merged_patch,
                    version=new_version,
                    retries=attempt,
                    saga_status="completed",
                )

            # Commit rejected — log and retry with a fresh snapshot.
            logger.warning(
                "OCCCommitEngine: OCC conflict on session %s, attempt %d/%d",
                session_id,
                attempt + 1,
                max_attempts,
            )

        # All attempts exhausted.
        logger.error(
            "OCCCommitEngine: all %d OCC retries exhausted for session %s",
            max_attempts,
            session_id,
        )
        await self._close_saga(session_id, "conflict_exhausted")
        return OccCommitResult(
            success=False,
            merged_patch=[],
            version=version,
            retries=max_attempts,
            saga_status="conflict_exhausted",
        )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    async def _collect_deltas(self, session_id: str, expected_workers: int) -> Any:
        """Block on the staging-area barrier and return the collected deltas."""
        logger.debug(
            "OCCCommitEngine: waiting for %d workers on session %s",
            expected_workers,
            session_id,
        )
        deltas = await self.staging.wait_for_all(
            session_id,
            expected_workers,
            timeout_ms=self.barrier_timeout_ms,
        )
        if deltas is None:
            # Treat a missing result as "no deltas" instead of crashing on len().
            deltas = []
        logger.debug(
            "OCCCommitEngine: collected %d delta(s) for session %s",
            len(deltas),
            session_id,
        )
        if len(deltas) < expected_workers:
            # The cycle still proceeds with whatever was collected; make that visible.
            logger.warning(
                "OCCCommitEngine: only %d of %d expected delta(s) collected "
                "for session %s; proceeding with partial set",
                len(deltas),
                expected_workers,
                session_id,
            )
        return deltas

    @staticmethod
    def _resolve_new_version(commit_result: Any, old_version: int) -> int:
        """Return the store-reported new version, else old_version + 1."""
        reported = getattr(commit_result, "new_version", None)
        # Trust the store when it reports a concrete integer version; bool is
        # excluded because it is an int subclass.
        if isinstance(reported, int) and not isinstance(reported, bool):
            return reported
        return old_version + 1

    async def _close_saga(self, session_id: str, status: str) -> None:
        """Call saga_writer.close_saga if a writer is configured (best-effort)."""
        if self.saga_writer is not None:
            try:
                await self.saga_writer.close_saga(session_id, status)
            except Exception as exc:  # pragma: no cover
                logger.warning(
                    "OCCCommitEngine: saga_writer.close_saga raised: %s",
                    exc,
                    exc_info=True,
                )

    def _emit_ledger_event(self, session_id: str, event: str) -> None:
        """Call cold_ledger.write_event if a ledger is configured (best-effort)."""
        if self.ledger is not None:
            try:
                self.ledger.write_event(session_id, event, {})
            except Exception as exc:  # pragma: no cover
                logger.warning(
                    "OCCCommitEngine: cold_ledger.write_event raised: %s",
                    exc,
                    exc_info=True,
                )
