# swarm_coordinator.py
import asyncio
import json
import os
import psutil
import time
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Optional, Callable
from dataclasses import dataclass, field, asdict
from enum import Enum

class AgentStatus(Enum):
    """Lifecycle state of an agent, identified by a lowercase string value."""

    IDLE = "idle"
    WORKING = "working"
    FAILED = "failed"
    SUCCESS = "success"

@dataclass
class Agent:
    """A swarm member: identity, role, capabilities, and the state of its latest task."""

    agent_id: str
    role: str
    capabilities: List[str]
    status: AgentStatus = field(default=AgentStatus.IDLE)
    task_id: Optional[str] = field(default=None)   # task currently held, if any
    result: Optional[Any] = field(default=None)    # output of the most recent successful run
    error: Optional[str] = field(default=None)     # message from the most recent failed run

    def to_dict(self) -> Dict:
        """Return every field as a plain dict, with ``status`` replaced by its string value."""
        status_value = None if not self.status else self.status.value
        # Overriding an existing key keeps its original position in the dict.
        return {**asdict(self), "status": status_value}

class TaskStatus(Enum):
    """Lifecycle state of a task, identified by a lowercase string value."""

    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"

@dataclass
class Task:
    """A unit of work, the capabilities it needs, and the outcome of running it."""

    task_id: str
    description: str
    required_capabilities: List[str]
    status: TaskStatus = field(default=TaskStatus.PENDING)
    assigned_agent: Optional[str] = field(default=None)  # agent_id of the current/last assignee
    result: Optional[Any] = field(default=None)
    error: Optional[str] = field(default=None)

