#!/usr/bin/env python3
"""
Genesis Orchestrator Agent
===========================
Spawns and coordinates multiple worker agents using asyncio.gather().

Based on research findings:
- Hierarchical Orchestrator-Worker Pattern (Anthropic's approach)
- 37% faster with parallel execution
- asyncio.gather() for Python-native concurrent execution

Usage:
    from orchestrator import Orchestrator

    orch = Orchestrator()
    result = await orch.orchestrate("Research MCP servers and summarize")
"""

import asyncio
import inspect
import json
import time
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import List, Dict, Any, Optional, Callable


from core.orchestrator_tiered_integration import OrchestratorTieredIntegration, route_task_with_true_method

# Single module-level tiered-integration instance, shared by every Orchestrator
# in this process (Orchestrator.orchestrate routes all tasks through it).
# It is constructed at import time, so importing this module runs
# OrchestratorTieredIntegration's constructor.
_tiered_integration = OrchestratorTieredIntegration()


class WorkerStatus(Enum):
    """Lifecycle state of a WorkerAgent and of the WorkerResult it produces.

    Serialized via ``.value`` (e.g. "success") in ``WorkerResult.to_dict``.
    """
    PENDING = "pending"  # worker created, execute() not yet called
    RUNNING = "running"  # execute() in progress
    SUCCESS = "success"  # executor (or default simulation) returned normally
    FAILED = "failed"    # execution raised; the error text is kept in WorkerResult.error


@dataclass
class WorkerResult:
    """Outcome of a single worker run.

    On success ``output`` carries whatever the executor returned; on failure
    ``error`` carries the stringified exception. ``duration_ms`` is the time
    spent inside ``WorkerAgent.execute``.
    """
    worker_id: str
    worker_type: str
    status: WorkerStatus
    output: Any = None
    error: Optional[str] = None
    duration_ms: float = 0.0
    tokens_used: int = 0

    def to_dict(self) -> Dict:
        """Return every field as a plain dict, with ``status`` replaced by its string value."""
        # Overriding an existing key inside a dict display keeps its original
        # position, so the field order matches asdict() exactly.
        return {**asdict(self), "status": self.status.value}


@dataclass
class OrchestratorState:
    """Shared state for one orchestration session.

    A fresh instance is created by ``Orchestrator.orchestrate`` for each task
    and is written to disk by ``Orchestrator._persist_state`` when a persist
    path is configured.
    """
    # Session id produced by Orchestrator._generate_task_id ("orch_<YYYYmmdd_HHMMSS>_<n>").
    task_id: str
    # The task text exactly as passed to orchestrate().
    original_task: str
    # Naive local-time ISO-8601 timestamp taken when the state object is created.
    started_at: str = field(default_factory=lambda: datetime.now().isoformat())
    # Worker ids appended by Orchestrator.spawn_workers().
    workers_spawned: List[str] = field(default_factory=list)
    # Per-worker results; replaced wholesale by Orchestrator.execute_workers().
    results: List[WorkerResult] = field(default_factory=list)
    # Final summary dict set by aggregate_results() or orchestrate().
    aggregated_result: Optional[Dict] = None
    # "running" until the session ends; the Orchestrator then sets "completed" or "failed".
    status: str = "running"


