# swarm_coordinator.py
import asyncio
import json
import logging
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import List, Dict, Any, Optional, Callable
from dataclasses import dataclass, field, asdict

# Configure process-wide logging at import time: every logging.info/warning/error
# call in this module goes to the root logger, which this sets to INFO with a
# timestamped "time - level - message" format. (basicConfig is a no-op if the
# root logger already has handlers, e.g. when an importing app configured it first.)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class AgentStatus(Enum):
    """Lifecycle state of an agent.

    The string values are what ``Agent.to_dict`` writes and what
    ``AgentStatus(value)`` parses back when saved state is reloaded.
    """

    IDLE = "idle"            # eligible to receive a task
    WORKING = "working"      # has been handed a task
    FAILED = "failed"        # its last execution raised
    COMPLETED = "completed"  # its last execution succeeded

@dataclass
class Agent:
    """A single swarm member together with its current execution state."""

    agent_id: str
    agent_type: str
    capabilities: List[str]  # capability names matched against Task.requirements
    status: AgentStatus = AgentStatus.IDLE
    task_id: Optional[str] = None  # ID of the task this agent is working on, if any
    result: Optional[Any] = None  # value produced by the last successful execution
    error: Optional[str] = None  # str() of the exception from the last failed execution
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None

    def to_dict(self) -> Dict:
        """Return a JSON-friendly dict view of this agent.

        ``status`` is emitted as its string value, and ``start_time`` /
        ``end_time`` as ISO-8601 strings (or ``None``). Without this
        conversion, ``asdict`` would keep raw ``datetime`` objects, which
        ``json.dump`` cannot encode. ``result`` is copied as-is and must itself
        be JSON-serializable for the whole dict to be.
        """
        d = asdict(self)
        d['status'] = self.status.value
        for key in ('start_time', 'end_time'):
            stamp = getattr(self, key)
            d[key] = stamp.isoformat() if stamp is not None else None
        return d

@dataclass
class Task:
    """A unit of work registered with the coordinator.

    ``requirements`` lists the capability names an agent needs to take the
    task. ``assigned_agents`` and ``results`` grow as the task is assigned and
    executed. ``status`` is a plain string that starts as "pending".
    ``budget`` and ``cost`` are plain numbers that the coordinator's budget
    report sums across tasks.
    """

    task_id: str
    description: str
    requirements: List[str]
    status: str = field(default="pending")
    assigned_agents: List[str] = field(default_factory=list)  # agent IDs, in assignment order
    results: List[Dict] = field(default_factory=list)  # one entry per successful execution
    budget: float = field(default=100.0)  # example default budget
    cost: float = field(default=0.0)  # accrued cost, reported against budget

