# queen_orchestrator.py
import time
import datetime
import subprocess
import json
import os
import redis
import psycopg2
import requests

class QueenOrchestrator:
    """
    The central nervous system of AIVA, coordinating all subsystems,
    making strategic decisions, and managing resources.

    Lifecycle: construct (loads plan/package, connects infrastructure),
    call run_sprint() to execute all phases, then shutdown().
    """

    def __init__(self, sprint_plan_path="QUEEN_ELEVATION_SPRINT_PLAN.md", genesis_package_path="GENESIS_COMPLETE_PACKAGE.md"):
        """
        Initializes the Queen Orchestrator.

        Args:
            sprint_plan_path: Path to the sprint plan markdown file.
            genesis_package_path: Path to the genesis package markdown file.
        """
        self.sprint_plan_path = sprint_plan_path
        self.genesis_package_path = genesis_package_path
        self.sprint_plan = self.load_sprint_plan(sprint_plan_path)
        self.genesis_package = self.load_genesis_package(genesis_package_path)
        self.current_phase = 0
        self.current_hour = 0.0
        self.start_time = datetime.datetime.now()
        self.budget_monitor = TokenBudgetMonitor()
        self.active_agents = {}     # Agent_ID -> subprocess.Popen
        self.completed_agents = {}  # Agent_ID -> {"stdout": ...}
        self.failed_agents = {}     # Agent_ID -> {"return_code"/"error", ...}
        self.checkpoint_manager = CheckpointManager(self.sprint_plan["CONTINUITY SAFEGUARDS"]["CHECKPOINTS"])
        self.token_tracker = TokenTracker()

        # Load system configurations from the genesis package BEFORE opening
        # any connections: connect_redis()/connect_postgres() read these
        # attributes, so connecting first (as the original code did) raised
        # AttributeError.  rsplit(":", 1) tolerates extra colons in the host.
        infra = self.genesis_package["ARCHITECTURE"]["AIVA CONSCIOUSNESS"]["Infrastructure"]
        self.ollama_endpoint = infra["Ollama"]
        self.redis_host, redis_port = infra["Redis CNS"].rsplit(":", 1)
        self.redis_port = int(redis_port)
        self.postgres_host, postgres_port = infra["PostgreSQL RLM"].rsplit(":", 1)
        self.postgres_port = int(postgres_port)
        self.qdrant_endpoint = infra["Qdrant Vectors"]

        # Open connections now that host/port attributes exist.
        self.redis_client = self.connect_redis()
        self.postgres_conn = self.connect_postgres()

        # Initialize infrastructure
        self.validate_infrastructure()

    def load_sprint_plan(self, sprint_plan_path):
        """
        Loads the sprint plan from the given path.

        Returns:
            dict mapping phase names to wave dicts (each wave holds an
            "Agents" list) plus the "CONTINUITY SAFEGUARDS" checkpoint table.
        """
        # In a real implementation this would robustly parse the markdown.
        # Here we read the file (which validates the path exists) and build
        # the known phase/wave skeleton directly.
        with open(sprint_plan_path, 'r') as f:
            f.read()  # content currently unused by the simulated parser

        phase_waves = {
            "PHASE 1: FOUNDATION": [
                "Wave 1.1: Infrastructure Validation",
                "Wave 1.2: Consciousness Loops",
            ],
            "PHASE 2: KNOWLEDGE ABSORPTION": [
                "Wave 2.1: Patent Extraction",
                "Wave 2.2: Validation Gates",
            ],
            "PHASE 3: CAPABILITY INTEGRATION": [
                "Wave 3.1: Core Capabilities",
                "Wave 3.2: Integration Bridges",
            ],
            "PHASE 4: SWARM INTELLIGENCE": [
                "Wave 4.1: Swarm Architecture",
            ],
            "PHASE 5: QUEEN CORONATION": [
                "Wave 5.1: Rank Progression Tests",
                "Wave 5.2: Final Systems",
            ],
        }
        sprint_plan = {
            phase: {wave: {"Agents": []} for wave in waves}
            for phase, waves in phase_waves.items()
        }
        sprint_plan["CONTINUITY SAFEGUARDS"] = {
            "CHECKPOINTS": {
                "hour_2": "sprint-checkpoints/phase-1-foundation.json",
                "hour_4": "sprint-checkpoints/phase-2-knowledge.json",
                "hour_6": "sprint-checkpoints/phase-3-capabilities.json",
                "hour_8": "sprint-checkpoints/phase-4-swarm.json",
                "hour_10": "sprint-checkpoints/phase-5-coronation.json",
            }
        }
        return sprint_plan

    def load_genesis_package(self, genesis_package_path):
        """
        Loads the genesis package from the given path.

        Returns:
            dict with the architecture/infrastructure endpoints used by the
            orchestrator to reach Ollama, Redis, PostgreSQL and Qdrant.
        """
        # In a real implementation this would robustly parse the markdown.
        # Here we read the file (which validates the path exists) and build
        # the known configuration dictionary directly.
        with open(genesis_package_path, 'r') as f:
            f.read()  # content currently unused by the simulated parser

        genesis_package = {
            "PRIME DIRECTIVES": {},
            "ARCHITECTURE": {
                "AIVA CONSCIOUSNESS": {
                    "Infrastructure": {
                        "Ollama": "http://152.53.201.152:23405",
                        "Redis CNS": "redis-genesis-u50607.vm.elestio.app:26379",
                        "PostgreSQL RLM": "postgresql-genesis-u50607.vm.elestio.app:5432",
                        "Qdrant Vectors": "qdrant-b3knu-u50607.vm.elestio.app:6333",
                    }
                }
            },
            "KNOWLEDGE PROCESSING PIPELINE": {},
        }
        return genesis_package

    def connect_redis(self):
        """
        Connects to the Redis server.

        Returns:
            A live redis.Redis client, or None if the connection failed.
        """
        try:
            redis_client = redis.Redis(host=self.redis_host, port=self.redis_port, db=0)
            redis_client.ping()  # fail fast instead of on first real command
            print("Connected to Redis successfully.")
            return redis_client
        except redis.exceptions.ConnectionError as e:
            print(f"Error connecting to Redis: {e}")
            return None

    def connect_postgres(self):
        """
        Connects to the PostgreSQL database.

        Returns:
            An open psycopg2 connection, or None if the connection failed.
        """
        try:
            conn = psycopg2.connect(
                host=self.postgres_host,
                port=self.postgres_port,
                database="genesis_memory",
                # Credentials come from the environment; the defaults preserve
                # the previous hardcoded behavior (user "postgres", no password).
                user=os.environ.get("POSTGRES_USER", "postgres"),
                password=os.environ.get("POSTGRES_PASSWORD"),
            )
            print("Connected to PostgreSQL successfully.")
            return conn
        except psycopg2.Error as e:
            print(f"Error connecting to PostgreSQL: {e}")
            return None

    def validate_infrastructure(self):
        """
        Validates the connectivity and health of the core infrastructure components.

        Reports each component's status; does not raise on failure.
        """
        print("Validating infrastructure...")
        # Validate Redis
        if self.redis_client:
            print("Redis connection validated.")
        else:
            print("Redis connection failed.")

        # Validate Ollama
        try:
            response = requests.get(self.ollama_endpoint + "/api/tags", timeout=10)
            if response.status_code == 200:
                print("Ollama connection validated.")
            else:
                print(f"Ollama connection failed. Status code: {response.status_code}")
        except requests.exceptions.RequestException as e:
            print(f"Ollama connection failed: {e}")

        # Validate PostgreSQL.  NOTE: the connection is deliberately kept
        # open — the original closed self.postgres_conn here, leaving a dead
        # handle on the instance for the rest of the run.
        if self.postgres_conn:
            print("PostgreSQL connection validated.")
        else:
            print("PostgreSQL connection failed.")

        # Validate Qdrant.
        # NOTE(review): verify=False disables TLS certificate checking — a
        # security risk; kept to preserve connectivity to the self-signed
        # endpoint, but this should be replaced with a proper CA bundle.
        try:
            response = requests.get(self.qdrant_endpoint + "/collections/genesis_knowledge", verify=False, timeout=10)
            if response.status_code == 200:
                print("Qdrant connection validated.")
            else:
                print(f"Qdrant connection failed. Status code: {response.status_code}")
        except requests.exceptions.RequestException as e:
            print(f"Qdrant connection failed: {e}")

    def execute_phase(self, phase_number):
        """
        Executes a given phase of the sprint.

        Args:
            phase_number: 1-based phase index matched against "PHASE <n>: ..."
                keys in the sprint plan.
        """
        phase_prefix = f"PHASE {phase_number}: "
        phase_name = None
        phase_data = None
        for key, value in self.sprint_plan.items():
            if key.startswith(phase_prefix):
                phase_name, phase_data = key, value
                break

        if phase_data is None:
            print(f"Phase {phase_number} not found in sprint plan.")
            return

        # Use the captured phase_name (the original printed the loop variable
        # after the loop, which only worked by accident).
        print(f"Starting Phase {phase_number}: {phase_name}")

        # Execute each wave within the phase
        for wave_key, wave_data in phase_data.items():
            if "Wave" in wave_key:
                self.execute_wave(wave_data)

    def execute_wave(self, wave_data):
        """
        Executes a given wave of tasks: launches all of the wave's agents in
        parallel, then blocks until every one of them has finished.
        """
        if "Agents" not in wave_data:
            print(f"No agents defined for wave: {wave_data}")
            return

        # Launch agents in parallel
        for agent_data in wave_data["Agents"]:
            self.launch_agent(agent_data)

        # Wait for all agents in the wave to complete
        self.wait_for_wave_completion()

    def launch_agent(self, agent_data):
        """
        Launches a single agent process.

        Args:
            agent_data: dict with optional "Agent ID", "Task" and "Role" keys;
                missing keys fall back to generic defaults.
        """
        agent_id = agent_data.get("Agent ID", "Unknown Agent")
        task = agent_data.get("Task", "Generic Task")
        role = agent_data.get("Role", "Worker")

        print(f"Launching agent: {agent_id} with task: {task}")

        # Argument-list form (shell=False) avoids shell-injection via task/role.
        command = [
            "python3",
            "agent.py",  # Assuming agent.py exists
            "--agent_id", agent_id,
            "--task", task,
            "--role", role,
        ]

        try:
            process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
            self.active_agents[agent_id] = process
            print(f"Agent {agent_id} started with PID: {process.pid}")
        except Exception as e:
            print(f"Error launching agent {agent_id}: {e}")
            self.failed_agents[agent_id] = {"error": str(e)}

    def wait_for_wave_completion(self):
        """
        Polls all active agents until every one has exited, recording each
        into completed_agents or failed_agents based on its return code.
        """
        while self.active_agents:
            finished = []
            for agent_id, process in self.active_agents.items():
                if process.poll() is None:
                    continue  # still running
                stdout, stderr = process.communicate()
                if process.returncode == 0:
                    print(f"Agent {agent_id} completed successfully.")
                    self.completed_agents[agent_id] = {"stdout": stdout}
                else:
                    print(f"Agent {agent_id} failed with return code: {process.returncode}")
                    print(f"Stderr: {stderr}")
                    self.failed_agents[agent_id] = {"return_code": process.returncode, "stderr": stderr}
                finished.append(agent_id)

            # Remove completed agents from the active list (deferred so we
            # never mutate the dict while iterating it).
            for agent_id in finished:
                del self.active_agents[agent_id]

            if self.active_agents:
                time.sleep(5)  # poll interval; skipped once the wave is empty

    def run_sprint(self):
        """
        Executes the entire Queen Elevation Sprint: all five phases in order,
        saving a checkpoint after each phase.
        """
        num_phases = 5
        for phase_number in range(1, num_phases + 1):
            self.current_phase = phase_number
            self.execute_phase(phase_number)

            # Track elapsed wall-clock hours so checkpoint selection works
            # (the original left current_hour frozen at 0.0 forever).
            self.current_hour = (datetime.datetime.now() - self.start_time).total_seconds() / 3600.0

            # Checkpoint after each phase (may be None before the first
            # checkpoint hour is reached).
            checkpoint_file = self.checkpoint_manager.get_checkpoint_file(self.current_hour)
            self.save_checkpoint(checkpoint_file)

        print("Queen Elevation Sprint completed.")

    def save_checkpoint(self, checkpoint_file):
        """
        Saves the current state of the orchestrator to a checkpoint file.

        Args:
            checkpoint_file: Destination path; a falsy value skips the save
                (the original crashed into the except clause on None).
        """
        if not checkpoint_file:
            print("No checkpoint file available yet; skipping save.")
            return

        checkpoint_data = {
            "phase": self.current_phase,
            "hour": self.current_hour,
            "active_agents": list(self.active_agents.keys()),
            "completed_agents": list(self.completed_agents.keys()),
            "failed_agents": list(self.failed_agents.keys()),
            "budget": self.budget_monitor.current_spend,
            "token_usage": self.token_tracker.total_tokens,
        }

        try:
            with open(checkpoint_file, 'w') as f:
                json.dump(checkpoint_data, f, indent=4)
            print(f"Checkpoint saved to: {checkpoint_file}")
        except Exception as e:
            print(f"Error saving checkpoint: {e}")

    def load_checkpoint(self, checkpoint_file):
        """
        Loads the orchestrator state from a checkpoint file.

        Restores phase and hour; agent/budget/token restoration is still TODO.
        """
        try:
            with open(checkpoint_file, 'r') as f:
                checkpoint_data = json.load(f)

            self.current_phase = checkpoint_data["phase"]
            self.current_hour = checkpoint_data["hour"]
            # TODO: restore agents, budget and token usage from the checkpoint.

            print(f"Checkpoint loaded from: {checkpoint_file}")

        except FileNotFoundError:
            print("Checkpoint file not found. Starting from scratch.")
        except Exception as e:
            print(f"Error loading checkpoint: {e}")

    def monitor_health(self):
        """
        Monitors the health of the system and reports any issues.

        Not yet implemented.
        """
        pass

    def make_strategic_decisions(self):
        """
        Makes strategic decisions based on the current state of the system.

        Not yet implemented.
        """
        pass

    def handle_system_events(self):
        """
        Handles system-wide events such as failures or budget overruns.

        Not yet implemented.
        """
        pass

    def shutdown(self):
        """
        Shuts down the system gracefully: terminates all active agents and
        saves a final checkpoint.
        """
        print("Shutting down Queen Orchestrator...")
        # Terminate all active agents (snapshot the items so the dict can be
        # mutated elsewhere without breaking iteration).
        for agent_id, process in list(self.active_agents.items()):
            print(f"Terminating agent: {agent_id}")
            process.terminate()
            try:
                process.wait(timeout=10)  # reap the child; avoid zombies
            except subprocess.TimeoutExpired:
                process.kill()

        # Save final state
        self.save_checkpoint("final_checkpoint.json")

        print("Shutdown complete.")


