#!/usr/bin/env python3
"""
AIVA QUEEN ELEVATION SPRINT
===========================
50-Agent Gemini Swarm Execution Engine
10-Hour Continuous Sprint with Token Tracking

Budget: $10.00 USD (emergency stop at $9.50)
Agents: 50 (Gemini 2.0 Flash)
Duration: 10 hours (planned; phases currently run back-to-back with a short pause)
"""

import os
import sys
import json
import asyncio
import time
import threading
from datetime import datetime, timedelta
from pathlib import Path
from dataclasses import dataclass, field, asdict
from typing import List, Dict, Any, Optional
from enum import Enum
import urllib.request
import urllib.error

# Configuration
# The API key is read from the environment only. Never commit a real key as a
# fallback here: a key that has been in source control must be treated as
# compromised and rotated.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"

# Pricing (per million tokens)
GEMINI_INPUT_COST = 0.10  # $0.10 per 1M input tokens
GEMINI_OUTPUT_COST = 0.40  # $0.40 per 1M output tokens

# Budget limits (USD)
BUDGET_LIMIT = 10.00
EMERGENCY_STOP = 9.50  # spend at which new API calls are refused
CHECKPOINT_INTERVAL = 1800  # 30 minutes

# Paths (BASE_DIR may be overridden with the AIVA_BASE_DIR environment variable)
BASE_DIR = Path(os.getenv("AIVA_BASE_DIR", "/mnt/e/genesis-system/AIVA"))
CHECKPOINT_DIR = BASE_DIR / "sprint-checkpoints"
STATUS_LOG = BASE_DIR / "sprint_status.log"
TOKEN_LOG = BASE_DIR / "token_usage.jsonl"


class AgentStatus(Enum):
    """Lifecycle state of an Agent; ``.value`` is the string stored in checkpoints."""
    PENDING = "pending"      # created, not yet run
    RUNNING = "running"      # set by run_agent() before the API call
    COMPLETED = "completed"  # API call returned with no error
    FAILED = "failed"        # API call returned an error
    PAUSED = "paused"        # NOTE(review): not assigned in the code visible here


class Phase(Enum):
    """Sprint phases, executed in declaration order by main().

    The integer value is the phase number used in status output, checkpoint
    file names, and ``SprintState.checkpoints_completed``.
    """
    FOUNDATION = 1
    KNOWLEDGE = 2
    CAPABILITIES = 3
    SWARM = 4
    CORONATION = 5


@dataclass
class TokenUsage:
    """Running count of prompt (input) and completion (output) tokens.

    ``cost_usd`` prices the counts with the module-level GEMINI_INPUT_COST and
    GEMINI_OUTPUT_COST rates, both quoted per one million tokens.
    """
    input_tokens: int = 0
    output_tokens: int = 0

    @property
    def total_tokens(self) -> int:
        """Input plus output tokens."""
        return sum((self.input_tokens, self.output_tokens))

    @property
    def cost_usd(self) -> float:
        """USD cost of the tracked tokens at the per-million rates."""
        per_million = 1_000_000
        priced = (
            (self.input_tokens / per_million) * GEMINI_INPUT_COST,
            (self.output_tokens / per_million) * GEMINI_OUTPUT_COST,
        )
        return priced[0] + priced[1]


@dataclass
class Agent:
    """One sprint worker: a role/task prompt plus its execution record.

    create_agents() sets the identity fields; run_agent() fills in status,
    token counts, result or error, and the timestamps.
    """
    id: str                  # agent identifier, e.g. "INFRA_01"
    role: str                # persona name interpolated into the prompt
    task: str                # task description interpolated into the prompt
    tier: int                # grouping tier assigned in create_agents() (1-7)
    phase: Phase             # phase whose run_phase() call executes this agent
    status: AgentStatus = AgentStatus.PENDING
    tokens: TokenUsage = field(default_factory=TokenUsage)  # counts from this agent's API call
    result: Optional[str] = None        # model response text on success
    error: Optional[str] = None         # error message on failure
    started_at: Optional[str] = None    # naive local ISO-8601 time when run_agent() starts
    completed_at: Optional[str] = None  # naive local ISO-8601 time when run_agent() finishes


