#!/usr/bin/env python3
"""
Gemini Swarm Simple — overnight parallel Gemini Flash requests
Fires 10 parallel requests every 30 seconds until 4am AEST.
Results written to data/overnight_logs/swarm_output.jsonl

Uses google.generativeai SDK with gemini-2.0-flash-lite (cost-efficient).
"""
import os
import sys
import json
import time
import threading
from pathlib import Path
from datetime import datetime, timezone, timedelta

# Suppress FutureWarning from deprecated SDK.
# NOTE: the previous blanket filterwarnings("ignore") silenced *every* warning
# category; scope it to FutureWarning so real problems still surface.
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

# Load .env for GEMINI_API_KEY
try:
    from dotenv import load_dotenv
    # Prefer the project's .env at a fixed Windows path; fall back to
    # python-dotenv's default search (cwd and parents) when it is absent.
    env_path = Path("E:/genesis-system/.env")
    if env_path.exists():
        load_dotenv(env_path)
        print(f"[SWARM] Loaded .env from {env_path}")
    else:
        load_dotenv()
        print("[SWARM] Loaded .env from default path")
except ImportError:
    # python-dotenv is optional: without it, GEMINI_API_KEY must already be
    # present in the process environment (checked in main()).
    print("[SWARM] dotenv not available, relying on environment variables")

import google.generativeai as genai

# Project root and output location. The log directory is created eagerly at
# import time so worker threads can append without racing on mkdir.
GENESIS = Path("E:/genesis-system")
LOG_DIR = GENESIS / "data" / "overnight_logs"
LOG_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_FILE = LOG_DIR / "swarm_output.jsonl"

# Fixed UTC+10 offset. NOTE(review): this has no daylight-saving handling —
# correct for Brisbane year-round, but Sydney/Melbourne shift to AEDT (UTC+11)
# in summer; confirm which "4am AEST" is intended.
AEST = timezone(timedelta(hours=10))
SWARM_INTERVAL_SECONDS = 30  # pause between waves
NUM_AGENTS = 10              # parallel requests per wave (one per prompt)

# Use gemini-2.0-flash-lite — cost-efficient, fast, available
GEMINI_MODEL = "gemini-2.0-flash-lite"

# One prompt per agent slot; agents are assigned prompts round-robin
# (PROMPTS[i % len(PROMPTS)] in run_wave).
PROMPTS = [
    "Generate 5 concise insights about AI voice agents for Australian businesses. Focus on ROI and missed calls.",
    "Generate 5 reasons small Australian businesses lose revenue without 24/7 call answering. Be specific.",
    "Generate 5 objections tradies have to AI receptionists and how to counter each one professionally.",
    "Generate 5 use cases for AI voice agents in the Australian trades and home services sector.",
    "Generate 5 key differentiators that make a premium AI receptionist worth $497/month to a busy tradie.",
    "Generate 5 statistics or data points about missed calls costing Australian small businesses revenue.",
    "Generate 5 features the ideal AI receptionist for Australian trades businesses must have.",
    "Generate 5 ways an AI receptionist improves customer experience for callers to a trades business.",
    "Generate 5 industries in Australia that would benefit most from AI voice agents and why.",
    "Generate 5 competitive advantages of using AI receptionists vs traditional answering services.",
]

# Serializes both the shared results list and appends to OUTPUT_FILE
# across agent threads (see run_agent).
write_lock = threading.Lock()


def get_4am_aest_utc(now_aest=None):
    """Return the *next* 4:00am AEST as a timezone-aware UTC datetime.

    Note this is not always "tomorrow": if called before 4am AEST, today's
    4am is returned (the original docstring was misleading on this point).

    Args:
        now_aest: Optional aware datetime to treat as "now" (useful for
            testing). Defaults to the current time in AEST.

    Returns:
        datetime: the next 4am in the input's timezone, converted to UTC.
    """
    if now_aest is None:
        now_aest = datetime.now(AEST)
    four_am_aest = now_aest.replace(hour=4, minute=0, second=0, microsecond=0)
    # If 4am has already passed today, roll forward one day.
    if four_am_aest <= now_aest:
        four_am_aest += timedelta(days=1)
    return four_am_aest.astimezone(timezone.utc)