class WorkerAgent:
    """
    A worker agent that executes a specific role.

    Workers are specialized by type and receive a subset of the task.
    ``execute`` updates ``self.status`` as it runs and always reports the
    outcome as a WorkerResult. Ordinary exceptions from the executor are
    captured in the result, not propagated.
    """

    def __init__(self, worker_id: str, worker_type: str, role: str):
        """
        Args:
            worker_id: Identifier copied into every WorkerResult.
            worker_type: Category label, e.g. "researcher".
            role: Role description passed to the executor.
        """
        self.worker_id = worker_id
        self.worker_type = worker_type
        self.role = role
        self.status = WorkerStatus.PENDING

    @staticmethod
    def _elapsed_ms(start: float) -> float:
        """Return milliseconds elapsed since ``start``, a ``time.perf_counter()`` reading."""
        return (time.perf_counter() - start) * 1000

    async def execute(self, task: str, shared_state: Dict, executor: Optional[Callable] = None) -> WorkerResult:
        """
        Execute the worker's assigned task.

        Args:
            task: The subtask to execute.
            shared_state: Shared context from the orchestrator, passed to the
                executor unchanged.
            executor: Optional callable invoked as
                ``executor(task, role, shared_state)``. It may be a coroutine
                function or a plain function; if it returns an awaitable, that
                awaitable is awaited. When omitted, a placeholder output is
                produced after a short simulated delay.

        Returns:
            WorkerResult with ``output`` and status SUCCESS, or with the
            stringified exception in ``error`` and status FAILED. Exceptions
            derived from ``Exception`` are captured, not raised.
            ``asyncio.CancelledError`` derives from ``BaseException`` on
            Python 3.8+ and therefore still propagates.
        """
        # perf_counter is monotonic. datetime.now() follows the wall clock,
        # which can jump (NTP, DST) and produce wrong or negative durations.
        start = time.perf_counter()
        self.status = WorkerStatus.RUNNING

        try:
            if executor:
                output = executor(task, self.role, shared_state)
                # Support both async executors and plain synchronous callables.
                if inspect.isawaitable(output):
                    output = await output
            else:
                # Default: simulate work (replace with actual LLM call)
                output = {
                    "worker_id": self.worker_id,
                    "role": self.role,
                    "task_received": task,
                    "result": f"Completed {self.role} analysis"
                }
                await asyncio.sleep(0.1)  # Simulate work
        except Exception as e:
            self.status = WorkerStatus.FAILED
            return WorkerResult(
                worker_id=self.worker_id,
                worker_type=self.worker_type,
                status=WorkerStatus.FAILED,
                error=str(e),
                duration_ms=self._elapsed_ms(start)
            )

        self.status = WorkerStatus.SUCCESS
        return WorkerResult(
            worker_id=self.worker_id,
            worker_type=self.worker_type,
            status=WorkerStatus.SUCCESS,
            output=output,
            duration_ms=self._elapsed_ms(start)
        )