class SwarmCoordinator:
    """
    Manages a swarm of agents, distributing tasks and aggregating results.

    Agents are loaded from ``<workspace>/config/<config_file>``. After each task
    execution, JSON snapshots of the task and of its agent are written to
    ``<workspace>/data``.
    """
    def __init__(self, workspace_path: str = "aiva_swarm", config_file: str = "swarm_config.json", budget: float = 100.0):
        """
        Args:
            workspace_path: Root directory holding ``config/`` and ``data/``.
            config_file: File name of the agent configuration inside ``config/``.
            budget: Starting budget, reduced by ``update_budget``.
        """
        self.workspace = Path(workspace_path)
        self.config_path = self.workspace / "config" / config_file
        self.data_path = self.workspace / "data"
        self.data_path.mkdir(parents=True, exist_ok=True)
        self.agents: Dict[str, Agent] = {}
        self.tasks: Dict[str, Task] = {}
        self.task_counter = 0
        self.budget = budget
        self.spent = 0.0
        self._load_config()

    def _load_config(self):
        """Load agent definitions from the JSON config file, if it exists.

        Expected shape: ``{"agents": [{"agent_id": ..., "role": ..., "capabilities": [...]}, ...]}``.
        Each agent is added to (or replaces the same-id entry in) ``self.agents``.
        Any error is printed and loading stops; agents parsed before the error are kept.
        """
        if self.config_path.exists():
            try:
                with open(self.config_path, "r", encoding="utf-8") as f:
                    config = json.load(f)
                    for agent_data in config.get("agents", []):
                        agent = Agent(
                            agent_id=agent_data["agent_id"],
                            role=agent_data["role"],
                            capabilities=agent_data["capabilities"]
                        )
                        self.agents[agent.agent_id] = agent
                print("Agent configurations loaded.")
            except Exception as e:
                # Best-effort load: report and continue with whatever was parsed.
                print(f"Error loading agent configurations: {e}")
        else:
            print("No agent configuration file found. Starting with an empty swarm.")

    def create_task(self, description: str, required_capabilities: List[str]) -> str:
        """Create a PENDING task, register it, and return its generated ID.

        The ID combines the current local timestamp with a per-coordinator
        counter, so IDs stay unique within one coordinator even inside the same second.
        """
        task_id = f"task_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{self.task_counter}"
        self.task_counter += 1
        task = Task(
            task_id=task_id,
            description=description,
            required_capabilities=required_capabilities
        )
        self.tasks[task_id] = task
        print(f"Task created: {task_id} - {description}")
        return task_id

    async def assign_task(self, task_id: str, exclude: Optional[List[str]] = None) -> bool:
        """Assign a PENDING task to the first available agent that has every required capability.

        An agent is available when it is IDLE, or when it finished its previous
        task successfully (SUCCESS; ``execute_task`` has already cleared its
        ``task_id``). FAILED agents stay unavailable until ``reassign_task`` resets them.

        Args:
            task_id: ID of the task to assign.
            exclude: Optional agent IDs that must not receive this task.

        Returns:
            True if the task was assigned (it is now IN_PROGRESS), otherwise False.
        """
        task = self.tasks.get(task_id)
        if not task:
            print(f"Task not found: {task_id}")
            return False

        if task.status != TaskStatus.PENDING:
            print(f"Task already assigned or completed: {task_id}")
            return False

        excluded = set(exclude or ())
        available_agents = [
            agent for agent in self.agents.values()
            if agent.agent_id not in excluded
            and agent.status in (AgentStatus.IDLE, AgentStatus.SUCCESS)
            and all(cap in agent.capabilities for cap in task.required_capabilities)
        ]

        if not available_agents:
            print(f"No suitable agent available for task: {task_id}")
            return False

        # Assign to the first available agent (can be improved with prioritization)
        agent = available_agents[0]
        task.assigned_agent = agent.agent_id
        agent.status = AgentStatus.WORKING
        agent.task_id = task_id
        task.status = TaskStatus.IN_PROGRESS

        print(f"Task {task_id} assigned to agent {agent.agent_id}")
        return True

    async def execute_task(self, task_id: str, executor: Callable):
        """Run an IN_PROGRESS task on its assigned agent.

        ``executor`` is awaited as ``executor(task.description, agent.role)``.
        Its return value becomes the task/agent result. Any exception it raises
        marks both the task and the agent FAILED. Either way, the agent's
        ``task_id`` is cleared and both status files are saved.
        """
        task = self.tasks.get(task_id)
        if not task or not task.assigned_agent:
            print(f"Task {task_id} not assigned or not found.")
            return

        # Only tasks that assign_task moved to IN_PROGRESS may run. Re-running a
        # finished task would overwrite its result and could clobber an agent
        # that has since been given another task.
        if task.status != TaskStatus.IN_PROGRESS:
            print(f"Task {task_id} is not in progress (status: {task.status.value}).")
            return

        agent = self.agents.get(task.assigned_agent)
        if not agent:
            print(f"Assigned agent {task.assigned_agent} not found.")
            task.status = TaskStatus.FAILED
            task.error = f"Assigned agent {task.assigned_agent} not found."
            await self.save_task_status(task_id)
            return

        print(f"Executing task {task_id} with agent {agent.agent_id}...")
        try:
            if not executor:
                raise ValueError("Executor function is required.")

            result = await executor(task.description, agent.role)  # Pass agent role
            task.result = result
            task.status = TaskStatus.COMPLETED
            agent.status = AgentStatus.SUCCESS
            agent.result = result
            print(f"Task {task_id} completed successfully.")

        except Exception as e:
            print(f"Task {task_id} failed: {e}")
            task.status = TaskStatus.FAILED
            task.error = str(e)
            agent.status = AgentStatus.FAILED
            agent.error = str(e)

        finally:
            agent.task_id = None  # Reset agent's task ID
            await self.save_task_status(task_id)
            await self.save_agent_status(agent.agent_id)

    async def reassign_task(self, task_id: str):
        """Reassign a FAILED task to a suitable agent other than the one that failed it.

        The failed agent is reset to IDLE (error/result cleared), and the task
        goes back to PENDING. If no other agent qualifies, the task stays
        PENDING; the caller may then offer it to any agent via ``assign_task``.
        """
        task = self.tasks.get(task_id)
        if not task:
            print(f"Task not found: {task_id}")
            return

        if task.status != TaskStatus.FAILED:
            print(f"Task is not in a failed state: {task_id}")
            return

        failed_agent_id = task.assigned_agent
        failed_agent = self.agents.get(failed_agent_id)
        if failed_agent:
            failed_agent.status = AgentStatus.IDLE
            failed_agent.task_id = None
            failed_agent.error = None
            failed_agent.result = None

        task.assigned_agent = None
        task.status = TaskStatus.PENDING
        task.error = None

        exclude = [failed_agent_id] if failed_agent_id else None
        if await self.assign_task(task_id, exclude=exclude):
            print(f"Task {task_id} successfully reassigned.")
        else:
            print(f"Failed to reassign task {task_id}.")

    async def aggregate_results(self) -> Dict[str, Any]:
        """Collect the results of all COMPLETED tasks, in task-creation order."""
        completed_tasks = [
            task for task in self.tasks.values() if task.status == TaskStatus.COMPLETED
        ]
        results = [task.result for task in completed_tasks]

        # Simple aggregation: collect results into a list
        return {
            "task_count": len(completed_tasks),
            "results": results
        }

    async def track_progress(self) -> Dict[str, int]:
        """Count tasks per status: pending, in_progress, completed, failed."""
        statuses = [task.status for task in self.tasks.values()]
        return {
            "pending": statuses.count(TaskStatus.PENDING),
            "in_progress": statuses.count(TaskStatus.IN_PROGRESS),
            "completed": statuses.count(TaskStatus.COMPLETED),
            "failed": statuses.count(TaskStatus.FAILED)
        }

    def update_budget(self, cost: float):
        """Record ``cost`` as spent and subtract it from the remaining budget; warn if it goes negative."""
        self.spent += cost
        self.budget -= cost
        if self.budget < 0:
            print("Warning: Budget exceeded!")

    async def save_task_status(self, task_id: str):
        """Write a JSON snapshot of a task to ``<data>/task_<task_id>_status.json``."""
        task = self.tasks.get(task_id)
        if not task:
            print(f"Task not found: {task_id}")
            return

        task_data = asdict(task)
        # Store the enum's value ("completed") instead of str(enum)
        # ("TaskStatus.COMPLETED"), matching the format Agent.to_dict uses.
        task_data["status"] = task.status.value
        file_path = self.data_path / f"task_{task_id}_status.json"
        with open(file_path, "w", encoding="utf-8") as f:
            # default=str stringifies arbitrary (non-JSON) executor results.
            json.dump(task_data, f, indent=2, default=str)

    async def save_agent_status(self, agent_id: str):
        """Write a JSON snapshot of an agent to ``<data>/agent_<agent_id>_status.json``."""
        agent = self.agents.get(agent_id)
        if not agent:
            print(f"Agent not found: {agent_id}")
            return

        file_path = self.data_path / f"agent_{agent_id}_status.json"
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(agent.to_dict(), f, indent=2, default=str)

    async def _process_pending(self, task_ids: List[str], executor: Callable):
        """Assign and run pending tasks from ``task_ids`` in rounds until none can be assigned.

        Each round assigns every pending task it can, then executes that batch
        concurrently. Executed tasks leave PENDING for good, so the loop ends
        after at most ``len(task_ids)`` rounds.
        """
        while True:
            pending = [tid for tid in task_ids if self.tasks[tid].status == TaskStatus.PENDING]
            if not pending:
                return
            assigned = await asyncio.gather(*(self.assign_task(tid) for tid in pending))
            batch = [tid for tid, ok in zip(pending, assigned) if ok]
            if not batch:
                return
            await asyncio.gather(*(self.execute_task(tid, executor) for tid in batch))

    async def run_swarm(self, tasks: List[Dict[str, Any]], executor: Callable):
        """
        Runs the swarm to process a list of tasks.

        Tasks are processed in rounds, so an agent can take several tasks. Each
        failed task is retried once, preferring a different agent. Aggregated
        results and progress are printed at the end.

        Args:
            tasks: A list of task dictionaries, each containing 'description' and 'required_capabilities'.
            executor: An asynchronous function that takes a task description and agent role as input and returns a result.
        """
        task_ids = [
            self.create_task(task_data['description'], task_data['required_capabilities'])
            for task_data in tasks
        ]

        await self._process_pending(task_ids, executor)

        # Handle failed tasks: retry each once.
        for task_id in task_ids:
            if self.tasks[task_id].status != TaskStatus.FAILED:
                continue
            await self.reassign_task(task_id)
            if self.tasks[task_id].status == TaskStatus.PENDING:
                # No other agent qualified; offer it to any agent, including the
                # one reassign_task just reset to IDLE.
                await self.assign_task(task_id)
            # Checked independently of the branch above, so a task placed by
            # reassign_task itself is executed too.
            if self.tasks[task_id].status == TaskStatus.IN_PROGRESS:
                await self.execute_task(task_id, executor)

        # Agents reset during the retry step may now be able to take tasks that
        # were left pending earlier.
        await self._process_pending(task_ids, executor)

        # Aggregate results
        aggregated_results = await self.aggregate_results()
        print("Aggregated Results:", aggregated_results)

        # Track progress
        progress = await self.track_progress()
        print("Progress:", progress)

