"""
core/coherence/coherence_orchestrator.py

CoherenceOrchestrator — 8-Step Coherence Execution Flow Coordinator.

Runs the complete coherence pipeline for a multi-agent swarm round:

    Step 1  MAP     — Push task DAG to Redis Stream via TaskDAGPusher.
    Step 2  CLAIM   — Workers auto-claim tasks via XREADGROUP (background).
    Step 3  PROPOSE — Workers run and submit StateDelta to StagingArea (background).
    Step 4  BARRIER — Wait for all worker deltas via StagingArea.wait_for_all.
    Step 5  REDUCE  — OCCCommitEngine collects deltas + SemanticMergeInterceptor merges.
    Step 6  COMMIT  — OCCCommitEngine issues OCC conditional write to RedisMasterState.
    Step 7  RELEASE — Execute merged side-effects (log to events.jsonl).
    Step 8  SCAR    — On any worker failure or OCC commit failure, write a scar event to events.jsonl
                      (a scar is also written when the whole run times out).

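Steps 2 and 3 are implicit — workers are background consumers that run
independently once the DAG is pushed to the stream. The orchestrator's
role is to coordinate steps 1 and 4-8. When a bulkhead guard is injected,
the orchestrator additionally runs one worker coroutine per task through it
during step 3 to obtain per-worker success/failure counts.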

Total orchestration timeout: 120 seconds (ORCHESTRATION_TIMEOUT_SECONDS), enforced by asyncio.wait_for inside execute().

Dependency injection:
    All collaborators are injected via the constructor. None are required
    at import time — pass mocks in tests.

Usage::

    orchestrator = CoherenceOrchestrator(
        dag_pusher=pusher,
        staging_area=staging,
        occ_engine=engine,
        bulkhead=guard,
    )
    result = await orchestrator.execute("sess-abc", tasks=[
        {"task_type": "research", "payload": {"query": "..."}},
        {"task_type": "synthesize", "payload": {}},
    ])
    assert result.success

# VERIFICATION_STAMP
# Story: 6.08
# Verified By: parallel-builder
# Verified At: 2026-02-25
# Tests: 22/22
# Coverage: 100%
"""

from __future__ import annotations

import asyncio
import inspect
import json
import logging
import os
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional
from uuid import uuid4

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Events log path
# ---------------------------------------------------------------------------

# Default location of the JSON-lines observability log. This path is specific
# to one deployment (a WSL mount), so it may be overridden with the
# COHERENCE_EVENTS_LOG_PATH environment variable, which is read once at import
# time. An empty value falls back to the default. The orchestrator looks this
# module attribute up on every write, so patching it (e.g. in tests) also works.
_DEFAULT_EVENTS_LOG_PATH = "/mnt/e/genesis-system/data/observability/events.jsonl"
EVENTS_LOG_PATH = Path(
    os.environ.get("COHERENCE_EVENTS_LOG_PATH") or _DEFAULT_EVENTS_LOG_PATH
)

# ---------------------------------------------------------------------------
# Total orchestration timeout (seconds)
# ---------------------------------------------------------------------------

# Upper bound for one CoherenceOrchestrator.execute() run, in seconds.
ORCHESTRATION_TIMEOUT_SECONDS = 120


# ---------------------------------------------------------------------------
# Result dataclass
# ---------------------------------------------------------------------------


@dataclass
class CoherenceResult:
    """
    Result of a CoherenceOrchestrator.execute() call.

    Attributes:
        success:           True when no worker was reported as failed AND the
                           commit step did not fail. When no OCC engine is
                           configured the commit step is skipped and counts
                           as successful.
        committed_state:   ``{"merged_patch": <patch>}`` when the OCC commit
                           succeeded, otherwise an empty dict. It can be
                           populated even when success=False (commit succeeded
                           but a worker failed), and it is empty when no OCC
                           engine is configured.
        saga_id:           UUID4 string identifying this orchestration run.
        workers_succeeded: Number of worker tasks the bulkhead reported as
                           successful (len(tasks) when no bulkhead is used).
        workers_failed:    Number of worker tasks the bulkhead reported as
                           unsuccessful (0 when no bulkhead is used).
    """

    success: bool
    committed_state: dict
    saga_id: str
    workers_succeeded: int
    workers_failed: int


# ---------------------------------------------------------------------------
# CoherenceOrchestrator
# ---------------------------------------------------------------------------


class CoherenceOrchestrator:
    """
    Coordinates the complete 8-step coherence execution pipeline.

    Constructor args:
        dag_pusher:   TaskDAGPusher — pushes task DAGs to Redis Stream.
                      Must expose: async push_dag(session_id, tasks) -> list[str]
        staging_area: StagingArea — collects worker deltas in Redis Hash.
                      Must expose: async wait_for_all(session_id, expected_count, timeout_ms)
        occ_engine:   OCCCommitEngine — barrier sync + OCC write coordinator.
                      Must expose: async execute_commit(session_id, expected_workers) -> OccCommitResult
        bulkhead:     BulkheadGuard — asyncio.gather exception isolation.
                      Must expose: async run_with_bulkhead(tasks) -> list[BulkheadResult]
                      where each result exposes a truthy/falsy ``success``.
        qdrant_client: Optional Qdrant client (reserved for future use).
        timeout_seconds: Keyword-only. Overrides the total orchestration
                      timeout. None (default) means "use the module-level
                      ORCHESTRATION_TIMEOUT_SECONDS", looked up at call time.
        barrier_timeout_ms: Keyword-only. ``timeout_ms`` passed to
                      StagingArea.wait_for_all in Step 4 (default 60_000).

    Each of dag_pusher, staging_area, occ_engine and bulkhead may be None, in
    which case its step is skipped (a missing bulkhead counts every task as
    succeeded).
    """

    # Class-level defaults so instances created without __init__ (e.g. via
    # ``__new__`` in tests) still have usable timeout settings.
    timeout_seconds: Optional[float] = None
    barrier_timeout_ms: int = 60_000

    def __init__(
        self,
        dag_pusher: Optional[Any] = None,
        staging_area: Optional[Any] = None,
        occ_engine: Optional[Any] = None,
        bulkhead: Optional[Any] = None,
        qdrant_client: Optional[Any] = None,
        *,
        timeout_seconds: Optional[float] = None,
        barrier_timeout_ms: int = 60_000,
    ) -> None:
        self.dag_pusher = dag_pusher
        self.staging_area = staging_area
        self.occ_engine = occ_engine
        self.bulkhead = bulkhead
        self.qdrant_client = qdrant_client
        self.timeout_seconds = timeout_seconds
        self.barrier_timeout_ms = barrier_timeout_ms

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def execute(
        self,
        session_id: str,
        tasks: list[dict],
    ) -> CoherenceResult:
        """
        Run the complete 8-step coherence pipeline under a total timeout.

        Wraps _execute_pipeline() in asyncio.wait_for so the caller always
        gets a result (or a timeout) within the configured limit:
        ``timeout_seconds`` if set, otherwise ORCHESTRATION_TIMEOUT_SECONDS.

        Args:
            session_id: Session identifier shared by all workers.
            tasks:      List of task dicts, each describing one unit of work.
                        Each task may include keys: task_type, payload, tier, priority.

        Returns:
            CoherenceResult describing the final pipeline outcome.

        Raises:
            asyncio.TimeoutError: If the pipeline exceeds the timeout. A scar
                                  event is written before re-raising; the
                                  caller should treat this as a failure.
        """
        saga_id = str(uuid4())
        # Resolve the module default at call time so patching the module
        # constant still takes effect for instances without an override.
        timeout = (
            self.timeout_seconds
            if self.timeout_seconds is not None
            else ORCHESTRATION_TIMEOUT_SECONDS
        )
        logger.info(
            "CoherenceOrchestrator: starting saga %s for session %s (%d tasks)",
            saga_id,
            session_id,
            len(tasks),
        )

        try:
            result = await asyncio.wait_for(
                self._execute_pipeline(session_id, tasks, saga_id),
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            # NOTE(review): a TimeoutError raised by a collaborator inside the
            # pipeline (e.g. if wait_for_all raises one on barrier timeout)
            # also lands here and is reported as an orchestration timeout —
            # confirm that is intended.
            logger.error(
                "CoherenceOrchestrator: saga %s TIMED OUT after %ss",
                saga_id,
                timeout,
            )
            # Write scar for timeout
            self._write_event(
                "scar",
                {
                    "saga_id": saga_id,
                    "session_id": session_id,
                    "reason": "orchestration_timeout",
                    "timeout_seconds": timeout,
                },
            )
            raise

        return result

    # ------------------------------------------------------------------
    # Private pipeline
    # ------------------------------------------------------------------

    async def _execute_pipeline(
        self,
        session_id: str,
        tasks: list[dict],
        saga_id: str,
    ) -> CoherenceResult:
        """
        Internal implementation of the 8-step pipeline.

        Each step is annotated with its name so the execution order is
        visible at a glance in logs and in code review.
        """

        # ------------------------------------------------------------------
        # Step 1: MAP — Push task DAG to Redis Stream
        # ------------------------------------------------------------------
        logger.debug("CoherenceOrchestrator [Step 1: MAP] saga=%s", saga_id)
        if self.dag_pusher is not None:
            entry_ids = await self.dag_pusher.push_dag(session_id, tasks)
            logger.debug(
                "CoherenceOrchestrator [MAP] pushed %d entries: %s",
                len(entry_ids),
                entry_ids,
            )
        else:
            logger.debug(
                "CoherenceOrchestrator [MAP] dag_pusher=None, skipping stream push"
            )

        # ------------------------------------------------------------------
        # Step 2: CLAIM — Workers auto-claim via XREADGROUP (background)
        # ------------------------------------------------------------------
        # Implicit: background worker processes listen on the Redis Stream
        # Consumer Group and claim tasks via XREADGROUP. The orchestrator
        # does not drive this step directly — it simply proceeds.
        logger.debug(
            "CoherenceOrchestrator [Step 2: CLAIM] saga=%s (implicit — workers listening)",
            saga_id,
        )

        # ------------------------------------------------------------------
        # Step 3: PROPOSE — Workers submit StateDelta to StagingArea (background)
        # ------------------------------------------------------------------
        # Implicit: worker processes run their tasks and call
        # StagingArea.submit_delta() for each completed unit. When a
        # bulkhead guard is injected, worker coroutines are also run through
        # it here to obtain workers_succeeded / workers_failed counts.
        logger.debug(
            "CoherenceOrchestrator [Step 3: PROPOSE] saga=%s (implicit — workers proposing)",
            saga_id,
        )

        if self.bulkhead is not None:
            workers_succeeded, workers_failed = await self._run_workers(
                session_id, tasks
            )
            logger.debug(
                "CoherenceOrchestrator [PROPOSE] saga=%s workers: %d ok, %d failed",
                saga_id,
                workers_succeeded,
                workers_failed,
            )
        else:
            # No bulkhead — assume all workers succeeded for pipeline purposes.
            workers_succeeded = len(tasks)
            workers_failed = 0

        # ------------------------------------------------------------------
        # Step 4: BARRIER — Wait for all worker deltas to arrive in StagingArea
        # ------------------------------------------------------------------
        logger.debug("CoherenceOrchestrator [Step 4: BARRIER] saga=%s", saga_id)
        if self.staging_area is not None:
            deltas = await self.staging_area.wait_for_all(
                session_id,
                expected_count=len(tasks),
                timeout_ms=self.barrier_timeout_ms,
            )
            logger.debug(
                "CoherenceOrchestrator [BARRIER] saga=%s collected %d delta(s)",
                saga_id,
                len(deltas),
            )
        else:
            logger.debug(
                "CoherenceOrchestrator [BARRIER] staging_area=None, skipping barrier"
            )

        # ------------------------------------------------------------------
        # Step 5: REDUCE  — OCCCommitEngine merges deltas (SemanticMergeInterceptor)
        # Step 6: COMMIT  — OCC conditional write to RedisMasterState
        # (Both steps are encapsulated inside OCCCommitEngine.execute_commit)
        # ------------------------------------------------------------------
        logger.debug(
            "CoherenceOrchestrator [Step 5: REDUCE / Step 6: COMMIT] saga=%s", saga_id
        )

        occ_result = None
        committed_state: dict = {}

        if self.occ_engine is not None:
            occ_result = await self.occ_engine.execute_commit(
                session_id,
                expected_workers=len(tasks),
            )
            if occ_result.success:
                # merged_patch is a list (JSON Patch ops). We wrap it in a
                # dict so CoherenceResult.committed_state is always a dict.
                committed_state = {"merged_patch": occ_result.merged_patch}
                # %s rather than %d: the version's type is owned by the OCC
                # engine, and a non-int would otherwise break log formatting.
                logger.info(
                    "CoherenceOrchestrator [COMMIT] saga=%s committed at version %s",
                    saga_id,
                    occ_result.version,
                )
            else:
                logger.warning(
                    "CoherenceOrchestrator [COMMIT] saga=%s OCC failed: %s",
                    saga_id,
                    occ_result.saga_status,
                )
        else:
            logger.debug(
                "CoherenceOrchestrator [COMMIT] occ_engine=None, skipping OCC commit"
            )

        # ------------------------------------------------------------------
        # Step 7: RELEASE — Execute merged side-effects (log to events.jsonl)
        # ------------------------------------------------------------------
        logger.debug("CoherenceOrchestrator [Step 7: RELEASE] saga=%s", saga_id)
        # A skipped commit (no OCC engine) counts as successful.
        commit_success = occ_result.success if occ_result is not None else True

        if commit_success:
            self._write_event(
                "release",
                {
                    "saga_id": saga_id,
                    "session_id": session_id,
                    "workers_succeeded": workers_succeeded,
                    "workers_failed": workers_failed,
                    "committed_state": committed_state,
                },
            )

        # ------------------------------------------------------------------
        # Step 8: SCAR — On any worker or commit failure, write scar event
        # ------------------------------------------------------------------
        logger.debug("CoherenceOrchestrator [Step 8: SCAR] saga=%s", saga_id)
        if workers_failed > 0 or not commit_success:
            scar_payload: dict = {
                "saga_id": saga_id,
                "session_id": session_id,
                "workers_failed": workers_failed,
                "workers_succeeded": workers_succeeded,
            }
            if occ_result is not None and not occ_result.success:
                scar_payload["occ_failure_reason"] = occ_result.saga_status
            self._write_event("scar", scar_payload)
            logger.warning(
                "CoherenceOrchestrator [SCAR] saga=%s: %d worker(s) failed, "
                "commit_success=%s",
                saga_id,
                workers_failed,
                commit_success,
            )

        # ------------------------------------------------------------------
        # Build and return CoherenceResult
        # ------------------------------------------------------------------
        overall_success = commit_success and (workers_failed == 0)

        result = CoherenceResult(
            success=overall_success,
            committed_state=committed_state,
            saga_id=saga_id,
            workers_succeeded=workers_succeeded,
            workers_failed=workers_failed,
        )

        logger.info(
            "CoherenceOrchestrator: saga %s COMPLETE — success=%s, "
            "workers_succeeded=%d, workers_failed=%d",
            saga_id,
            overall_success,
            workers_succeeded,
            workers_failed,
        )

        return result

    # ------------------------------------------------------------------
    # Worker execution through the bulkhead
    # ------------------------------------------------------------------

    async def _run_workers(
        self,
        session_id: str,
        tasks: list[dict],
    ) -> tuple[int, int]:
        """
        Run one worker coroutine per task through ``self.bulkhead``.

        Args:
            session_id: The orchestration session identifier.
            tasks:      The task dicts from the DAG.

        Returns:
            ``(workers_succeeded, workers_failed)``, counted from the
            ``success`` attribute of each bulkhead result.

        Raises:
            Whatever ``run_with_bulkhead`` raises. Never-started coroutines
            are still closed first.
        """
        # (agent_id, coroutine) pairs. The agent_id falls back to the task's
        # position when it has no task_type.
        worker_coros = [
            (
                task.get("task_type", f"task-{i}"),
                self._run_worker(session_id, task, i),
            )
            for i, task in enumerate(tasks)
        ]
        try:
            # Materialize once so the results can be counted safely even if
            # the bulkhead hands back a one-shot iterator.
            bulkhead_results = list(
                await self.bulkhead.run_with_bulkhead(worker_coros)
            )
        finally:
            # Close coroutines the bulkhead never started (e.g. a mocked
            # run_with_bulkhead, or one that raised early), including on the
            # error path, to avoid "coroutine was never awaited" warnings.
            for _, coro in worker_coros:
                self._close_if_unstarted(coro)

        succeeded = sum(1 for r in bulkhead_results if r.success)
        return succeeded, len(bulkhead_results) - succeeded

    @staticmethod
    def _close_if_unstarted(coro: Any) -> None:
        """Close a worker coroutine that was created but never awaited."""
        if inspect.iscoroutine(coro):
            # Started coroutines are left alone so a still-running task is
            # not interrupted; finished ones need no closing.
            if inspect.getcoroutinestate(coro) == inspect.CORO_CREATED:
                coro.close()
        elif hasattr(coro, "close"):
            # Non-coroutine stand-ins (e.g. test doubles returned by a
            # patched _run_worker) are closed if they expose close().
            coro.close()

    # ------------------------------------------------------------------
    # Worker stub (production override point)
    # ------------------------------------------------------------------

    async def _run_worker(
        self,
        session_id: str,
        task: dict,
        task_index: int,
    ) -> dict:
        """
        Default worker coroutine.

        In production this would invoke the real Gemini swarm worker logic.
        In tests, the BulkheadGuard's task list is replaced with mock coroutines.

        Args:
            session_id:  The orchestration session identifier.
            task:        The task dict from the DAG.
            task_index:  Zero-based index of this task in the DAG.

        Returns:
            A dict representing the worker's output (placeholder for now).
        """
        return {
            "session_id": session_id,
            "task_index": task_index,
            "task_type": task.get("task_type", "unknown"),
            "status": "completed",
        }

    # ------------------------------------------------------------------
    # Events logging helper
    # ------------------------------------------------------------------

    def _write_event(self, event_type: str, payload: dict) -> None:
        """
        Append a JSON-lines event to the observability log.

        The log file is created (including parent directories) on first write.
        Failures are logged but never propagated — event logging must never
        crash the orchestration pipeline.

        Args:
            event_type: Short label (e.g. "release", "scar").
            payload:    Dict of event metadata; must be JSON-serializable or
                        the event is dropped (and the failure logged).
        """
        try:
            EVENTS_LOG_PATH.parent.mkdir(parents=True, exist_ok=True)
            entry = {
                "event_type": event_type,
                "timestamp": datetime.now(tz=timezone.utc).isoformat(),
                **payload,
            }
            with EVENTS_LOG_PATH.open("a", encoding="utf-8") as fh:
                fh.write(json.dumps(entry) + "\n")
            logger.debug(
                "CoherenceOrchestrator: wrote %s event to %s",
                event_type,
                EVENTS_LOG_PATH,
            )
        except Exception as exc:
            logger.error(
                "CoherenceOrchestrator: failed to write %s event: %s",
                event_type,
                exc,
            )
