# swarm_coordinator.py
import asyncio
import json
import os
import psutil
import time
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Optional, Callable
from dataclasses import dataclass, field, asdict
from enum import Enum
import random

class WorkerStatus(Enum):
    """Lifecycle state of a single agent execution.

    Values are plain lowercase strings so they serialize directly
    (see ``WorkerResult.to_dict``).
    """

    # Created but not yet started.
    PENDING = "pending"
    # Currently inside ``execute``.
    RUNNING = "running"
    # Finished and produced an output.
    SUCCESS = "success"
    # Finished with an exception; the message is stored in ``WorkerResult.error``.
    FAILED = "failed"


@dataclass
class WorkerResult:
    """Outcome of one agent's execution of one subtask."""

    worker_id: str
    worker_type: str
    status: WorkerStatus
    # Whatever the agent/executor produced; None on failure.
    output: Any = None
    # Exception message when status is FAILED.
    error: Optional[str] = None
    duration_ms: float = 0.0
    tokens_used: int = 0

    def to_dict(self) -> Dict:
        """Return a JSON-friendly dict of all fields, with ``status`` as its string value."""
        # asdict() deep-copies the fields; the status key keeps its position when overridden.
        return {**asdict(self), 'status': self.status.value}


@dataclass
class SwarmState:
    """Mutable bookkeeping shared across one swarm run.

    Created by ``QueenCore.initialize_swarm`` and updated as tasks are
    assigned, executed and aggregated.
    """

    swarm_id: str
    original_task: str
    # Budget available for the run; ``spent`` accumulates simulated cost against it.
    budget: float
    spent: float = 0.0
    # ISO-8601 local (naive) timestamp captured when the state object is created.
    started_at: str = field(default_factory=lambda: datetime.now().isoformat())
    workers_spawned: List[str] = field(default_factory=list)
    results: List[WorkerResult] = field(default_factory=list)
    # Filled in by QueenCore.aggregate_results.
    aggregated_result: Optional[Dict] = None
    # "running" until aggregation marks it "completed".
    status: str = "running"
    # Number of subtasks successfully matched to an agent.
    tasks_assigned: int = 0


class SwarmAgent:
    """Base class for all agents in the swarm.

    An agent has an identifier, a type label, and a list of capability strings
    that the queen uses to match subtasks to agents.
    """

    def __init__(self, agent_id: str, agent_type: str, capabilities: List[str]):
        self.agent_id = agent_id
        self.agent_type = agent_type
        self.capabilities = capabilities
        self.status = WorkerStatus.PENDING

    async def execute(self, task: str, shared_state: Dict, executor: Optional[Callable] = None) -> WorkerResult:
        """Run *task* and wrap the outcome in a WorkerResult.

        Args:
            task: Subtask description.
            shared_state: Swarm-level context forwarded to *executor*.
            executor: Optional coroutine function awaited as
                ``executor(task, capabilities, shared_state)``. When omitted,
                work is simulated with a short random sleep.

        Returns:
            A SUCCESS result with the output and a simulated token count, or a
            FAILED result carrying ``str(exception)``. Exceptions derived from
            ``Exception`` raised by the work are converted, not propagated.
        """
        # Monotonic clock: datetime.now() can jump (NTP/DST) and yield wrong or
        # negative durations.
        started = time.perf_counter()
        self.status = WorkerStatus.RUNNING

        try:
            if executor is not None:
                output = await executor(task, self.capabilities, shared_state)
            else:
                # Simulated work.
                output = {
                    "agent_id": self.agent_id,
                    "agent_type": self.agent_type,
                    "task_received": task,
                    "result": f"Completed {self.agent_type} task with capabilities: {self.capabilities}"
                }
                await asyncio.sleep(random.uniform(0.1, 0.5))  # Simulate varying work times
        except Exception as e:
            self.status = WorkerStatus.FAILED
            return WorkerResult(
                worker_id=self.agent_id,
                worker_type=self.agent_type,
                status=WorkerStatus.FAILED,
                error=str(e),
                duration_ms=(time.perf_counter() - started) * 1000
            )

        self.status = WorkerStatus.SUCCESS
        duration = (time.perf_counter() - started) * 1000

        # Token usage is simulated; the coordinator converts it into cost.
        tokens_used = random.randint(100, 1000)

        return WorkerResult(
            worker_id=self.agent_id,
            worker_type=self.agent_type,
            status=WorkerStatus.SUCCESS,
            output=output,
            duration_ms=duration,
            tokens_used=tokens_used
        )


