"""
PM-001: API Token Manager
Centralized API key management with budget tracking for Genesis.

Acceptance Criteria:
- [x] GIVEN env vars WHEN TokenManager() THEN keys loaded securely
- [x] AND supports Anthropic ($30) + Gemini ($300) budgets
- [x] AND persists budget state to Redis `genesis:api_budget`
- [x] AND logs to `logs/api_usage.jsonl`
"""

import os
import json
import logging
import math
from datetime import datetime, timezone
from typing import Optional, Dict, Any
from pathlib import Path

try:
    import redis
except ImportError:
    redis = None

# Module-level logger; TokenManager reports Redis connection/persistence
# failures and low-budget alerts through it.
logger = logging.getLogger(__name__)


class BudgetExhaustedError(Exception):
    """
    Signal that a provider has no budget left to charge against.

    Attributes:
        provider: Name of the provider whose budget ran out.
        remaining: Remaining budget (dollars) observed when the charge was refused.
    """

    def __init__(self, provider: str, remaining: float):
        message = f"Budget exhausted for {provider}. Remaining: ${remaining:.4f}"
        super().__init__(message)
        self.provider = provider
        self.remaining = remaining


class TokenManager:
    """
    Centralized API key management with budget tracking.

    Supports:
    - Anthropic API ($30 budget)
    - Gemini API ($300 budget)

    API keys are read from the ``ANTHROPIC_API_KEY`` / ``GEMINI_API_KEY``
    environment variables. Budget state is persisted to Redis under
    ``REDIS_KEY`` when a server is reachable (otherwise it lives only in this
    instance), and every recorded charge is appended to a JSONL usage log.
    """

    REDIS_KEY = "genesis:api_budget"
    DEFAULT_BUDGETS = {
        "anthropic": 30.0,
        "gemini": 300.0
    }

    def __init__(self,
                 redis_url: Optional[str] = None,
                 log_path: Optional[str] = None):
        """
        Initialize TokenManager with API keys and budget tracking.

        Args:
            redis_url: Redis connection URL. Defaults to env REDIS_URL or
                redis://localhost:6379.
            log_path: Path to usage log file. Defaults to logs/api_usage.jsonl.
                Its parent directory is created if missing.
        """
        # API keys come from the environment; None when not configured.
        self.anthropic_key = os.getenv("ANTHROPIC_API_KEY")
        self.gemini_key = os.getenv("GEMINI_API_KEY")

        # Budgets start at the class defaults, with one zeroed spend counter
        # per budgeted provider.
        self.budgets: Dict[str, float] = self.DEFAULT_BUDGETS.copy()
        self.spent: Dict[str, float] = dict.fromkeys(self.budgets, 0.0)

        # The Redis client itself is created lazily by ``redis_client``.
        self.redis_url = redis_url or os.getenv("REDIS_URL", "redis://localhost:6379")
        self._redis: Optional[Any] = None

        self.log_path = Path(log_path or "logs/api_usage.jsonl")
        self.log_path.parent.mkdir(parents=True, exist_ok=True)

        # Persisted state (if any) overrides the defaults set above.
        self._load_budget_state()

        logger.info(
            "TokenManager initialized. Anthropic budget: $%.2f, Gemini budget: $%.2f",
            self.get_remaining("anthropic"),
            self.get_remaining("gemini"),
        )

    @property
    def redis_client(self) -> Optional[Any]:
        """
        Lazily create and return the Redis client.

        Returns None when the ``redis`` package is not installed or the server
        does not answer a ping. A failed attempt is not remembered, so the
        next access tries to connect again.
        """
        if self._redis is None and redis is not None:
            try:
                client = redis.from_url(self.redis_url)
                client.ping()
            except Exception as e:
                logger.warning("Redis connection failed: %s. Using local state only.", e)
            else:
                self._redis = client
        return self._redis

    @staticmethod
    def _utc_timestamp() -> str:
        """
        Return the current UTC time as a naive ISO-8601 string.

        Same format as the deprecated ``datetime.utcnow().isoformat()`` (no
        ``+00:00`` suffix), so existing consumers of the log/Redis payload
        see no change.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    def _load_budget_state(self) -> None:
        """
        Merge persisted budget state from Redis into this instance.

        The payload is fully parsed and validated before anything is applied.
        Budget overrides and spent amounts are then merged per provider
        (coerced to float) rather than replacing the dicts wholesale, so a
        partial payload cannot drop a provider's counter. Any failure is
        logged and the current in-memory state is kept.
        """
        client = self.redis_client  # single access: avoids a second ping when Redis is down
        if client is None:
            return
        try:
            state = client.get(self.REDIS_KEY)
            if not state:
                return
            data = json.loads(state)
            raw_budgets = data.get("budgets")
            raw_spent = data.get("spent")
            new_budgets = ({p: float(v) for p, v in raw_budgets.items()}
                           if isinstance(raw_budgets, dict) else {})
            new_spent = ({p: float(v) for p, v in raw_spent.items()}
                         if isinstance(raw_spent, dict) else {})
            # Allow budget overrides from Redis
            self.budgets.update(new_budgets)
            self.spent.update(new_spent)
            logger.info("Loaded budget state from Redis: spent=%s", self.spent)
        except Exception as e:
            logger.warning("Failed to load budget state: %s", e)

    def _save_budget_state(self) -> None:
        """Persist budgets, spent amounts and a UTC timestamp to Redis (best effort)."""
        client = self.redis_client  # single access: avoids a second ping when Redis is down
        if client is None:
            return
        try:
            state = {
                "budgets": self.budgets,
                "spent": self.spent,
                "updated_at": self._utc_timestamp()
            }
            client.set(self.REDIS_KEY, json.dumps(state))
        except Exception as e:
            logger.warning("Failed to save budget state: %s", e)

    def get_api_key(self, provider: str) -> Optional[str]:
        """
        Get API key for a provider.

        Args:
            provider: 'anthropic' or 'gemini'

        Returns:
            API key string or None if not configured (or provider unknown).
        """
        if provider == "anthropic":
            return self.anthropic_key
        if provider == "gemini":
            return self.gemini_key
        logger.warning("Unknown provider: %s", provider)
        return None

    def get_remaining(self, provider: str) -> float:
        """
        Get remaining budget for a provider.

        Args:
            provider: 'anthropic' or 'gemini'

        Returns:
            Remaining budget in dollars, clamped at 0. Unknown providers
            log a warning and return 0.0.
        """
        if provider not in self.budgets:
            logger.warning("Unknown provider: %s", provider)
            return 0.0
        return max(0.0, self.budgets[provider] - self.spent.get(provider, 0.0))

    def get_remaining_percentage(self, provider: str) -> float:
        """
        Get remaining budget as percentage.

        Args:
            provider: 'anthropic' or 'gemini'

        Returns:
            Remaining budget as percentage (0-100); 0.0 for unknown providers
            or non-positive budgets.
        """
        if provider not in self.budgets or self.budgets[provider] <= 0:
            return 0.0
        return (self.get_remaining(provider) / self.budgets[provider]) * 100

    def deduct_cost(self,
                   provider: str,
                   amount: float,
                   model: Optional[str] = None,
                   task_id: Optional[str] = None,
                   metadata: Optional[Dict[str, Any]] = None) -> float:
        """
        Record a charge against a provider's budget.

        A charge is refused only when the budget is already exhausted. A
        charge larger than what remains is still recorded in full (the spend
        has already happened), and the remaining budget then reads 0.

        Args:
            provider: 'anthropic' or 'gemini'
            amount: Cost in dollars to deduct (finite, >= 0)
            model: Model name (for logging)
            task_id: Task identifier (for logging)
            metadata: Additional metadata to log

        Returns:
            Remaining budget after deduction.

        Raises:
            ValueError: If the provider is unknown or ``amount`` is negative
                or not finite.
            BudgetExhaustedError: If budget is exhausted.
        """
        if provider not in self.budgets:
            raise ValueError(f"Unknown provider: {provider}")
        # A negative cost would silently refill the budget; NaN/inf would
        # poison every later comparison.
        if not math.isfinite(amount) or amount < 0:
            raise ValueError(
                f"Invalid cost for {provider}: {amount!r} (must be finite and >= 0)"
            )

        remaining = self.get_remaining(provider)
        if remaining <= 0:
            raise BudgetExhaustedError(provider, remaining)

        self.spent[provider] = self.spent.get(provider, 0.0) + amount
        new_remaining = self.get_remaining(provider)

        self._log_usage(
            provider=provider,
            amount=amount,
            remaining=new_remaining,
            model=model,
            task_id=task_id,
            metadata=metadata
        )
        self._save_budget_state()

        # Low-budget alerts: below 10% is critical, below 20% is a warning.
        percentage = self.get_remaining_percentage(provider)
        if percentage < 10:
            logger.warning("CRITICAL: %s budget at %.1f%% ($%.2f remaining)",
                           provider, percentage, new_remaining)
        elif percentage < 20:
            logger.warning("LOW BUDGET: %s at %.1f%% ($%.2f remaining)",
                           provider, percentage, new_remaining)

        return new_remaining

    def _log_usage(self,
                  provider: str,
                  amount: float,
                  remaining: float,
                  model: Optional[str] = None,
                  task_id: Optional[str] = None,
                  metadata: Optional[Dict[str, Any]] = None) -> None:
        """Append one JSON line describing a charge to the usage log (best effort)."""
        try:
            log_entry = {
                "timestamp": self._utc_timestamp(),
                "provider": provider,
                "amount": amount,
                "remaining": remaining,
                "percentage_remaining": self.get_remaining_percentage(provider),
                "model": model,
                "task_id": task_id
            }
            if metadata:
                log_entry["metadata"] = metadata

            with open(self.log_path, "a", encoding="utf-8") as f:
                f.write(json.dumps(log_entry) + "\n")
        except Exception as e:
            logger.error("Failed to log usage: %s", e)

    def reset_budget(self, provider: str, new_budget: Optional[float] = None) -> None:
        """
        Reset a provider's spent amount to zero, optionally setting a new budget.

        Args:
            provider: 'anthropic' or 'gemini'
            new_budget: New budget amount. When omitted the current budget is kept.

        Raises:
            ValueError: If the provider is unknown.
        """
        if provider not in self.budgets:
            raise ValueError(f"Unknown provider: {provider}")

        if new_budget is not None:
            self.budgets[provider] = new_budget
        self.spent[provider] = 0.0
        self._save_budget_state()
        logger.info("Reset %s budget to $%.2f", provider, self.budgets[provider])

    def get_status(self) -> Dict[str, Any]:
        """
        Get complete budget status.

        Totals are computed over budgeted providers only, so stray entries in
        ``spent`` (e.g. from an older Redis payload) do not skew them.

        Returns:
            Dictionary with budget status for all providers.
        """
        providers = {
            provider: {
                "budget": self.budgets[provider],
                "spent": self.spent.get(provider, 0.0),
                "remaining": self.get_remaining(provider),
                "percentage_remaining": self.get_remaining_percentage(provider),
                "has_key": bool(self.get_api_key(provider))
            }
            for provider in self.budgets
        }
        return {
            "providers": providers,
            "total_budget": sum(self.budgets.values()),
            "total_spent": sum(info["spent"] for info in providers.values()),
            "total_remaining": sum(info["remaining"] for info in providers.values()),
            "redis_connected": self.redis_client is not None,
            "log_path": str(self.log_path)
        }

    def can_afford(self, provider: str, amount: float) -> bool:
        """
        Check if a cost can be afforded.

        Args:
            provider: 'anthropic' or 'gemini'
            amount: Cost in dollars

        Returns:
            True if budget can cover the cost.
        """
        return self.get_remaining(provider) >= amount


# Singleton instance for global access
# Process-wide TokenManager, created on the first get_token_manager() call.
_token_manager: Optional[TokenManager] = None


def get_token_manager() -> TokenManager:
    """Return the shared TokenManager, constructing it with defaults on first use."""
    global _token_manager
    if _token_manager is not None:
        return _token_manager
    _token_manager = TokenManager()
    return _token_manager


if __name__ == "__main__":
    # Manual smoke test. NOTE: deduct_cost() persists to the configured Redis
    # (REDIS_URL or localhost) under TokenManager.REDIS_KEY and appends to the
    # usage log, so these simulated charges count against the real budget.
    logging.basicConfig(level=logging.INFO)

    manager = TokenManager()
    print("Initial status:")
    print(json.dumps(manager.get_status(), indent=2))

    simulated_calls = [
        ("gemini", 0.01, "gemini-flash", "test-001"),
        ("anthropic", 0.15, "claude-sonnet", "test-002"),
    ]
    try:
        for provider_name, cost, model_name, task in simulated_calls:
            manager.deduct_cost(provider_name, cost, model=model_name, task_id=task)
        print("\nAfter usage:")
        print(json.dumps(manager.get_status(), indent=2))
    except BudgetExhaustedError as exc:
        print(f"Budget exhausted: {exc}")
