# swarm_coordinator.py
import asyncio
import json
import logging
import os
import random
import time
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import List, Dict, Any, Optional, Callable
from dataclasses import dataclass, field, asdict

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class AgentStatus(Enum):
    """Lifecycle state of one agent; the member values are the strings written to saved state."""

    IDLE = "idle"            # free to be picked for a new task
    WORKING = "working"      # bound to a task that has not finished yet
    FAILED = "failed"        # the executor raised while running the task
    COMPLETED = "completed"  # the executor returned a result

@dataclass
class AgentState:
    """Mutable runtime record for a single agent in the swarm."""

    agent_id: str
    agent_type: str
    status: AgentStatus = AgentStatus.IDLE
    task_id: Optional[str] = None
    start_time: Optional[datetime] = None  # stamped when a task is assigned
    end_time: Optional[datetime] = None    # stamped when the task finishes or fails
    result: Optional[Any] = None
    error: Optional[str] = None
    tokens_used: int = 0  # Placeholder for tracking token usage

    def to_dict(self) -> Dict:
        """Return the fields as a dict with JSON-friendly status and timestamps.

        The status enum is replaced by its string value and both datetimes by
        ISO-8601 strings (or None when unset). All other fields come from
        ``dataclasses.asdict``.
        """
        def _as_iso(moment):
            return moment.isoformat() if moment else None

        # Overriding keys in a dict merge keeps their original field position.
        return {
            **asdict(self),
            'status': self.status.value if self.status else None,
            'start_time': _as_iso(self.start_time),
            'end_time': _as_iso(self.end_time),
        }