class TokenBudgetMonitor:
    """
    Monitors cumulative spend against a hard budget and triggers a graceful
    shutdown of the orchestrator when the emergency threshold is reached.
    """
    BUDGET_LIMIT = 10.00   # USD — absolute budget ceiling
    EMERGENCY_STOP = 9.50  # USD — threshold that triggers shutdown

    def __init__(self, orchestrator=None):
        """
        Args:
            orchestrator: Optional object whose shutdown() is called on budget
                breach.  (The original code constructed a brand-new
                QueenOrchestrator here — reconnecting all infrastructure and
                shutting down the wrong instance — under a comment that
                falsely claimed it was a singleton.)
        """
        self.orchestrator = orchestrator
        self.current_spend = 0.0

    def update_spend(self, cost):
        """
        Updates the current spend and checks if the budget has been exceeded.

        Args:
            cost: Incremental spend in USD to add to the running total.
        """
        self.current_spend += cost
        print(f"Current spend: ${self.current_spend:.2f}")
        self.check_budget()

    def check_budget(self):
        """
        Checks if the budget has been exceeded and triggers a graceful shutdown
        of the registered orchestrator (if any) when it has.
        """
        if self.current_spend >= self.EMERGENCY_STOP:
            print("Emergency stop triggered! Budget exceeded.")
            if self.orchestrator is not None:
                self.orchestrator.shutdown()