class SwarmCoordinator:
    """
    Manages a swarm of agents, distributes tasks, and aggregates results.

    Architecture:
    - Queen Core: Central decision hub
    - Guardian Ring: 6-node defensive validation
    - Processing Ring: 10-node operational tier
    - Worker Swarm: Dynamic execution layer

    Tier membership is positional (see ``initialize_tiers``). In this class,
    the tiers only influence ``assign_task``, which prefers Processing Ring
    agents.
    """

    # Ring sizes used by initialize_tiers (the Queen Core is always one agent).
    GUARDIAN_RING_SIZE = 6
    PROCESSING_RING_SIZE = 10

    def __init__(self, config_path: str = "swarm_config.json"):
        """Create a coordinator and populate it from ``config_path``.

        Args:
            config_path: JSON file with a top-level ``"agents"`` list. A missing
                or malformed file is logged and yields an empty swarm.
        """
        self.config_path = Path(config_path)
        self.agents: Dict[str, Agent] = {}
        self.tasks: Dict[str, Task] = {}
        self.queen_core: Optional[str] = None  # agent ID, set by initialize_tiers
        self.guardian_ring: List[str] = []     # agent IDs
        self.processing_ring: List[str] = []   # agent IDs
        self.load_config()
        self.initialize_tiers()

    def load_config(self):
        """Load agent definitions from ``self.config_path`` into ``self.agents``.

        Each entry in ``"agents"`` must provide ``agent_id``, ``agent_type`` and
        ``capabilities``. Errors are logged rather than raised. Agents parsed
        before an error stay loaded.
        """
        try:
            with open(self.config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
            for agent_data in config.get("agents", []):
                agent = Agent(
                    agent_id=agent_data["agent_id"],
                    agent_type=agent_data["agent_type"],
                    capabilities=agent_data["capabilities"]
                )
                self.agents[agent.agent_id] = agent
            logging.info(f"Loaded configuration from {self.config_path}")
        except FileNotFoundError:
            logging.warning(f"Config file not found at {self.config_path}. Starting with empty agent list.")
        except json.JSONDecodeError as e:
            logging.error(f"Error decoding JSON from {self.config_path}: {e}")
        except Exception as e:
            logging.error(f"Error loading config: {e}")

    def initialize_tiers(self):
        """Assign agents to the Queen Core, Guardian Ring and Processing Ring.

        Assignment follows the insertion order of ``self.agents``:
        - index 0 becomes the Queen Core;
        - the next ``GUARDIAN_RING_SIZE`` agents (indices 1-6 by default) form
          the Guardian Ring;
        - the next ``PROCESSING_RING_SIZE`` agents (indices 7-16 by default)
          form the Processing Ring.

        Rings are shorter when fewer agents exist. With no agents, nothing is
        changed and a warning is logged.
        """
        agent_ids = list(self.agents)
        if not agent_ids:
            logging.warning("No agents loaded. Cannot initialize tiers.")
            return

        guardian_end = 1 + self.GUARDIAN_RING_SIZE
        processing_end = guardian_end + self.PROCESSING_RING_SIZE

        self.queen_core = agent_ids[0]
        logging.info(f"Queen Core assigned to agent: {self.queen_core}")

        self.guardian_ring = agent_ids[1:guardian_end]
        logging.info(f"Guardian Ring agents: {self.guardian_ring}")

        self.processing_ring = agent_ids[guardian_end:processing_end]
        logging.info(f"Processing Ring agents: {self.processing_ring}")

    def register_task(self, description: str, requirements: List[str], budget: float = 100.0) -> str:
        """Register a new pending task and return its generated ID.

        The ID combines the current local timestamp (second resolution) with
        a sequence number based on how many tasks are already registered.
        """
        task_id = f"task_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{len(self.tasks) + 1}"
        task = Task(task_id=task_id, description=description, requirements=requirements, budget=budget)
        self.tasks[task_id] = task
        logging.info(f"Registered task: {task_id} - {description}")
        return task_id

    async def assign_task(self, task_id: str) -> bool:
        """Assign ``task_id`` to one IDLE agent that has every required capability.

        Capable idle agents in the Processing Ring are preferred. Otherwise the
        first capable idle agent in insertion order is chosen. The agent is
        marked WORKING, given ``task_id``, and appended to
        ``task.assigned_agents``.

        Returns:
            True if an agent was assigned. False if the task is unknown or no
            capable idle agent exists.
        """
        task = self.tasks.get(task_id)
        if not task:
            logging.error(f"Task not found: {task_id}")
            return False

        # Find suitable agents based on requirements
        suitable_agents = [
            agent for agent in self.agents.values()
            if agent.status == AgentStatus.IDLE
            and all(req in agent.capabilities for req in task.requirements)
        ]

        if not suitable_agents:
            logging.warning(f"No suitable agents found for task: {task_id}")
            return False

        # Prioritize agents in the Processing Ring
        processing_ids = set(self.processing_ring)
        preferred = [agent for agent in suitable_agents if agent.agent_id in processing_ids]
        agent = (preferred or suitable_agents)[0]

        agent.status = AgentStatus.WORKING
        agent.task_id = task_id
        task.assigned_agents.append(agent.agent_id)
        logging.info(f"Assigned task {task_id} to agent {agent.agent_id}")
        return True

    async def execute_task(self, task_id: str, executor: Optional[Callable] = None):
        """Run ``task_id`` on its first assigned agent.

        Args:
            task_id: ID of a registered task that already has an assigned agent.
            executor: Optional coroutine function, awaited as
                ``executor(task.description, agent.capabilities)``. Its return
                value becomes the result. When omitted, execution is simulated
                with a one-second sleep.

        On success:
            The agent is marked COMPLETED and released (``task_id`` cleared).
            The task is marked "completed" and ``{"agent_id", "result"}`` is
            appended to ``task.results``.

        On any exception from the executor:
            The agent is marked FAILED and the task "failed". The agent KEEPS
            its ``task_id`` so ``handle_agent_failure`` can find the task and
            reassign it.
        """
        task = self.tasks.get(task_id)
        if not task:
            logging.error(f"Task not found: {task_id}")
            return

        if not task.assigned_agents:
            logging.warning(f"No agents assigned to task: {task_id}")
            return

        agent_id = task.assigned_agents[0]
        agent = self.agents.get(agent_id)
        if agent is None:
            logging.error(f"Agent not found: {agent_id}")
            return

        agent.start_time = datetime.now()

        try:
            if executor is not None:
                result = await executor(task.description, agent.capabilities)
            else:
                # Simulate task execution
                await asyncio.sleep(1)
                result = f"Task '{task.description}' completed by {agent.agent_id}"
        except Exception as e:
            agent.status = AgentStatus.FAILED
            agent.error = str(e)
            agent.end_time = datetime.now()
            task.status = "failed"
            logging.error(f"Task {task_id} failed on agent {agent_id}: {e}")
            # agent.task_id is intentionally left set: handle_agent_failure()
            # relies on it to locate and reassign the task.
            return

        agent.status = AgentStatus.COMPLETED
        agent.result = result
        agent.end_time = datetime.now()
        agent.task_id = None  # release the agent from the finished task
        task.status = "completed"
        task.results.append({"agent_id": agent_id, "result": result})

        logging.info(f"Task {task_id} completed by agent {agent_id}: {result}")

    async def handle_agent_failure(self, agent_id: str):
        """Recover from a FAILED agent by moving its task to a different agent.

        The failed agent is detached from its task and a replacement is sought
        while the agent is still FAILED, so ``assign_task`` cannot hand the
        task straight back to it. The agent is then reset to IDLE (error and
        result cleared) so it can take other work.

        Task status afterwards:
        - "pending" if a replacement was found;
        - "failed" otherwise.
        """
        agent = self.agents.get(agent_id)
        if not agent:
            logging.error(f"Agent not found: {agent_id}")
            return

        if agent.status != AgentStatus.FAILED:
            logging.warning(f"Agent {agent_id} is not in FAILED state.")
            return

        task_id = agent.task_id
        if not task_id:
            logging.warning(f"Agent {agent_id} has no task assigned.")
            return

        task = self.tasks.get(task_id)
        if not task:
            logging.error(f"Task not found: {task_id}")
            return

        # Detach the failed agent from the task.
        if agent_id in task.assigned_agents:
            task.assigned_agents.remove(agent_id)
        agent.task_id = None

        # Reassign while the agent is still FAILED (i.e. not IDLE) so it is
        # excluded from the candidates for the task it just failed.
        reassigned = await self.assign_task(task_id)

        # Reset agent status
        agent.status = AgentStatus.IDLE
        agent.error = None
        agent.result = None

        if reassigned:
            task.status = "pending"  # ready to be executed again by the new agent
            logging.info(f"Task {task_id} reassigned after agent failure.")
        else:
            logging.warning(f"Failed to reassign task {task_id} after agent failure.")
            task.status = "failed"  # Mark task as failed if reassignment fails

    def aggregate_results(self, task_id: str) -> List[Dict]:
        """Return the result entries recorded for ``task_id`` ([] if unknown).

        The returned list is the task's own ``results`` list, not a copy.
        """
        task = self.tasks.get(task_id)
        if not task:
            logging.error(f"Task not found: {task_id}")
            return []

        return task.results

    def get_task_status(self, task_id: str) -> str:
        """Return the status string of a task, or "unknown" if it does not exist."""
        task = self.tasks.get(task_id)
        if not task:
            logging.error(f"Task not found: {task_id}")
            return "unknown"
        return task.status

    def get_agent_status(self, agent_id: str) -> str:
        """Return the status value of an agent, or "unknown" if it does not exist."""
        agent = self.agents.get(agent_id)
        if not agent:
            logging.error(f"Agent not found: {agent_id}")
            return "unknown"
        return agent.status.value

    def track_progress(self) -> Dict[str, Any]:
        """Summarize task counts, WORKING agents and completion percentage."""
        total_tasks = len(self.tasks)
        completed_tasks = sum(1 for task in self.tasks.values() if task.status == "completed")
        failed_tasks = sum(1 for task in self.tasks.values() if task.status == "failed")
        active_agents = sum(1 for agent in self.agents.values() if agent.status == AgentStatus.WORKING)

        return {
            "total_tasks": total_tasks,
            "completed_tasks": completed_tasks,
            "failed_tasks": failed_tasks,
            "active_agents": active_agents,
            "completion_rate": (completed_tasks / total_tasks) * 100 if total_tasks else 0
        }

    def track_budget(self) -> Dict[str, float]:
        """Sum budget and cost over all tasks and report what remains.

        NOTE(review): nothing in this class updates ``Task.cost``, so
        ``total_cost`` only reflects values set externally or loaded from state.
        """
        total_budget = sum(task.budget for task in self.tasks.values())
        total_cost = sum(task.cost for task in self.tasks.values())
        return {
            "total_budget": total_budget,
            "total_cost": total_cost,
            "remaining_budget": total_budget - total_cost
        }

    @staticmethod
    def _json_default(obj: Any) -> str:
        """``json`` fallback encoder: ``datetime`` -> ISO-8601 string, anything else is an error."""
        if isinstance(obj, datetime):
            return obj.isoformat()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    @staticmethod
    def _parse_datetime(value: Optional[str]) -> Optional[datetime]:
        """Inverse of ``_json_default`` for timestamps; empty/missing values yield None."""
        if not value:
            return None
        return datetime.fromisoformat(value)

    def save_state(self, filepath: str = "swarm_state.json"):
        """Write all agents and tasks to ``filepath`` as indented JSON.

        The state is encoded to a string before the file is opened. A value
        that cannot be encoded (for example a non-JSON task result) is
        therefore logged without truncating an existing state file.
        ``datetime`` values are stored as ISO-8601 strings.
        """
        state = {
            "agents": {agent_id: agent.to_dict() for agent_id, agent in self.agents.items()},
            "tasks": {task_id: asdict(task) for task_id, task in self.tasks.items()}
        }
        try:
            payload = json.dumps(state, indent=4, default=self._json_default)
            with open(filepath, "w", encoding="utf-8") as f:
                f.write(payload)
            logging.info(f"Swarm state saved to {filepath}")
        except Exception as e:
            logging.error(f"Error saving swarm state: {e}")

    def load_state(self, filepath: str = "swarm_state.json"):
        """Merge agents and tasks from a file written by ``save_state``.

        Loaded entries are added to ``self.agents`` / ``self.tasks``, replacing
        entries with the same ID, and the tiers are then recomputed.
        ``start_time`` / ``end_time`` are parsed back from ISO-8601 strings
        when present. Errors are logged, not raised; entries loaded before an
        error remain.
        """
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                state = json.load(f)

            # Load agents
            for agent_id, agent_data in state.get("agents", {}).items():
                self.agents[agent_id] = Agent(
                    agent_id=agent_data["agent_id"],
                    agent_type=agent_data["agent_type"],
                    capabilities=agent_data["capabilities"],
                    status=AgentStatus(agent_data["status"]),
                    task_id=agent_data.get("task_id"),
                    result=agent_data.get("result"),
                    error=agent_data.get("error"),
                    start_time=self._parse_datetime(agent_data.get("start_time")),
                    end_time=self._parse_datetime(agent_data.get("end_time")),
                )

            # Load tasks
            for task_id, task_data in state.get("tasks", {}).items():
                self.tasks[task_id] = Task(
                    task_id=task_data["task_id"],
                    description=task_data["description"],
                    requirements=task_data["requirements"],
                    status=task_data["status"],
                    assigned_agents=task_data["assigned_agents"],
                    results=task_data["results"],
                    budget=task_data["budget"],
                    cost=task_data["cost"]
                )

            logging.info(f"Swarm state loaded from {filepath}")

            # Re-initialize tiers after loading state
            self.initialize_tiers()

        except FileNotFoundError:
            logging.warning(f"Swarm state file not found at {filepath}. Starting with empty state.")
        except json.JSONDecodeError as e:
            logging.error(f"Error decoding JSON from {filepath}: {e}")
        except Exception as e:
            logging.error(f"Error loading swarm state: {e}")

async def example_executor(task_description: str, agent_capabilities: List[str]) -> str:
    """Stand-in executor used by the demo.

    Pauses for half a second to mimic work, then returns a message naming the
    task and the capabilities it was run with.
    """
    simulated_work_seconds = 0.5
    await asyncio.sleep(simulated_work_seconds)
    summary = f"Task '{task_description}' executed with capabilities: {agent_capabilities}"
    return summary

async def main():
    """Demonstrate the coordinator end to end.

    Steps:
    1. Register two demo tasks.
    2. Assign both tasks, then execute both with ``example_executor``.
    3. For any task that did not complete, simulate an agent failure and
       retry it.
    4. Print results, progress and budget.
    5. Round-trip the state through ``swarm_state.json`` and print the agents
       of both coordinators.
    """
    coordinator = SwarmCoordinator(config_path="swarm_config.json")

    # (description, requirements, budget) for each demo task, in registration order.
    demo_specs = [
        ("Analyze market trends", ["data_analysis", "market_research"], 50.0),
        ("Write a Python script", ["programming", "python"], 30.0),
    ]
    task_id_1, task_id_2 = (
        coordinator.register_task(description=desc, requirements=reqs, budget=cap)
        for desc, reqs, cap in demo_specs
    )
    demo_ids = (task_id_1, task_id_2)

    # Every task is assigned before any is executed.
    for tid in demo_ids:
        await coordinator.assign_task(tid)
    for tid in demo_ids:
        await coordinator.execute_task(tid, executor=example_executor)

    # Simulated failure: knock out the agent of each unfinished task, let the
    # coordinator recover, then run the task again.
    for task in coordinator.tasks.values():
        if task.status == "completed" or not task.assigned_agents:
            continue
        failing_agent = task.assigned_agents[0]
        coordinator.agents[failing_agent].status = AgentStatus.FAILED
        await coordinator.handle_agent_failure(failing_agent)
        await coordinator.execute_task(task.task_id, executor=example_executor)

    results_1 = coordinator.aggregate_results(task_id_1)
    print(f"\nResults for task {task_id_1}: {results_1}")

    results_2 = coordinator.aggregate_results(task_id_2)
    print(f"Results for task {task_id_2}: {results_2}")

    progress, budget_info = coordinator.track_progress(), coordinator.track_budget()
    print(f"\nSwarm Progress: {progress}")
    print(f"Budget Information: {budget_info}")

    # Persist, then rebuild a fresh coordinator from the saved file.
    state_path = "swarm_state.json"
    coordinator.save_state(state_path)
    new_coordinator = SwarmCoordinator()
    new_coordinator.load_state(state_path)

    for heading, swarm in (
        ("\nOriginal Coordinator Agents:", coordinator),
        ("\nLoaded Coordinator Agents:", new_coordinator),
    ):
        print(heading)
        for agent_id, agent in swarm.agents.items():
            print(f"Agent {agent_id}: {agent}")

if __name__ == "__main__":
    # Seed a demo swarm_config.json only when none exists, so running the demo
    # never overwrites a configuration the user has already written.
    demo_config_path = Path("swarm_config.json")
    if not demo_config_path.exists():
        dummy_config = {
            "agents": [
                {"agent_id": "agent_1", "agent_type": "analyzer", "capabilities": ["data_analysis", "market_research"]},
                {"agent_id": "agent_2", "agent_type": "writer", "capabilities": ["programming", "python"]},
                {"agent_id": "agent_3", "agent_type": "validator", "capabilities": ["data_analysis", "validation"]},
                {"agent_id": "agent_4", "agent_type": "researcher", "capabilities": ["market_research", "data_gathering"]},
                {"agent_id": "agent_5", "agent_type": "tester", "capabilities": ["programming", "testing"]},
                {"agent_id": "agent_6", "agent_type": "documenter", "capabilities": ["writing", "documentation"]},
                {"agent_id": "agent_7", "agent_type": "integrator", "capabilities": ["programming", "integration"]},
                {"agent_id": "agent_8", "agent_type": "deployer", "capabilities": ["deployment", "cloud"]},
                {"agent_id": "agent_9", "agent_type": "monitor", "capabilities": ["monitoring", "alerts"]},
                {"agent_id": "agent_10", "agent_type": "optimizer", "capabilities": ["data_analysis", "optimization"]},
            ]
        }
        # UTF-8 explicitly, matching how load_config reads the file back.
        with demo_config_path.open("w", encoding="utf-8") as f:
            json.dump(dummy_config, f, indent=4)

    asyncio.run(main())