#!/usr/bin/env python3
"""
GENESIS KIMI EXECUTOR
======================
Moonshot AI Kimi (moonshot-v1 series) as a Genesis execution backend.

Uses Moonshot's OpenAI-compatible API directly (not via OpenRouter).
Base URL: https://api.moonshot.cn/v1
Models: moonshot-v1-8k, moonshot-v1-32k, moonshot-v1-128k

Same interface as gemini_executor: execute_task_sync() + execute_tasks_parallel().

Model routing:
  - moonshot-v1-8k    (short tasks, rapid Q&A, classification)   $0.15/1M tokens
  - moonshot-v1-32k   (typical Genesis stories, code gen)         $0.15/1M tokens
  - moonshot-v1-128k  (long documents, multi-file, deep research) $0.60/1M tokens

Usage:
    executor = KimiExecutor()
    result = executor.execute_task_sync("What is 2 + 2?")
    print(result.response)

    # Parallel swarm
    results = executor.execute_tasks_parallel([
        "Analyse revenue pipeline",
        "Draft outreach email for TradiesVoice",
        "Research competitor pricing",
    ])
    for r in results.results:
        print(r.response)

Author: Genesis System
Version: 1.0.0
"""

import asyncio
import json
import os
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

# Auto-load .env so MOONSHOT_API_KEY is always available without manual export
try:
    from dotenv import load_dotenv
    _env_path = Path(__file__).parent.parent / ".env"
    load_dotenv(_env_path)
except ImportError:
    pass  # dotenv not installed — rely on shell env
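
# Expected .env entry (the key shown is a placeholder):
#   MOONSHOT_API_KEY=sk-...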


# ─────────────────────────────────────────────────────────────────────────────
# Constants & Pricing
# ─────────────────────────────────────────────────────────────────────────────

MOONSHOT_BASE_URL = "https://api.moonshot.cn/v1"

MOONSHOT_MODELS = {
    "8k":    "moonshot-v1-8k",
    "32k":   "moonshot-v1-32k",
    "128k":  "moonshot-v1-128k",
    # Friendly aliases used by genesis_execution_layer
    "fast":  "moonshot-v1-8k",
    "standard": "moonshot-v1-32k",
    "max":   "moonshot-v1-128k",
}
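
# Example: MOONSHOT_MODELS["standard"] resolves to "moonshot-v1-32k". execute()
# falls back to the raw value for unknown keys, so full model names also work.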

# USD per 1M tokens (input = output pricing, Moonshot flat rate)
MOONSHOT_PRICING: Dict[str, Dict[str, float]] = {
    "moonshot-v1-8k":   {"input": 0.15, "output": 0.15},
    "moonshot-v1-32k":  {"input": 0.15, "output": 0.15},
    "moonshot-v1-128k": {"input": 0.60, "output": 0.60},
}

DEFAULT_MODEL = "moonshot-v1-32k"
MAX_CONCURRENT_WORKERS = 50  # Conservative ceiling; Moonshot supports more
DEFAULT_CONCURRENT_WORKERS = 10

# Retry configuration
MAX_RETRIES = 4
RETRY_BASE_DELAY = 1.0   # seconds
RETRY_MAX_DELAY = 60.0   # seconds
RATE_LIMIT_CODES = {429}
RETRYABLE_CODES = {429, 500, 502, 503, 504}
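
# With MAX_RETRIES = 4 the inter-attempt sleeps are 1s, 2s, 4s (rate-limited
# calls wait at least 5s each); the fourth failed attempt returns the error.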


# ─────────────────────────────────────────────────────────────────────────────
# Data Classes
# ─────────────────────────────────────────────────────────────────────────────

@dataclass
class ExecutionResult:
    """
    Unified result from Kimi execution.

    Mirrors the GeminiResponse pattern so the two executors are
    interchangeable in the Genesis Execution Layer.
    """
    success: bool
    response: str
    model: str
    tokens_used: int
    cost_estimate: float
    execution_time: float
    task_complete: bool
    error: Optional[str] = None
    # Kimi-specific extras
    prompt_tokens: int = 0
    completion_tokens: int = 0

    def __str__(self) -> str:
        return self.response


@dataclass
class KimiSwarmResult:
    """Aggregate result from a parallel swarm run."""
    results: List[ExecutionResult] = field(default_factory=list)
    total_tokens: int = 0
    total_cost: float = 0.0
    elapsed_seconds: float = 0.0
    success_count: int = 0
    failure_count: int = 0

    @property
    def responses(self) -> List[str]:
        """All response texts (including failures as empty strings)."""
        return [r.response for r in self.results]

    @property
    def successful_responses(self) -> List[str]:
        """Only successful response texts."""
        return [r.response for r in self.results if r.success]

    @property
    def all_success(self) -> bool:
        return self.failure_count == 0
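
# Example (sketch): inspecting a swarm result via the properties above.
#
#     swarm = executor.execute_tasks_parallel(["task a", "task b"])
#     texts = swarm.responses if swarm.all_success else swarm.successful_responses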


# ─────────────────────────────────────────────────────────────────────────────
# Main Executor
# ─────────────────────────────────────────────────────────────────────────────

class KimiExecutor:
    """
    Direct Moonshot API executor for Genesis.

    Provides the same interface as GeminiExecutor:
      - execute_task_sync(task)  — synchronous single call
      - execute_tasks_parallel(tasks) — parallel swarm execution

    Thread-safe; safe to instantiate multiple times (each instance holds its
    own OpenAI client and usage counters).
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        default_model: str = "standard",
        max_workers: int = DEFAULT_CONCURRENT_WORKERS,
    ):
        self.api_key = api_key or self._load_api_key()
        self.default_model = MOONSHOT_MODELS.get(default_model, DEFAULT_MODEL)
        self.max_workers = min(max_workers, MAX_CONCURRENT_WORKERS)

        self.usage_log_path = Path("E:/genesis-system/data/kimi_executor_usage.jsonl")
        self.usage_log_path.parent.mkdir(parents=True, exist_ok=True)

        self.total_calls = 0
        self.total_tokens = 0
        self.total_cost = 0.0

        # Lazy OpenAI client
        self._client = None

    # ── API key loading ──────────────────────────────────────────────────────

    @staticmethod
    def _load_api_key() -> Optional[str]:
        """Load API key from environment (prefers MOONSHOT_API_KEY)."""
        for var in ("MOONSHOT_API_KEY", "MOONSHOT_KEY"):
            key = os.environ.get(var, "").strip()
            if key:
                return key
        return None

    def is_configured(self) -> bool:
        """Return True if the API key is available."""
        return bool(self.api_key)

    # ── OpenAI client ────────────────────────────────────────────────────────

    def _get_client(self):
        """Lazy-load and return the OpenAI-compatible client."""
        if self._client is None:
            try:
                from openai import OpenAI  # type: ignore
            except ImportError as exc:
                raise RuntimeError(
                    "openai package not installed. Run: pip install openai"
                ) from exc
            self._client = OpenAI(
                api_key=self.api_key,
                base_url=MOONSHOT_BASE_URL,
            )
        return self._client

    # ── Cost estimation ──────────────────────────────────────────────────────

    @staticmethod
    def _estimate_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float:
        """Return estimated USD cost for a single call."""
        pricing = MOONSHOT_PRICING.get(model, {"input": 0.15, "output": 0.15})
        return (
            (prompt_tokens / 1_000_000) * pricing["input"]
            + (completion_tokens / 1_000_000) * pricing["output"]
        )

    # ── Model selection ──────────────────────────────────────────────────────

    @staticmethod
    def _select_model_for_task(task: str) -> str:
        """
        Auto-select model based on task length.

        Heuristic:
          < 2 000 words  → 8k  (cheapest)
          2 000–8 000    → 32k
          > 8 000        → 128k (long context)
        """
        word_count = len(task.split())
        if word_count > 8_000:
            return "moonshot-v1-128k"
        if word_count > 2_000:
            return "moonshot-v1-32k"
        return "moonshot-v1-8k"

    # ── Core execution ───────────────────────────────────────────────────────

    def execute(
        self,
        prompt: str,
        model: Optional[str] = None,
        system_prompt: Optional[str] = None,
        max_tokens: int = 8_192,
        temperature: float = 0.7,
    ) -> ExecutionResult:
        """
        Execute a single prompt with exponential-backoff retry on rate limits.

        Args:
            prompt:        The user prompt / task description.
            model:         Model key ("fast", "standard", "max") or full name.
                           If None, auto-selected from prompt length.
            system_prompt: Optional system message.
            max_tokens:    Maximum output tokens.
            temperature:   Sampling temperature (0–1).

        Returns:
            ExecutionResult with response and cost metadata.
        """
        if not self.is_configured():
            return ExecutionResult(
                success=False,
                response="",
                model=model or self.default_model,
                tokens_used=0,
                cost_estimate=0.0,
                execution_time=0.0,
                task_complete=False,
                error="MOONSHOT_API_KEY not set. Add to E:\\genesis-system\\.env",
            )

        resolved_model = (
            MOONSHOT_MODELS.get(model, model)
            if model
            else self._select_model_for_task(prompt)
        )

        messages: List[Dict[str, str]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        start_time = time.time()
        last_error: Optional[str] = None

        for attempt in range(MAX_RETRIES):
            try:
                client = self._get_client()
                response = client.chat.completions.create(
                    model=resolved_model,
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=temperature,
                )

                elapsed = time.time() - start_time
                usage = response.usage
                prompt_tokens = usage.prompt_tokens if usage else 0
                completion_tokens = usage.completion_tokens if usage else 0
                total_tokens = usage.total_tokens if usage else 0
                cost = self._estimate_cost(resolved_model, prompt_tokens, completion_tokens)

                text = response.choices[0].message.content or ""
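                # Genesis convention (assumed here): agents mark a finished
                # task by emitting the literal sentinel "TASK_COMPLETE".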
                task_complete = "TASK_COMPLETE" in text

                result = ExecutionResult(
                    success=True,
                    response=text,
                    model=resolved_model,
                    tokens_used=total_tokens,
                    cost_estimate=cost,
                    execution_time=elapsed,
                    task_complete=task_complete,
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                )

                # Update stats
                self.total_calls += 1
                self.total_tokens += total_tokens
                self.total_cost += cost
                self._log_usage(result, prompt[:120])

                return result

            except Exception as exc:
                last_error = str(exc)
                # openai>=1.0 errors carry status_code directly; fall back to
                # the underlying httpx response for wrapped exceptions.
                status_code = getattr(exc, "status_code", None) or getattr(
                    getattr(exc, "response", None), "status_code", None
                )

                # Determine if we should retry
                is_rate_limit = (
                    status_code in RATE_LIMIT_CODES
                    or "rate limit" in last_error.lower()
                    or "429" in last_error
                )
                is_retryable = (
                    status_code in RETRYABLE_CODES
                    or is_rate_limit
                    or "timeout" in last_error.lower()
                    or "connection" in last_error.lower()
                )

                if not is_retryable or attempt == MAX_RETRIES - 1:
                    break

                # Exponential backoff
                delay = min(RETRY_BASE_DELAY * (2 ** attempt), RETRY_MAX_DELAY)
                if is_rate_limit:
                    delay = max(delay, 5.0)  # Rate limits get at least 5s breathing room
                time.sleep(delay)

        elapsed = time.time() - start_time
        return ExecutionResult(
            success=False,
            response="",
            model=resolved_model,
            tokens_used=0,
            cost_estimate=0.0,
            execution_time=elapsed,
            task_complete=False,
            error=last_error,
        )
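
    # Usage sketch (illustrative prompt and system text; alias resolved via
    # MOONSHOT_MODELS):
    #
    #     executor = KimiExecutor()
    #     result = executor.execute(
    #         prompt="Summarise this changelog in three bullets: ...",
    #         model="standard",              # resolves to moonshot-v1-32k
    #         system_prompt="You are a terse release-notes writer.",
    #         max_tokens=256,
    #     )
    #     print(result.response if result.success else result.error)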

    # ── Parallel swarm execution ─────────────────────────────────────────────

    def execute_tasks_parallel(
        self,
        tasks: List[str],
        model: Optional[str] = None,
        system_prompt: Optional[str] = None,
        max_tokens: int = 8_192,
        temperature: float = 0.7,
        max_workers: Optional[int] = None,
    ) -> KimiSwarmResult:
        """
        Execute multiple tasks in parallel using asyncio + thread pool.

        Up to 50 concurrent workers (configurable).
        Tasks are returned in the same order they were submitted.

        Args:
            tasks:         List of prompt strings.
            model:         Model key or full name (None = auto-select per task).
            system_prompt: Shared system message applied to every task.
            max_tokens:    Max output tokens per task.
            temperature:   Sampling temperature.
            max_workers:   Override concurrent workers (default: self.max_workers).

        Returns:
            KimiSwarmResult with ordered results and aggregate stats.
        """
        if not tasks:
            return KimiSwarmResult()

        effective_workers = min(
            max_workers or self.max_workers,
            MAX_CONCURRENT_WORKERS,
            len(tasks),
        )

        start_time = time.time()

        # Run via asyncio — each task dispatched to the thread executor so the
        # blocking OpenAI SDK calls don't block the event loop.
        async def _run_all() -> List[ExecutionResult]:
            sem = asyncio.Semaphore(effective_workers)
            loop = asyncio.get_running_loop()
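            # Note: run_in_executor(None, ...) uses the loop's default thread
            # pool, which CPython caps at roughly min(32, cpu_count + 4);
            # beyond that, submitted calls queue until a thread frees up.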

            async def _run_one(task_prompt: str) -> ExecutionResult:
                async with sem:
                    return await loop.run_in_executor(
                        None,
                        lambda p=task_prompt: self.execute(
                            prompt=p,
                            model=model,
                            system_prompt=system_prompt,
                            max_tokens=max_tokens,
                            temperature=temperature,
                        ),
                    )

            return await asyncio.gather(*[_run_one(t) for t in tasks])

        # Handle both "inside" and "outside" event-loop contexts
        try:
            asyncio.get_running_loop()
            in_running_loop = True
        except RuntimeError:
            in_running_loop = False

        if in_running_loop:
            # Already inside an asyncio event loop (e.g. Jupyter): run the
            # swarm on a fresh loop in a worker thread.
            import concurrent.futures
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                future = pool.submit(asyncio.run, _run_all())
                results_list: List[ExecutionResult] = future.result()
        else:
            results_list = asyncio.run(_run_all())

        elapsed = time.time() - start_time

        total_tokens = sum(r.tokens_used for r in results_list)
        total_cost = sum(r.cost_estimate for r in results_list)
        success_count = sum(1 for r in results_list if r.success)
        failure_count = len(results_list) - success_count

        return KimiSwarmResult(
            results=results_list,
            total_tokens=total_tokens,
            total_cost=total_cost,
            elapsed_seconds=elapsed,
            success_count=success_count,
            failure_count=failure_count,
        )
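
    # Sketch (a caller-side pattern, not part of this module's API): retry
    # only the failed tasks from a swarm.
    #
    #     swarm = executor.execute_tasks_parallel(tasks, model="fast")
    #     if not swarm.all_success:
    #         failed = [t for t, r in zip(tasks, swarm.results) if not r.success]
    #         swarm = executor.execute_tasks_parallel(failed)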

    # ── Sync convenience wrappers ────────────────────────────────────────────

    def execute_task_sync(self, task: str, **kwargs) -> ExecutionResult:
        """
        Synchronous single-task execution.

        Matches the genesis_execution_layer.execute_task_sync() signature.
        """
        return self.execute(prompt=task, **kwargs)

    # ── Logging ──────────────────────────────────────────────────────────────

    def _log_usage(self, result: ExecutionResult, prompt_preview: str) -> None:
        """Append a JSONL log entry for budget tracking."""
        entry = {
            "timestamp": datetime.now().isoformat(),
            "model": result.model,
            "prompt_tokens": result.prompt_tokens,
            "completion_tokens": result.completion_tokens,
            "total_tokens": result.tokens_used,
            "cost_usd": round(result.cost_estimate, 8),
            "execution_time_s": round(result.execution_time, 3),
            "success": result.success,
            "task_complete": result.task_complete,
            "prompt_preview": prompt_preview,
        }
        try:
            with open(self.usage_log_path, "a", encoding="utf-8") as fh:
                fh.write(json.dumps(entry) + "\n")
        except Exception:
            pass  # Never let logging break execution

    # ── Status ───────────────────────────────────────────────────────────────

    def get_status(self) -> Dict[str, Any]:
        """Return current executor status (mirrors GeminiExecutor pattern)."""
        return {
            "configured": self.is_configured(),
            "default_model": self.default_model,
            "base_url": MOONSHOT_BASE_URL,
            "max_workers": self.max_workers,
            "total_calls": self.total_calls,
            "total_tokens": self.total_tokens,
            "total_cost_usd": round(self.total_cost, 6),
            "api_key_preview": (
                f"{self.api_key[:8]}...{self.api_key[-4:]}"
                if self.api_key
                else "NOT SET"
            ),
        }

    def get_budget_status(self) -> Dict[str, Any]:
        """Get cost tracking (parallel to GeminiExecutor.get_budget_status)."""
        spent = 0.0
        calls = 0
        if self.usage_log_path.exists():
            with open(self.usage_log_path, "r", encoding="utf-8") as fh:
                for line in fh:
                    try:
                        entry = json.loads(line)
                        spent += entry.get("cost_usd", 0.0)
                        calls += 1
                    except (json.JSONDecodeError, KeyError):
                        pass
        avg = spent / calls if calls > 0 else 0.0
        return {
            "total_spent_usd": round(spent, 6),
            "total_calls": calls,
            "avg_cost_per_call": round(avg, 8),
            "models": list(MOONSHOT_MODELS.values()),
        }


# ─────────────────────────────────────────────────────────────────────────────
# Module-Level Convenience Shims (genesis_execution_layer compatible)
# ─────────────────────────────────────────────────────────────────────────────

# Shared singleton for module-level calls
_default_executor: Optional[KimiExecutor] = None


def _get_default_executor() -> KimiExecutor:
    global _default_executor
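    # Lazy singleton: built on first call. A duplicate created under a rare
    # thread race is harmless; each instance holds its own client and counters.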
    if _default_executor is None:
        _default_executor = KimiExecutor()
    return _default_executor


def execute_task_sync(task: str, **kwargs) -> ExecutionResult:
    """
    Drop-in replacement for genesis_execution_layer.execute_task_sync().

    Routes through the Moonshot API with auto model selection.
    """
    return _get_default_executor().execute_task_sync(task, **kwargs)


def execute_tasks_parallel(tasks: List[str], **kwargs) -> KimiSwarmResult:
    """
    Parallel swarm execution — up to 50 concurrent workers.

    Example:
        from core.kimi_executor import execute_tasks_parallel
        results = execute_tasks_parallel([
            "Summarise the Q1 revenue pipeline",
            "Draft outreach email for TradiesVoice",
        ])
        for r in results.results:
            print(r.response)
    """
    return _get_default_executor().execute_tasks_parallel(tasks, **kwargs)


# ─────────────────────────────────────────────────────────────────────────────
# Live Test
# ─────────────────────────────────────────────────────────────────────────────

def _run_live_test() -> bool:
    """
    Execute a live API call against Moonshot and print results.

    Returns True on SUCCESS, False on failure.
    """
    print("=" * 60)
    print("KIMI EXECUTOR — LIVE TEST")
    print("=" * 60)

    executor = KimiExecutor()

    if not executor.is_configured():
        print("[FAIL] MOONSHOT_API_KEY not set.")
        print("       Add to E:\\genesis-system\\.env:")
        print("       MOONSHOT_API_KEY=sk-...")
        return False

    status = executor.get_status()
    print(f"API Key : {status['api_key_preview']}")
    print(f"Model   : {status['default_model']}")
    print(f"Workers : {status['max_workers']}")
    print()

    # ── Test 1: Single execution ──
    print("--- Test 1: Single execution (moonshot-v1-8k) ---")
    result = executor.execute(
        prompt="What is 2 + 2? Reply with just the number.",
        model="fast",  # moonshot-v1-8k — cheapest
        max_tokens=16,
    )

    if result.success:
        print(f"  SUCCESS")
        print(f"  Response : {result.response.strip()}")
        print(f"  Tokens   : {result.tokens_used}")
        print(f"  Cost     : ${result.cost_estimate:.8f}")
        print(f"  Latency  : {result.execution_time:.2f}s")
        print(f"  Model    : {result.model}")
    else:
        print(f"  FAILED: {result.error}")
        return False

    print()

    # ── Test 2: Parallel swarm (3 tasks) ──
    print("--- Test 2: Parallel swarm (3 tasks, moonshot-v1-8k) ---")
    swarm_result = executor.execute_tasks_parallel(
        tasks=[
            "Reply with just the word: ALPHA",
            "Reply with just the word: BETA",
            "Reply with just the word: GAMMA",
        ],
        model="fast",
        max_tokens=16,
        max_workers=3,
    )

    if swarm_result.all_success:
        print(f"  SUCCESS — {swarm_result.success_count}/{len(swarm_result.results)} tasks passed")
        for i, r in enumerate(swarm_result.results, 1):
            print(f"  Task {i}: {r.response.strip()!r}")
        print(f"  Total tokens : {swarm_result.total_tokens}")
        print(f"  Total cost   : ${swarm_result.total_cost:.8f}")
        print(f"  Wall time    : {swarm_result.elapsed_seconds:.2f}s")
    else:
        print(f"  PARTIAL FAILURE — {swarm_result.failure_count} tasks failed")
        for r in swarm_result.results:
            if not r.success:
                print(f"  Error: {r.error}")
        # Still count as SUCCESS if at least 2/3 passed (rate limits are transient)
        if swarm_result.success_count >= 2:
            print("  (Partial success — rate limit transient, acceptable)")
        else:
            return False

    print()
    print("=" * 60)
    print("KIMI EXECUTOR TEST: SUCCESS")
    print("=" * 60)
    return True


if __name__ == "__main__":
    import sys
    ok = _run_live_test()
    sys.exit(0 if ok else 1)