def run_agent(agent_id, prompt, wave_num, results_list):
    """Execute one Gemini request and record the outcome.

    The result dict is appended to results_list and persisted as one JSONL
    line in OUTPUT_FILE; both happen under write_lock since many agent
    threads run concurrently. Errors are captured, never raised.
    """
    started = time.time()
    try:
        reply = genai.GenerativeModel(GEMINI_MODEL).generate_content(prompt)
        took = time.time() - started
        body = reply.text
        record = {
            "wave": wave_num,
            "agent_id": agent_id,
            "timestamp_utc": datetime.now(timezone.utc).isoformat(),
            "timestamp_aest": datetime.now(AEST).isoformat(),
            "prompt_preview": prompt[:80],
            "response_chars": len(body),
            "elapsed_seconds": round(took, 2),
            "status": "success",
            "response_preview": body[:400],
        }
    except Exception as exc:
        # Any SDK/network failure becomes an "error" record instead of
        # killing the thread — the swarm is best-effort.
        took = time.time() - started
        record = {
            "wave": wave_num,
            "agent_id": agent_id,
            "timestamp_utc": datetime.now(timezone.utc).isoformat(),
            "timestamp_aest": datetime.now(AEST).isoformat(),
            "prompt_preview": prompt[:80],
            "elapsed_seconds": round(took, 2),
            "status": "error",
            "error": str(exc),
        }

    with write_lock:
        results_list.append(record)
        with OUTPUT_FILE.open('a', encoding='utf-8') as fh:
            fh.write(json.dumps(record) + '\n')


def run_wave(wave_num):
    """Fire NUM_AGENTS parallel Gemini requests and summarize the wave.

    Returns the list of result dicts collected by the agent threads.
    Prompts are assigned round-robin from PROMPTS.
    """
    outcomes = []
    workers = [
        threading.Thread(
            target=run_agent,
            args=(idx + 1, PROMPTS[idx % len(PROMPTS)], wave_num, outcomes),
        )
        for idx in range(NUM_AGENTS)
    ]
    for worker in workers:
        worker.start()

    # Cap the wait per agent; a straggler thread keeps running in the
    # background but no longer blocks the wave summary.
    for worker in workers:
        worker.join(timeout=60)  # Max 60s per agent

    ok = sum(1 for o in outcomes if o.get('status') == 'success')
    chars = sum(o.get('response_chars', 0) for o in outcomes)
    print(
        f"[SWARM] Wave {wave_num} complete | "
        f"{ok}/{NUM_AGENTS} success | "
        f"{chars:,} chars generated | "
        f"{datetime.now(AEST).strftime('%H:%M AEST')}"
    )
    sys.stdout.flush()
    return outcomes


def main():
    """Configure Gemini, then fire waves every SWARM_INTERVAL_SECONDS until 4am AEST.

    Exits with status 1 if GEMINI_API_KEY is missing. On completion, prints
    the total number of JSONL records accumulated in OUTPUT_FILE.
    """
    api_key = os.environ.get('GEMINI_API_KEY')
    if not api_key:
        print("[SWARM] ERROR: GEMINI_API_KEY not found in environment. Exiting.")
        sys.exit(1)

    genai.configure(api_key=api_key)
    print(f"[SWARM] Gemini configured with model: {GEMINI_MODEL}")
    # Only the key's tail is printed, to avoid leaking the secret into logs.
    print(f"[SWARM] API key: ...{api_key[-6:]}")
    print(f"[SWARM] Starting at {datetime.now(AEST).strftime('%Y-%m-%d %H:%M:%S AEST')}")

    deadline = get_4am_aest_utc()
    print(f"[SWARM] Will stop at {deadline.strftime('%Y-%m-%d %H:%M:%S UTC')} (4am AEST)")
    print(f"[SWARM] Output: {OUTPUT_FILE}")
    print(f"[SWARM] PID: {os.getpid()}")

    wave_count = 0
    while datetime.now(timezone.utc) < deadline:
        wave_count += 1
        print(f"[SWARM] Firing wave {wave_count} ({NUM_AGENTS} agents in parallel)...")
        sys.stdout.flush()
        run_wave(wave_count)

        # Sleep in 5s chunks to stay responsive to stop time
        pause_until = datetime.now(timezone.utc) + timedelta(seconds=SWARM_INTERVAL_SECONDS)
        while (datetime.now(timezone.utc) < pause_until
               and datetime.now(timezone.utc) < deadline):
            time.sleep(5)

    # Count every JSONL line ever written (the file is opened in append
    # mode, so this may include records from previous runs).
    written = 0
    if OUTPUT_FILE.exists():
        with OUTPUT_FILE.open(encoding='utf-8') as fh:
            written = sum(1 for _ in fh)
    print(f"[SWARM] 4am AEST reached. Total outputs written: {written}. Exiting.")


# Entry-point guard: start the swarm loop only when run as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
