#!/usr/bin/env python3
"""
CHECKPOINT RECOVERY UTILITY
===========================
Recover and resume sprint from saved checkpoints.

Usage:
    python3 checkpoint_recovery.py list        # List all checkpoints
    python3 checkpoint_recovery.py status      # Current sprint status
    python3 checkpoint_recovery.py resume [N]  # Resume from checkpoint phase N
"""

import sys
import json
from datetime import datetime
from pathlib import Path

# Workspace root; every path this utility touches is derived from it.
BASE_DIR = Path("/mnt/e/genesis-system/AIVA")
CHECKPOINT_DIR = BASE_DIR.joinpath("sprint-checkpoints")  # checkpoint *.json files
TOKEN_LOG = BASE_DIR.joinpath("token_usage.jsonl")        # parsed one JSON record per line
STATUS_LOG = BASE_DIR.joinpath("sprint_status.log")       # not referenced by the functions below


def list_checkpoints(total_agents: int = 50):
    """List all available checkpoints.

    Prints one summary per ``*.json`` file in ``CHECKPOINT_DIR`` (in filename
    order): modification time, ``current_phase``, ``checkpoints_completed``,
    how many ``agents`` entries have status ``"completed"``, and
    ``spent_budget``. A file that cannot be read or summarized is reported on
    a single error line and the listing continues.

    Args:
        total_agents: Denominator shown on the "Agents Completed" line.
            Defaults to 50, the value previously hard-coded here.
    """
    if not CHECKPOINT_DIR.exists():
        print("No checkpoints directory found.")
        return

    checkpoints = sorted(CHECKPOINT_DIR.glob("*.json"))
    if not checkpoints:
        print("No checkpoints found.")
        return

    print("\n" + "="*60)
    print("AVAILABLE CHECKPOINTS")
    print("="*60)

    for cp in checkpoints:
        # Per-file boundary: any problem with one checkpoint (unreadable file,
        # bad JSON, unexpected field types) is reported and the loop goes on.
        try:
            with open(cp, "r", encoding="utf-8") as f:
                data = json.load(f)

            mtime = datetime.fromtimestamp(cp.stat().st_mtime)
            phase = data.get("current_phase", "?")
            completed = data.get("checkpoints_completed", [])
            agents_data = data.get("agents", [])
            completed_agents = sum(1 for a in agents_data if a.get("status") == "completed")
            # Format everything up front so a bad value (e.g. a non-numeric
            # spent_budget) fails before anything is printed for this file,
            # rather than leaving a half-printed entry above the error line.
            spent_text = f"${data.get('spent_budget', 0):.2f}"
            modified_text = mtime.strftime('%Y-%m-%d %H:%M:%S')
        except Exception as e:
            print(f"\n{cp.name}: Error reading - {e}")
            continue

        print(f"\n{cp.name}")
        print(f"  Modified: {modified_text}")
        print(f"  Phase: {phase}")
        print(f"  Completed Phases: {completed}")
        print(f"  Agents Completed: {completed_agents}/{total_agents}")
        print(f"  Budget Spent: {spent_text}")

    print("\n" + "="*60)


def _summarize_token_log(path):
    """Aggregate a JSON-lines token log.

    Returns ``(total_input, total_output, total_cost, last_agent)``:
    ``input_tokens``/``output_tokens`` summed over all records,
    the last non-null ``cumulative_cost`` seen (0.0 if none), and the
    ``agent_id`` of the final usable record (None if there is none).
    Blank, malformed, and non-object lines are skipped.
    """
    total_input = 0
    total_output = 0
    total_cost = 0.0
    last_agent = None
    # errors="replace": a multi-byte character cut off at the end of the log
    # must not abort the whole report with UnicodeDecodeError.
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            try:
                entry = json.loads(line.strip())
            except json.JSONDecodeError:
                continue
            # Valid JSON that is not an object (e.g. a bare number) has no
            # .get(); skip it like any other unusable line.
            if not isinstance(entry, dict):
                continue
            cost = entry.get("cumulative_cost")
            if cost is not None:
                total_cost = cost
            total_input += entry.get("input_tokens") or 0
            total_output += entry.get("output_tokens") or 0
            last_agent = entry.get("agent_id")
    return total_input, total_output, total_cost, last_agent


def _latest_checkpoint(directory):
    """Return the most recently modified ``*.json`` file in *directory*, or None."""
    if not directory.exists():
        return None
    return max(directory.glob("*.json"), key=lambda p: p.stat().st_mtime, default=None)