class TokenTracker:
    """
    Accumulates input/output token counts over the course of the sprint and
    maintains their running total.
    """

    def __init__(self):
        # Running totals, refreshed on every call to update_tokens().
        self.input_tokens = 0
        self.output_tokens = 0
        self.total_tokens = 0

    def update_tokens(self, input_tokens, output_tokens):
        """
        Records another batch of token usage and recomputes the total.
        """
        self.input_tokens = self.input_tokens + input_tokens
        self.output_tokens = self.output_tokens + output_tokens
        self.total_tokens = self.input_tokens + self.output_tokens
        print(f"Total tokens used: {self.total_tokens}")

class CheckpointManager:
    """
    Manages the checkpoint files for continuity safeguards.
    """
    def __init__(self, checkpoints):
        """
        Args:
            checkpoints: Mapping of "hour_<N>" keys to checkpoint file paths.
        """
        self.checkpoints = checkpoints

    def get_checkpoint_file(self, current_hour):
        """
        Determines the appropriate checkpoint file based on the current hour.

        Returns:
            The path mapped to the largest hour <= current_hour, or None when
            no checkpoint hour has been reached yet.
        """
        best_hour = None
        for hour_str in self.checkpoints:
            hour = int(hour_str.split('_')[1])  # "hour_4" -> 4
            if hour <= current_hour and (best_hour is None or hour > best_hour):
                best_hour = hour

        if best_hour is None:
            return None
        return self.checkpoints[f"hour_{best_hour}"]

    def save_checkpoint(self, data, filename):
        """
        Saves the given data to a checkpoint file as JSON.
        """
        try:
            with open(filename, 'w') as f:
                json.dump(data, f)
            # Fixed: previously printed the literal "(unknown)" instead of
            # the actual destination path.
            print(f"Checkpoint saved to {filename}")
        except Exception as e:
            print(f"Error saving checkpoint: {e}")

    def load_checkpoint(self, filename):
        """
        Loads data from a checkpoint file.

        Returns:
            The deserialized JSON payload, or None on missing/unreadable file.
        """
        try:
            with open(filename, 'r') as f:
                return json.load(f)
        except FileNotFoundError:
            # Fixed: previously printed the literal "(unknown)" instead of
            # the actual missing path.
            print(f"Checkpoint file not found: {filename}")
            return None
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            return None