async def dummy_executor(task_description: str, agent_role: str) -> str:
    """Stand-in executor: pause for 0.1 s, then report which role handled the task."""
    await asyncio.sleep(0.1)  # pretend to do some work
    message = f"Task '{task_description}' completed by {agent_role}."
    return message

if __name__ == "__main__":
    async def main():
        """Demo: write a sample agent config, run three tasks, and print the outcome."""
        coordinator = SwarmCoordinator(budget=500.0)

        # Define some tasks
        tasks = [
            {"description": "Research climate change impact on coastal cities", "required_capabilities": ["research", "analysis"]},
            {"description": "Develop a Python script to automate data analysis", "required_capabilities": ["coding", "analysis"]},
            {"description": "Write a report summarizing the research findings", "required_capabilities": ["writing", "summary"]},
        ]

        # Create a sample swarm_config.json (or point the coordinator at a real one)
        config_data = {
            "agents": [
                {"agent_id": "research_agent_1", "role": "Researcher", "capabilities": ["research", "analysis"]},
                {"agent_id": "coding_agent_1", "role": "Developer", "capabilities": ["coding", "analysis"]},
                {"agent_id": "writer_agent_1", "role": "Writer", "capabilities": ["writing", "summary"]},
            ]
        }
        coordinator.config_path.parent.mkdir(parents=True, exist_ok=True)
        # Write with the same encoding _load_config reads with, instead of the platform default.
        with open(coordinator.config_path, "w", encoding="utf-8") as f:
            json.dump(config_data, f, indent=4)
        coordinator._load_config()  # Reload so the freshly written agents are registered

        # Run the swarm to process the tasks
        await coordinator.run_swarm(tasks, dummy_executor)

        # Print remaining budget
        print(f"Remaining Budget: ${coordinator.budget:.2f}")

    asyncio.run(main())