# queen_orchestrator.py
import time
import datetime
import subprocess
import json
import os
import logging
import random  # For simulating agent completion and failures
import threading

# --- Configuration ---
# Input markdown documents read at start-up, and the directory that receives
# one JSON checkpoint file per phase.
SPRINT_PLAN_PATH = "QUEEN_ELEVATION_SPRINT_PLAN.md"
GENESIS_PACKAGE_PATH = "GENESIS_COMPLETE_PACKAGE.md"
CHECKPOINTS_DIR = "sprint-checkpoints"

# --- Logging Setup ---
# Root logger at INFO, formatted as "<time> - <LEVEL> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# --- Utility Functions ---
def load_sprint_plan(filepath):
    """Load and parse the sprint plan from a markdown file.

    Recognised structure:

    * ``## PHASE ...`` lines open a phase. Its name is the text after the
      first colon, or the heading text itself when there is no colon.
    * ``#### Wave ...`` lines inside a phase open a wave. Its name is the text
      after the first colon, or the text after ``Wave`` when there is none.
    * Markdown table rows (``| ... |``) inside a wave become tasks. The
      ``Agent ID`` header row and ``|---|`` alignment rows are skipped.
      Everything else (titles, prose, rows outside a wave) is ignored.

    Args:
        filepath: Path to the UTF-8 markdown sprint plan.

    Returns:
        ``{'phases': [{'name': str, 'waves': [{'name': str, 'tasks': [...]}]}]}``
        where each task is a dict with ``agent_id``, ``role``, ``task`` and
        ``priority`` keys (missing cells become ``''``). Returns ``None`` if
        the file cannot be read or parsed; the error is logged.
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            content = f.read()
        return _parse_sprint_plan(content)
    except FileNotFoundError:
        logging.error(f"Sprint plan file not found: {filepath}")
        return None
    except Exception as e:
        logging.error(f"Error loading sprint plan: {e}")
        return None


# Characters that make up a markdown table alignment row such as |---|:--:|.
_TABLE_RULE_CHARS = frozenset("-: ")


def _parse_sprint_plan(content):
    """Build the phase/wave/task structure from the markdown text."""
    plan = {'phases': []}
    current_phase = None
    # Reset on every new phase so table rows never attach to a previous
    # phase's wave.
    current_wave = None
    for line in content.splitlines():
        if line.startswith("## PHASE"):
            heading = line.lstrip("#").strip()
            name = heading.split(":", 1)[1].strip() if ":" in heading else heading
            current_phase = {'name': name, 'waves': []}
            plan['phases'].append(current_phase)
            current_wave = None
        elif current_phase is not None and line.startswith("#### Wave"):
            if ":" in line:
                name = line.split(":", 1)[1].strip()
            else:
                name = line.split("Wave", 1)[1].strip()
            current_wave = {'name': name, 'tasks': []}
            current_phase['waves'].append(current_wave)
        elif current_wave is not None and line.startswith("|"):
            task = _parse_task_row(line)
            if task is not None:
                current_wave['tasks'].append(task)
    return plan


def _parse_task_row(line):
    """Turn one table row into a task dict, or return None for non-task rows."""
    # Keep empty cells so the columns stay aligned with the header.
    cells = [cell.strip() for cell in line.strip().strip("|").split("|")]
    if all(set(cell) <= _TABLE_RULE_CHARS for cell in cells):
        return None  # alignment row (or a row of empty cells)
    if cells[0] in ("", "Agent ID") or not any(cells[1:]):
        return None  # header row, or no agent id / no data beside it
    padded = cells + [''] * (4 - len(cells))
    return {
        'agent_id': padded[0],
        'role': padded[1],
        'task': padded[2],
        'priority': padded[3],
    }

def load_genesis_package(filepath, encoding='utf-8'):
    """Read the genesis package file and return its full text.

    Args:
        filepath: Path to the genesis package markdown file.
        encoding: Text encoding used to decode the file. Defaults to UTF-8 so
            the result does not depend on the platform's locale encoding.

    Returns:
        The file contents as a string. Returns ``None`` if the file is
        missing or cannot be read or decoded; the error is logged.
    """
    try:
        with open(filepath, 'r', encoding=encoding) as f:
            return f.read()
    except FileNotFoundError:
        logging.error(f"Genesis package file not found: {filepath}")
        return None
    except Exception as e:
        logging.error(f"Error loading genesis package: {e}")
        return None

def execute_shell_command(command, timeout=None):
    """Run a command and return its stripped standard output.

    Args:
        command: Either a command string, which is run through the system
            shell (``shell=True``, so never pass untrusted text), or a
            sequence of program arguments, which is executed directly
            without a shell. Passing a list with ``shell=True`` would run
            only its first element.
        timeout: Optional limit in seconds; ``None`` waits indefinitely.

    Returns:
        The command's stdout with surrounding whitespace stripped. Returns
        ``None`` if the command exits non-zero, times out or cannot be
        started. Each failure is logged, including stderr when there is any.
    """
    use_shell = isinstance(command, (str, bytes))
    try:
        result = subprocess.run(
            command,
            shell=use_shell,
            capture_output=True,
            text=True,
            check=True,
            timeout=timeout,
        )
    except subprocess.CalledProcessError as e:
        stderr = (e.stderr or "").strip()
        detail = f" | stderr: {stderr}" if stderr else ""
        logging.error(f"Command failed: {e}{detail}")
        return None
    except subprocess.TimeoutExpired as e:
        logging.error(f"Command timed out: {e}")
        return None
    except OSError as e:
        # e.g. the program in an argument list does not exist.
        logging.error(f"Command could not be started: {e}")
        return None
    return result.stdout.strip()

def _checkpoint_path(phase_name, checkpoints_dir=None):
    """Return the JSON checkpoint path for *phase_name*.

    The name is lower-cased and spaces become underscores. Path separators
    are also replaced, so a name like "Build/Test" stays a single file inside
    the checkpoint directory.
    """
    directory = CHECKPOINTS_DIR if checkpoints_dir is None else checkpoints_dir
    safe_name = phase_name.lower().replace(' ', '_')
    for separator in ('/', '\\'):
        safe_name = safe_name.replace(separator, '_')
    return os.path.join(directory, f"{safe_name}_checkpoint.json")


def save_checkpoint(phase_name, data, checkpoints_dir=None):
    """Atomically save *data* as the JSON checkpoint for *phase_name*.

    The data is serialized before anything touches the disk. It is then
    written to a temporary file that replaces the real checkpoint via
    ``os.replace``, so a serialization error or interrupted write cannot
    truncate the previous checkpoint. Errors are logged, not raised.

    Args:
        phase_name: Phase the checkpoint belongs to (used to name the file).
        data: JSON-serializable object to store.
        checkpoints_dir: Target directory; defaults to CHECKPOINTS_DIR.
    """
    filename = _checkpoint_path(phase_name, checkpoints_dir)
    tmp_filename = filename + ".tmp"
    try:
        payload = json.dumps(data, indent=4)
        os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)
        with open(tmp_filename, 'w', encoding='utf-8') as f:
            f.write(payload)
        os.replace(tmp_filename, filename)
        logging.info(f"Checkpoint saved: {filename}")
    except Exception as e:
        logging.error(f"Error saving checkpoint: {e}")
        # Best-effort cleanup of a partially written temp file.
        try:
            os.remove(tmp_filename)
        except OSError:
            pass


def load_checkpoint(phase_name, checkpoints_dir=None):
    """Load the JSON checkpoint for *phase_name*.

    Args:
        phase_name: Phase whose checkpoint to read.
        checkpoints_dir: Directory to read from; defaults to CHECKPOINTS_DIR.

    Returns:
        The decoded data. Returns ``None`` if the file does not exist (logged
        as a warning) or cannot be read or decoded (logged as an error).
    """
    filename = _checkpoint_path(phase_name, checkpoints_dir)
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        logging.warning(f"Checkpoint file not found: {filename}")
        return None
    except Exception as e:
        logging.error(f"Error loading checkpoint: {e}")
        return None
    logging.info(f"Checkpoint loaded: {filename}")
    return data

class TokenBudgetMonitor:
    """Track spend against a fixed USD budget and stop the sprint near the limit.

    ``update_spend`` and ``check_budget`` may be called from agent worker
    threads as well as the main thread, so the running total is guarded by a
    lock.
    """

    BUDGET_LIMIT = 10.00  # USD
    EMERGENCY_STOP = 9.50  # USD

    def __init__(self):
        self.current_spend = 0.0
        self.start_time = time.time()
        # Set once the emergency stop fires. Worker threads cannot end the
        # process, so callers there must observe this flag (or the False
        # returned by check_budget) and stop.
        self.shutdown_requested = False
        self._lock = threading.Lock()

    def update_spend(self, cost):
        """Add *cost* (USD) to the running total; safe from any thread."""
        with self._lock:
            self.current_spend += cost

    def check_budget(self):
        """Return True while spend is below EMERGENCY_STOP.

        At or above the threshold a graceful shutdown is triggered. On the
        main thread that raises ``SystemExit``; on any other thread this
        method returns False and the caller must stop work.
        """
        with self._lock:
            over_limit = self.current_spend >= self.EMERGENCY_STOP
        if over_limit:
            logging.warning("Emergency stop triggered: Budget exceeded.")
            self.trigger_graceful_shutdown()
            return False
        return True

    def trigger_graceful_shutdown(self):
        """Record the shutdown request and exit if running on the main thread.

        ``SystemExit`` raised in a worker thread only ends that thread, not
        the process, so it is raised solely from the main thread.
        """
        # TODO: save current state and stop running agents before exiting.
        with self._lock:
            first_request = not self.shutdown_requested
            self.shutdown_requested = True
        if first_request:
            logging.info("Initiating graceful shutdown...")
            print("GRACEFUL SHUTDOWN TRIGGERED")
        if threading.current_thread() is threading.main_thread():
            raise SystemExit

    def get_status(self):
        """Return the budget total, spend, remaining amount and USD/hour rate."""
        with self._lock:
            spent = self.current_spend
        elapsed_time = time.time() - self.start_time
        rate = spent / (elapsed_time / 3600) if elapsed_time > 0 else 0
        return {
            "total": self.BUDGET_LIMIT,
            "spent": spent,
            "remaining": self.BUDGET_LIMIT - spent,
            "rate": rate,
        }

# --- Agent Management ---
class Agent:
    """One sprint task and its execution state.

    ``status`` moves from "pending" to "running" (``start``), then to
    "completed" (``complete``) or "failed" (``fail``).
    """

    def __init__(self, agent_id, role, task, status="pending", cost_per_run=0.01, priority="MEDIUM"):
        self.agent_id = agent_id
        self.role = role
        self.task = task
        self.status = status
        self.cost_per_run = cost_per_run  # assumed flat USD cost per finished run
        # Task priority from the sprint plan (e.g. "CRITICAL", "HIGH").
        # Appended as the last parameter to keep positional callers working.
        self.priority = priority
        self.start_time = None
        self.end_time = None
        self.result = None

    def start(self):
        """Mark the agent running and record its start time."""
        logging.info(f"Agent {self.agent_id} starting task: {self.task}")
        self.status = "running"
        self.start_time = time.time()

    def complete(self, result="Success"):  # Simulate success
        """Mark the agent completed with *result* and record its end time."""
        self.status = "completed"
        self.end_time = time.time()
        self.result = result
        logging.info(f"Agent {self.agent_id} completed task: {self.task} with result: {result}")

    def fail(self, reason="Unknown"):  # Simulate failure
        """Mark the agent failed with *reason* and record its end time."""
        self.status = "failed"
        self.end_time = time.time()
        self.result = reason
        logging.error(f"Agent {self.agent_id} failed task: {self.task} with reason: {reason}")

    def get_runtime(self):
        """Return seconds spent running so far (0 if never started)."""
        if self.start_time is None:
            return 0
        end = self.end_time if self.end_time is not None else time.time()
        return end - self.start_time

    def get_cost(self):
        """Return the cost so far: flat when finished, a runtime estimate while running."""
        if self.status in ("completed", "failed"):
            return self.cost_per_run
        if self.status == "running":
            # Very rough estimate: twice the per-run cost per minute of runtime.
            return (self.get_runtime() / 60) * (self.cost_per_run * 2)
        return 0

    def __repr__(self):
        return (f"Agent(id={self.agent_id}, role={self.role}, task={self.task}, "
                f"status={self.status}, priority={self.priority})")

# --- Queen Orchestrator Class ---
class QueenOrchestrator:
    """Run the sprint plan phase by phase, wave by wave, and checkpoint progress.

    Within a wave every pending agent runs on its own thread, and the wave
    ends when all of them have finished. The agent bookkeeping lists are
    written by those worker threads and read by ``get_status_log``, which may
    run on yet another thread. They are therefore only touched while holding
    ``_state_lock``.
    """

    # Simulated outcome per priority: (failure probability, failure reason,
    # success result). Any other priority (MEDIUM, LOW, ...) uses the default.
    _SIMULATION_PROFILES = {
        "CRITICAL": (0.1, "Critical failure during execution", "Critical task successful"),
        "HIGH": (0.2, "High priority task failed", "High priority task successful"),
    }
    _DEFAULT_SIMULATION_PROFILE = (0.05, "Task failed", "Task successful")

    def __init__(self, sprint_plan_path, genesis_package_path):
        """Load both input documents and register every agent in the plan.

        Raises:
            SystemExit: with status 1 if either document could not be loaded.
        """
        self.sprint_plan = load_sprint_plan(sprint_plan_path)
        self.genesis_package = load_genesis_package(genesis_package_path)
        self.budget_monitor = TokenBudgetMonitor()
        self.agents = {}
        self.active_agents = []  # agents currently running
        self.completed_agents = []
        self.failed_agents = []
        self._state_lock = threading.Lock()
        self.current_phase_index = 0
        # Set by start_phase() (or run() before resuming); read by
        # process_phase(), checkpoint() and load_checkpoint().
        self.current_phase = None
        self.phase_start_time = None
        self.start_time = time.time()

        if not self.sprint_plan or not self.genesis_package:
            logging.error("Failed to initialize orchestrator.  Exiting.")
            # Non-zero status so the calling shell/script sees the failure.
            raise SystemExit(1)

        self.load_agents()

    def load_agents(self):
        """Create an Agent for every task row in the sprint plan, keyed by agent ID."""
        if not self.sprint_plan or 'phases' not in self.sprint_plan:
            logging.error("Invalid sprint plan format.")
            return

        for phase in self.sprint_plan['phases']:
            for wave in phase['waves']:
                for task_data in wave.get('tasks', []):
                    agent_id = task_data['agent_id']
                    if agent_id in self.agents:
                        # The later row replaces the earlier one; make that visible.
                        logging.warning(f"Duplicate agent ID {agent_id}; keeping the last definition.")
                    self.agents[agent_id] = Agent(
                        agent_id=agent_id,
                        role=task_data['role'],
                        task=task_data['task'],
                        # A blank priority cell falls back to MEDIUM too.
                        priority=task_data.get('priority') or 'MEDIUM',
                    )
        logging.info(f"Loaded {len(self.agents)} agents.")

    def start_phase(self):
        """Make the phase at ``current_phase_index`` the current phase.

        Returns:
            True if a phase was started; False once every phase has run.
        """
        phases = self.sprint_plan['phases']
        if self.current_phase_index >= len(phases):
            logging.info("Sprint complete!")
            return False

        self.current_phase = phases[self.current_phase_index]
        logging.info(f"Starting phase: {self.current_phase['name']}")
        self.phase_start_time = time.time()
        return True

    def execute_wave(self, wave):
        """Start every pending agent of *wave* on its own thread and wait for all."""
        logging.info(f"Executing wave: {wave['name']}")
        threads = []
        for task in wave['tasks']:
            agent_id = task['agent_id']
            agent = self.agents.get(agent_id)
            if agent is None or agent.status != "pending":
                logging.warning(f"Agent {agent_id} not found or not pending.")
                continue
            agent.start()
            with self._state_lock:
                self.active_agents.append(agent)
            thread = threading.Thread(target=self.simulate_agent_execution, args=(agent,))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()  # the wave is done only when every agent has finished

    def simulate_agent_execution(self, agent):
        """Simulate *agent* running its task (replace with real execution logic).

        Sleeps 1-5 s, then fails with a probability that depends on the
        agent's priority. The agent always leaves ``active_agents``, even if
        the simulation itself raises. Its cost is then charged to the budget.
        """
        try:
            time.sleep(random.uniform(1, 5))  # simulate different task durations
            failure_rate, failure_reason, success_result = self._SIMULATION_PROFILES.get(
                agent.priority, self._DEFAULT_SIMULATION_PROFILE)
            if random.random() < failure_rate:
                agent.fail(reason=failure_reason)
                outcome_list = self.failed_agents
            else:
                agent.complete(result=success_result)
                outcome_list = self.completed_agents
        except Exception as e:
            # Never leave an agent stuck in "running" because its thread died.
            agent.fail(reason=f"Simulation error: {e}")
            outcome_list = self.failed_agents

        # Move between lists atomically so status counts never double-count.
        with self._state_lock:
            self.active_agents.remove(agent)
            outcome_list.append(agent)

        self.budget_monitor.update_spend(agent.get_cost())
        # From a worker thread a failed check cannot stop the sprint by itself;
        # process_phase() re-checks the budget on the main thread before every wave.
        self.budget_monitor.check_budget()

    def process_phase(self):
        """Run the current phase's waves in order, then checkpoint the phase.

        Returns:
            True if every wave ran. False if there is no current phase, or if
            the budget stopped the phase early (a checkpoint is saved first).
        """
        if not self.current_phase:
            logging.error("No current phase to process.")
            return False

        for wave in self.current_phase['waves']:
            if not self._within_budget():
                return False
            self.execute_wave(wave)

        # Checkpoint after phase completion
        self.checkpoint()
        return True

    def _within_budget(self):
        """Check the budget from the main thread, checkpointing before a stop."""
        try:
            within = self.budget_monitor.check_budget()
        except SystemExit:
            # The monitor ends the process on the main thread; persist progress first.
            self.checkpoint()
            raise
        if not within:
            self.checkpoint()
        return within

    def checkpoint(self):
        """Save agent states, budget status and a timestamp for the current phase."""
        phase_name = self.current_phase['name']
        with self._state_lock:
            # Copy each agent's attributes so the snapshot cannot change mid-write.
            agents_snapshot = {agent_id: dict(vars(agent)) for agent_id, agent in self.agents.items()}
        data = {
            'phase': phase_name,
            'agents': agents_snapshot,
            'budget': self.budget_monitor.get_status(),
            'timestamp': datetime.datetime.now().isoformat()
        }
        save_checkpoint(phase_name, data)

    def load_checkpoint(self):
        """Restore agent states and spend from the current phase's checkpoint.

        Agents saved as "running" were interrupted mid-task, so they are reset
        to "pending" and their wave will run them again. Logs a warning and
        changes nothing if there is no current phase or no checkpoint.
        """
        if not self.current_phase:
            logging.warning("No current phase to load checkpoint for.")
            return

        phase_name = self.current_phase['name']
        data = load_checkpoint(phase_name)  # module-level checkpoint reader
        if not data:
            logging.warning(f"No checkpoint found for phase: {phase_name}")
            return

        logging.info(f"Checkpoint loaded for phase: {phase_name}")
        for agent_id, agent_data in data.get('agents', {}).items():
            agent = self.agents.get(agent_id)
            if agent is None:
                continue
            status = agent_data.get('status', agent.status)
            if status == "running":
                status = "pending"
            agent.status = status
            if status != "pending":
                agent.result = agent_data.get('result')
                agent.start_time = agent_data.get('start_time')
                agent.end_time = agent_data.get('end_time')

        spent = data.get('budget', {}).get('spent')
        if spent is not None:
            self.budget_monitor.current_spend = spent

        # Rebuild the bookkeeping lists from the restored statuses.
        with self._state_lock:
            self.active_agents = [agent for agent in self.agents.values() if agent.status == "running"]
            self.completed_agents = [agent for agent in self.agents.values() if agent.status == "completed"]
            self.failed_agents = [agent for agent in self.agents.values() if agent.status == "failed"]

    def run(self):
        """Run every remaining phase, resuming from a checkpoint when one exists."""
        logging.info("Queen Orchestrator started.")

        phases = self.sprint_plan['phases']
        # load_checkpoint() keys the checkpoint file on current_phase, so
        # point it at the first phase before trying to resume.
        if self.current_phase_index == 0 and phases:
            self.current_phase = phases[0]
            self.load_checkpoint()

        while self.start_phase():
            if not self.process_phase():
                logging.warning("Sprint halted before completion.")
                break
            self.current_phase_index += 1

        logging.info("Queen Orchestrator finished.")

    def get_status_log(self):
        """Render the boxed status report: time, budget, agent counts and phases."""
        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        phase_num = min(self.current_phase_index + 1, len(self.sprint_plan['phases']))
        hour = (time.time() - self.start_time) / 3600
        budget_status = self.budget_monitor.get_status()

        # Read the lists together so the derived pending count is consistent.
        with self._state_lock:
            num_active = len(self.active_agents)
            num_completed = len(self.completed_agents)
            num_failed = len(self.failed_agents)
        num_pending = len(self.agents) - num_active - num_completed - num_failed

        log = f"""╔══════════════════════════════════════════════════════════════════╗
