# swarm_coordinator.py
import asyncio
import json
import os
import psutil
import time
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Optional, Callable
from dataclasses import dataclass, field, asdict
from enum import Enum
import random

class WorkerStatus(Enum):
    """Lifecycle states for a worker agent and its task result."""
    PENDING = "pending"    # created but not yet started
    RUNNING = "running"    # currently executing
    SUCCESS = "success"    # completed without error
    FAILED = "failed"      # raised an exception or was rejected (e.g. over budget)


@dataclass
class WorkerResult:
    """Outcome record produced by a single worker agent run."""
    worker_id: str
    worker_type: str
    status: WorkerStatus
    output: Any = None
    error: Optional[str] = None
    duration_ms: float = 0.0
    tokens_used: int = 0

    def to_dict(self) -> Dict:
        """Serialize to a plain dict, flattening the status enum to its string value."""
        return {**asdict(self), 'status': self.status.value}


@dataclass
class SwarmState:
    """Shared mutable state for one swarm session (budget, results, status)."""
    task_id: str        # unique id produced by SwarmCoordinator._generate_task_id
    original_task: str  # the top-level task string as given by the caller
    # ISO-8601 timestamp captured when the state object is created.
    started_at: str = field(default_factory=lambda: datetime.now().isoformat())
    # Reserved for tracking spawned worker ids (not populated in this file).
    workers_spawned: List[str] = field(default_factory=list)
    # Per-subtask results, extended by SwarmCoordinator.execute_subtasks.
    results: List[WorkerResult] = field(default_factory=list)
    # Final synthesis; set by SwarmCoordinator.aggregate_results.
    aggregated_result: Optional[Dict] = None
    status: str = "running"  # transitions "running" -> "completed" on aggregation
    budget: float = 100.0  # Example budget cap for simulated execution cost
    cost_spent: float = 0.0  # accumulated simulated cost charged by agents


class SwarmAgent:
    """
    A swarm agent that executes a specific role.

    Agents are specialized by type and receive a subset of the task.
    Execution is budget-aware: a task is refused up front if its simulated
    cost would exceed the shared swarm budget.
    """

    def __init__(self, agent_id: str, agent_type: str, capabilities: List[str]):
        """
        Args:
            agent_id: Unique identifier for this agent.
            agent_type: Role label (e.g. "queen", "guardian", "processor", "worker").
            capabilities: Capability tags used by the coordinator for allocation.
        """
        self.agent_id = agent_id
        self.agent_type = agent_type
        self.capabilities = capabilities
        self.status = WorkerStatus.PENDING

    @staticmethod
    def _elapsed_ms(start_time: datetime) -> float:
        """Milliseconds elapsed since *start_time*."""
        return (datetime.now() - start_time).total_seconds() * 1000

    def _result(self, status: WorkerStatus, start_time: datetime,
                output: Any = None, error: Optional[str] = None) -> WorkerResult:
        """Record *status* on the agent and build the matching WorkerResult."""
        self.status = status
        return WorkerResult(
            worker_id=self.agent_id,
            worker_type=self.agent_type,
            status=status,
            output=output,
            error=error,
            duration_ms=self._elapsed_ms(start_time),
        )

    async def execute(self, task: str, swarm_state: SwarmState,
                      executor: Optional[Callable] = None) -> WorkerResult:
        """
        Execute the agent's assigned task.

        Args:
            task: The subtask to execute.
            swarm_state: Shared context from the swarm (budget/cost tracking).
            executor: Optional custom async executor called as
                ``executor(task, capabilities, swarm_state)``.

        Returns:
            WorkerResult with output on success, or an error message when the
            budget is exceeded or the executor raises.
        """
        start_time = datetime.now()
        self.status = WorkerStatus.RUNNING
        cost_per_task = random.uniform(0.1, 1.0)  # Simulate cost of execution

        # Budget guard: refuse the task rather than overspend the swarm budget.
        if swarm_state.cost_spent + cost_per_task > swarm_state.budget:
            return self._result(WorkerStatus.FAILED, start_time,
                                error="Budget exceeded")

        try:
            if executor:
                output = await executor(task, self.capabilities, swarm_state)
            else:
                # Default: simulate work (replace with actual LLM call)
                output = {
                    "agent_id": self.agent_id,
                    "capabilities": self.capabilities,
                    "task_received": task,
                    "result": f"Completed task using {self.capabilities}"
                }
                await asyncio.sleep(0.1)  # Simulate work

            # Charge the budget only for work that actually completed.
            swarm_state.cost_spent += cost_per_task
            return self._result(WorkerStatus.SUCCESS, start_time, output=output)

        except Exception as e:
            # Surface the executor failure as a FAILED result; no cost charged.
            return self._result(WorkerStatus.FAILED, start_time, error=str(e))


