#!/usr/bin/env python3
"""
Gemini 2.0 Flash Worker
Cost-efficient task execution using Gemini API.

Cost comparison:
- Claude Opus: $5/$25 per MTok (input/output)
- Gemini 2.0 Flash: $0.10/$0.40 per MTok (input/output) = roughly 50-60x cheaper

Usage:
- AIVA orchestrator (Claude) makes decisions
- Gemini workers execute the actual tasks
"""

import asyncio
import json
import logging
import os
import re
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, Any, List

# Load secrets
secrets_path = Path("/mnt/e/genesis-system/config/secrets.env")
if secrets_path.exists():
    with open(secrets_path) as f:
        for line in f:
            if "=" in line and not line.startswith("#"):
                key, value = line.strip().split("=", 1)
                os.environ[key] = value.strip('"')

import google.generativeai as genai

# Logging (configured first so configuration problems below can be reported)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("GeminiWorker")

# Configure Gemini
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
if not GEMINI_API_KEY:
    # Surface the misconfiguration at startup; otherwise it only appears
    # later as a failed API call on every task.
    logger.warning("GEMINI_API_KEY is not set; Gemini API calls will likely fail")
genai.configure(api_key=GEMINI_API_KEY)


class GeminiWorker:
    """Cost-efficient Gemini 2.0 Flash worker for task execution.

    Each instance wraps one ``genai.GenerativeModel`` and keeps running totals
    of the tokens it has consumed and their estimated USD cost.
    """

    # Gemini 2.0 Flash pricing, USD per million tokens.
    INPUT_PRICE_PER_MTOK = 0.10
    OUTPUT_PRICE_PER_MTOK = 0.40

    # A fenced block explicitly tagged as Python (```python, ```py, ```python3).
    _PYTHON_FENCE_RE = re.compile(
        r"```[ \t]*(?:python3?|py)[ \t]*\r?\n(.*?)(?:```|\Z)",
        re.DOTALL | re.IGNORECASE,
    )
    # Any fenced block, with or without a language tag after the backticks.
    _ANY_FENCE_RE = re.compile(
        r"```[ \t]*[\w+#.-]*[ \t]*\r?\n(.*?)(?:```|\Z)",
        re.DOTALL,
    )

    def __init__(self, model_name: str = "gemini-2.0-flash-exp"):
        """Create the model client and reset usage counters.

        Args:
            model_name: Gemini model identifier passed to ``GenerativeModel``.
        """
        self.model_name = model_name
        self.model = genai.GenerativeModel(model_name)
        self.generation_config = genai.GenerationConfig(
            temperature=0.7,
            top_p=0.95,
            top_k=40,
            max_output_tokens=8192,
        )
        self.system_prompt = self._load_system_prompt()
        self.total_tokens = 0
        self.total_cost = 0.0

    def _load_system_prompt(self) -> str:
        """Load worker system prompt."""
        return """You are a Genesis system worker executing tasks efficiently.

Your role:
1. Execute the given task completely
2. Write clean, production-ready code
3. Follow existing patterns in the codebase
4. Include error handling and logging
5. Test your work when possible

Guidelines:
- Use Python 3.10+ features
- Follow PEP 8 style
- Add docstrings to all functions
- Include type hints
- Handle edge cases

Output format:
- For code tasks: Output the complete file content
- For analysis tasks: Output structured findings
- Always be concise and actionable"""

    def _estimate_cost(self, input_tokens: int, output_tokens: int) -> float:
        """Return the estimated USD cost of one Gemini 2.0 Flash call.

        Args:
            input_tokens: Prompt tokens billed for the call.
            output_tokens: Response (candidate) tokens billed for the call.
        """
        input_cost = (input_tokens / 1_000_000) * self.INPUT_PRICE_PER_MTOK
        output_cost = (output_tokens / 1_000_000) * self.OUTPUT_PRICE_PER_MTOK
        return input_cost + output_cost

    @classmethod
    def _extract_code(cls, text: str) -> str:
        """Return the source code contained in a model response.

        Prefers the first Python-tagged fenced block, falls back to the first
        fenced block with any (or no) language tag, and otherwise uses the
        whole text. An unterminated fence runs to the end of the text. The
        result is stripped of surrounding whitespace.
        """
        match = cls._PYTHON_FENCE_RE.search(text) or cls._ANY_FENCE_RE.search(text)
        code = match.group(1) if match else text
        return code.strip()

    async def execute_task(self, task: Dict[str, Any]) -> Dict[str, Any]:
        """Execute a task using Gemini 2.0 Flash.

        Args:
            task: Task mapping; ``title`` and ``description`` are read when
                present.

        Returns:
            On success: ``success=True`` plus ``result`` (response text),
            ``tokens``, ``cost``, ``elapsed`` and ``model``. On any exception:
            ``success=False`` plus ``error`` and ``model``. Failures are
            logged and returned, never raised.
        """
        start_time = time.time()

        prompt = f"""{self.system_prompt}

TASK: {task.get('title', 'Untitled')}

DESCRIPTION:
{task.get('description', '')}

REQUIREMENTS:
- Output complete, working code
- Follow Genesis patterns
- Include all imports
- Add comprehensive error handling

Execute this task now. Output the complete solution."""

        try:
            # generate_content is a blocking network call; run it in the
            # default executor so the event loop is not stalled meanwhile.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                lambda: self.model.generate_content(
                    prompt,
                    generation_config=self.generation_config,
                ),
            )

            # Usage metadata (or individual counts) may be absent on some
            # responses; count those as zero rather than crashing on None.
            usage = getattr(response, "usage_metadata", None)
            input_tokens = getattr(usage, "prompt_token_count", 0) or 0
            output_tokens = getattr(usage, "candidates_token_count", 0) or 0

            self.total_tokens += input_tokens + output_tokens
            cost = self._estimate_cost(input_tokens, output_tokens)
            self.total_cost += cost

            elapsed = time.time() - start_time

            logger.info(f"Task completed in {elapsed:.1f}s | Tokens: {input_tokens}+{output_tokens} | Cost: ${cost:.4f}")

            return {
                "success": True,
                "result": response.text,
                "tokens": {"input": input_tokens, "output": output_tokens},
                "cost": cost,
                "elapsed": elapsed,
                "model": self.model_name,
            }

        except Exception as e:
            logger.exception(f"Task failed: {e}")
            return {
                "success": False,
                "error": str(e),
                "model": self.model_name,
            }

    async def execute_code_task(self, task: Dict[str, Any], output_path: str) -> Dict[str, Any]:
        """Execute a code generation task and save the extracted code to a file.

        Args:
            task: Task mapping passed to ``execute_task``.
            output_path: Destination file; parent directories are created.

        Returns:
            The ``execute_task`` result, with ``file_path`` added on success.
            If the file cannot be written, ``success`` is set to False and
            ``error`` describes the failure (no exception is raised).
        """
        result = await self.execute_task(task)
        if not result["success"]:
            return result

        code = self._extract_code(result["result"])
        output_file = Path(output_path)
        try:
            output_file.parent.mkdir(parents=True, exist_ok=True)
            # Terminate non-empty files with a newline (POSIX text file).
            output_file.write_text(code + "\n" if code else "", encoding="utf-8")
        except OSError as e:
            logger.error(f"Failed to save {output_file}: {e}")
            result["success"] = False
            result["error"] = f"Failed to write {output_file}: {e}"
            return result

        result["file_path"] = str(output_file)
        logger.info(f"Saved to: {output_file}")
        return result

    def get_stats(self) -> Dict[str, Any]:
        """Return this worker's cumulative token count, cost (4 dp) and model."""
        return {
            "total_tokens": self.total_tokens,
            "total_cost": round(self.total_cost, 4),
            "model": self.model_name,
        }