║                    AIVA QUEEN SPRINT - TOKEN STATUS              ║
╠══════════════════════════════════════════════════════════════════╣
║ Time: {now} | Phase: {phase_num}/{len(self.sprint_plan['phases'])} | Hour: {hour:.1f}/10            ║
╠══════════════════════════════════════════════════════════════════╣
║ BUDGET                                                           ║
║ ├── Total: ${budget_status['total']:.2f}                                                ║
║ ├── Spent: ${budget_status['spent']:.2f} ({budget_status['spent'] / budget_status['total'] * 100:.1f}%)                                         ║
║ ├── Remaining: ${budget_status['remaining']:.2f}                                             ║
║ └── Rate: ${budget_status['rate']:.2f}/hr                                               ║
╠══════════════════════════════════════════════════════════════════╣
║ TOKENS                                                           ║
║ ├── Input:  NOT IMPLEMENTED tokens consumed                                ║
║ ├── Output: NOT IMPLEMENTED tokens generated                                ║
║ └── Total:  NOT IMPLEMENTED / 53.75M (NOT IMPLEMENTED%)                              ║
╠══════════════════════════════════════════════════════════════════╣
║ AGENTS                                                           ║
║ ├── Active: {num_active}/{len(self.agents)}                                                ║
║ ├── Completed: {num_completed}/{len(self.agents)}                                             ║
║ ├── Pending: {num_pending}/{len(self.agents)}                                                 ║
║ └── Failed: {num_failed}/{len(self.agents)}                                                 ║
╠══════════════════════════════════════════════════════════════════╣
║ CHECKPOINTS                                                      ║"""

        for i, phase in enumerate(self.sprint_plan['phases']):
            status = "[✓]" if i < self.current_phase_index else "[ ]"
            if i == self.current_phase_index:
                status = "[◐]"
                phase_status = " - IN PROGRESS"
            else:
                phase_status = ""
            log += f"\n║ ├── {status} Phase {i+1}: {phase['name']} (Hour {(i+1)*2}){phase_status}                           ║"

        log += """\n╚══════════════════════════════════════════════════════════════════╝"""
        return log

# --- Main Execution ---
if __name__ == "__main__":
    orchestrator = QueenOrchestrator(SPRINT_PLAN_PATH, GENESIS_PACKAGE_PATH)

    def _report_status_forever():
        """Print the orchestrator's status box, then wait a minute; repeat forever."""
        while True:
            print(orchestrator.get_status_log())
            time.sleep(60)

    # Daemon thread: the loop never returns, so it must not keep the process
    # alive after orchestrator.run() finishes.
    threading.Thread(target=_report_status_forever, daemon=True).start()

    orchestrator.run()