class SwarmCoordinator:
    """
    Manages the swarm coordination for task execution.

    Architecture:
    - Queen Core: Central decision hub
    - Guardian Ring: 6-node defensive validation
    - Processing Ring: 10-node operational tier
    - Worker Swarm: Dynamic execution layer
    """

    def __init__(self):
        self.queen_core: Optional[SwarmAgent] = None
        self.guardian_ring: List[SwarmAgent] = []
        self.processing_ring: List[SwarmAgent] = []
        self.worker_swarm: List[SwarmAgent] = []
        self.agents: Dict[str, SwarmAgent] = {}  # id -> agent, across all tiers
        self.swarm_state: Optional[SwarmState] = None
        self.task_counter = 0

        self._initialize_swarm()

    def _generate_task_id(self) -> str:
        """Generate unique task ID from a timestamp plus a monotonic counter."""
        self.task_counter += 1
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return f"swarm_{timestamp}_{self.task_counter}"

    def _register(self, agent: SwarmAgent, ring: List[SwarmAgent]) -> None:
        """Add *agent* to its tier list and the global id lookup."""
        ring.append(agent)
        self.agents[agent.agent_id] = agent

    def _initialize_swarm(self):
        """Initializes the swarm with predefined agents."""
        # Queen Core: single central coordinator.
        self.queen_core = SwarmAgent(
            agent_id="queen_1",
            agent_type="queen",
            capabilities=["task_decomposition", "resource_allocation", "result_aggregation"]
        )
        self.agents["queen_1"] = self.queen_core

        # Guardian Ring (6 nodes): validation / security tier.
        for i in range(1, 7):
            self._register(SwarmAgent(
                agent_id=f"guardian_{i}",
                agent_type="guardian",
                capabilities=["validation", "security", "risk_assessment"]
            ), self.guardian_ring)

        # Processing Ring (10 nodes): analysis / reporting tier.
        for i in range(1, 11):
            self._register(SwarmAgent(
                agent_id=f"processor_{i}",
                agent_type="processor",
                capabilities=["data_processing", "analysis", "reporting"]
            ), self.processing_ring)

        # Example Worker Swarm (can be dynamic): general execution tier.
        for i in range(1, 6):
            self._register(SwarmAgent(
                agent_id=f"worker_{i}",
                agent_type="worker",
                capabilities=["research", "coding", "documentation"]
            ), self.worker_swarm)

    async def decompose_task(self, task: str) -> List[Dict[str, Any]]:
        """
        Decomposes the main task into subtasks using the Queen Core.

        In production, this would use an LLM. Each subtask carries the
        capability tags that allocate_agents matches against.
        """
        # Simulate task decomposition
        subtasks = [
            {"task": f"Research related information for: {task}", "required_capabilities": ["research"]},
            {"task": f"Analyze the data gathered for: {task}", "required_capabilities": ["analysis", "data_processing"]},
            {"task": f"Write a summary report for: {task}", "required_capabilities": ["reporting", "documentation"]},
            {"task": f"Validate results and assess risks for: {task}", "required_capabilities": ["validation", "risk_assessment"]}
        ]
        return subtasks

    def allocate_agents(self, subtasks: List[Dict[str, Any]]) -> Dict[str, List[str]]:
        """
        Allocates agents to subtasks based on their capabilities.

        Returns:
            Mapping of "subtask_N" (1-based) to the ids of all agents that
            have at least one of the subtask's required capabilities.
        """
        allocation = {}
        for i, subtask in enumerate(subtasks):
            required_capabilities = subtask["required_capabilities"]
            eligible_agents = [
                agent.agent_id for agent in self.agents.values()
                if any(cap in agent.capabilities for cap in required_capabilities)
            ]
            allocation[f"subtask_{i+1}"] = eligible_agents
        return allocation

    async def execute_subtasks(self, subtasks: List[Dict[str, Any]],
                               allocation: Dict[str, List[str]],
                               executor: Optional[Callable] = None) -> List[WorkerResult]:
        """
        Executes subtasks using allocated agents in parallel.

        Each scheduled coroutine is paired with its chosen agent so that an
        exception surfaced by asyncio.gather is attributed to the real worker
        (previously such failures were recorded with worker_id "unknown").
        """
        scheduled_agents: List[SwarmAgent] = []
        tasks = []
        results: List[WorkerResult] = []
        for i, subtask in enumerate(subtasks):
            task_id = f"subtask_{i+1}"
            # .get guards against an allocation dict missing a subtask key.
            eligible_agents = allocation.get(task_id, [])
            if not eligible_agents:
                print(f"No agents available for {task_id}")
                continue

            # Assign a random agent for simplicity
            agent = self.agents[random.choice(eligible_agents)]
            scheduled_agents.append(agent)
            tasks.append(agent.execute(subtask["task"], self.swarm_state, executor))

        # Execute all tasks in parallel
        worker_results = await asyncio.gather(*tasks, return_exceptions=True)

        # Process results; zip keeps each outcome paired with its agent.
        for agent, result in zip(scheduled_agents, worker_results):
            if isinstance(result, Exception):
                print(f"Task failed: {result}")
                results.append(WorkerResult(
                    worker_id=agent.agent_id,
                    worker_type=agent.agent_type,
                    status=WorkerStatus.FAILED,
                    error=str(result)
                ))
            else:
                results.append(result)

        self.swarm_state.results.extend(results)
        return results

    async def aggregate_results(self, results: List[WorkerResult]) -> Dict[str, Any]:
        """
        Aggregates and synthesizes results from all subtasks.

        Marks the swarm state "completed" and stores the aggregate on it.
        """
        successful = [r for r in results if r.status == WorkerStatus.SUCCESS]
        failed = [r for r in results if r.status == WorkerStatus.FAILED]

        # Calculate metrics
        total_duration = sum(r.duration_ms for r in results)
        success_rate = len(successful) / len(results) if results else 0

        # Synthesize outputs. The isinstance guard prevents an AttributeError
        # when a custom executor returns a non-dict output.
        synthesized = {
            "worker_outputs": [r.output for r in successful],
            "key_findings": [r.output.get("result", "") for r in successful
                             if r.output and isinstance(r.output, dict)],
        }

        aggregated = {
            "task_id": self.swarm_state.task_id,
            "original_task": self.swarm_state.original_task,
            "synthesis": synthesized,
            "metrics": {
                "total_workers": len(results),
                "successful": len(successful),
                "failed": len(failed),
                "success_rate": success_rate,
                "total_duration_ms": total_duration
            },
            "failed_workers": [{"id": r.worker_id, "error": r.error} for r in failed],
            "completed_at": datetime.now().isoformat()
        }

        self.swarm_state.aggregated_result = aggregated
        self.swarm_state.status = "completed"

        return aggregated

    async def execute_swarm(self, task: str, executor: Optional[Callable] = None) -> Dict[str, Any]:
        """
        Main function to execute the swarm for a given task.

        Pipeline: decompose -> allocate -> execute in parallel -> aggregate.
        """
        self.swarm_state = SwarmState(task_id=self._generate_task_id(), original_task=task)
        print(f"Swarm started for task: {task} (Task ID: {self.swarm_state.task_id})")

        # 1. Decompose the task
        subtasks = await self.decompose_task(task)
        print(f"Decomposed into {len(subtasks)} subtasks.")

        # 2. Allocate agents to subtasks
        allocation = self.allocate_agents(subtasks)
        print(f"Agent allocation: {allocation}")

        # 3. Execute subtasks in parallel
        results = await self.execute_subtasks(subtasks, allocation, executor)
        print(f"Subtasks execution completed. {len(results)} results received.")

        # 4. Aggregate results
        aggregated_results = await self.aggregate_results(results)
        print("Results aggregated.")

        print(f"Swarm completed with status: {self.swarm_state.status}")
        print(f"Total cost spent: {self.swarm_state.cost_spent:.2f} / {self.swarm_state.budget:.2f}")
        return aggregated_results


async def main():
    """Demo entry point: run the swarm on a sample analysis task."""
    demo_task = "Analyze the current trends in AI and their potential impact on society."
    coordinator = SwarmCoordinator()
    final_report = await coordinator.execute_swarm(demo_task)
    print(f"Final Results: {final_report}")


if __name__ == "__main__":
    asyncio.run(main())