class GeminiTaskRunner:
    """Run tasks from Redis queue using Gemini workers.

    Tasks are JSON objects popped from the ``genesis:task_queue`` sorted set
    (lowest score first). While running they are stored in the
    ``genesis:active_tasks`` hash; afterwards they are stored in
    ``genesis:completed_tasks`` or ``genesis:failed_tasks``, keyed by task id.
    """

    # Generated code may only be written below this directory.
    OUTPUT_ROOT = "/mnt/e/genesis-system"
    # An absolute ``.py`` path under OUTPUT_ROOT mentioned in a description.
    _OUTPUT_PATH_RE = re.compile("(" + re.escape(OUTPUT_ROOT) + r"/\S+\.py)\b")

    def __init__(self, num_workers: int = 4):
        """Create ``num_workers`` Gemini workers and connect to Redis.

        Raises:
            ValueError: If ``num_workers`` is less than 1; tasks are assigned
                to workers round-robin, which needs at least one worker.
        """
        if num_workers < 1:
            raise ValueError(f"num_workers must be >= 1, got {num_workers}")
        self.num_workers = num_workers
        self.workers = [GeminiWorker() for _ in range(num_workers)]

        # Redis connection (imported here, so only the queue runner needs
        # the redis package).
        import redis
        redis_config = {
            "host": os.environ.get("GENESIS_REDIS_HOST", ""),
            "port": int(os.environ.get("GENESIS_REDIS_PORT", 26379)),
            "password": os.environ.get("GENESIS_REDIS_PASSWORD", ""),
            "decode_responses": True,
        }
        self.redis = redis.Redis(**redis_config)

        self.queue_key = "genesis:task_queue"
        self.active_key = "genesis:active_tasks"
        self.completed_key = "genesis:completed_tasks"
        self.failed_key = "genesis:failed_tasks"

    @classmethod
    def _extract_output_path(cls, description: Any) -> Optional[str]:
        """Return the normalized ``.py`` output path named in ``description``.

        Returns None when no such path is present, when ``description`` is not
        a string, or when the path would resolve (lexically, e.g. via ``..``)
        outside ``OUTPUT_ROOT``. The description is task input and must not
        be trusted to choose arbitrary write locations.
        """
        if not isinstance(description, str):
            return None
        match = cls._OUTPUT_PATH_RE.search(description)
        if not match:
            return None
        normalized = os.path.normpath(match.group(1))
        root = os.path.normpath(cls.OUTPUT_ROOT)
        if not normalized.startswith(root + os.sep):
            return None
        return normalized

    @staticmethod
    def _parse_task(task_json: Any) -> Optional[Dict[str, Any]]:
        """Decode a queue payload; log and return None if it is unusable.

        A usable task is a JSON object with a non-null ``id``.
        """
        try:
            task = json.loads(task_json)
        except (TypeError, ValueError):  # JSONDecodeError subclasses ValueError
            task = None
        if not isinstance(task, dict) or task.get("id") is None:
            logger.error(f"Dropping malformed task payload: {str(task_json)[:200]!r}")
            return None
        return task

    async def _run_task(self, worker: "GeminiWorker", task: Dict[str, Any]) -> None:
        """Execute one task on ``worker`` and record the outcome in Redis."""
        task_id = task["id"]
        title = str(task.get("title", "Untitled"))
        logger.info(f"Processing: {title[:50]}...")

        # Mark as active
        task["status"] = "active"
        task["started_at"] = datetime.now().isoformat()
        self.redis.hset(self.active_key, task_id, json.dumps(task))

        output_path = self._extract_output_path(task.get("description", ""))

        try:
            if output_path:
                result = await worker.execute_code_task(task, output_path)
            else:
                result = await worker.execute_task(task)
        except Exception as e:
            # Workers normally report failures in the result dict; this guards
            # against anything that escapes so the task is recorded as failed
            # instead of crashing the runner with the task still "active".
            logger.exception(f"Worker raised while running {title[:40]}")
            result = {"success": False, "error": str(e)}

        # Update task status
        self.redis.hdel(self.active_key, task_id)

        if result.get("success"):
            task["status"] = "completed"
            task["completed_at"] = datetime.now().isoformat()
            task["result"] = result.get("file_path", "completed")
            task["cost"] = result.get("cost", 0)
            self.redis.hset(self.completed_key, task_id, json.dumps(task))
            logger.info(f"✅ Completed: {title[:40]} | Cost: ${result.get('cost', 0):.4f}")
        else:
            task["status"] = "failed"
            task["failed_at"] = datetime.now().isoformat()
            task["error"] = result.get("error", "Unknown error")
            self.redis.hset(self.failed_key, task_id, json.dumps(task))
            logger.warning(f"❌ Failed: {title[:40]}")

    def _log_session_stats(self, processed: int) -> None:
        """Log the task count plus token and cost totals across all workers."""
        total_cost = sum(w.total_cost for w in self.workers)
        total_tokens = sum(w.total_tokens for w in self.workers)
        logger.info(f"\n{'='*50}")
        logger.info(f"Session complete: {processed} tasks")
        logger.info(f"Total tokens: {total_tokens:,}")
        logger.info(f"Total cost: ${total_cost:.4f}")
        logger.info(f"{'='*50}")

    async def process_queue(self, max_tasks: int = 100):
        """Process up to ``max_tasks`` tasks from the queue.

        When the queue is empty the loop sleeps 30 s and polls again, so the
        call only returns once ``max_tasks`` tasks have been handled.
        Malformed payloads are logged and dropped without counting toward
        ``max_tasks``.
        """
        processed = 0

        logger.info(f"Starting Gemini task runner with {self.num_workers} workers")

        while processed < max_tasks:
            popped = self.redis.zpopmin(self.queue_key, count=1)
            if not popped:
                logger.info("Queue empty, waiting...")
                await asyncio.sleep(30)
                continue

            task_json, _score = popped[0]
            task = self._parse_task(task_json)
            if task is None:
                continue

            # Select worker (round robin)
            worker = self.workers[processed % self.num_workers]
            await self._run_task(worker, task)

            processed += 1

            # Brief pause between tasks
            await asyncio.sleep(2)

        self._log_session_stats(processed)