class SwarmCoordinator:
    """
    Manages a swarm of agents to execute tasks, handle failures, and aggregate results.

    Agents are read from a JSON config file (a default roster is written if the
    file is missing). A task is a dict with at least ``"description"`` and
    ``"required_type"`` keys. Each successfully assigned task is stored in
    ``self.tasks`` under a unique ID and bound to exactly one agent.
    """

    def __init__(self, config_path: str = "swarm_config.json", workspace_path: str = "swarm_data",
                 budget: float = 10000):
        """
        Args:
            config_path: JSON file shaped like ``{"agents": [{"id": ..., "type": ...}, ...]}``.
            workspace_path: Directory for saved state; created if it does not exist.
            budget: Starting budget (defaults to the previous hard-coded example value).
        """
        self.config_path = Path(config_path)
        self.workspace_path = Path(workspace_path)
        self.workspace_path.mkdir(parents=True, exist_ok=True)
        self.task_counter = 0
        self.agents: Dict[str, AgentState] = {}
        self.tasks: Dict[str, Dict] = {}  # task_id -> task details
        self.load_config()
        self.budget = budget
        self.budget_spent = 0

    def load_config(self):
        """Populate ``self.agents`` from the JSON config file.

        A missing file is first replaced with the default roster (see
        ``create_default_config``) and then loaded. Malformed JSON, or a
        top-level value that is not an object, is logged and leaves
        ``self.agents`` unchanged.

        Raises:
            KeyError: if an agent entry lacks ``"id"`` or ``"type"``.
        """
        if not self.config_path.exists():
            logging.warning(f"Configuration file not found at {self.config_path}. Using default configuration.")
            # Write the default once and fall through to a normal load, instead
            # of recursing (which could loop if the file still can't be found).
            self.create_default_config()

        try:
            with open(self.config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        except json.JSONDecodeError as e:
            logging.error(f"Error decoding JSON from {self.config_path}: {e}")
            return

        if not isinstance(config, dict):
            logging.error(f"Configuration in {self.config_path} must be a JSON object.")
            return

        for agent_spec in config.get("agents", []):
            agent_id = agent_spec["id"]
            self.agents[agent_id] = AgentState(
                agent_id=agent_id,
                agent_type=agent_spec["type"]
            )
        logging.info(f"Loaded configuration from {self.config_path}")

    def create_default_config(self):
        """Write the default roster to ``self.config_path``.

        The roster is ``queen_core`` plus guardian_1..6, processor_1..10 and
        worker_1..13, in that order.
        """
        roster = [{"id": "queen_core", "type": "queen"}]
        for agent_type, count in (("guardian", 6), ("processor", 10), ("worker", 13)):
            roster.extend({"id": f"{agent_type}_{n}", "type": agent_type} for n in range(1, count + 1))
        default_config = {"agents": roster}

        with open(self.config_path, "w", encoding="utf-8") as f:
            json.dump(default_config, f, indent=4)
        logging.info(f"Created default configuration file at {self.config_path}")

    def generate_task_id(self) -> str:
        """Generate a task ID that is not already a key of ``self.tasks``.

        IDs combine a second-resolution timestamp with a per-instance counter.
        The counter restarts at 0 for every new coordinator, so an ID may
        already exist among tasks restored by ``load_state``. Such IDs are
        skipped rather than reused.
        """
        while True:
            self.task_counter += 1
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            task_id = f"task_{timestamp}_{self.task_counter}"
            if task_id not in self.tasks:
                return task_id

    def _find_idle_agent(self, required_type: str, exclude: Optional[str] = None) -> Optional[str]:
        """Return the first IDLE agent of ``required_type`` (skipping ``exclude``), or None."""
        for agent_id, agent_state in self.agents.items():
            if (agent_id != exclude
                    and agent_state.status == AgentStatus.IDLE
                    and agent_state.agent_type == required_type):
                return agent_id
        return None

    def _start_task(self, agent_id: str, task: Dict) -> str:
        """Register ``task`` under a fresh ID, bind it to ``agent_id``, and return the ID."""
        task_id = self.generate_task_id()
        self.tasks[task_id] = task
        agent_state = self.agents[agent_id]
        agent_state.status = AgentStatus.WORKING
        agent_state.task_id = task_id
        agent_state.start_time = datetime.now()
        # Drop leftovers from this agent's previous task.
        agent_state.end_time = None
        agent_state.result = None
        agent_state.error = None
        logging.info(f"Assigned task {task_id} to agent {agent_id}")
        return task_id

    def _reset_agent(self, agent_id: str) -> None:
        """Return an agent to IDLE with no task; its last result/error are kept for inspection."""
        agent_state = self.agents[agent_id]
        agent_state.status = AgentStatus.IDLE
        agent_state.task_id = None
        agent_state.start_time = None
        agent_state.end_time = None

    def assign_task(self, task: Dict) -> Optional[str]:
        """Assign ``task`` to the first idle agent whose type is ``task["required_type"]``.

        Returns:
            The new task ID, or None if no matching agent is idle. In that case
            the task is NOT recorded in ``self.tasks``, because no caller could
            reference it.

        Raises:
            KeyError: if ``task`` has no ``"required_type"``.
        """
        required_type = task["required_type"]
        agent_id = self._find_idle_agent(required_type)
        if agent_id is None:
            logging.warning(
                f"No available agent of type {required_type} to assign task {task.get('description')!r}."
            )
            return None
        return self._start_task(agent_id, task)

    async def execute_task(self, agent_id: str, executor: Callable) -> Optional[Any]:
        """Run the task currently assigned to ``agent_id`` through ``executor``.

        ``executor`` is awaited with the task's ``"description"``. On success the
        agent becomes COMPLETED and stores the result. On any exception it
        becomes FAILED, stores the error text, and None is returned.

        Returns:
            The executor's result, or None if the agent has no task or the run failed.

        Raises:
            KeyError: if ``agent_id`` is unknown.
        """
        agent_state = self.agents[agent_id]
        task_id = agent_state.task_id

        if not task_id:
            logging.error(f"Agent {agent_id} has no assigned task.")
            return None

        task = self.tasks[task_id]

        try:
            result = await executor(task["description"])  # Pass the task description to the executor
        except Exception as e:  # executor is arbitrary code: record any failure on the agent
            agent_state.status = AgentStatus.FAILED
            agent_state.end_time = datetime.now()
            agent_state.error = str(e)
            logging.exception(f"Agent {agent_id} failed to execute task {task_id}: {e}")
            return None

        agent_state.status = AgentStatus.COMPLETED
        agent_state.end_time = datetime.now()
        agent_state.result = result
        agent_state.error = None  # clear any error left from an earlier attempt
        logging.info(f"Agent {agent_id} completed task {task_id} successfully.")
        return result

    def handle_failure(self, agent_id: str) -> Optional[str]:
        """Move the task held by ``agent_id`` to another agent and free ``agent_id``.

        A different idle agent of the required type is preferred. Only if none
        is free is the task retried on the failed agent itself. The task runs
        under a new task ID; the old ID stays in ``self.tasks`` as history.

        Returns:
            The new task ID, or None if the agent had no task or reassignment failed.
        """
        agent_state = self.agents[agent_id]
        task_id = agent_state.task_id

        if not task_id:
            logging.warning(f"Agent {agent_id} failed but has no assigned task.")
            return None

        task = self.tasks[task_id]
        logging.info(f"Reassigning task {task_id} from failed agent {agent_id}.")
        required_type = task["required_type"]

        # Search BEFORE freeing the failed agent; otherwise it would be the
        # first idle match and simply get its own task back.
        replacement_id = self._find_idle_agent(required_type, exclude=agent_id)
        self._reset_agent(agent_id)
        if replacement_id is None:
            # No other agent is free: retry on the now-idle failed agent.
            replacement_id = self._find_idle_agent(required_type)

        if replacement_id is None:
            logging.error(f"Failed to reassign task {task_id} after agent {agent_id}'s failure.")
            return None
        return self._start_task(replacement_id, task)

    def aggregate_results(self, task_ids: List[str]) -> List[Any]:
        """Return results of the given tasks whose agent is COMPLETED, in ``task_ids`` order.

        Tasks without a completed agent are skipped, so the output may be
        shorter than ``task_ids``.
        """
        # Index completed results once instead of scanning every agent per task;
        # setdefault keeps the first agent in dict order, matching a linear scan.
        completed: Dict[Optional[str], Any] = {}
        for agent_state in self.agents.values():
            if agent_state.status == AgentStatus.COMPLETED:
                completed.setdefault(agent_state.task_id, agent_state.result)
        return [completed[task_id] for task_id in task_ids if task_id in completed]

    def track_progress(self):
        """Log each agent's status and current task ID."""
        logging.info("Swarm Progress:")
        for agent_id, agent_state in self.agents.items():
            logging.info(f"  Agent {agent_id}: Status={agent_state.status.value}, Task={agent_state.task_id}")

    def update_budget(self, cost: float):
        """Charge ``cost`` against the budget, log the totals, and warn if the budget goes negative."""
        self.budget_spent += cost
        self.budget -= cost
        logging.info(f"Budget updated: Spent=${self.budget_spent:.2f}, Remaining=${self.budget:.2f}")
        if self.budget < 0:
            logging.warning(f"Budget exceeded by ${-self.budget:.2f}")

    def save_state(self, filename: str = "swarm_state.json"):
        """Persist agents, tasks and budget to ``workspace_path / filename`` as JSON.

        The state is serialized fully in memory and written to a sibling
        ``.tmp`` file, which then replaces the target via ``os.replace``. A
        serialization or write error therefore leaves any previously saved file
        intact. Errors are logged, not raised.
        """
        filepath = self.workspace_path / filename
        state = {
            "agents": {agent_id: agent_state.to_dict() for agent_id, agent_state in self.agents.items()},
            "tasks": self.tasks,
            "budget": self.budget,
            "budget_spent": self.budget_spent
        }
        try:
            # Serialize first: dumping straight into a file opened with "w"
            # truncates it and leaves partial JSON behind if, e.g., a task
            # result is not JSON-serializable.
            payload = json.dumps(state, indent=4)
        except (TypeError, ValueError) as e:
            logging.error(f"Error saving swarm state to {filepath}: {e}")
            return

        tmp_path = filepath.with_name(filepath.name + ".tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(payload)
            os.replace(tmp_path, filepath)  # single-step overwrite (atomic on POSIX)
        except OSError as e:
            logging.error(f"Error saving swarm state to {filepath}: {e}")
            try:
                tmp_path.unlink()
            except OSError:
                pass  # best-effort cleanup; the real error is already logged
            return
        logging.info(f"Swarm state saved to {filepath}")

    @staticmethod
    def _parse_agent_record(agent_data: Dict) -> Dict[str, Any]:
        """Convert one serialized agent dict (as produced by ``to_dict``) back to field values.

        Raises:
            ValueError: on an unknown status value or a malformed ISO timestamp.
        """
        status = agent_data.get("status")
        start_time = agent_data.get("start_time")
        end_time = agent_data.get("end_time")
        return {
            "status": AgentStatus(status) if status else AgentStatus.IDLE,
            "task_id": agent_data.get("task_id"),
            "start_time": datetime.fromisoformat(start_time) if start_time else None,
            "end_time": datetime.fromisoformat(end_time) if end_time else None,
            "result": agent_data.get("result"),
            "error": agent_data.get("error"),
            "tokens_used": agent_data.get("tokens_used", 0),
        }

    def load_state(self, filename: str = "swarm_state.json"):
        """Restore agents, tasks and budget from ``workspace_path / filename``.

        Every record is parsed before anything is modified, so a malformed file
        leaves the coordinator untouched. Saved agents not present in the
        current config are skipped with a warning. Errors are logged, not raised.
        """
        filepath = self.workspace_path / filename
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                state = json.load(f)

            parsed_agents = {
                agent_id: self._parse_agent_record(agent_data)
                for agent_id, agent_data in state.get("agents", {}).items()
            }
            tasks = state.get("tasks", {})
            budget = state.get("budget", self.budget)
            budget_spent = state.get("budget_spent", self.budget_spent)
        except FileNotFoundError:
            logging.warning(f"Swarm state file not found at {filepath}.")
            return
        except json.JSONDecodeError as e:
            logging.error(f"Error decoding JSON from {filepath}: {e}")
            return
        except Exception as e:
            logging.error(f"Error loading swarm state from {filepath}: {e}")
            return

        for agent_id, fields in parsed_agents.items():
            agent_state = self.agents.get(agent_id)
            if agent_state is None:
                logging.warning(f"Ignoring saved state for unknown agent {agent_id}.")
                continue
            for field_name, value in fields.items():
                setattr(agent_state, field_name, value)

        self.tasks = tasks
        self.budget = budget
        self.budget_spent = budget_spent
        logging.info(f"Swarm state loaded from {filepath}")

async def dummy_executor(task_description: str) -> str:
    """Stand-in executor: wait a random 0.1-0.5 s, then echo the description back."""
    pause_seconds = random.uniform(0.1, 0.5)  # Simulate work
    await asyncio.sleep(pause_seconds)
    return "Task completed: {}".format(task_description)

async def main():
    """Demonstrate swarm coordination end to end.

    Assigns one example task per agent type, runs the assigned tasks
    concurrently, simulates a failure of the first task's agent, aggregates
    the completed results, charges the budget, and round-trips the state
    through disk.
    """
    coordinator = SwarmCoordinator()

    # Example tasks
    example_tasks = [
        {"description": "Research quantum computing", "required_type": "worker"},
        {"description": "Analyze market trends", "required_type": "processor"},
        {"description": "Validate security protocols", "required_type": "guardian"},
        {"description": "Coordinate resource allocation", "required_type": "queen"},
    ]

    def agent_holding(task_id):
        """Return the ID of the agent currently bound to ``task_id``, or None."""
        return next(
            (agent_id for agent_id, agent_state in coordinator.agents.items()
             if agent_state.task_id == task_id),
            None,
        )

    # Assign tasks; assign_task returns None when no idle agent of the type exists.
    task_ids = [coordinator.assign_task(task) for task in example_tasks]

    # Execute concurrently: each task is bound to a different agent, so the
    # coroutines touch disjoint agent records. Awaiting them one by one would
    # serialize the simulated work for no benefit.
    busy_agents = [agent_holding(task_id) for task_id in task_ids if task_id]
    await asyncio.gather(
        *(coordinator.execute_task(agent_id, dummy_executor) for agent_id in busy_agents if agent_id)
    )

    # Simulate a failure of the agent that handled the first task
    first_task_id = task_ids[0]
    if first_task_id:
        failed_agent = agent_holding(first_task_id)
        if failed_agent:
            coordinator.handle_failure(failed_agent)

    # Track progress
    coordinator.track_progress()

    # Aggregate results for tasks currently held by a COMPLETED agent, in
    # task-creation order (one pass over agents instead of one per task).
    completed_ids = {
        agent_state.task_id
        for agent_state in coordinator.agents.values()
        if agent_state.status == AgentStatus.COMPLETED
    }
    completed_tasks = [task_id for task_id in coordinator.tasks if task_id in completed_ids]
    results = coordinator.aggregate_results(completed_tasks)
    logging.info(f"Aggregated results: {results}")

    # Update budget
    coordinator.update_budget(500)

    # Save state
    coordinator.save_state()

    # Load state into a fresh coordinator
    loaded_coordinator = SwarmCoordinator()
    loaded_coordinator.load_state()
    logging.info("Loaded coordinator state.")
    loaded_coordinator.track_progress()

if __name__ == "__main__":
    # Script entry point: run the async demo on a fresh event loop.
    asyncio.run(main())