#!/usr/bin/env python3
"""
=============================================================================
Genesis Gemini Swarm Launcher
=============================================================================
Launches N parallel Gemini CLI processes, each with a specific task.
Each process is a subprocess (not exec), so crashes are isolated and
the launcher can retry failed tasks automatically.

Usage:
    python gemini_swarm_launcher.py --tasks tasks.json --workers 10
    python gemini_swarm_launcher.py --tasks tasks.json --workers 5 --dry-run
    python gemini_swarm_launcher.py --tasks tasks.json --list-tasks

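Task file format (JSON array, or an object with a "tasks" array; a markdown
work queue whose "## P<n> — Title" headers each start a task is also accepted):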
    [
        {
            "id": "task_001",
            "priority": 1,
            "type": "browser|code|research|content|scraping|qa",
            "prompt": "Your full task prompt here",
            "model": "gemini-2.5-pro",   (optional, default chosen by task type)
            "timeout_seconds": 300        (optional, default 180)
        },
        ...
    ]

Output: /mnt/e/genesis-system/swarm-output/gemini/{task_id}_{timestamp}.txt
=============================================================================
"""

import argparse
import json
import os
import queue
import signal
import subprocess
import sys
import tempfile
import threading
import time
from datetime import datetime, timezone
from pathlib import Path

# ── Configuration ──────────────────────────────────────────────────────────────
# Root of the Genesis checkout. Every path below derives from it; set the
# GENESIS_ROOT environment variable to run the launcher from another location.
GENESIS_ROOT = os.environ.get("GENESIS_ROOT") or "/mnt/e/genesis-system"
DEFAULT_OUTPUT_DIR = f"{GENESIS_ROOT}/swarm-output/gemini"
CREDS_FILE = f"{GENESIS_ROOT}/Credentials/gemini_api_key.txt"
DEFAULT_TASK_FILE = f"{GENESIS_ROOT}/.gemini/knowledge/WORK_QUEUE_TODAY.md"

# Minimum seconds between launching new workers on the same model.
# The RPM figures are the quota assumptions these gaps were sized for —
# verify them against the project's actual Gemini quotas.
RATE_LIMIT_WINDOWS = {
    "gemini-2.5-pro": 0.5,        # assumed 150 RPM; 0.5s gap = 120 launches/min
    "gemini-2.5-flash": 0.4,      # assumed 200 RPM; 0.4s gap = 150 launches/min
    "gemini-2.0-flash-lite": 0.3, # assumed 300 RPM; 0.3s gap = 200 launches/min
    "default": 0.5,
}

# Model selection by task type (used when a task has no explicit "model")
TASK_TYPE_TO_MODEL = {
    "browser":    "gemini-2.5-pro",    # Browser automation needs best reasoning
    "research":   "gemini-2.5-pro",    # Research needs grounding + web search
    "code":       "gemini-2.5-flash",  # Code gen: fast + capable
    "content":    "gemini-2.5-flash",  # Content: fast is fine
    "scraping":   "gemini-2.5-flash",  # Data extraction
    "qa":         "gemini-2.5-flash",  # QA / verification
    "default":    "gemini-2.5-flash",
}

MAX_RETRIES = 2         # retry budget for tasks that end "failed" or "error"
DEFAULT_TIMEOUT = 180   # seconds per task


# ── Worker ─────────────────────────────────────────────────────────────────────
class GeminiWorker:
    """Single Gemini CLI worker running one task as an isolated subprocess.

    ``task`` is one entry from a task file (see the module docstring).  It may
    carry ``id``, ``priority``, ``type``, ``prompt``, ``model`` and
    ``timeout_seconds``.  ``run()`` writes a short header plus the CLI's
    combined stdout/stderr to ``{output_dir}/{task_id}_{timestamp}.txt``.
    It returns a result dict built by ``_result``.  A caller may invoke
    ``run()`` again on the same instance to retry a failed task.
    """

    def __init__(self, task: dict, output_dir: str, worker_id: int,
                 yolo: bool = True, dry_run: bool = False):
        self.task = task
        self.output_dir = output_dir
        self.worker_id = worker_id
        self.yolo = yolo
        self.dry_run = dry_run
        self.task_id = task.get("id", f"task_{worker_id:03d}")
        # An explicit per-task model wins; otherwise choose by task type.
        self.model = task.get("model") or TASK_TYPE_TO_MODEL.get(
            task.get("type", "default"), TASK_TYPE_TO_MODEL["default"]
        )
        self.timeout = task.get("timeout_seconds", DEFAULT_TIMEOUT)
        self.retries = 0          # incremented by the caller before each re-run
        self.status = "pending"
        self.output_file = None
        self.started_at = None
        self.finished_at = None

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        This replaces ``datetime.utcnow()``, which is deprecated since
        Python 3.12.  The ISO strings written to prompts and results keep
        their existing format.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _safe_name(name) -> str:
        """Return *name* reduced to a single, safe path component.

        Task ids come from task files or markdown titles.  They may contain
        ``/``, spaces or other characters that would otherwise point the
        output path into a non-existent subdirectory.  Letters, digits,
        ``-``, ``_`` and ``.`` are kept; everything else becomes ``_``.
        """
        cleaned = "".join(ch if ch.isalnum() or ch in "-_." else "_"
                          for ch in str(name))
        # Strip dots so an id such as ".." can never form a relative component.
        return cleaned.strip(".") or "task"

    def _build_prompt(self) -> str:
        """Return the full prompt: a task header, the task's prompt, and completion rules."""
        prompt = self.task.get("prompt", "")
        task_type = self.task.get("type", "general")
        priority = self.task.get("priority", 99)

        return f"""# Genesis Task: {self.task_id}
## Priority: {priority} | Type: {task_type} | Worker: {self.worker_id}
## Timestamp: {self._utcnow().isoformat()}Z

---

{prompt}

---

## Completion Requirements
- Execute the task completely
- Write all outputs to the specified file paths
- Print a brief summary when done (max 5 lines)
- Do NOT ask for confirmation — JUST EXECUTE
"""

    def _build_prompt_file(self) -> str:
        """Write the full prompt to a UTF-8 temp file under /tmp and return its path.

        ``run()`` pipes the prompt straight to stdin and no longer needs this.
        It is kept for external callers.  The caller must delete the file.
        """
        with tempfile.NamedTemporaryFile(
            mode='w', encoding='utf-8', suffix='.txt',
            prefix=f'genesis_task_{self._safe_name(self.task_id)}_',
            dir='/tmp',
            delete=False,
        ) as tmp:
            tmp.write(self._build_prompt())
        return tmp.name

    def _finish(self, status: str, message: str, output_file: str = None) -> dict:
        """Record the terminal *status* and finish time, then return the result dict."""
        self.status = status
        self.finished_at = self._utcnow().isoformat()
        return self._result(status, message, output_file=output_file)

    def run(self) -> dict:
        """Execute the task as a Gemini CLI subprocess and return a result dict.

        The result ``status`` is one of:

        - ``dry_run``: nothing was executed.
        - ``success``: the CLI exited with code 0.
        - ``failed``: the CLI exited with a non-zero code.
        - ``timeout``: the CLI was killed after ``self.timeout`` seconds.
        - ``error``: the CLI is missing, the output file is unwritable, or
          any other exception occurred.

        Problems are reported in the result rather than raised.
        """
        self.started_at = self._utcnow().isoformat()
        self.status = "running"
        tag = f"  [W{self.worker_id:02d}]"

        ts = self._utcnow().strftime("%Y%m%d_%H%M%S")
        # Sanitised so ids like "p0_fix_a/b" cannot redirect the file elsewhere.
        self.output_file = (f"{self.output_dir}/"
                            f"{self._safe_name(self.task_id)}_{ts}.txt")

        print(f"{tag} Starting {self.task_id} (model: {self.model}, "
              f"timeout: {self.timeout}s)")

        if self.dry_run:
            print(f"{tag} DRY RUN — would execute: {self.task_id}")
            return self._finish("dry_run", "Dry run — no execution")

        # Non-interactive invocation: the prompt is written to the CLI's stdin,
        # so nothing passes through a shell and no escaping is required.
        cmd = ["gemini"]
        if self.yolo:
            cmd.append("--yolo")   # auto-approve tool calls

        env = os.environ.copy()
        # NOTE(review): assumes the CLI picks its model from GEMINI_MODEL — confirm.
        env["GEMINI_MODEL"] = self.model

        try:
            prompt_content = self._build_prompt()
            with open(self.output_file, 'w', encoding='utf-8') as outf:
                outf.write("# Gemini Task Output\n")
                outf.write(f"# Task ID: {self.task_id}\n")
                outf.write(f"# Model: {self.model}\n")
                outf.write(f"# Started: {self.started_at}\n")
                outf.write(f"# Worker: {self.worker_id}\n")
                outf.write(f"# {'='*60}\n\n")
                # Flush so the header lands before anything the child writes to this fd.
                outf.flush()

                try:
                    proc = subprocess.Popen(
                        cmd,
                        stdin=subprocess.PIPE,
                        stdout=outf,
                        stderr=outf,
                        env=env,
                        cwd=GENESIS_ROOT,
                        text=True,
                        encoding='utf-8',
                    )
                except FileNotFoundError as exc:
                    # Popen raises FileNotFoundError for a missing cwd as well;
                    # only blame the binary when it is the file that is missing.
                    if exc.filename != cmd[0]:
                        raise
                    msg = ("gemini binary not found. "
                           "Install: npm install -g @google/gemini-cli")
                    print(f"{tag} ERROR: {msg}")
                    return self._finish("error", msg)

                try:
                    proc.communicate(input=prompt_content, timeout=self.timeout)
                except subprocess.TimeoutExpired:
                    proc.kill()
                    proc.wait()   # reap the killed child
                    print(f"{tag} TIMEOUT: {self.task_id} "
                          f"(exceeded {self.timeout}s)")
                    return self._finish("timeout",
                                        f"Timed out after {self.timeout}s")
                exit_code = proc.returncode

        except Exception as exc:  # boundary: any failure becomes an "error" result
            print(f"{tag} ERROR: {self.task_id}: {exc}")
            return self._finish("error", str(exc))

        if exit_code == 0:
            print(f"{tag} DONE: {self.task_id} -> {self.output_file}")
            status = "success"
        else:
            print(f"{tag} FAILED: {self.task_id} (exit code: {exit_code})")
            status = "failed"
        return self._finish(status, f"Exit code: {exit_code}",
                            output_file=self.output_file)

    def _result(self, status: str, message: str, output_file: str = None) -> dict:
        """Build the JSON-serialisable result dict for this task run."""
        return {
            "task_id": self.task_id,
            "worker_id": self.worker_id,
            "model": self.model,
            "status": status,
            "message": message,
            "output_file": output_file or self.output_file,
            "started_at": self.started_at,
            "finished_at": self.finished_at,
        }


# ── Swarm Orchestrator ─────────────────────────────────────────────────────────
class GeminiSwarm:
    """
    Manages a pool of parallel Gemini CLI workers.
    Each worker is an independent subprocess — crashes are isolated.

    Tasks are dispatched in ascending ``priority`` order.  At most
    ``max_workers`` threads run at once, each driving one GeminiWorker.
    Every final result is appended to ``self.results`` and to a JSONL log.
    A JSON manifest of the run is written when ``run()`` starts.
    """

    def __init__(self, tasks: list, max_workers: int = 5,
                 output_dir: str = DEFAULT_OUTPUT_DIR,
                 yolo: bool = True, dry_run: bool = False):
        # With zero workers run() would loop forever without dispatching anything.
        if max_workers < 1:
            raise ValueError(f"max_workers must be at least 1, got {max_workers}")
        self.tasks = sorted(tasks, key=lambda t: t.get("priority", 99))
        self.max_workers = max_workers
        self.output_dir = output_dir
        self.yolo = yolo
        self.dry_run = dry_run
        # Filled for API compatibility; run() dispatches from its own list.
        self.task_queue = queue.Queue()
        self.results = []
        self.results_lock = threading.Lock()
        self.active_workers = {}          # worker id -> currently running GeminiWorker
        self.workers_lock = threading.Lock()
        self._shutdown = threading.Event()
        self._worker_counter = 0
        # Launch throttling: earliest time.monotonic() at which each model may launch next.
        self._rate_lock = threading.Lock()
        self._next_launch = {}

        for task in self.tasks:
            self.task_queue.put(task)

        Path(output_dir).mkdir(parents=True, exist_ok=True)

        ts = self._utcnow().strftime("%Y%m%d_%H%M%S")
        self.results_file = f"{output_dir}/swarm_results_{ts}.jsonl"
        self.manifest_file = f"{output_dir}/swarm_manifest_{ts}.json"

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        This replaces ``datetime.utcnow()``, which is deprecated since
        Python 3.12.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _wait_for_launch_slot(self, model: str) -> None:
        """Block until *model*'s stagger gap since its previous launch has passed.

        Slots are reserved under a lock.  N threads dispatched together
        therefore launch at 0, gap, 2*gap, ... rather than all sleeping the
        same fixed gap and then hitting the API at once.  The wait ends early
        if shutdown is requested.
        """
        gap = RATE_LIMIT_WINDOWS.get(model, RATE_LIMIT_WINDOWS["default"])
        with self._rate_lock:
            now = time.monotonic()
            slot = max(now, self._next_launch.get(model, now))
            self._next_launch[model] = slot + gap
        delay = slot - now
        if delay > 0:
            self._shutdown.wait(delay)

    def _worker_thread(self, task: dict):
        """Thread body: run one task, with retries, and record its final result."""
        with self.workers_lock:
            self._worker_counter += 1
            wid = self._worker_counter

        worker = GeminiWorker(task, self.output_dir, wid,
                              yolo=self.yolo, dry_run=self.dry_run)
        with self.workers_lock:
            self.active_workers[wid] = worker

        try:
            self._wait_for_launch_slot(worker.model)

            if self._shutdown.is_set():
                # Ctrl+C arrived while queued: don't start a new CLI process.
                worker.status = "cancelled"
                result = worker._result("cancelled", "Swarm shut down before launch")
            else:
                result = worker.run()
                # Re-run failed/errored tasks up to MAX_RETRIES more times.
                while (result["status"] in ("failed", "error")
                       and worker.retries < MAX_RETRIES
                       and not self._shutdown.is_set()):
                    worker.retries += 1
                    print(f"  [W{wid:02d}] Retrying {worker.task_id} "
                          f"(attempt {worker.retries}/{MAX_RETRIES})...")
                    # Linear backoff (5s, 10s, ...); wait() returns True on shutdown.
                    if self._shutdown.wait(5 * worker.retries):
                        break
                    result = worker.run()

            with self.results_lock:
                self.results.append(result)
                with open(self.results_file, 'a', encoding='utf-8') as f:
                    f.write(json.dumps(result) + '\n')
        finally:
            with self.workers_lock:
                self.active_workers.pop(wid, None)

    def _dispatch_all(self, total: int) -> None:
        """Start worker threads until every task has finished or shutdown is requested."""
        active_threads = []
        pending = list(self.tasks)

        while (pending or active_threads) and not self._shutdown.is_set():
            # Fill free worker slots in priority order.
            while pending and len(active_threads) < self.max_workers:
                task = pending.pop(0)
                thread = threading.Thread(
                    target=self._worker_thread,
                    args=(task,),
                    daemon=True,
                )
                thread.start()
                active_threads.append(thread)
                print(f"  [swarm] Dispatched {task.get('id')} | "
                      f"Active: {len(active_threads)} | "
                      f"Done: {len(self.results)}/{total}")

            active_threads = [t for t in active_threads if t.is_alive()]
            if pending or active_threads:
                # Acts as a 0.3s sleep that returns early once shutdown is set.
                self._shutdown.wait(0.3)

        # Only non-empty after a shutdown: give running workers a moment to finish.
        for thread in active_threads:
            thread.join(timeout=5)

    def run(self) -> list:
        """Run all tasks, at most ``max_workers`` at a time, and return the results.

        The first Ctrl+C stops dispatching, cancels queued launches and
        retries, and briefly waits for running workers.  A second Ctrl+C
        raises KeyboardInterrupt.  The previous SIGINT handler is restored
        on exit.  Must be called from the main thread, because it installs a
        signal handler.
        """
        total = len(self.tasks)
        print(f"\n{'='*70}")
        print("  GENESIS GEMINI SWARM LAUNCHER")
        print(f"{'='*70}")
        print(f"  Tasks: {total}")
        print(f"  Workers: {self.max_workers} parallel")
        print(f"  Output: {self.output_dir}")
        print(f"  YOLO: {self.yolo}")
        print(f"  Dry run: {self.dry_run}")
        print(f"  Results: {self.results_file}")
        print(f"{'='*70}\n")

        if total == 0:
            print("  No tasks to process.")
            return []

        manifest = {
            "started_at": self._utcnow().isoformat(),
            "total_tasks": total,
            "max_workers": self.max_workers,
            "tasks": [{"id": t.get("id"), "type": t.get("type"),
                       "priority": t.get("priority")} for t in self.tasks],
        }
        with open(self.manifest_file, 'w', encoding='utf-8') as f:
            json.dump(manifest, f, indent=2)

        def _sigint_handler(sig, frame):
            print("\n\n  Swarm shutdown requested (Ctrl+C). Finishing active workers...")
            print("  Press Ctrl+C again to abort immediately.")
            self._shutdown.set()
            # Re-arm the default handler so a second Ctrl+C is not swallowed.
            signal.signal(signal.SIGINT, signal.default_int_handler)

        previous_handler = signal.signal(signal.SIGINT, _sigint_handler)
        try:
            self._dispatch_all(total)
        finally:
            # None means the old handler wasn't installed from Python; nothing to restore.
            if previous_handler is not None:
                signal.signal(signal.SIGINT, previous_handler)

        self._print_report()
        return self.results

    def _print_report(self):
        """Print a per-status summary of recorded results and list every non-success."""
        with self.results_lock:
            results = list(self.results)   # workers may still append after a shutdown
        statuses = [r["status"] for r in results]
        success = statuses.count("success")
        failed = statuses.count("failed")
        errors = statuses.count("error")
        timeouts = statuses.count("timeout")
        cancelled = statuses.count("cancelled")
        dry_runs = statuses.count("dry_run")
        not_run = len(self.tasks) - len(results)

        print(f"\n{'='*70}")
        print("  SWARM COMPLETE")
        print(f"{'='*70}")
        print(f"  Total tasks:  {len(results)}")
        print(f"  Success:      {success}")
        print(f"  Failed:       {failed}")
        print(f"  Errors:       {errors}")
        print(f"  Timeouts:     {timeouts}")
        if cancelled:
            print(f"  Cancelled:    {cancelled}")
        if not_run > 0:
            print(f"  Not run:      {not_run}")
        if dry_runs:
            print(f"  Dry runs:     {dry_runs}")
        print("")
        print(f"  Results log:  {self.results_file}")
        print(f"  Output dir:   {self.output_dir}")

        if failed + errors + timeouts + cancelled > 0:
            print("\n  FAILED TASKS:")
            for r in results:
                if r["status"] not in ("success", "dry_run"):
                    print(f"    - {r['task_id']}: {r['status']} — {r['message']}")
        print(f"{'='*70}\n")


# ── Task file loaders ──────────────────────────────────────────────────────────
def load_tasks_from_json(filepath: str) -> list:
    """Load tasks from a JSON task file.

    Accepts either a top-level array of task objects or an object with a
    ``"tasks"`` array.  The file is decoded as UTF-8, the encoding JSON
    mandates, rather than the platform locale.

    Raises:
        OSError: the file cannot be opened.
        ValueError: the JSON is malformed (``json.JSONDecodeError`` is a
            subclass), the top level has the wrong shape, or an entry is not
            a JSON object.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        data = json.load(f)

    tasks = data.get("tasks") if isinstance(data, dict) else data
    if not isinstance(tasks, list):
        raise ValueError(f"Expected JSON array or object with 'tasks' key in {filepath}")

    # Downstream code calls task.get(...), so every entry must be an object.
    bad = [i for i, task in enumerate(tasks) if not isinstance(task, dict)]
    if bad:
        raise ValueError(f"Task entries at positions {bad} in {filepath} are not JSON objects")
    return tasks


def load_tasks_from_work_queue(filepath: str) -> list:
    """
    Convert a WORK_QUEUE_TODAY.md file into a task list.

    Each line starting with ``## P`` and containing an em dash (``—``)
    starts a new task, e.g. ``## P0 — Title``.  The text between ``## P``
    and the dash is the priority.  If it is not an integer, the header's
    0-based position among headers is used instead.  Every following line,
    up to the next such header, is appended to that task's prompt.  Text
    before the first header is ignored.

    Ids are ``p<priority>_<title>``, where the title is lower-cased and
    truncated to 30 characters.  Any character other than a letter, digit,
    ``-``, ``_`` or ``.`` becomes ``_``, so ids are safe to use in file
    names.  Repeated ids get ``_2``, ``_3``, ... suffixes so that tasks
    never share output files.  Every task gets type ``browser`` and a
    600-second timeout.
    """
    # The header marker is a non-ASCII em dash, so never depend on the locale encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    tasks = []
    used_ids = set()
    current_task = None
    header_index = 0

    for line in content.split('\n'):
        if line.startswith('## P') and '—' in line:
            if current_task is not None:
                tasks.append(current_task)
            head, _, rest = line.partition('—')
            priority_str = head.strip().replace('## P', '')
            title = rest.strip()
            try:
                priority = int(priority_str)
            except ValueError:
                priority = header_index
            header_index += 1

            slug = "".join(ch if ch.isalnum() or ch in "-_." else "_"
                           for ch in title.lower())[:30]
            base_id = f"p{priority}_{slug}"
            task_id = base_id
            suffix = 2
            while task_id in used_ids:
                task_id = f"{base_id}_{suffix}"
                suffix += 1
            used_ids.add(task_id)

            current_task = {
                "id": task_id,
                "priority": priority,
                "type": "browser",  # most work queue tasks are browser tasks
                "prompt": f"# Task: {title}\n\n",
                "timeout_seconds": 600,
            }
        elif current_task is not None:
            current_task["prompt"] += line + "\n"

    if current_task is not None:
        tasks.append(current_task)

    return tasks


# ── CLI ────────────────────────────────────────────────────────────────────────
def main():
    """Command-line entry point: parse arguments, load tasks and run the swarm.

    Tasks come from one of three sources:

    - ``--prompt``: a single ad-hoc task.
    - ``--tasks``: a ``.md`` work queue or a JSON task file.
    - otherwise, the default work queue.

    Exit status:

    - 0 when every loaded task ended as ``success`` (or ``dry_run`` under
      ``--dry-run``).
    - 1 otherwise, including when Ctrl+C left tasks unrun or the task file
      could not be loaded.
    - 2 for invalid arguments (argparse).
    """
    parser = argparse.ArgumentParser(
        description="Genesis Gemini Swarm Launcher — parallel Gemini CLI workers"
    )
    parser.add_argument(
        "--tasks", default=None,
        help="Path to JSON task file or WORK_QUEUE_TODAY.md"
    )
    parser.add_argument(
        "--workers", type=int, default=5,
        help="Max parallel workers (default: 5, max recommended: 20)"
    )
    parser.add_argument(
        "--output-dir", default=DEFAULT_OUTPUT_DIR,
        help=f"Output directory (default: {DEFAULT_OUTPUT_DIR})"
    )
    parser.add_argument(
        "--no-yolo", action="store_true",
        help="Disable YOLO mode (interactive confirmation required)"
    )
    parser.add_argument(
        "--dry-run", action="store_true",
        help="Print tasks but do not execute"
    )
    parser.add_argument(
        "--list-tasks", action="store_true",
        help="List tasks from file and exit"
    )
    parser.add_argument(
        "--prompt", default=None,
        help="Single prompt to run (bypasses task file)"
    )
    parser.add_argument(
        "--model", default="gemini-2.5-flash",
        help="Model to use for --prompt mode"
    )
    args = parser.parse_args()

    # Zero or negative workers would leave the dispatcher unable to start anything.
    if args.workers < 1:
        parser.error("--workers must be at least 1")
    if args.workers > 20:
        print(f"Warning: {args.workers} workers requested. "
              f"Capping at 20 to avoid rate limits.")
        args.workers = 20

    # Ensure an API key is in os.environ, which the worker subprocesses copy.
    if not os.environ.get("GOOGLE_API_KEY"):
        if os.path.exists(CREDS_FILE):
            with open(CREDS_FILE, encoding="utf-8") as f:
                key = f.read().strip()
            if key:
                os.environ["GOOGLE_API_KEY"] = key
                print(f"Loaded GOOGLE_API_KEY from {CREDS_FILE}")
            else:
                print(f"Warning: creds file {CREDS_FILE} is empty; GOOGLE_API_KEY not set")
        else:
            print(f"Warning: No GOOGLE_API_KEY set and no creds file at {CREDS_FILE}")

    if args.prompt:
        # Single prompt mode
        tasks = [{
            "id": "single_prompt",
            "priority": 0,
            "type": "default",
            "prompt": args.prompt,
            "model": args.model,
        }]
    elif args.tasks:
        task_path = args.tasks
        try:
            if task_path.endswith(".md"):
                print(f"Parsing work queue: {task_path}")
                tasks = load_tasks_from_work_queue(task_path)
            else:
                print(f"Loading task file: {task_path}")
                tasks = load_tasks_from_json(task_path)
        except (OSError, ValueError) as exc:
            # Missing/unreadable file or malformed JSON: report cleanly, no traceback.
            print(f"Error: could not load tasks from {task_path}: {exc}")
            sys.exit(1)
    else:
        # Default: load from standard work queue
        wq = DEFAULT_TASK_FILE
        if os.path.exists(wq):
            print(f"Loading default work queue: {wq}")
            tasks = load_tasks_from_work_queue(wq)
        else:
            print(f"No task file specified and default not found: {wq}")
            print("Usage: python gemini_swarm_launcher.py --tasks <file.json>")
            sys.exit(1)

    if args.list_tasks:
        print(f"\nTasks in queue ({len(tasks)} total):")
        for t in tasks:
            # Resolve the model the same way GeminiWorker does: explicit, else by type.
            model = t.get("model") or TASK_TYPE_TO_MODEL.get(
                t.get("type", "default"), TASK_TYPE_TO_MODEL["default"]
            )
            print(f"  P{t.get('priority', '?')} | {t.get('id', '?')} | "
                  f"type={t.get('type', 'default')} | model={model}")
        return

    swarm = GeminiSwarm(
        tasks=tasks,
        max_workers=args.workers,
        output_dir=args.output_dir,
        yolo=not args.no_yolo,
        dry_run=args.dry_run,
    )
    results = swarm.run()
    # Dry runs never execute anything, so "dry_run" is the success state there.
    ok_statuses = ("success", "dry_run") if args.dry_run else ("success",)
    succeeded = sum(1 for r in results if r["status"] in ok_statuses)
    # Compare with the task count, not len(results): Ctrl+C can leave tasks unrun.
    sys.exit(0 if succeeded == len(tasks) else 1)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, so the module can be imported
    # for its worker/swarm classes and task loaders without side effects.
    main()