# Dummy agent.py for testing purposes
if __name__ == "__main__":
    # Stub agent script that the orchestrator launches via subprocess.
    agent_source = """
import argparse
import time
import random

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--agent_id", required=True)
    parser.add_argument("--task", required=True)
    parser.add_argument("--role", required=True)
    args = parser.parse_args()

    print(f"Agent {args.agent_id} starting task: {args.task} as role: {args.role}")
    time.sleep(random.randint(1, 5))  # Simulate work
    print(f"Agent {args.agent_id} finished task: {args.task}")

if __name__ == "__main__":
    main()
"""
    with open("agent.py", "w") as out:
        out.write(agent_source)

    # Ensure the checkpoint directory exists.
    if not os.path.exists("sprint-checkpoints"):
        os.makedirs("sprint-checkpoints")

    # Minimal stand-in documents for the sprint plan and genesis package.
    with open("QUEEN_ELEVATION_SPRINT_PLAN.md", "w") as out:
        out.write("# AIVA QUEEN ELEVATION SPRINT PLAN")
    with open("GENESIS_COMPLETE_PACKAGE.md", "w") as out:
        out.write("# GENESIS SYSTEM - COMPLETE DOCUMENTATION PACKAGE")

    # Build the orchestrator and run the full five-phase sprint.
    orchestrator = QueenOrchestrator()
    orchestrator.run_sprint()