class QueenCore(SwarmAgent):
    """Central decision hub: decomposes the task, assigns subtasks to agents,
    and aggregates worker results.

    ``initialize_swarm`` must be called before assignment or aggregation,
    because both read and update the swarm state it creates.
    """

    def __init__(self, agent_id: str = "queen_core", budget: float = 100.0):
        super().__init__(agent_id, "queen_core", ["task_analysis", "resource_allocation", "result_aggregation"])
        # Copied into each new SwarmState; not enforced by this class.
        self.budget = budget
        self.swarm_state = None

    def _require_state(self) -> SwarmState:
        """Return the current swarm state, raising RuntimeError if none was initialized."""
        if self.swarm_state is None:
            raise RuntimeError("Swarm not initialized; call initialize_swarm() first")
        return self.swarm_state

    async def initialize_swarm(self, original_task: str) -> str:
        """Create a fresh SwarmState for *original_task* and return its ID.

        NOTE: IDs have one-second resolution, so two swarms initialized within
        the same second receive the same ID.
        """
        swarm_id = f"swarm_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.swarm_state = SwarmState(swarm_id=swarm_id, original_task=original_task, budget=self.budget)
        return swarm_id

    async def analyze_and_decompose_task(self, task: str) -> List[Dict[str, Any]]:
        """Decompose the main task into subtasks.

        Each subtask is a dict with ``task`` (description) and
        ``required_capabilities`` (list of capability strings).
        """
        # In production, use an LLM to analyze and decompose the task.
        # For now, return a static list of subtasks.
        subtasks = [
            {"task": f"Research the main aspects of {task}", "required_capabilities": ["research"]},
            {"task": f"Analyze the data gathered for {task}", "required_capabilities": ["analysis"]},
            {"task": f"Validate the findings related to {task}", "required_capabilities": ["validation"]},
            {"task": f"Summarize the results for {task}", "required_capabilities": ["summarization"]}
        ]
        return subtasks

    async def assign_tasks_to_agents(self, subtasks: List[Dict[str, Any]], agents: List[SwarmAgent]) -> List[Dict[str, Any]]:
        """Greedily assign each subtask to the first available capable agent.

        Each agent receives at most one subtask. Subtasks with no capable
        agent left produce a printed warning and are skipped.

        Returns:
            A list of ``{"agent_id": ..., "task": ...}`` dicts.

        Raises:
            RuntimeError: if the swarm has not been initialized.
        """
        state = self._require_state()
        assignments = []
        # Copy so agents can be consumed without mutating the caller's list.
        available_agents = list(agents)

        for subtask in subtasks:
            required = subtask["required_capabilities"]
            agent = next(
                (a for a in available_agents if all(cap in a.capabilities for cap in required)),
                None,
            )
            if agent is None:
                print(f"Warning: No agent found for task: {subtask['task']}")
                continue

            assignments.append({
                "agent_id": agent.agent_id,
                "task": subtask["task"]
            })
            available_agents.remove(agent)  # Prevent reassignment of the same agent
            state.tasks_assigned += 1

        return assignments

    async def aggregate_results(self, results: List[WorkerResult]) -> Dict[str, Any]:
        """Combine worker results into a summary and mark the swarm completed.

        ``key_findings`` collects the ``"result"`` entry of non-empty dict
        outputs. Non-dict outputs (which custom executors may return) appear
        only in ``worker_outputs``.

        Raises:
            RuntimeError: if the swarm has not been initialized.
        """
        state = self._require_state()
        successful = [r for r in results if r.status == WorkerStatus.SUCCESS]
        failed = [r for r in results if r.status == WorkerStatus.FAILED]

        # Sum of per-worker durations; workers run concurrently, so this is not wall time.
        total_duration = sum(r.duration_ms for r in results)
        success_rate = len(successful) / len(results) if results else 0

        synthesized = {
            "worker_outputs": [r.output for r in successful],
            # Guard with isinstance: calling .get() on a str/list output would raise.
            "key_findings": [
                r.output.get("result", "")
                for r in successful
                if isinstance(r.output, dict) and r.output
            ],
        }

        aggregated = {
            "swarm_id": state.swarm_id,
            "original_task": state.original_task,
            "synthesis": synthesized,
            "metrics": {
                "total_workers": len(results),
                "successful": len(successful),
                "failed": len(failed),
                "success_rate": success_rate,
                "total_duration_ms": total_duration
            },
            "failed_workers": [{"id": r.worker_id, "error": r.error} for r in failed],
            "completed_at": datetime.now().isoformat()
        }

        state.aggregated_result = aggregated
        state.status = "completed"

        return aggregated


class GuardianRing(SwarmAgent):
    """Defensive validation node: agent type "guardian" with validation/security capabilities."""

    def __init__(self, agent_id: str):
        super().__init__(
            agent_id=agent_id,
            agent_type="guardian",
            capabilities=["validation", "security"],
        )


class ProcessingRing(SwarmAgent):
    """Operational tier node: agent type "processor" with caller-supplied capabilities."""

    def __init__(self, agent_id: str, capabilities: List[str]):
        super().__init__(
            agent_id=agent_id,
            agent_type="processor",
            capabilities=capabilities,
        )