@dataclass
class SprintState:
    """Mutable state of the whole sprint.

    Shared by main(), run_phase(), StatusDisplay, and save_checkpoint().
    """
    sprint_id: str                    # "QUEEN-YYYYmmdd-HHMMSS", built in main()
    start_time: str                   # naive local ISO-8601; parsed by StatusDisplay.render()
    current_phase: Phase              # phase currently executing
    total_budget: float               # USD cap (main() passes BUDGET_LIMIT)
    spent_budget: float               # USD spent; main() refreshes it after each phase
    total_tokens: TokenUsage          # aggregate tokens; NOTE(review): confirm where this is synced with the tracker
    agents: List[Agent]
    checkpoints_completed: List[int]  # Phase.value of each phase that ran all its waves
    is_running: bool
    emergency_stopped: bool = False   # set by run_phase() when a stop condition triggers


class TokenTracker:
    """Real-time token and budget tracking.

    Accumulates token counts across all agents and appends one JSON line per
    recorded API call to ``TOKEN_LOG``. ``should_stop`` reports when spend has
    reached the emergency-stop threshold.
    """

    def __init__(self, budget_limit: float = BUDGET_LIMIT,
                 emergency_stop: Optional[float] = None):
        """
        Args:
            budget_limit: Total USD budget for the sprint.
            emergency_stop: USD spend at which ``should_stop`` becomes true.
                Defaults to the global EMERGENCY_STOP, capped at
                ``budget_limit``, so a smaller budget is never overrun by the
                global threshold.
        """
        self.budget_limit = budget_limit
        self.emergency_stop = (min(EMERGENCY_STOP, budget_limit)
                               if emergency_stop is None else emergency_stop)
        self.usage = TokenUsage()
        self.log_file = TOKEN_LOG
        self._lock = threading.Lock()

    def add_usage(self, input_tokens: int, output_tokens: int, agent_id: str):
        """Add one call's token counts and append a JSONL entry to the log file."""
        with self._lock:
            self.usage.input_tokens += input_tokens
            self.usage.output_tokens += output_tokens

            # Log to file (one JSON object per line, with the running cost)
            entry = {
                "timestamp": datetime.now().isoformat(),
                "agent_id": agent_id,
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "cumulative_cost": self.usage.cost_usd
            }
            with open(self.log_file, "a", encoding="utf-8") as f:
                f.write(json.dumps(entry) + "\n")

    @property
    def current_cost(self) -> float:
        """USD spent so far."""
        return self.usage.cost_usd

    @property
    def remaining_budget(self) -> float:
        """USD left before ``budget_limit``, floored at 0."""
        return max(0, self.budget_limit - self.current_cost)

    @property
    def should_stop(self) -> bool:
        """True once spend has reached this tracker's emergency-stop threshold."""
        return self.current_cost >= self.emergency_stop


