"""
PM-002: Cost Tracker Core
Real-time cost tracking per API call for Genesis.

Acceptance Criteria:
- [x] GIVEN API call WHEN completed THEN budget decremented
- [x] AND if budget < 10% THEN alert logged
- [x] AND if budget exhausted THEN raise BudgetExhaustedError
- [x] AND tracks per-model costs accurately

Dependencies: PM-001 (api_token_manager.py)
"""

import json
import logging
import os
from dataclasses import dataclass, asdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional, Dict, Any, List

from core.api_token_manager import TokenManager, BudgetExhaustedError, get_token_manager

logger = logging.getLogger(__name__)


# Pricing rows: (model name, input rate, output rate).
# Both rates are expressed as cost per 1K tokens. Row order is the
# insertion order of MODEL_COSTS below.
_MODEL_RATE_ROWS = (
    # Gemini models
    ("gemini-flash", 0.00001, 0.00004),
    ("gemini-1.5-flash", 0.00001, 0.00004),
    ("gemini-2.0-flash", 0.00001, 0.00004),  # Retiring 2026-03-03
    ("gemini-2.5-flash", 0.000075, 0.00030),
    ("gemini-pro", 0.00025, 0.00075),
    ("gemini-1.5-pro", 0.00125, 0.00375),
    ("gemini-2.0-pro", 0.00125, 0.00375),
    # Claude models
    ("claude-3-haiku", 0.00025, 0.00125),
    ("claude-3-sonnet", 0.003, 0.015),
    ("claude-3-5-sonnet", 0.003, 0.015),
    ("claude-3-opus", 0.015, 0.075),
    ("claude-opus-4-5", 0.015, 0.075),
    ("claude-sonnet-4", 0.003, 0.015),
)

# Model cost configuration (cost per 1K tokens), keyed by model name.
# Every model gets its own {"input": ..., "output": ...} dict.
MODEL_COSTS = {
    name: {"input": input_rate, "output": output_rate}
    for name, input_rate, output_rate in _MODEL_RATE_ROWS
}

# Map model names to providers (Gemini entries first, then Anthropic ones).
MODEL_PROVIDERS = {
    **dict.fromkeys(
        (
            "gemini-flash",
            "gemini-1.5-flash",
            "gemini-2.0-flash",
            "gemini-2.5-flash",
            "gemini-pro",
            "gemini-1.5-pro",
            "gemini-2.0-pro",
        ),
        "gemini",
    ),
    **dict.fromkeys(
        (
            "claude-3-haiku",
            "claude-3-sonnet",
            "claude-3-5-sonnet",
            "claude-3-opus",
            "claude-opus-4-5",
            "claude-sonnet-4",
        ),
        "anthropic",
    ),
}


@dataclass
class APICallRecord:
    """One tracked API call: what was called, how many tokens, and its cost.

    The first six fields are required; the remaining ones are optional
    context and default to ``None`` (``success`` defaults to ``True``).
    """
    timestamp: str  # when the call was recorded, as a string
    provider: str  # provider the cost is charged to
    model: str  # model name as supplied by the caller
    input_tokens: int  # number of input tokens
    output_tokens: int  # number of output tokens
    cost: float  # computed cost of the call
    task_id: Optional[str] = None  # task identifier, if any
    success: bool = True  # whether the call succeeded
    error_message: Optional[str] = None  # error message for a failed call
    duration_ms: Optional[int] = None  # call duration in milliseconds
    metadata: Optional[Dict[str, Any]] = None  # additional free-form metadata

    def to_dict(self) -> Dict[str, Any]:
        """Return the record as a plain dict (field name -> value).

        Built with ``dataclasses.asdict``, so nested containers such as
        ``metadata`` are copied rather than shared with the record.
        """
        as_mapping = asdict(self)
        return as_mapping