def show_status(total_budget: float = 10.00):
    """Show current sprint status from logs.

    Prints three sections:
      * token totals and the latest cumulative cost from ``TOKEN_LOG``;
      * the completed phases and a per-status agent count from the most
        recently modified checkpoint in ``CHECKPOINT_DIR``;
      * the budget remaining, with a warning/caution below fixed thresholds.

    Args:
        total_budget: Sprint budget in dollars that spend is measured against.
            Defaults to 10.00, the value previously hard-coded here.
    """
    print("\n" + "="*60)
    print("CURRENT SPRINT STATUS")
    print("="*60)

    # Token usage
    total_cost = 0.0
    if TOKEN_LOG.exists():
        total_input, total_output, total_cost, last_agent = _summarize_token_log(TOKEN_LOG)
        print("\nToken Usage:")
        print(f"  Input:  {total_input:,} tokens")
        print(f"  Output: {total_output:,} tokens")
        print(f"  Total:  {total_input + total_output:,} tokens")
        print(f"  Cost:   ${total_cost:.2f}")
        print(f"  Last Agent: {last_agent}")
    else:
        print("\nNo token usage log found.")

    # Latest checkpoint
    latest_cp = _latest_checkpoint(CHECKPOINT_DIR)
    if latest_cp is not None:
        print(f"\nLatest Checkpoint: {latest_cp.name}")
        try:
            with open(latest_cp, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # A truncated or corrupt checkpoint is reported instead of
            # aborting the rest of the report (ValueError covers JSON and
            # Unicode decode errors).
            print(f"  Error reading - {e}")
        else:
            print(f"  Phases Complete: {data.get('checkpoints_completed', [])}")
            by_status = {}
            for a in data.get("agents", []):
                status = a.get("status", "unknown")
                by_status[status] = by_status.get(status, 0) + 1
            print(f"  Agent Status: {by_status}")

    # Budget remaining
    budget_remaining = total_budget - total_cost
    print(f"\nBudget Remaining: ${budget_remaining:.2f}")

    if budget_remaining <= 0.50:
        print("  WARNING: Budget nearly exhausted!")
    elif budget_remaining <= 2.00:
        print("  CAUTION: Budget running low")

    print("\n" + "="*60)


def create_resume_state(phase: int):
    """Create a resume state file for the sprint executor.

    Loads the newest ``phase-<phase>-*.json`` checkpoint from
    ``CHECKPOINT_DIR`` and writes ``BASE_DIR / "resume_state.json"`` with
    ``resume_from_phase`` (``phase + 1``), the full checkpoint under
    ``checkpoint_data``, and a ``created_at`` ISO timestamp. If no checkpoint
    exists or it cannot be read, prints a message and writes nothing.

    Args:
        phase: Number of the completed phase whose checkpoint is restored.
    """
    candidates = []
    if CHECKPOINT_DIR.exists():
        candidates = list(CHECKPOINT_DIR.glob(f"phase-{phase}-*.json"))

    if not candidates:
        print(f"No checkpoint found for phase {phase}")
        return

    # glob() order is filesystem-dependent, so taking the "first match" was
    # arbitrary when a phase has several checkpoints. Use the newest one
    # (filename breaks mtime ties so the choice is deterministic).
    checkpoint_file = max(candidates, key=lambda p: (p.stat().st_mtime, p.name))

    print(f"Creating resume state from: {checkpoint_file}")

    try:
        with open(checkpoint_file, "r", encoding="utf-8") as f:
            checkpoint = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Could not read checkpoint {checkpoint_file.name}: {e}")
        return

    resume_state = {
        "resume_from_phase": phase + 1,  # Resume from NEXT phase
        "checkpoint_data": checkpoint,
        "created_at": datetime.now().isoformat()
    }

    resume_file = BASE_DIR / "resume_state.json"
    tmp_file = resume_file.with_name(resume_file.name + ".tmp")
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump(resume_state, f, indent=2)
    # Swap the finished file into place with a single rename instead of
    # truncating and rewriting resume_state.json in place.
    tmp_file.replace(resume_file)

    print(f"Resume state saved to: {resume_file}")
    print(f"Sprint will resume from phase {phase + 1}")
    print("\nRun: python3 queen_elevation_sprint.py")
    print("\nRun: python3 queen_elevation_sprint.py")


def main():
    """Command-line entry point: dispatch ``list``, ``status`` or ``resume N``.

    The command word is matched case-insensitively. With no arguments or an
    unknown command, the module docstring is printed as usage help.
    """
    if len(sys.argv) < 2:
        print(__doc__)
        return

    command = sys.argv[1].lower()

    if command == "list":
        list_checkpoints()
    elif command == "status":
        show_status()
    elif command == "resume":
        if len(sys.argv) < 3:
            print("Usage: python3 checkpoint_recovery.py resume <phase_number>")
            print("Example: python3 checkpoint_recovery.py resume 2")
            return
        try:
            phase = int(sys.argv[2])
        except ValueError:
            print("Phase must be a number (1-5)")
            return
        # Called outside the try so a ValueError raised while loading a
        # checkpoint (json.JSONDecodeError is a ValueError) is not
        # misreported as an invalid phase number.
        create_resume_state(phase)
    else:
        print(f"Unknown command: {command}")
        print(__doc__)


if __name__ == "__main__":
    main()