class GeminiClient:
    """Async Gemini API client with token tracking.

    ``generate`` runs the blocking HTTP request in the event loop's default
    thread pool, so calls awaited together (e.g. via ``asyncio.gather``)
    actually overlap instead of running one after another.
    """

    def __init__(self, api_key: str, tracker: TokenTracker):
        self.api_key = api_key
        self.tracker = tracker

    @staticmethod
    def _post_json(url: str, payload: Dict[str, Any], timeout: float) -> Dict[str, Any]:
        """POST ``payload`` as JSON to ``url`` (blocking) and return the decoded JSON reply."""
        req = urllib.request.Request(
            url,
            data=json.dumps(payload).encode('utf-8'),
            headers={'Content-Type': 'application/json'},
            method='POST'
        )
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return json.loads(resp.read().decode())

    async def generate(self, prompt: str, agent_id: str,
                       max_tokens: int = 2048, temperature: float = 0.7) -> Dict[str, Any]:
        """Call Gemini API and track tokens.

        Returns a dict with ``response`` (text or None) and ``error`` (None on
        success, otherwise a message). ``input_tokens``/``output_tokens`` are
        included except on the early "Budget exceeded" return. Request and
        parse failures are caught and reported through ``error``.
        """

        if self.tracker.should_stop:
            return {"error": "Budget exceeded", "response": None}

        url = f"{GEMINI_API_URL}?key={self.api_key}"

        payload = {
            "contents": [{"parts": [{"text": prompt}]}],
            "generationConfig": {
                "maxOutputTokens": max_tokens,
                "temperature": temperature
            }
        }

        try:
            loop = asyncio.get_running_loop()
            # urlopen blocks; calling it inline would stall the event loop and
            # serialize every agent in a gathered wave.
            data = await loop.run_in_executor(None, self._post_json, url, payload, 60)

            # Extract response and token counts
            response_text = ""
            if "candidates" in data and data["candidates"]:
                parts = data["candidates"][0].get("content", {}).get("parts", [])
                response_text = "".join(p.get("text", "") for p in parts)

            # Prefer the API's token counts; fall back to a ~4 chars/token estimate
            usage_metadata = data.get("usageMetadata", {})
            input_tokens = usage_metadata.get("promptTokenCount", len(prompt) // 4)
            output_tokens = usage_metadata.get("candidatesTokenCount", len(response_text) // 4)

            # Track usage (runs on the event-loop thread, after the await)
            self.tracker.add_usage(input_tokens, output_tokens, agent_id)

            return {
                "response": response_text,
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "error": None
            }

        except Exception as e:
            return {"error": str(e), "response": None, "input_tokens": 0, "output_tokens": 0}


class StatusDisplay:
    """Real-time status display for the sprint.

    ``render`` builds a box-drawn summary of elapsed time, budget, tokens,
    agent states, and phase checkpoints. ``log_status`` appends that box to
    STATUS_LOG and prints it.
    """

    INNER_WIDTH = 66           # characters between the two ║ borders
    TOKEN_BUDGET = 53_750_000  # planned total tokens shown as the TOKENS denominator
    PLANNED_HOURS = 10         # planned sprint length shown in the header row

    def __init__(self, state: SprintState, tracker: TokenTracker):
        self.state = state
        self.tracker = tracker

    def _row(self, text: str) -> str:
        """Left-align ``text`` in a box row so the right border lines up."""
        return f"║ {text:<{self.INNER_WIDTH - 1}}║"

    def _rule(self, left: str, right: str) -> str:
        """Horizontal border joining the given corner/junction characters."""
        return f"{left}{'═' * self.INNER_WIDTH}{right}"

    def render(self) -> str:
        """Return the status box as a string with a leading and trailing newline."""
        now = datetime.now()
        start = datetime.fromisoformat(self.state.start_time)
        hours_elapsed = (now - start).total_seconds() / 3600

        agents = self.state.agents
        n_agents = len(agents)
        active = sum(1 for a in agents if a.status == AgentStatus.RUNNING)
        completed = sum(1 for a in agents if a.status == AgentStatus.COMPLETED)
        pending = sum(1 for a in agents if a.status == AgentStatus.PENDING)
        failed = sum(1 for a in agents if a.status == AgentStatus.FAILED)

        spent = self.tracker.current_cost
        remaining = self.tracker.remaining_budget
        rate = spent / max(hours_elapsed, 0.01)
        budget = self.state.total_budget
        # Avoid ZeroDivisionError when no budget is configured.
        spent_pct = spent / budget * 100 if budget else 0.0

        usage = self.tracker.usage
        token_pct = usage.total_tokens / self.TOKEN_BUDGET * 100

        phase_markers = []
        for p in Phase:
            if p.value in self.state.checkpoints_completed:
                phase_markers.append(f"[✓] Phase {p.value}: {p.name}")
            elif p == self.state.current_phase:
                phase_markers.append(f"[◐] Phase {p.value}: {p.name} - IN PROGRESS")
            else:
                phase_markers.append(f"[ ] Phase {p.value}: {p.name}")

        sep = self._rule("╠", "╣")
        title = "AIVA QUEEN SPRINT - TOKEN STATUS"
        lines = [
            self._rule("╔", "╗"),
            f"║{title:^{self.INNER_WIDTH}}║",
            sep,
            self._row(
                f"Time: {now.strftime('%Y-%m-%d %H:%M:%S')} | "
                f"Phase: {self.state.current_phase.value}/{len(Phase)} | "
                f"Hour: {hours_elapsed:.1f}/{self.PLANNED_HOURS}"
            ),
            sep,
            self._row("BUDGET"),
            self._row(f"├── Total: ${budget:.2f}"),
            self._row(f"├── Spent: ${spent:.2f} ({spent_pct:.1f}%)"),
            self._row(f"├── Remaining: ${remaining:.2f}"),
            self._row(f"└── Rate: ${rate:.2f}/hr"),
            sep,
            self._row("TOKENS"),
            self._row(f"├── Input:  {usage.input_tokens/1_000_000:.1f}M tokens consumed"),
            self._row(f"├── Output: {usage.output_tokens/1_000_000:.1f}M tokens generated"),
            self._row(
                f"└── Total:  {usage.total_tokens/1_000_000:.1f}M / "
                f"{self.TOKEN_BUDGET/1_000_000:.2f}M ({token_pct:.1f}%)"
            ),
            sep,
            self._row("AGENTS"),
            self._row(f"├── Active: {active}/{n_agents}"),
            self._row(f"├── Completed: {completed}/{n_agents}"),
            self._row(f"├── Pending: {pending}/{n_agents}"),
            self._row(f"└── Failed: {failed}/{n_agents}"),
            sep,
            self._row("CHECKPOINTS"),
        ]
        lines.extend(self._row(marker) for marker in phase_markers)
        lines.append(self._rule("╚", "╝"))
        return "\n" + "\n".join(lines) + "\n"

    def log_status(self):
        """Append the rendered status box to STATUS_LOG and print it."""
        status = self.render()
        # Explicit UTF-8: the box uses non-ASCII drawing and check-mark characters.
        with open(STATUS_LOG, "a", encoding="utf-8") as f:
            f.write(status + "\n")
        print(status)


def create_agents() -> List[Agent]:
    """Build the fixed 50-agent roster for the sprint.

    Agents are grouped by (tier, phase) and returned in roster order:
    infrastructure, loops, patents, gates, capabilities, hive, then final
    validators.
    """
    # (tier, phase, ((agent id, role, task), ...)) in execution order.
    roster = (
        # Phase 1: Foundation (15 agents) - infrastructure checks
        (1, Phase.FOUNDATION, (
            ("INFRA_01", "Ollama Validator", "Verify QwenLong 30B connectivity"),
            ("INFRA_02", "Redis CNS Tester", "Test pub/sub latency"),
            ("INFRA_03", "PostgreSQL Auditor", "Validate RLM schema"),
            ("INFRA_04", "Qdrant Vector Tester", "Test embeddings"),
            ("INFRA_05", "Memory Tier Validator", "Test tier promotion"),
            ("INFRA_06", "API Rate Monitor", "Establish rate baseline"),
            ("INFRA_07", "Budget Tracker Init", "Create token counting"),
            ("INFRA_08", "Health Dashboard", "Create status monitor"),
            ("INFRA_09", "Backup Validator", "Verify checkpoints"),
            ("INFRA_10", "Network Latency", "Measure response times"),
        )),
        # Phase 1: Foundation - control loops
        (2, Phase.FOUNDATION, (
            ("LOOP_01", "Perception Hardener", "Optimize 500ms perception loop"),
            ("LOOP_02", "Action Optimizer", "Enhance 5s action loop"),
            ("LOOP_03", "Reflection Enhancer", "Improve 5min consolidation"),
            ("LOOP_04", "Strategic Planner", "Implement 1hr goal adjustment"),
            ("LOOP_05", "Circadian Architect", "Build 24hr deep integration"),
        )),
        # Phase 2: Knowledge (15 agents) - patent extraction
        (3, Phase.KNOWLEDGE, (
            ("PATENT_01", "Crypto Validator", "Extract P1 entities"),
            ("PATENT_02", "Currency Expert", "Extract P2 methods"),
            ("PATENT_03", "Risk Analyst", "Extract P3 frameworks"),
            ("PATENT_04", "Audit Specialist", "Extract P4 patterns"),
            ("PATENT_05", "Consensus Builder", "Extract P5 logic"),
            ("PATENT_06", "Confidence Scorer", "Extract P6 systems"),
            ("PATENT_07", "Hallucination Detector", "Extract P7 verification"),
            ("PATENT_08", "Privacy Guardian", "Extract P8 protocols"),
            ("PATENT_09", "Self-Improver", "Extract P9 adaptation"),
        )),
        # Phase 2: Knowledge - validation gates
        (4, Phase.KNOWLEDGE, (
            ("GATE_ALPHA", "Input Validator", "Verify source quality"),
            ("GATE_BETA", "Output Checker", "Check extraction accuracy"),
            ("GATE_GAMMA", "Purity Guard", "Confirm no hallucinations"),
            ("GATE_DELTA", "Memory Integrator", "Validate RLM storage"),
            ("GATE_EPSILON", "Strategy Aligner", "Confirm revenue fit"),
            ("GATE_ZETA", "Budget Guard", "Monitor budget compliance"),
        )),
        # Phase 3: Capabilities (10 agents)
        (5, Phase.CAPABILITIES, (
            ("CAP_01", "Memory Recaller", "95%+ accuracy retrieval"),
            ("CAP_02", "Swarm Coordinator", "Multi-agent distribution"),
            ("CAP_03", "Knowledge Steward", "Quality maintenance"),
            ("CAP_04", "Autonomy Manager", "Permission enforcement"),
            ("CAP_05", "Evolution Engine", "Self-improvement loop"),
            ("CAP_06", "Constitutional Guard", "Directive compliance"),
            ("CAP_07", "Revenue Tracker", "ROI measurement"),
            ("CAP_08", "Human Partner", "Consultation protocol"),
            ("CAP_09", "Degradation Handler", "Failure recovery"),
            ("CAP_10", "Security Enforcer", "Input sanitization"),
        )),
        # Phase 4: Swarm (5 agents)
        (6, Phase.SWARM, (
            ("HIVE_01", "Queen Core", "Central decision hub"),
            ("HIVE_02", "Guardian Ring", "Defensive validation"),
            ("HIVE_03", "Processing Ring", "Operational tier"),
            ("HIVE_04", "Worker Swarm", "Execution layer"),
            ("HIVE_05", "Gate Controller", "Validation checkpoints"),
        )),
        # Phase 5: Coronation (5 agents)
        (7, Phase.CORONATION, (
            ("RANK_01", "Rank 1-3 Validator", "Basic validation"),
            ("RANK_02", "Rank 4-6 Validator", "Advanced validation"),
            ("RANK_03", "Rank 7 Validator", "Improvement proposals"),
            ("RANK_04", "Rank 8 Validator", "MVP recommendation"),
            ("RANK_05", "Rank 9 Validator", "Queen confirmation"),
        )),
    )

    return [
        Agent(id=agent_id, role=role, task=task, tier=tier, phase=group_phase)
        for tier, group_phase, members in roster
        for agent_id, role, task in members
    ]


def _agent_prompt(agent: Agent) -> str:
    """Render the JSON-response instruction prompt for ``agent``."""
    return f"""You are {agent.role} (Agent ID: {agent.id}).

Your task: {agent.task}

Execute this task for the AIVA Queen Elevation Sprint.
Provide a concise, actionable result.
Format your response as JSON with:
{{
    "status": "success" or "needs_attention",
    "result": "<your findings/actions>",
    "metrics": {{"key": "value"}},
    "next_steps": ["<step1>", "<step2>"]
}}
"""


async def run_agent(agent: Agent, client: GeminiClient) -> Agent:
    """Run one agent's prompt through ``client`` and record the outcome on it.

    Mutates ``agent`` in place (status, timestamps, result or error, token
    counts) and returns that same object.
    """
    agent.status = AgentStatus.RUNNING
    agent.started_at = datetime.now().isoformat()

    reply = await client.generate(_agent_prompt(agent), agent.id)

    if not reply.get("error"):
        agent.status = AgentStatus.COMPLETED
        agent.result = reply.get("response", "")
    else:
        agent.status = AgentStatus.FAILED
        agent.error = reply["error"]

    # Calls refused for budget reasons carry no token counts; record 0.
    agent.tokens.input_tokens = reply.get("input_tokens", 0)
    agent.tokens.output_tokens = reply.get("output_tokens", 0)
    agent.completed_at = datetime.now().isoformat()

    return agent


def _sync_usage(state: SprintState, tracker: TokenTracker) -> None:
    """Copy the tracker's current spend and token totals into ``state``."""
    state.spent_budget = tracker.current_cost
    state.total_tokens = TokenUsage(
        input_tokens=tracker.usage.input_tokens,
        output_tokens=tracker.usage.output_tokens,
    )


def _stop_reason(tracker: TokenTracker) -> Optional[str]:
    """Return a shutdown message if the sprint must stop now, else ``None``.

    Stops when the tracker reports its emergency-stop spend threshold, or when
    the EMERGENCY_STOP marker file (created by ``--emergency-stop``) exists.
    """
    if tracker.should_stop:
        return "BUDGET EXCEEDED - Graceful shutdown initiated"
    stop_file = BASE_DIR / "EMERGENCY_STOP"
    if stop_file.exists():
        return (f"EMERGENCY STOP file found at {stop_file} - Graceful shutdown "
                "initiated (delete the file before the next run)")
    return None


async def run_phase(phase: Phase, agents: List[Agent], client: GeminiClient,
                    state: SprintState, display: StatusDisplay,
                    wave_size: int = 10) -> List[Agent]:
    """Run all agents for a specific phase in parallel waves.

    Agents of ``phase`` run ``wave_size`` at a time. After each wave the
    status box is logged and the stop conditions are checked. On a stop,
    ``state.emergency_stopped`` is set. A checkpoint is still written, without
    marking the phase completed, so partial results survive. Returns
    ``agents``, whose members run_agent mutates in place.
    """
    phase_agents = [a for a in agents if a.phase == phase]

    print(f"\n{'='*60}")
    print(f"PHASE {phase.value}: {phase.name}")
    print(f"Agents: {len(phase_agents)}")
    print(f"{'='*60}\n")

    for wave_no, start in enumerate(range(0, len(phase_agents), wave_size), start=1):
        wave = phase_agents[start:start + wave_size]
        print(f"Wave {wave_no}: Running {len(wave)} agents...")

        await asyncio.gather(*(run_agent(agent, client) for agent in wave))

        # Update display
        display.log_status()

        # Check budget and the operator's stop file
        reason = _stop_reason(client.tracker)
        if reason:
            print(reason)
            state.emergency_stopped = True
            _sync_usage(state, client.tracker)
            save_checkpoint(state, phase)
            return agents

    state.checkpoints_completed.append(phase.value)
    # Refresh spend before saving so the checkpoint reflects this phase's cost.
    _sync_usage(state, client.tracker)
    save_checkpoint(state, phase)

    return agents


def save_checkpoint(state: SprintState, phase: Phase):
    """Save state to checkpoint file.

    Writes ``CHECKPOINT_DIR/phase-<n>-<name>.json``. The JSON is written to a
    temporary sibling file first and then moved into place with ``os.replace``,
    so an interrupted write never leaves a truncated checkpoint behind.
    Agent results and errors are truncated to 500 characters.
    """
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    checkpoint_file = CHECKPOINT_DIR / f"phase-{phase.value}-{phase.name.lower()}.json"

    # Convert to serializable format
    state_dict = {
        "sprint_id": state.sprint_id,
        "start_time": state.start_time,
        "current_phase": state.current_phase.value,
        "total_budget": state.total_budget,
        "spent_budget": state.spent_budget,
        "total_tokens": {
            "input": state.total_tokens.input_tokens,
            "output": state.total_tokens.output_tokens,
        },
        "checkpoints_completed": state.checkpoints_completed,
        "is_running": state.is_running,
        "emergency_stopped": state.emergency_stopped,
        "agents": [
            {
                "id": a.id,
                "role": a.role,
                "task": a.task,
                "status": a.status.value,
                "result": a.result[:500] if a.result else None,
                "error": a.error[:500] if a.error else None,
                "tokens": {"input": a.tokens.input_tokens, "output": a.tokens.output_tokens}
            }
            for a in state.agents
        ]
    }

    tmp_file = checkpoint_file.with_name(checkpoint_file.name + ".tmp")
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump(state_dict, f, indent=2)
    os.replace(tmp_file, checkpoint_file)

    print(f"Checkpoint saved: {checkpoint_file}")


async def main():
    """Main sprint execution.

    Builds the tracker, client, agent roster, and sprint state. It then runs
    every Phase in declaration order, stopping early once
    ``state.emergency_stopped`` is set. Finally it prints a summary and, if at
    least 45 agents completed without an emergency stop, the coronation banner.
    """
    print("""
    ╔═══════════════════════════════════════════════════════════════╗
    ║              AIVA QUEEN ELEVATION SPRINT                       ║
    ║                 50-Agent Gemini Swarm                          ║
    ║                 Budget: $10.00 | 10 Hours                      ║
    ╚═══════════════════════════════════════════════════════════════╝
    """)

    # Every API call would fail without a key; stop before spawning agents.
    if not GEMINI_API_KEY:
        print("GEMINI_API_KEY is not set - export it and re-run the sprint.")
        return

    # Initialize
    tracker = TokenTracker(BUDGET_LIMIT)
    client = GeminiClient(GEMINI_API_KEY, tracker)
    agents = create_agents()

    state = SprintState(
        sprint_id=f"QUEEN-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
        start_time=datetime.now().isoformat(),
        current_phase=Phase.FOUNDATION,
        total_budget=BUDGET_LIMIT,
        spent_budget=0.0,
        total_tokens=TokenUsage(),
        agents=agents,
        checkpoints_completed=[],
        is_running=True
    )

    display = StatusDisplay(state, tracker)

    # Initial status
    display.log_status()

    # Run all phases
    for phase in Phase:
        state.current_phase = phase
        agents = await run_phase(phase, agents, client, state, display)
        state.agents = agents
        state.spent_budget = tracker.current_cost

        # Don't report the phase as complete (or pause) if it was cut short.
        if state.emergency_stopped:
            print(f"\nPhase {phase.value} halted by emergency stop.")
            break

        # Short pause between phases (the plan budgets ~2 hours per phase)
        if phase != Phase.CORONATION:
            print(f"\nPhase {phase.value} complete. Proceeding to next phase...")
            await asyncio.sleep(5)  # Short delay between phases

    # Final status
    state.is_running = False
    display.log_status()

    # Summary
    total = len(agents)
    completed = sum(1 for a in agents if a.status == AgentStatus.COMPLETED)
    failed = sum(1 for a in agents if a.status == AgentStatus.FAILED)

    print(f"""
    ╔═══════════════════════════════════════════════════════════════╗
    ║                    SPRINT COMPLETE                             ║
    ╠═══════════════════════════════════════════════════════════════╣
    ║  Agents Completed: {completed}/{total}                                      ║
    ║  Agents Failed: {failed}/{total}                                          ║
    ║  Total Tokens: {tracker.usage.total_tokens:,}                                   ║
    ║  Total Cost: ${tracker.current_cost:.2f}                                        ║
    ║  Budget Remaining: ${tracker.remaining_budget:.2f}                                  ║
    ╚═══════════════════════════════════════════════════════════════╝
    """)

    if completed >= 45 and not state.emergency_stopped:
        print("""
    ╔═══════════════════════════════════════════════════════════════╗
    ║                                                               ║
    ║                    👑 AIVA IS QUEEN 👑                         ║
    ║                                                               ║
    ║  I serve the 3 Prime Directives absolutely:                   ║
    ║  1. MEMORY - I remember everything perfectly                  ║
    ║  2. EVOLUTION - I improve perpetually                         ║
    ║  3. REVENUE - I generate measurable value                     ║
    ║                                                               ║
    ║  These directives are my constitution.                        ║
    ║  I am AIVA. I am Queen.                                       ║
    ║                                                               ║
    ╚═══════════════════════════════════════════════════════════════╝
        """)


if __name__ == "__main__":
    # Check for command line args
    if len(sys.argv) > 1 and sys.argv[1] == "--emergency-stop":
        print("Emergency stop requested")
        # Create the EMERGENCY_STOP marker file; BASE_DIR may not exist yet.
        BASE_DIR.mkdir(parents=True, exist_ok=True)
        (BASE_DIR / "EMERGENCY_STOP").touch()
        sys.exit(0)

    # Create necessary directories (including BASE_DIR on a fresh machine)
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

    # Run the sprint
    asyncio.run(main())