class CostTracker:
    """
    Real-time cost tracking per API call.

    Features:
    - Automatic budget deduction after each call
    - Per-model cost calculation
    - Low budget alerts (10%, 20%)
    - BudgetExhaustedError when budget depleted
    - Detailed call history
    """

    # Fallback pricing (cost per 1K tokens) for models missing from MODEL_COSTS.
    _DEFAULT_COST = {"input": 0.001, "output": 0.002}

    # (remaining-percentage threshold, severity label) pairs checked after each call.
    _ALERT_THRESHOLDS = ((10, "CRITICAL"), (20, "WARNING"))

    def __init__(self, token_manager: Optional[TokenManager] = None):
        """
        Initialize CostTracker.

        Args:
            token_manager: TokenManager instance. Uses global if not provided.
        """
        self.token_manager = token_manager or get_token_manager()
        self.call_history: List[APICallRecord] = []
        self.model_costs: Dict[str, float] = {}  # Track per-model spending
        # Thresholds already reported per provider, so each alert fires once.
        self._alert_thresholds_crossed: Dict[str, set] = {
            "anthropic": set(),
            "gemini": set()
        }

    @staticmethod
    def _match_model_key(model: str, table: Dict[str, Any]) -> Optional[str]:
        """
        Resolve a caller-supplied model name to a key of ``table``.

        The name is lower-cased and underscores become hyphens. Candidates
        are then tried from most to least specific:

        1. exact match;
        2. the longest key contained in the name (e.g. a dated variant such
           as ``claude-3-sonnet-20240229``);
        3. the first key that contains the name (short aliases).

        Args:
            model: Model name as supplied by the caller
            table: Mapping whose keys are known model names

        Returns:
            The matching key, or None when the name is empty or unknown
        """
        model_key = model.lower().replace("_", "-")

        # An empty string is a substring of every key; never treat it as a match.
        if not model_key:
            return None

        if model_key in table:
            return model_key

        contained = [key for key in table if key in model_key]
        if contained:
            # max() keeps the earliest key on ties, i.e. table order.
            return max(contained, key=len)

        for key in table:
            if model_key in key:
                return key

        return None

    def calculate_cost(self,
                      model: str,
                      input_tokens: int,
                      output_tokens: int) -> float:
        """
        Calculate cost for an API call.

        Args:
            model: Model name
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens

        Returns:
            Cost in dollars

        Raises:
            ValueError: If either token count is negative
        """
        # A negative count would yield a negative cost, which track_call
        # would then hand to the token manager as a deduction.
        if input_tokens < 0 or output_tokens < 0:
            raise ValueError(
                "Token counts must be non-negative "
                f"(input_tokens={input_tokens}, output_tokens={output_tokens})"
            )

        matched_key = self._match_model_key(model, MODEL_COSTS)

        if matched_key is None:
            logger.warning(f"Unknown model '{model}', using default cost")
            cost_config = self._DEFAULT_COST
        else:
            cost_config = MODEL_COSTS[matched_key]

        # Rates are per 1K tokens.
        input_cost = (input_tokens / 1000) * cost_config["input"]
        output_cost = (output_tokens / 1000) * cost_config["output"]

        return input_cost + output_cost

    def get_provider_for_model(self, model: str) -> str:
        """
        Get the provider for a model.

        Args:
            model: Model name

        Returns:
            Provider name ('anthropic' or 'gemini')
        """
        matched_key = self._match_model_key(model, MODEL_PROVIDERS)
        if matched_key is not None:
            return MODEL_PROVIDERS[matched_key]

        # Default heuristic for names not present in MODEL_PROVIDERS
        if "gemini" in model.lower().replace("_", "-"):
            return "gemini"
        return "anthropic"

    def track_call(self,
                  model: str,
                  input_tokens: int,
                  output_tokens: int,
                  task_id: Optional[str] = None,
                  success: bool = True,
                  error_message: Optional[str] = None,
                  duration_ms: Optional[int] = None,
                  metadata: Optional[Dict[str, Any]] = None) -> APICallRecord:
        """
        Track an API call and deduct from budget.

        Args:
            model: Model name used
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens
            task_id: Task identifier
            success: Whether the call succeeded
            error_message: Error message if failed
            duration_ms: Call duration in milliseconds
            metadata: Additional metadata

        Returns:
            APICallRecord with call details

        Raises:
            BudgetExhaustedError: If the provider's remaining budget is
                already zero or negative before this call is deducted
            ValueError: If either token count is negative
        """
        provider = self.get_provider_for_model(model)
        cost = self.calculate_cost(model, input_tokens, output_tokens)

        # Check budget before deducting
        remaining = self.token_manager.get_remaining(provider)
        if remaining <= 0:
            raise BudgetExhaustedError(provider, remaining)

        # Create record. The timestamp is naive UTC in ISO format; tzinfo is
        # dropped so the string matches what datetime.utcnow() (deprecated
        # since Python 3.12) used to produce.
        record = APICallRecord(
            timestamp=datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
            provider=provider,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost=cost,
            task_id=task_id,
            success=success,
            error_message=error_message,
            duration_ms=duration_ms,
            metadata=metadata
        )

        # Deduct cost from budget
        self.token_manager.deduct_cost(
            provider=provider,
            amount=cost,
            model=model,
            task_id=task_id,
            metadata=metadata
        )

        # Track per-model costs (keyed by the model name exactly as supplied)
        self.model_costs[model] = self.model_costs.get(model, 0.0) + cost

        # Add to history
        self.call_history.append(record)

        # Check for alerts
        self._check_budget_alerts(provider)

        logger.debug(f"Tracked API call: {model}, cost=${cost:.6f}, remaining=${self.token_manager.get_remaining(provider):.2f}")

        return record

    def _check_budget_alerts(self, provider: str) -> None:
        """
        Check and log budget alerts.

        Each threshold is reported at most once per provider: CRITICAL via
        ``logger.error`` and WARNING via ``logger.warning``.

        Args:
            provider: Provider whose remaining budget is inspected
        """
        percentage = self.token_manager.get_remaining_percentage(provider)
        remaining = self.token_manager.get_remaining(provider)

        # setdefault tolerates a provider that was not pre-seeded in __init__.
        crossed = self._alert_thresholds_crossed.setdefault(provider, set())

        for threshold, level in self._ALERT_THRESHOLDS:
            if percentage >= threshold or threshold in crossed:
                continue
            crossed.add(threshold)

            if level == "CRITICAL":
                logger.error(
                    f"[{level}] {provider.upper()} budget at {percentage:.1f}% "
                    f"(${remaining:.2f} remaining). Execution may halt soon!"
                )
            else:
                logger.warning(
                    f"[{level}] {provider.upper()} budget at {percentage:.1f}% "
                    f"(${remaining:.2f} remaining)"
                )

    def can_afford_model(self,
                        model: str,
                        estimated_input_tokens: int = 1000,
                        estimated_output_tokens: int = 1000) -> bool:
        """
        Check if budget can afford a model call.

        Args:
            model: Model name
            estimated_input_tokens: Expected input tokens
            estimated_output_tokens: Expected output tokens

        Returns:
            True if budget can cover the estimated cost

        Raises:
            ValueError: If either estimated token count is negative
        """
        provider = self.get_provider_for_model(model)
        estimated_cost = self.calculate_cost(model, estimated_input_tokens, estimated_output_tokens)
        return self.token_manager.can_afford(provider, estimated_cost)

    def get_model_spending(self) -> Dict[str, float]:
        """Get spending breakdown by model (a copy; safe for callers to mutate)."""
        return self.model_costs.copy()

    def get_provider_spending(self) -> Dict[str, float]:
        """Get spending breakdown by provider."""
        spending = {"anthropic": 0.0, "gemini": 0.0}
        for model, cost in self.model_costs.items():
            provider = self.get_provider_for_model(model)
            spending[provider] = spending.get(provider, 0.0) + cost
        return spending

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get comprehensive cost statistics.

        Returns:
            Dict with call counts, token and cost totals computed from the
            call history, spending by model and by provider, and a
            "budgets" entry per provider read from the token manager.
        """
        provider_spending = self.get_provider_spending()

        return {
            "total_calls": len(self.call_history),
            "successful_calls": sum(1 for r in self.call_history if r.success),
            "failed_calls": sum(1 for r in self.call_history if not r.success),
            "total_cost": sum(r.cost for r in self.call_history),
            "total_input_tokens": sum(r.input_tokens for r in self.call_history),
            "total_output_tokens": sum(r.output_tokens for r in self.call_history),
            "spending_by_model": self.get_model_spending(),
            "spending_by_provider": provider_spending,
            "budgets": {
                provider: {
                    "budget": self.token_manager.budgets[provider],
                    "spent": self.token_manager.spent.get(provider, 0.0),
                    "remaining": self.token_manager.get_remaining(provider),
                    "percentage_remaining": self.token_manager.get_remaining_percentage(provider)
                }
                for provider in ("anthropic", "gemini")
            }
        }

    def export_history(self, filepath: Optional[str] = None) -> str:
        """
        Export call history to JSONL file.

        Any existing file at the target path is overwritten. Missing parent
        directories are created.

        Args:
            filepath: Output file path. Defaults to logs/cost_history.jsonl

        Returns:
            Path to exported file
        """
        filepath = filepath or "logs/cost_history.jsonl"
        target = Path(filepath)
        target.parent.mkdir(parents=True, exist_ok=True)

        # Explicit encoding so the output does not depend on the platform default.
        with target.open("w", encoding="utf-8") as f:
            f.writelines(
                json.dumps(record.to_dict()) + "\n"
                for record in self.call_history
            )

        return filepath


# Module-wide tracker, created lazily on first request
_cost_tracker: Optional[CostTracker] = None


def get_cost_tracker() -> CostTracker:
    """Return the shared CostTracker, constructing it on the first call."""
    global _cost_tracker
    if _cost_tracker is not None:
        return _cost_tracker

    _cost_tracker = CostTracker()
    return _cost_tracker


if __name__ == "__main__":
    # Manual smoke test: track a few sample calls and print the statistics
    logging.basicConfig(level=logging.INFO)

    tracker = CostTracker()

    # Sample calls as (model, input tokens, output tokens, task id)
    sample_calls = (
        ("gemini-flash", 1000, 500, "test-001"),
        ("claude-3-sonnet", 2000, 1000, "test-002"),
        ("gemini-2.0-pro", 5000, 3000, "test-003"),
    )

    try:
        for model_name, tokens_in, tokens_out, task in sample_calls:
            tracker.track_call(
                model=model_name,
                input_tokens=tokens_in,
                output_tokens=tokens_out,
                task_id=task,
            )

        print("\nCost Statistics:")
        print(json.dumps(tracker.get_statistics(), indent=2))

    except BudgetExhaustedError as e:
        # Raised by track_call once a provider has no budget left
        print(f"Budget exhausted: {e}")
