#!/usr/bin/env python3
"""
KIMI K2.5 SWARM INTEGRATION
============================
Moonshot AI's Kimi K2.5 as a Genesis execution backend.

Native PARL capabilities:
- Up to 100 sub-agents per API call
- Up to 1,500 tool calls per request
- 8K / 32K / 128K context windows
- OpenAI-compatible API (drop-in replacement)

OpenRouter API base URL: https://openrouter.ai/api/v1
Model: moonshotai/kimi-k2.5

Usage:
    from core.kimi_swarm import KimiSwarm, execute_kimi_task, execute_kimi_swarm

    # Single task
    result = execute_kimi_task("Summarise this 100K doc", model="max")

    # Parallel swarm
    results = execute_kimi_swarm([
        {"prompt": "Analyse revenue pipeline"},
        {"prompt": "Draft marketing copy for TradiesVoice"},
        {"prompt": "Research competitor pricing"},
    ])

Author: Genesis System
Version: 1.0.0
"""

import os
import json
import time
import asyncio
import concurrent.futures
import threading
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any

# ─────────────────────────────────────────────
# Configuration
# ─────────────────────────────────────────────

# OpenRouter API key (fallback to MOONSHOT_API_KEY for legacy compatibility)
OPENROUTER_API_KEY = os.getenv(
    "OPENROUTER_API_KEY",
    os.getenv("MOONSHOT_API_KEY", "")
)
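
# A minimal .env sketch (the "sk-or-" prefix follows OpenRouter's key format;
# the Moonshot variable is only read when OPENROUTER_API_KEY is unset):
#   OPENROUTER_API_KEY=sk-or-...
#   MOONSHOT_API_KEY=sk-...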

# OpenRouter base URL with Kimi K2.5 model
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
KIMI_MODEL = "moonshotai/kimi-k2.5"

KIMI_MODELS = {
    "fast":     KIMI_MODEL,    # Kimi K2.5 via OpenRouter
    "standard": KIMI_MODEL,    # Same model, consistent interface
    "max":      KIMI_MODEL,    # Same model, consistent interface
}

# Pricing estimates (OpenRouter pass-through — approximate)
KIMI_PRICING: Dict[str, Dict[str, float]] = {
    KIMI_MODEL: {"input": 0.60, "output": 0.60},
}

# Worker limits for parallel swarm execution
DEFAULT_SWARM_WORKERS = 10
MAX_SWARM_WORKERS = 50  # Kimi PARL native: up to 100 sub-agents


# ─────────────────────────────────────────────
# Data Classes
# ─────────────────────────────────────────────

@dataclass
class KimiResponse:
    """Structured response from a Kimi API call."""
    text: str
    model: str
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    execution_time: float = 0.0
    cost_estimate: float = 0.0
    success: bool = True
    error: Optional[str] = None
    raw_response: Dict[str, Any] = field(default_factory=dict)

    def __str__(self) -> str:
        return self.text


@dataclass
class KimiSwarmResult:
    """Result from a parallel swarm execution."""
    responses: List[KimiResponse] = field(default_factory=list)
    total_tokens: int = 0
    total_cost: float = 0.0
    elapsed_seconds: float = 0.0
    success_count: int = 0
    failure_count: int = 0

    @property
    def texts(self) -> List[str]:
        """Return just the text strings from successful responses."""
        return [r.text for r in self.responses if r.success]

    @property
    def all_success(self) -> bool:
        return self.failure_count == 0


# ─────────────────────────────────────────────
# Core Client
# ─────────────────────────────────────────────

class KimiSwarm:
    """
    Kimi K2.5 swarm executor with native PARL capabilities.

    Thread-safe for parallel swarm use. Per-model-tier singleton instances
    are managed at module level via _get_instance(); constructing additional
    instances directly is also safe.
    Modelled after Genesis's Qwen/Gemini integration patterns.
    """

    def __init__(self, model: str = "standard"):
        self.model_key = model
        self.model = KIMI_MODELS.get(model, KIMI_MODELS["standard"])
        self.api_key = OPENROUTER_API_KEY
        self.base_url = OPENROUTER_BASE_URL

        # Logging
        self.log_path = Path(__file__).parent.parent / "data" / "kimi_usage.jsonl"
        self.log_path.parent.mkdir(parents=True, exist_ok=True)

        # Stats
        self.total_calls = 0
        self.total_tokens = 0
        self.total_cost = 0.0

        # Lazy OpenAI client
        self._client = None

    def _get_client(self):
        """Lazy-load OpenAI client (OpenAI-compatible Moonshot endpoint)."""
        if self._client is None:
            try:
                from openai import OpenAI
                self._client = OpenAI(
                    api_key=self.api_key,
                    base_url=self.base_url,
                )
            except ImportError:
                raise RuntimeError(
                    "openai package not installed. Run: pip install openai"
                )
        return self._client

    def _estimate_cost(self, model: str, prompt_tokens: int, completion_tokens: int) -> float:
        """Estimate USD cost from token counts."""
        pricing = KIMI_PRICING.get(model, {"input": 0.60, "output": 0.60})
        return (
            (prompt_tokens / 1_000_000) * pricing["input"] +
            (completion_tokens / 1_000_000) * pricing["output"]
        )
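
    # Worked example with the $0.60/M defaults above: a call with 10,000 prompt
    # tokens and 2,000 completion tokens costs
    # (10_000 / 1e6) * 0.60 + (2_000 / 1e6) * 0.60 = $0.0072.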

    def _log(self, event: str, data: Optional[Dict] = None):
        """Append a JSONL log entry."""
        entry = {
            "timestamp": datetime.now().isoformat(),
            "event": event,
            "model": self.model,
            "data": data or {}
        }
        try:
            with open(self.log_path, "a", encoding="utf-8") as f:
                f.write(json.dumps(entry) + "\n")
        except Exception:
            pass  # Never let logging break execution

    def _is_configured(self) -> bool:
        """Check if API key is set."""
        return bool(self.api_key and self.api_key.strip())

    def execute(
        self,
        prompt: str,
        system: Optional[str] = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        model_override: Optional[str] = None,
    ) -> KimiResponse:
        """
        Execute a single Kimi task.

        Args:
            prompt: The user prompt
            system: Optional system message
            max_tokens: Max tokens to generate
            temperature: Sampling temperature (0-1)
            model_override: Override the default model for this call

        Returns:
            KimiResponse with text and metadata
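
        Example (a minimal sketch; requires OPENROUTER_API_KEY to be set):
            resp = KimiSwarm(model="fast").execute("Ping?", max_tokens=16)
            if resp.success:
                print(resp.text, f"${resp.cost_estimate:.6f}")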
        """
        if not self._is_configured():
            return KimiResponse(
                text="",
                model=self.model,
                success=False,
                error="OPENROUTER_API_KEY not set. Get key from https://openrouter.ai and add to E:\\genesis-system\\.env"
            )

        model = model_override or self.model
        start_time = time.time()

        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})

        try:
            client = self._get_client()
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )

            elapsed = time.time() - start_time
            usage = response.usage

            prompt_tokens = usage.prompt_tokens if usage else 0
            completion_tokens = usage.completion_tokens if usage else 0
            total_tokens = usage.total_tokens if usage else 0
            cost = self._estimate_cost(model, prompt_tokens, completion_tokens)

            # Update stats
            self.total_calls += 1
            self.total_tokens += total_tokens
            self.total_cost += cost

            text = response.choices[0].message.content or ""

            self._log("execute_success", {
                "prompt_preview": prompt[:100],
                "tokens": total_tokens,
                "cost_usd": cost,
                "elapsed_s": round(elapsed, 3),
            })

            return KimiResponse(
                text=text,
                model=model,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                execution_time=elapsed,
                cost_estimate=cost,
                success=True,
                raw_response={"id": response.id, "model": response.model},
            )

        except Exception as e:
            elapsed = time.time() - start_time
            error_msg = str(e)

            self._log("execute_error", {
                "prompt_preview": prompt[:100],
                "error": error_msg,
                "elapsed_s": round(elapsed, 3),
            })

            return KimiResponse(
                text="",
                model=model,
                execution_time=elapsed,
                success=False,
                error=error_msg,
            )

    def swarm_execute(
        self,
        tasks: List[Dict[str, Any]],
        max_workers: int = DEFAULT_SWARM_WORKERS,
    ) -> KimiSwarmResult:
        """
        Execute multiple tasks in parallel using ThreadPoolExecutor.

        Parallelism here is client-side: independent API calls fan out over a
        thread pool. Kimi's native PARL orchestration (up to 100 concurrent
        sub-agents, each with its own chain of up to 1,500 tool calls) happens
        server-side within each call.

        Args:
            tasks: List of task dicts with keys:
                   - prompt (required)
                   - system (optional)
                   - max_tokens (optional, default 4096)
                   - temperature (optional, default 0.7)
                   - model_override (optional)
            max_workers: Parallel threads (max 50 via this wrapper; Kimi supports 100 native)

        Returns:
            KimiSwarmResult with all responses and aggregate stats
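
        Example (sketch; task keys as listed above):
            result = KimiSwarm().swarm_execute(
                [{"prompt": "Task A"}, {"prompt": "Task B", "temperature": 0.2}],
                max_workers=2,
            )
            print(result.success_count, result.texts)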
        """
        if not tasks:
            return KimiSwarmResult()

        start_time = time.time()
        effective_workers = min(max_workers, MAX_SWARM_WORKERS, len(tasks))

        self._log("swarm_start", {
            "task_count": len(tasks),
            "workers": effective_workers,
        })

        def run_task(task: Dict) -> KimiResponse:
            return self.execute(
                prompt=task.get("prompt", ""),
                system=task.get("system", None),
                max_tokens=task.get("max_tokens", 4096),
                temperature=task.get("temperature", 0.7),
                model_override=task.get("model_override", None),
            )

        ordered: List[Optional[KimiResponse]] = [None] * len(tasks)
        with concurrent.futures.ThreadPoolExecutor(max_workers=effective_workers) as executor:
            future_map = {executor.submit(run_task, task): i for i, task in enumerate(tasks)}
            # Collect results as they finish, storing them back in submission order
            for future in concurrent.futures.as_completed(future_map):
                idx = future_map[future]
                try:
                    ordered[idx] = future.result()
                except Exception as e:
                    ordered[idx] = KimiResponse(
                        text="", model=self.model, success=False, error=str(e)
                    )
        responses = ordered

        elapsed = time.time() - start_time
        success_count = sum(1 for r in responses if r and r.success)
        failure_count = len(responses) - success_count
        total_tokens = sum(r.total_tokens for r in responses if r)
        total_cost = sum(r.cost_estimate for r in responses if r)

        self._log("swarm_complete", {
            "success": success_count,
            "failed": failure_count,
            "total_tokens": total_tokens,
            "total_cost_usd": total_cost,
            "elapsed_s": round(elapsed, 3),
        })

        return KimiSwarmResult(
            responses=responses,
            total_tokens=total_tokens,
            total_cost=total_cost,
            elapsed_seconds=elapsed,
            success_count=success_count,
            failure_count=failure_count,
        )

    async def execute_async(
        self,
        prompt: str,
        system: Optional[str] = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> KimiResponse:
        """Async wrapper for use in asyncio contexts."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self.execute(prompt, system, max_tokens, temperature)
        )

    async def swarm_execute_async(
        self,
        tasks: List[Dict[str, Any]],
        max_workers: int = DEFAULT_SWARM_WORKERS,
    ) -> KimiSwarmResult:
        """Async wrapper for swarm_execute."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self.swarm_execute(tasks, max_workers)
        )
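
    # Async usage sketch (assumes an asyncio entry point; names as defined above):
    #
    #     async def main():
    #         swarm = KimiSwarm(model="fast")
    #         single = await swarm.execute_async("Summarise X")
    #         batch = await swarm.swarm_execute_async([{"prompt": "A"}, {"prompt": "B"}])
    #         print(single.text, batch.texts)
    #
    #     asyncio.run(main())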

    def get_status(self) -> Dict[str, Any]:
        """Return current client status."""
        return {
            "configured": self._is_configured(),
            "model": self.model,
            "model_key": self.model_key,
            "base_url": self.base_url,
            "total_calls": self.total_calls,
            "total_tokens": self.total_tokens,
            "total_cost_usd": round(self.total_cost, 6),
            "api_key_set": bool(self.api_key),
            "api_key_preview": f"{self.api_key[:8]}..." if self.api_key else "NOT SET",
        }


# ─────────────────────────────────────────────
# Module-Level Convenience Functions
# ─────────────────────────────────────────────

# Singleton instances per model tier
_instances: Dict[str, KimiSwarm] = {}
_instances_lock = threading.Lock()


def _get_instance(model: str = "standard") -> KimiSwarm:
    """Get or create a singleton KimiSwarm for a given model tier."""
    with _instances_lock:
        if model not in _instances:
            _instances[model] = KimiSwarm(model=model)
        return _instances[model]


def execute_kimi_task(
    prompt: str,
    model: str = "standard",
    system: Optional[str] = None,
    max_tokens: int = 4096,
) -> str:
    """
    Quick single-task Kimi executor.

    Returns the response text directly (empty string on failure).
    Check KimiSwarm.execute() for structured response with metadata.
    """
    swarm = _get_instance(model)
    response = swarm.execute(prompt=prompt, system=system, max_tokens=max_tokens)
    return response.text


def execute_kimi_swarm(
    tasks: List[Dict[str, Any]],
    model: str = "standard",
    max_workers: int = DEFAULT_SWARM_WORKERS,
) -> List[str]:
    """
    Quick parallel swarm executor.

    Returns the texts of successful responses, in submission order
    (failed tasks are omitted; use KimiSwarm.swarm_execute for full results).

    Args:
        tasks: List of {"prompt": ..., "system": ..., "max_tokens": ...}
        model: "fast" | "standard" | "max"
        max_workers: Parallel threads

    Example:
        results = execute_kimi_swarm([
            {"prompt": "Analyse Q1 revenue"},
            {"prompt": "Draft tradie outreach email"},
            {"prompt": "Research Moonshot pricing"},
        ])
    """
    swarm = _get_instance(model)
    result = swarm.swarm_execute(tasks=tasks, max_workers=max_workers)
    return result.texts


def kimi_status(model: str = "standard") -> Dict[str, Any]:
    """Return status dict for a given model tier."""
    return _get_instance(model).get_status()
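
# Example status payload from kimi_status() (a sketch; values illustrative):
#   {"configured": True, "model": "moonshotai/kimi-k2.5", "model_key": "standard",
#    "base_url": "https://openrouter.ai/api/v1", "total_calls": 0, "total_tokens": 0,
#    "total_cost_usd": 0.0, "api_key_set": True, "api_key_preview": "sk-or-xx..."}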


# ─────────────────────────────────────────────
# CLI / Quick Test
# ─────────────────────────────────────────────
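
# Example invocations (file path core/kimi_swarm.py assumed from the import in
# the module docstring):
#   python core/kimi_swarm.py --status
#   python core/kimi_swarm.py --model fast --prompt "Hello"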

def _cli_test():
    """Run a quick connectivity test from the command line."""
    import argparse

    parser = argparse.ArgumentParser(description="Kimi K2.5 Swarm — Genesis Integration")
    parser.add_argument("--model", choices=["fast", "standard", "max"], default="fast",
                        help="Model tier to test")
    parser.add_argument("--prompt", type=str, default="What is 2 + 2? Reply with just the number.",
                        help="Test prompt")
    parser.add_argument("--status", action="store_true", help="Show status only")
    args = parser.parse_args()

    if args.status:
        import json as _json
        print(_json.dumps(kimi_status(args.model), indent=2))
        return

    print(f"Testing Kimi K2.5 ({args.model}) ...")
    swarm = KimiSwarm(model=args.model)
    status = swarm.get_status()

    if not status["configured"]:
        print("\n[ERROR] MOONSHOT_API_KEY not set.")
        print("To get your key:")
        print("  1. Go to https://platform.moonshot.cn")
        print("  2. Create account / log in")
        print("  3. Navigate to API Keys section")
        print("  4. Create a new key")
        print("  5. Add to E:\\genesis-system\\.env:")
        print("     MOONSHOT_API_KEY=sk-...")
        return

    response = swarm.execute(args.prompt)

    if response.success:
        print(f"\n[SUCCESS]")
        print(f"Model:    {response.model}")
        print(f"Response: {response.text}")
        print(f"Tokens:   {response.total_tokens}")
        print(f"Cost:     ${response.cost_estimate:.6f}")
        print(f"Latency:  {response.execution_time:.2f}s")
    else:
        print(f"\n[FAILED] {response.error}")


if __name__ == "__main__":
    _cli_test()