class Orchestrator:
    """
    Main orchestrator that spawns and coordinates worker agents.

    Implements the Hierarchical Orchestrator-Worker pattern:
    1. Analyze task and determine required workers
    2. Spawn the workers, then execute them concurrently using asyncio.gather()
    3. Aggregate results from all workers

    ``orchestrate`` currently routes the whole task through the module-level
    ``_tiered_integration`` rather than this worker pipeline. The pipeline
    methods (analyze_task, spawn_workers, execute_workers, aggregate_results)
    remain available to callers that drive them directly after a session
    state has been set on ``self.state``.
    """

    # Worker type definitions
    WORKER_TYPES = {
        "researcher": "Research and gather information",
        "analyzer": "Analyze data and extract insights",
        "implementer": "Write code and implement solutions",
        "reviewer": "Review and validate outputs",
        "documenter": "Create documentation and summaries"
    }

    def __init__(self, persist_path: Optional[str] = None):
        """
        Args:
            persist_path: Directory where each session's state is saved as
                ``<task_id>.json``. Persistence is disabled when omitted.
        """
        self.workers: Dict[str, WorkerAgent] = {}
        self.state: Optional[OrchestratorState] = None
        self.persist_path = Path(persist_path) if persist_path else None
        self.task_counter = 0

    def _generate_task_id(self) -> str:
        """Generate a task ID of the form ``orch_<YYYYmmdd_HHMMSS>_<n>``.

        ``n`` is a per-instance counter. IDs are therefore unique within one
        Orchestrator, but two instances created in the same second can collide.
        """
        self.task_counter += 1
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return f"orch_{timestamp}_{self.task_counter}"

    async def analyze_task(self, task: str) -> List[Dict[str, str]]:
        """
        Analyze task and determine required workers.

        In production, this would use an LLM to determine worker allocation.
        For now, it returns the same three worker specs for every task.

        Returns:
            Worker specs, each a dict with "id", "type" and "role" keys.
        """
        # Default: spawn 3 workers for any task
        # In production: LLM would analyze task and select appropriate workers
        return [
            {"id": "worker_1", "type": "researcher", "role": "Gather relevant information"},
            {"id": "worker_2", "type": "analyzer", "role": "Analyze and extract insights"},
            {"id": "worker_3", "type": "reviewer", "role": "Validate and summarize findings"}
        ]

    async def spawn_workers(self, worker_specs: List[Dict[str, str]]) -> List[str]:
        """
        Spawn worker agents based on specifications.

        Workers are registered in ``self.workers`` by id, replacing any
        existing worker with the same id. ``self.state`` must already be set,
        because each id is also recorded in ``self.state.workers_spawned``.

        Args:
            worker_specs: List of worker configurations ("id", "type", "role").

        Returns:
            List of worker IDs, in spec order.
        """
        worker_ids = []

        for spec in worker_specs:
            worker = WorkerAgent(
                worker_id=spec["id"],
                worker_type=spec["type"],
                role=spec["role"]
            )
            self.workers[spec["id"]] = worker
            worker_ids.append(spec["id"])
            self.state.workers_spawned.append(spec["id"])

        return worker_ids

    async def execute_workers(
        self,
        worker_ids: List[str],
        task: str,
        executor: Optional[Callable] = None
    ) -> List[WorkerResult]:
        """
        Execute all workers concurrently using asyncio.gather().

        This is the key performance optimization: running workers in parallel
        is ~37% faster than sequential execution, per the research cited in
        the module docstring.

        Returns:
            One WorkerResult per id, in ``worker_ids`` order. Any exception
            that escapes a worker is converted into a FAILED result.
        """
        shared_state = {
            "task_id": self.state.task_id,
            "original_task": self.state.original_task,
            "workers": worker_ids
        }

        coroutines = [
            self.workers[worker_id].execute(task, shared_state, executor)
            for worker_id in worker_ids
        ]

        # Execute all workers in parallel
        results = await asyncio.gather(*coroutines, return_exceptions=True)

        worker_results = []
        for worker_id, result in zip(worker_ids, results):
            # Check BaseException, not Exception. With return_exceptions=True,
            # a cancelled child yields an asyncio.CancelledError instance,
            # which is a BaseException on Python 3.8+. Treating it as a
            # result would break aggregate_results later.
            if isinstance(result, BaseException):
                worker_results.append(WorkerResult(
                    worker_id=worker_id,
                    worker_type=self.workers[worker_id].worker_type,
                    status=WorkerStatus.FAILED,
                    error=str(result)
                ))
            else:
                worker_results.append(result)

        self.state.results = worker_results
        return worker_results

    async def aggregate_results(self, results: List[WorkerResult]) -> Dict[str, Any]:
        """
        Aggregate and synthesize worker outputs.

        Aggregation strategy:
        1. Collect successful outputs
        2. Note failures for retry/escalation
        3. Synthesize final result

        Also stores the aggregate on ``self.state`` and marks the session
        "completed".
        """
        successful = [r for r in results if r.status == WorkerStatus.SUCCESS]
        failed = [r for r in results if r.status == WorkerStatus.FAILED]

        # Sum of per-worker durations. Workers run concurrently, so this is
        # total worker time, not elapsed wall-clock time.
        total_duration = sum(r.duration_ms for r in results)
        success_rate = len(successful) / len(results) if results else 0

        # Synthesize outputs
        # In production: LLM would synthesize these
        synthesized = {
            "worker_outputs": [r.output for r in successful],
            # Only dict outputs can carry a "result" key. Outputs of other
            # types (e.g. a string from a custom executor) have no .get().
            "key_findings": [
                r.output.get("result", "")
                for r in successful
                if isinstance(r.output, dict) and r.output
            ],
        }

        aggregated = {
            "task_id": self.state.task_id,
            "original_task": self.state.original_task,
            "synthesis": synthesized,
            "metrics": {
                "total_workers": len(results),
                "successful": len(successful),
                "failed": len(failed),
                "success_rate": success_rate,
                "total_duration_ms": total_duration
            },
            "failed_workers": [{"id": r.worker_id, "error": r.error} for r in failed],
            "completed_at": datetime.now().isoformat()
        }

        self.state.aggregated_result = aggregated
        self.state.status = "completed"

        return aggregated

    async def orchestrate(
        self,
        task: str,
        custom_workers: Optional[List[Dict[str, str]]] = None,  # Keep for potential sub-orchestration
        executor: Optional[Callable] = None  # Keep for potential sub-orchestration
    ) -> Dict[str, Any]:
        """
        Main orchestration flow using the Genesis TRUE Method.

        It routes the main task to the TieredExecutor or a simple executor
        based on complexity, then aggregates the result.

        Args:
            task: The task to execute
            custom_workers: Optional custom worker specs (currently unused;
                reserved for future sub-orchestration)
            executor: Optional custom execution function (currently unused;
                reserved for future sub-orchestration)

        Returns:
            Aggregated results from the execution.

        Raises:
            Whatever ``route_and_execute`` raises. Before re-raising, the
            session is marked "failed" and persisted if persistence is on.
        """
        self.state = OrchestratorState(
            task_id=self._generate_task_id(),
            original_task=task
        )

        # Use the TRUE method for main task routing and execution.
        task_dict_for_routing = {"id": self.state.task_id, "description": task, "task": task}
        # NOTE(review): route_and_execute is called without await, so it is
        # presumably synchronous and blocks the event loop while it runs.
        # Confirm against OrchestratorTieredIntegration.
        try:
            routing_result = _tiered_integration.route_and_execute(
                task_id=self.state.task_id,
                task=task_dict_for_routing,
                additional_context=None  # Can be supplied if orchestrator has context
            )
        except Exception:
            # Don't leave the session marked "running"; record the failure
            # and let the caller still see the original error.
            self.state.status = "failed"
            if self.persist_path:
                self._persist_state()
            raise

        succeeded = bool(routing_result.get("success"))

        # Aggregate the single result from the tiered execution
        self.state.aggregated_result = {
            "task_id": self.state.task_id,
            "original_task": task,
            "synthesis": routing_result.get("output", ""),
            "metrics": {
                "total_workers": 1,  # Represents the single tiered execution process
                "successful": 1 if succeeded else 0,
                "failed": 0 if succeeded else 1,
                "total_duration_ms": routing_result.get("duration_ms", 0)
            },
            "failed_workers": [] if succeeded else [
                {"id": self.state.task_id, "error": routing_result.get("error_message")}
            ],
            "completed_at": datetime.now().isoformat(),
            "tiered_execution_details": routing_result
        }
        self.state.status = "completed" if succeeded else "failed"

        # Persist if configured
        if self.persist_path:
            self._persist_state()

        return self.state.aggregated_result

    def _persist_state(self):
        """Save the current session state to ``<persist_path>/<task_id>.json``.

        This is a no-op when persistence is disabled or no session exists.
        The JSON is fully serialized before anything touches disk and is then
        moved into place. A non-serializable value (json raises TypeError) or
        an interrupted write therefore cannot leave a truncated state file.
        """
        if not self.persist_path or not self.state:
            return

        self.persist_path.mkdir(parents=True, exist_ok=True)
        filepath = self.persist_path / f"{self.state.task_id}.json"

        state_dict = {
            "task_id": self.state.task_id,
            "original_task": self.state.original_task,
            "started_at": self.state.started_at,
            "workers_spawned": self.state.workers_spawned,
            "results": [r.to_dict() for r in self.state.results],
            "aggregated_result": self.state.aggregated_result,
            "status": self.state.status
        }

        payload = json.dumps(state_dict, indent=2)
        tmp_path = filepath.with_name(filepath.name + ".tmp")
        tmp_path.write_text(payload, encoding="utf-8")
        # Rename over the target, which is atomic on POSIX within one filesystem.
        tmp_path.replace(filepath)