class WorkerSwarm(SwarmAgent):
    """Dynamic execution layer node: agent type "worker" with caller-supplied capabilities."""

    def __init__(self, agent_id: str, capabilities: List[str]):
        super().__init__(
            agent_id=agent_id,
            agent_type="worker",
            capabilities=capabilities,
        )


class SwarmCoordinator:
    """Builds the agent rings and drives one swarm run end to end.

    A run initializes the queen's swarm state, decomposes the task, assigns
    subtasks to agents by capability, executes them concurrently, and has the
    queen aggregate the results.
    """

    # Simulated price per token, used to accumulate SwarmState.spent.
    COST_PER_TOKEN = 0.0001

    def __init__(self, num_guardians: int = 6, num_processors: int = 10, num_workers: int = 15):
        self.queen_core = QueenCore()
        self.guardian_ring = [GuardianRing(f"guardian_{i}") for i in range(num_guardians)]
        self.processing_ring = [ProcessingRing(f"processor_{i}", ["analysis", "processing"]) for i in range(num_processors)]
        self.worker_swarm = [WorkerSwarm(f"worker_{i}", ["research", "summarization"]) for i in range(num_workers)]
        # Order matters: assignment scans this list and picks the first capable agent.
        self.all_agents = [self.queen_core] + self.guardian_ring + self.processing_ring + self.worker_swarm

    async def execute_swarm(self, task: str, executor: Optional[Callable] = None) -> Dict[str, Any]:
        """Run the full swarm pipeline for *task* and return the aggregated result.

        Args:
            task: Top-level task description.
            executor: Optional coroutine function forwarded to every agent's
                ``execute``; when None, agents simulate their work.
        """
        # 1. Initialize swarm
        swarm_id = await self.queen_core.initialize_swarm(task)
        print(f"Swarm initialized with ID: {swarm_id}")

        # 2. Analyze and decompose task
        subtasks = await self.queen_core.analyze_and_decompose_task(task)
        print(f"Task decomposed into {len(subtasks)} subtasks.")

        # 3. Assign tasks to agents
        assignments = await self.queen_core.assign_tasks_to_agents(subtasks, self.all_agents)
        print(f"Assigned {len(assignments)} tasks to agents.")

        # 4. Execute tasks in parallel
        worker_results = await self.execute_tasks_parallel(assignments, executor)

        # 5. Aggregate results
        aggregated_results = await self.queen_core.aggregate_results(worker_results)
        print(f"Aggregated results: {aggregated_results}")

        return aggregated_results

    async def execute_tasks_parallel(self, assignments: List[Dict[str, Any]], executor: Optional[Callable] = None) -> List[WorkerResult]:
        """Execute all assignments concurrently and return one WorkerResult per assignment.

        Results are in assignment order. Any exception escaping an agent,
        including ``asyncio.CancelledError`` (a BaseException), is converted
        into a FAILED result. Token usage of returned results is charged to
        the swarm state's ``spent``.

        Raises:
            RuntimeError: if there are assignments but the swarm was never initialized.
        """
        agent_map = {agent.agent_id: agent for agent in self.all_agents}
        state = self.queen_core.swarm_state
        if assignments and state is None:
            raise RuntimeError("Swarm not initialized; call QueenCore.initialize_swarm() first")

        coros = []
        for assignment in assignments:
            agent = agent_map[assignment["agent_id"]]
            # A fresh dict per agent so one agent's mutations cannot leak into another's.
            shared_state = {
                "swarm_id": state.swarm_id,
                "original_task": state.original_task
            }
            coros.append(agent.execute(assignment["task"], shared_state, executor))

        results = await asyncio.gather(*coros, return_exceptions=True)

        worker_results = []
        for assignment, result in zip(assignments, results):
            # BaseException, not Exception: a cancelled child yields CancelledError,
            # which would otherwise be appended raw and break aggregation.
            if isinstance(result, BaseException):
                agent = agent_map[assignment["agent_id"]]
                worker_results.append(WorkerResult(
                    worker_id=agent.agent_id,
                    worker_type=agent.agent_type,
                    status=WorkerStatus.FAILED,
                    error=str(result) or type(result).__name__
                ))
                continue

            worker_results.append(result)

            # Update spent budget from simulated token usage.
            if isinstance(result, WorkerResult) and result.tokens_used:
                state.spent += result.tokens_used * self.COST_PER_TOKEN

        return worker_results

# Example usage
async def main():
    """Run the demo swarm on a sample task and report the simulated budget spent."""
    swarm = SwarmCoordinator()
    demo_task = "Research the current state of AI and its potential impact on society."
    final = await swarm.execute_swarm(demo_task)
    print(f"Final Swarm Results: {final}")

    state = swarm.queen_core.swarm_state
    if state:
        print(f"Total budget spent: {state.spent:.4f}")


if __name__ == "__main__":
    # Script entry point: drive the async demo on a fresh event loop.
    asyncio.run(main())