async def main(argv: Optional[List[str]] = None):
    """Main entry point.

    Args:
        argv: Command-line arguments without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.
    """
    import argparse

    def positive_int(text: str) -> int:
        """argparse ``type=`` converter accepting integers >= 1."""
        value = int(text)
        if value < 1:
            raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
        return value

    parser = argparse.ArgumentParser(description="Gemini 2.0 Flash Worker")
    parser.add_argument("command", choices=["run", "test", "stats"])
    parser.add_argument("--max-tasks", type=int, default=100)
    # Workers are picked round-robin, so at least one is required.
    parser.add_argument("--workers", type=positive_int, default=4)

    args = parser.parse_args(argv)

    if args.command == "run":
        runner = GeminiTaskRunner(num_workers=args.workers)
        await runner.process_queue(max_tasks=args.max_tasks)

    elif args.command == "test":
        worker = GeminiWorker()
        result = await worker.execute_task({
            "title": "Test task",
            "description": "Write a simple Python hello world function"
        })
        print(json.dumps(result, indent=2))

    elif args.command == "stats":
        # NOTE: usage totals live only in memory on each worker instance, so
        # a freshly created worker always reports zero tokens and cost.
        worker = GeminiWorker()
        print(json.dumps(worker.get_stats(), indent=2))


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C (e.g. stopping a long "run"): exit with the conventional
        # SIGINT status instead of printing a traceback.
        sys.exit(130)