# CLI for testing
async def main():
    """Run one sample orchestration, persisting its state, and print the result as JSON."""
    sample_task = "Research and implement MCP server integration patterns"
    orchestrator = Orchestrator(persist_path="/mnt/e/genesis-system/orchestrator_logs")

    outcome = await orchestrator.orchestrate(task=sample_task)
    rendered = json.dumps(outcome, indent=2)
    print(rendered)


if __name__ == "__main__":
    # Command-line entry point: "test" runs main(); "orchestrate <task>" runs
    # one task and prints its aggregated result. Usage errors go to stderr
    # and exit with status 2.
    import sys

    if len(sys.argv) < 2:
        print("""
Genesis Orchestrator
====================

Commands:
  test                     Run test orchestration
  orchestrate "<task>"     Execute task with workers

Examples:
  python orchestrator.py test
  python orchestrator.py orchestrate "Analyze codebase structure"
        """)
        sys.exit(0)

    command = sys.argv[1]

    if command == "test":
        asyncio.run(main())
    elif command == "orchestrate":
        if len(sys.argv) < 3:
            # Previously misreported as "Unknown command: orchestrate".
            print('Missing task. Usage: python orchestrator.py orchestrate "<task>"', file=sys.stderr)
            sys.exit(2)
        task = sys.argv[2]
        result = asyncio.run(Orchestrator().orchestrate(task))
        # Print the result; before this fix the command produced no output.
        print(json.dumps(result, indent=2))
    else:
        print(f"Unknown command: {command}", file=sys.stderr)
        sys.exit(2)
