#!/usr/bin/env python3
"""
Deep Think Stream - Uses streaming to avoid server-side timeouts.
Streaming keeps the connection alive as tokens are generated.
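
Usage:
    python3 deep_think_stream.py <prompt_file> <output_file> [model] [max_tokens]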
"""

import json
import os
import sys
import time
from pathlib import Path


def get_api_key():
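    """Return the Gemini API key from GEMINI_API_KEY or a local credentials file."""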
    api_key = os.environ.get("GEMINI_API_KEY")
    if api_key:
        return api_key
    creds_file = Path(__file__).parent.parent / "Credentials" / "gemini_api_key.txt"
    if creds_file.exists():
        return creds_file.read_text().strip()
    return None


def fire_prompt_streaming(prompt_file, output_file, model="gemini-2.5-pro", max_output_tokens=16384):
    """Fire a prompt using streaming to avoid server-side timeout."""
    import google.generativeai as genai

    api_key = get_api_key()
    if not api_key:
        print("ERROR: No GEMINI_API_KEY found (set the env var or add Credentials/gemini_api_key.txt)", file=sys.stderr)
        sys.exit(1)

    genai.configure(api_key=api_key)

    # Read prompt
    prompt_text = Path(prompt_file).read_text().strip()
    print(f"Prompt file: {prompt_file} ({len(prompt_text)} chars)")

    generation_config = {
        "temperature": 1.0,
        "max_output_tokens": max_output_tokens,
    }

    model_obj = genai.GenerativeModel(
        model_name=model,
        generation_config=generation_config,
    )

    print(f"STREAMING to {model} (max_tokens={max_output_tokens})...")
    start_time = time.time()

    full_text = []
    try:
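        # The request_options timeout is the client-side deadline in seconds;
        # streaming delivers chunks as they are generated, so the connection
        # does not sit idle waiting for one large response.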
        response = model_obj.generate_content(
            prompt_text,
            stream=True,
            request_options={"timeout": 600}
        )
        chunk_count = 0
        for chunk in response:
            # chunk.text raises ValueError on chunks that carry no text parts
            # (e.g. a safety block or a bare finish_reason), so guard the access.
            try:
                chunk_text = chunk.text
            except ValueError:
                continue
            if chunk_text:
                full_text.append(chunk_text)
                chunk_count += 1
                if chunk_count % 10 == 0:
                    elapsed = time.time() - start_time
                    chars = sum(len(t) for t in full_text)
                    print(f"  ... {chunk_count} chunks, {chars} chars, {elapsed:.0f}s")
    except Exception as e:
        print(f"ERROR during streaming: {e}", file=sys.stderr)
        if full_text:
            # Salvage whatever arrived before the failure; it is saved below.
            print(f"  Recovered {len(full_text)} chunks ({sum(len(t) for t in full_text)} chars)")
        else:
            sys.exit(1)

    response_text = "".join(full_text)
    elapsed = time.time() - start_time
    print(f"Response complete in {elapsed:.2f}s ({len(response_text)} chars, {len(full_text)} chunks)")

    result = {
        "model": model,
        "elapsed_seconds": elapsed,
        "prompt_file": str(prompt_file),
        "prompt_length_chars": len(prompt_text),
        "response_length_chars": len(response_text),
        "chunk_count": len(full_text),
        # Kept empty: this streaming path does not capture separate "thinking" output.
        "thinking": "",
        "response": response_text,
        "raw": response_text,
    }

    # Save JSON
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)

    # Save MD
    md_path = output_path.with_suffix(".md")
    with open(md_path, "w", encoding="utf-8") as f:
        f.write("# Gemini Deep Think Result\n\n")
        f.write(f"**Model**: {model}\n")
        f.write(f"**Elapsed**: {elapsed:.2f}s\n")
        f.write(f"**Prompt**: {prompt_file}\n\n")
        f.write(f"## Response\n\n{response_text}\n")

    print(f"Saved to {output_path} and {md_path}")
    return result


if __name__ == "__main__":
    if len(sys.argv) < 3:
        print("Usage: python3 deep_think_stream.py <prompt_file> <output_file> [model] [max_tokens]")
        sys.exit(1)

    prompt_file = sys.argv[1]
    output_file = sys.argv[2]
    model = sys.argv[3] if len(sys.argv) > 3 else "gemini-2.5-pro"
    max_tokens = int(sys.argv[4]) if len(sys.argv) > 4 else 16384

    fire_prompt_streaming(prompt_file, output_file, model, max_tokens)
