# queen_orchestrator.py
import time
import datetime
import subprocess
import json
import os
import logging
import threading
import random  # For simulating agent activity and failures

# --- Configuration ---
# Markdown sprint plan; parsed by load_sprint_plan() into phases, waves and agents.
SPRINT_PLAN_PATH = "QUEEN_ELEVATION_SPRINT_PLAN.md"
# Genesis package file; its text is read whole by load_genesis_package().
GENESIS_PACKAGE_PATH = "GENESIS_COMPLETE_PACKAGE.md"
# Directory (relative to the working directory) where save_checkpoint() writes JSON files.
CHECKPOINTS_DIR = "sprint-checkpoints"
# Destination file for the root logger configured below.
LOG_FILE = "queen_orchestrator.log"
# Denominator used for the agent counters rendered by update_status_log().
AGENT_COUNT = 50

# --- Logging Setup ---
# Configures the root logger at import time: records at INFO and above are
# written to LOG_FILE as "<timestamp> - <level> - <message>".
logging.basicConfig(filename=LOG_FILE, level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

# --- Utility Functions ---
def load_sprint_plan(filepath):
    """Parse the sprint plan markdown file into phases, waves and agents.

    Lines are stripped of surrounding whitespace and interpreted as follows:

    * ``## PHASE ...`` opens a new phase (name is the text after ``## ``).
    * ``#### Wave ...`` opens a new wave inside the current phase (name is
      the text after ``#### ``).  Ignored when no phase has been opened yet.
    * ``| a | b | c | ...`` table rows inside a wave describe one agent.
      Rows containing ``Agent ID`` (header) or ``----------`` (separator)
      and rows with fewer than three cells are skipped.  Rows that appear
      before the first wave of the current phase are skipped as well.

    Args:
        filepath: Path of the markdown file to read (decoded as UTF-8).

    Returns:
        ``{'phases': [{'name': str, 'waves': [{'name': str, 'agents':
        [{'agent_id', 'role', 'task', 'priority'}, ...]}, ...]}, ...]}``
        where ``priority`` is ``None`` for three-column rows, or ``None``
        when the file does not exist (an error is logged).
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            content = f.read()
    except FileNotFoundError:
        logging.error(f"Sprint plan file not found: {filepath}")
        return None

    plan = {'phases': []}
    current_phase = None
    # Must start as None: a table row can legitimately appear before any wave.
    current_wave = None

    for raw_line in content.splitlines():
        line = raw_line.strip()

        if line.startswith("## PHASE"):
            current_phase = {'name': line[3:].strip(), 'waves': []}
            plan['phases'].append(current_phase)
            # A new phase starts without a wave; otherwise later rows would
            # be attached to the previous phase's last wave.
            current_wave = None
            continue

        if line.startswith("#### Wave"):
            if current_phase is not None:
                current_wave = {'name': line[4:].strip(), 'agents': []}
                current_phase['waves'].append(current_wave)
            continue

        is_data_row = (
            line.startswith("|")
            and "Agent ID" not in line
            and "----------" not in line
        )
        if current_wave is None or not is_data_row:
            continue

        # Drop the empty fragments produced by the leading/trailing pipes.
        parts = [cell.strip() for cell in line.split("|")[1:-1]]
        if len(parts) < 3:
            continue

        current_wave['agents'].append({
            'agent_id': parts[0],
            'role': parts[1],
            'task': parts[2],
            'priority': parts[3] if len(parts) > 3 else None,
        })

    return plan

def load_genesis_package(filepath):
    """Return the entire text of the genesis package file.

    Args:
        filepath: Path of the file to read.

    Returns:
        The file contents as a string, or ``None`` (after logging an error)
        when the file does not exist.
    """
    try:
        handle = open(filepath, 'r')
    except FileNotFoundError:
        logging.error(f"Genesis package file not found: {filepath}")
        return None

    with handle:
        return handle.read()

def save_checkpoint(phase, data):
    """Write *data* as indented JSON to a timestamped file in CHECKPOINTS_DIR.

    The file is named ``checkpoint_<phase>_<YYYYmmdd_HHMMSS>.json``.  Every
    character of *phase* that is not alphanumeric, ``-``, ``_`` or ``.`` is
    replaced by ``_`` so that names such as ``"PHASE 1: FOUNDATION"`` or
    ones containing path separators always yield a single valid filename.

    Args:
        phase: Label identifying the checkpoint (typically the phase name).
        data: JSON-serialisable object to persist.

    Returns:
        The path of the written file, or ``None`` if serialisation or the
        write failed (the error is logged).
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    safe_phase = "".join(
        ch if (ch.isalnum() or ch in "-_.") else "_" for ch in str(phase)
    )
    filename = f"{CHECKPOINTS_DIR}/checkpoint_{safe_phase}_{timestamp}.json"
    try:
        # Serialise before touching the filesystem so a non-serialisable
        # payload cannot leave a truncated checkpoint file behind.
        payload = json.dumps(data, indent=4)
        os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
        with open(filename, 'w') as f:
            f.write(payload)
        logging.info(f"Checkpoint saved: {filename}")
        return filename
    except Exception as e:
        # Best-effort by design: callers treat None as "checkpoint not saved".
        logging.error(f"Error saving checkpoint: {e}")
        return None

def load_checkpoint(filepath):
    """Read a JSON checkpoint file and return the decoded object.

    Args:
        filepath: Path of the checkpoint file.

    Returns:
        The decoded JSON value, or ``None`` when the file is missing
        (logged as a warning) or cannot be read/decoded (logged as an error).
    """
    decoded = None
    try:
        with open(filepath, 'r') as source:
            raw_text = source.read()
        decoded = json.loads(raw_text)
    except FileNotFoundError:
        logging.warning(f"Checkpoint file not found: {filepath}")
    except Exception as err:
        logging.error(f"Error loading checkpoint: {err}")
    return decoded

def execute_shell_command(command):
    """Run *command* through the shell and return its stripped stdout.

    The command string is handed to the shell unchanged (``shell=True``),
    so callers must not build it from untrusted input.

    Args:
        command: Shell command line to execute.

    Returns:
        Standard output with surrounding whitespace removed, or ``None``
        when the command exits with a non-zero status (stderr is logged).
    """
    try:
        finished = subprocess.run(
            command,
            shell=True,
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as failure:
        logging.error(f"Command failed: {command}, Error: {failure.stderr.strip()}")
        return None

    output = finished.stdout.strip()
    logging.info(f"Command executed: {command}, Output: {output}")
    return output

def _first_integer(text):
    """Return the first run of ASCII digits in *text* as an int, or 0 if none."""
    digits = ""
    for ch in str(text):
        if ch in "0123456789":
            digits += ch
        elif digits:
            break
    return int(digits) if digits else 0


def update_status_log(phase, hour, budget_spent, input_tokens, output_tokens, active_agents, completed_agents, pending_agents, failed_agents, checkpoints):
    """Render the sprint status box, log it at INFO level and print it.

    Args:
        phase: Phase label; the first integer found in it (e.g. the ``1`` in
            ``"PHASE 1: FOUNDATION"`` or ``"Phase 2/5"``) is shown as the
            phase number, 0 when the label contains no digits.
        hour: Elapsed hours (int or float); shown with one decimal and used
            as the divisor for the spend rate (rate is 0 when hour <= 0).
        budget_spent: Dollars spent so far.
        input_tokens: Input tokens consumed, in millions.
        output_tokens: Output tokens generated, in millions.
        active_agents: Number of agents currently running.
        completed_agents: Number of agents that completed.
        pending_agents: Number of agents not yet run.
        failed_agents: Number of agents that failed.
        checkpoints: Mapping of checkpoint name -> status marker; one row
            is rendered per entry, in mapping order.

    Returns:
        None.  The rendered box is emitted via ``logging.info`` and ``print``.
    """
    total_budget = 10.00
    token_capacity = 53.75  # millions of tokens, shown as the "Total" ceiling
    budget_remaining = total_budget - budget_spent
    token_total = input_tokens + output_tokens
    token_percentage = (token_total / token_capacity) * 100
    hourly_rate = budget_spent / hour if hour > 0 else 0

    now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    phase_num = _first_integer(phase)

    # Rows are padded programmatically so the right border stays aligned
    # regardless of how wide the interpolated values are.
    width = 66

    def row(text):
        return "║" + text.ljust(width) + "║"

    def rule(left, right):
        return left + "═" * width + right

    lines = [
        "",  # the box starts on its own line in the log record
        rule("╔", "╗"),
        row(" " * 20 + "AIVA QUEEN SPRINT - TOKEN STATUS"),
        rule("╠", "╣"),
        row(f" Time: {now} | Phase: {phase_num}/5 | Hour: {hour:.1f}/10"),
        rule("╠", "╣"),
        row(" BUDGET"),
        row(f" ├── Total: ${total_budget:.2f}"),
        row(f" ├── Spent: ${budget_spent:.2f} ({budget_spent / total_budget * 100:.1f}%)"),
        row(f" ├── Remaining: ${budget_remaining:.2f}"),
        row(f" └── Rate: ${hourly_rate:.2f}/hr"),
        rule("╠", "╣"),
        row(" TOKENS"),
        row(f" ├── Input:  {input_tokens:.1f}M tokens consumed"),
        row(f" ├── Output: {output_tokens:.1f}M tokens generated"),
        row(f" └── Total:  {token_total:.1f}M / {token_capacity:.2f}M ({token_percentage:.1f}%)"),
        rule("╠", "╣"),
        row(" AGENTS"),
        row(f" ├── Active: {active_agents}/{AGENT_COUNT}"),
        row(f" ├── Completed: {completed_agents}/{AGENT_COUNT}"),
        row(f" ├── Pending: {pending_agents}/{AGENT_COUNT}"),
        row(f" └── Failed: {failed_agents}/{AGENT_COUNT}"),
        rule("╠", "╣"),
        row(" CHECKPOINTS"),
    ]
    # The checkpoint name already carries its phase label, so it is shown
    # as-is instead of being re-split (which broke on names without spaces).
    for checkpoint_name, checkpoint_status in checkpoints.items():
        lines.append(row(f" ├── [{checkpoint_status}] {checkpoint_name}"))
    lines.append(rule("╚", "╝"))

    status = "\n".join(lines)
    logging.info(status)
    print(status)  # Print to console as well

# --- Core Classes ---
class TokenBudgetMonitor:
    """Track cumulative spend and force a shutdown near the budget limit.

    Attributes:
        current_spend: Dollars recorded so far via ``add_cost``.
        lock: Guards ``current_spend``.
        on_shutdown: Optional zero-argument callable invoked before the
            process exits; when ``None`` the monitor falls back to a
            module-level ``orchestrator`` object, if one exists.
    """

    BUDGET_LIMIT = 10.00  # USD
    EMERGENCY_STOP = 9.50  # USD

    def __init__(self, on_shutdown=None):
        """Create a monitor with zero spend.

        Args:
            on_shutdown: Optional callable used to stop agents when the
                emergency threshold is reached.
        """
        self.current_spend = 0.0
        self.lock = threading.Lock()  # Thread safety
        self.on_shutdown = on_shutdown

    def add_cost(self, cost):
        """Add *cost* (USD) to the running total and log the new total."""
        with self.lock:
            self.current_spend += cost
            # Four decimals: individual agent costs are fractions of a cent.
            logging.info(f"Added cost: ${cost:.4f}, Current spend: ${self.current_spend:.4f}")

    def check_budget(self):
        """Return True while spend is below ``EMERGENCY_STOP``.

        Once the threshold is reached this triggers the shutdown sequence,
        which terminates the process; ``False`` is the nominal return value
        for that path.
        """
        with self.lock:
            over_threshold = self.current_spend >= self.EMERGENCY_STOP
        if not over_threshold:
            return True  # Budget OK

        # Shutdown runs outside the lock so nothing it calls can block on it.
        logging.warning("Emergency stop threshold reached!")
        self.trigger_graceful_shutdown()
        return False

    def trigger_graceful_shutdown(self):
        """Stop agents if possible, flush output and exit with status 1."""
        logging.critical("Initiating graceful shutdown due to budget exceeded!")

        stop_agents = self.on_shutdown
        if stop_agents is None:
            # The script entry point publishes its orchestrator as a module
            # global; it is absent when this module is merely imported.
            controller = globals().get("orchestrator")
            stop_agents = getattr(controller, "stop_all_agents", None)

        if callable(stop_agents):
            try:
                stop_agents()
            except Exception:
                # The exit below must still happen even if stopping fails.
                logging.exception("Failed to stop agents during shutdown.")
        else:
            logging.error("No orchestrator available to stop agents before shutdown.")

        # os._exit skips interpreter cleanup, so flush explicitly first.
        print("System shutting down gracefully...", flush=True)
        logging.shutdown()
        os._exit(1)  # Force exit

class Agent(threading.Thread):
    """Worker thread that runs one (currently simulated) sprint task.

    Attributes:
        agent_id, role, task: Identity and assignment of the agent.
        budget_monitor: Object providing ``check_budget()`` and
            ``add_cost(cost)``.
        active: True while ``run`` is executing.
        completed: True only when the task actually ran and did not fail.
        failed: True when the task raised an exception.
        state: Free-form dict of agent state (``initial_state`` or empty).
        start_time, end_time: ``time.time()`` stamps set by ``run``.
        input_tokens, output_tokens: Token counts recorded for the task.
        lock: Guards this agent's token counters.
    """

    def __init__(self, agent_id, role, task, budget_monitor, initial_state=None):
        super().__init__()
        self.agent_id = agent_id
        self.role = role
        self.task = task
        self.budget_monitor = budget_monitor
        self.active = False
        self.completed = False
        self.failed = False
        self.state = initial_state or {}  # Agent state
        self.start_time = None
        self.end_time = None
        self.input_tokens = 0
        self.output_tokens = 0
        self.lock = threading.Lock()  # Thread safety

    def run(self):
        """Execute the task and record the outcome flags and timestamps."""
        self.active = True
        self.start_time = time.time()
        logging.info(f"Agent {self.agent_id} ({self.role}) started: {self.task}")
        task_ran = False
        try:
            # Only an explicit False means "skipped"; None keeps overrides
            # of execute_task that return nothing counted as having run.
            task_ran = self.execute_task() is not False
        except Exception as e:
            self.failed = True
            logging.exception(f"Agent {self.agent_id} failed: {e}")
        finally:
            self.active = False
            # A task skipped for budget reasons is neither failed nor done.
            self.completed = task_ran and not self.failed
            self.end_time = time.time()
            logging.info(f"Agent {self.agent_id} finished. Completed: {self.completed}, Failed: {self.failed}")

    def execute_task(self):
        """Simulate the task: charge a random cost, sleep, maybe fail.

        Returns:
            True when the task ran to completion, False when it was skipped
            because the budget check failed.

        Raises:
            Exception: With a 20% probability, to simulate a task failure.
        """
        # Simulate task execution with random token usage and cost
        # Replace with actual task logic using Gemini Flash 2.0 or other models
        estimated_input_tokens = random.randint(10000, 50000)  # Simulate input
        estimated_output_tokens = random.randint(5000, 20000)  # Simulate output

        # Prices are expressed in USD per million tokens.
        input_cost = (estimated_input_tokens / 1000000) * 0.10
        output_cost = (estimated_output_tokens / 1000000) * 0.40
        total_cost = input_cost + output_cost

        with self.lock:
            if not self.budget_monitor.check_budget():
                logging.warning(f"Agent {self.agent_id} stopping due to budget limit.")
                return False  # Stop if budget exceeded

            self.budget_monitor.add_cost(total_cost)
            self.input_tokens = estimated_input_tokens
            self.output_tokens = estimated_output_tokens

        # Four decimals: simulated costs are fractions of a cent.
        logging.info(f"Agent {self.agent_id} running task. Estimated cost: ${total_cost:.4f}, Input tokens: {estimated_input_tokens}, Output tokens: {estimated_output_tokens}")

        # Simulate work
        time.sleep(random.randint(1, 5))  # Simulate task duration

        # Simulate outcome (success or failure)
        if random.random() < 0.2:  # 20% chance of failure
            raise Exception("Simulated task failure.")

        logging.info(f"Agent {self.agent_id} completed task successfully.")
        return True

    def get_token_usage(self):
        """Return ``(input_tokens, output_tokens)`` recorded for this agent."""
        with self.lock:
            return self.input_tokens, self.output_tokens


class QueenOrchestrator:
    """Drive the sprint: build agents from the plan and run it phase by phase.

    Attributes:
        sprint_plan: Parsed plan as returned by ``load_sprint_plan``.
        genesis_package: Text returned by ``load_genesis_package``.
        budget_monitor: Shared ``TokenBudgetMonitor`` handed to every agent.
        agents: Mapping of agent ID -> ``Agent``.
        active_phase: Name of the phase currently (or last) executed.
        start_time, end_time: ``time.time()`` stamps set by ``run_sprint``.
        checkpoints: Mapping of checkpoint name -> status marker
            (``" "`` pending, ``"✓"`` succeeded, ``"✗"`` failed).
        lock: Lock created for orchestrator state.
    """

    def __init__(self, sprint_plan_path, genesis_package_path):
        """Load the plan and genesis package, then create all agents.

        The process exits with status 1 if either file cannot be loaded.
        """
        self.sprint_plan = load_sprint_plan(sprint_plan_path)
        self.genesis_package = load_genesis_package(genesis_package_path)
        self.budget_monitor = TokenBudgetMonitor()
        self.agents = {}
        self.active_phase = None
        self.start_time = None
        self.end_time = None
        self.checkpoints = {
            "Phase 1: Foundation": " ",
            "Phase 2: Knowledge": " ",
            "Phase 3: Capabilities": " ",
            "Phase 4: Swarm": " ",
            "Phase 5: Coronation": " "
        }

        self.lock = threading.Lock()  # Thread safety

        if not self.sprint_plan or not self.genesis_package:
            logging.error("Failed to load sprint plan or genesis package. Exiting.")
            os._exit(1)

        self.create_agents()

    def create_agents(self):
        """Instantiate one ``Agent`` per plan entry, keyed by agent ID.

        The role passed to each agent is ``"<group>: <role>"`` where the
        group is looked up from the agent ID prefix (text before the first
        underscore); unknown prefixes map to ``"Unknown"``.
        """
        agent_id_prefix_map = {
            "INFRA": "Infrastructure",
            "LOOP": "Consciousness",
            "PATENT": "Patent Processing",
            "GATE": "Validation Gates",
            "CAP": "Capabilities",
            "BRIDGE": "Integration",
            "HIVE": "Hive Architecture",
            "RANK": "Rank Validation",
            "FINAL": "Final Systems"
        }

        for phase in self.sprint_plan['phases']:
            for wave in phase['waves']:
                for agent_data in wave['agents']:
                    agent_id = agent_data['agent_id']
                    agent_group = agent_id_prefix_map.get(agent_id.split('_')[0], "Unknown")
                    agent_name = f"{agent_group}: {agent_data['role']}"

                    self.agents[agent_id] = Agent(
                        agent_id, agent_name, agent_data['task'], self.budget_monitor
                    )
        logging.info(f"Created {len(self.agents)} agents.")

    def _checkpoint_key(self, phase_name):
        """Map a plan phase name onto an existing checkpoint key.

        Plan names start with ``PHASE`` while the predefined keys start with
        ``Phase``, so matching is case-insensitive: first on the whole name,
        then on the part before the first colon.  Falls back to
        *phase_name* itself when nothing matches.
        """
        wanted = phase_name.casefold()
        for key in self.checkpoints:
            if key.casefold() == wanted:
                return key

        wanted_prefix = phase_name.split(':', 1)[0].strip().casefold()
        for key in self.checkpoints:
            if key.split(':', 1)[0].strip().casefold() == wanted_prefix:
                return key
        return phase_name

    @staticmethod
    def _wait_for_agents(agent_threads, timeout):
        """Block until every thread has finished or *timeout* seconds pass."""
        deadline = time.monotonic() + timeout
        for agent in agent_threads:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            agent.join(remaining)

    def execute_phase(self, phase_data, poll_interval=60):
        """Run every agent of one phase and report whether all completed.

        Args:
            phase_data: Phase dict from the sprint plan (``name``, ``waves``).
            poll_interval: Maximum seconds between status reports while
                agents are still running.

        Returns:
            True when every launched agent completed (a checkpoint is then
            saved), False when any agent did not complete or the budget
            check failed.
        """
        phase_name = phase_data['name']
        logging.info(f"Starting phase: {phase_name}")
        self.active_phase = phase_name

        # Launch agents for all waves in the phase
        agent_threads = []
        launched_ids = set()
        for wave in phase_data['waves']:
            for agent_data in wave['agents']:
                agent_id = agent_data['agent_id']
                agent = self.agents.get(agent_id)
                if agent is None:
                    logging.warning(f"Agent {agent_id} not found in agent registry.")
                    continue
                if agent_id in launched_ids:
                    # A thread can only be started once.
                    logging.warning(f"Agent {agent_id} listed more than once in phase; skipping duplicate.")
                    continue
                launched_ids.add(agent_id)
                agent_threads.append(agent)
                agent.start()

        # Spend is cumulative over the sprint, so the reported hour (and the
        # rate derived from it) is measured from the sprint start when known.
        clock_origin = self.start_time if self.start_time is not None else time.time()

        # Monitor agents and update status log
        while any(agent.is_alive() for agent in agent_threads):
            # Returns as soon as the agents finish instead of always
            # sleeping for the full interval.
            self._wait_for_agents(agent_threads, poll_interval)

            registry = list(self.agents.values())
            active_agents = sum(1 for agent in registry if agent.active)
            completed_agents = sum(1 for agent in registry if agent.completed)
            failed_agents = sum(1 for agent in registry if agent.failed)
            pending_agents = sum(
                1 for agent in registry
                if not agent.active and not agent.completed and not agent.failed
            )

            # Status box expects token counts in millions.
            input_tokens = sum(agent.input_tokens for agent in registry) / 1000000
            output_tokens = sum(agent.output_tokens for agent in registry) / 1000000

            elapsed_time = time.time() - clock_origin
            update_status_log(
                phase_name,
                elapsed_time / 3600,  # Hours
                self.budget_monitor.current_spend,
                input_tokens,
                output_tokens,
                active_agents,
                completed_agents,
                pending_agents,
                failed_agents,
                self.checkpoints
            )

            if not self.budget_monitor.check_budget():
                logging.warning("Budget exceeded during phase. Terminating agents.")
                self.stop_all_agents()
                return False

        # Wait for all agents to complete
        for agent in agent_threads:
            agent.join()

        # Check if all agents completed successfully
        phase_success = all(agent.completed for agent in agent_threads)
        checkpoint_key = self._checkpoint_key(phase_name)

        if not phase_success:
            logging.error(f"Phase {phase_name} failed. Some agents did not complete successfully.")
            self.checkpoints[checkpoint_key] = "✗"
            return False

        logging.info(f"Phase {phase_name} completed successfully.")
        self.checkpoints[checkpoint_key] = "✓"

        # Save checkpoint data (can be improved with more relevant data)
        checkpoint_data = {
            "phase": phase_name,
            "agents": {agent_id: agent.state for agent_id, agent in self.agents.items()},
            "budget_spent": self.budget_monitor.current_spend
        }
        save_checkpoint(phase_name, checkpoint_data)
        return True

    def run_sprint(self):
        """Execute all phases in order; stop at the first failing phase.

        Returns:
            True when every phase succeeded, False otherwise.
        """
        logging.info("Starting AIVA Queen Elevation Sprint...")
        self.start_time = time.time()

        for phase in self.sprint_plan['phases']:
            if not self.execute_phase(phase):
                logging.critical(f"Sprint aborted during phase: {phase['name']}")
                self.end_time = time.time()
                return False

        self.end_time = time.time()
        logging.info("AIVA Queen Elevation Sprint completed!")
        return True

    def stop_all_agents(self):
        """Clear the ``active`` flag on every agent and log it.

        NOTE: this only flips the flag; running agent threads are not
        interrupted because they do not poll it.
        """
        logging.info("Stopping all agents...")
        for agent in self.agents.values():
            # Implement a way to gracefully stop agents.  This might involve setting a flag
            # that the agent checks periodically, or using a threading.Event.
            agent.active = False  # Basic Attempt
            logging.info(f"Agent {agent.agent_id} stopped.")

    def run_watchdog(self):
        """Check the budget every five minutes, forever.

        Intended to run in a daemon thread; it never returns.
        """
        while True:
            # Check system health (e.g., CPU usage, memory usage, disk space)
            # Implement budget check
            self.budget_monitor.check_budget()

            # Implement checkpoint verification and auto-recovery if needed
            time.sleep(300)  # Check every 5 minutes

    def finalize_sprint(self):
        """Persist the final sprint state as a ``final_state`` checkpoint."""
        # Generate documentation, finalize audit trail, etc.
        logging.info("Finalizing sprint and generating documentation...")

        # Example: Execute shell command to generate documentation
        # documentation_command = "python3 /mnt/e/genesis-system/AIVA/documentation_generator.py"
        # execute_shell_command(documentation_command)

        # Example: Save final state to a file
        final_state = {
            "status": "Queen Achieved",
            "checkpoints": self.checkpoints,
            "budget_spent": self.budget_monitor.current_spend,
            "end_time": datetime.datetime.now().isoformat()
        }
        save_checkpoint("final_state", final_state)
        logging.info("Sprint finalized.")

# --- Main Execution ---
if __name__ == "__main__":
    # Kept as a module-level name: the budget monitor's shutdown path looks
    # up the global `orchestrator` to stop the agents.
    orchestrator = QueenOrchestrator(SPRINT_PLAN_PATH, GENESIS_PACKAGE_PATH)

    # Budget watchdog runs as a daemon thread so it never blocks process exit.
    threading.Thread(target=orchestrator.run_watchdog, daemon=True).start()

    # Finalize only after a fully successful sprint; otherwise just log.
    if orchestrator.run_sprint():
        orchestrator.finalize_sprint()
    else:
        logging.error("Sprint failed.")