"""
PM-005: Model Tier Loader
Load and manage 3-tier model hierarchy configuration for Genesis.

Acceptance Criteria:
- [x] Tier 1: Gemini Flash, 20 attempts, $0.01/attempt
- [x] Tier 2: Claude Sonnet, 10 attempts, $0.15/attempt
- [x] Tier 3: Claude Opus, 10 attempts, $0.75/attempt
- [x] All tiers loadable via load_model_tiers()

Dependencies: None
"""

import dataclasses
import json
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


@dataclass
class TierConfig:
    """
    Settings for one rung of the model-tier ladder.

    Holds which provider/model to call, how many attempts the tier may spend,
    what each attempt costs, and the per-attempt timeout.
    """
    tier_number: int         # position in the ladder (1 = first tried)
    name: str                # human-readable label
    provider: str            # backend name, e.g. "gemini" or "anthropic"
    model: str               # canonical model identifier
    max_attempts: int        # attempt budget for this tier
    cost_per_attempt: float  # cost charged per attempt
    timeout_seconds: int     # per-attempt timeout
    priority: int            # ordering hint
    description: str         # free-text description

    @property
    def max_cost(self) -> float:
        """Worst-case spend for this tier: every attempt used at full price."""
        attempts, unit_cost = self.max_attempts, self.cost_per_attempt
        return attempts * unit_cost

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialise the tier to a plain dict.

        Keys follow field declaration order, followed by the derived
        ``max_cost`` entry.
        """
        field_names = (
            "tier_number",
            "name",
            "provider",
            "model",
            "max_attempts",
            "cost_per_attempt",
            "timeout_seconds",
            "priority",
            "description",
        )
        payload = {key: getattr(self, key) for key in field_names}
        payload["max_cost"] = self.max_cost
        return payload


@dataclass
class EscalationRule:
    """
    Describe when work leaves one tier and where it goes next.

    Fields are declared in positional-constructor order; ``action`` and
    ``notify`` are optional and default to a plain, silent escalation.
    """
    from_tier: int             # tier this rule applies to
    to_tier: Optional[int]     # destination tier; None means human review
    after_attempts: int        # attempt threshold associated with this rule
    on_error_types: List[str]  # error type names that trigger escalation
    action: str = "escalate"   # action label, e.g. "mark_for_human_review"
    notify: bool = False       # whether this rule requests a notification


class ModelTierLoader:
    """
    Load and manage model tier configuration.

    Features:
    - Load from JSON config file
    - Default configuration fallback
    - Tier lookup by number, model name, or alias
    - Escalation rule management

    Each loader owns its own TierConfig objects. When defaults are used, the
    class-level DEFAULT_TIERS templates are copied, so mutating
    ``loader.tiers[n]`` never leaks into other loaders or into the defaults.
    """

    DEFAULT_CONFIG_PATH = "config/model_tiers.json"

    # Default configuration if the file is missing or unreadable. Treat these
    # as read-only templates: _use_defaults() installs per-instance copies.
    DEFAULT_TIERS = {
        1: TierConfig(
            tier_number=1,
            name="Tier 1 - Fast & Cheap",
            provider="gemini",
            model="gemini-2.5-flash",
            max_attempts=20,
            cost_per_attempt=0.01,
            timeout_seconds=120,
            priority=1,
            description="Gemini Flash for quick, cost-effective attempts"
        ),
        2: TierConfig(
            tier_number=2,
            name="Tier 2 - Balanced",
            provider="anthropic",
            model="claude-sonnet-4",
            max_attempts=10,
            cost_per_attempt=0.15,
            timeout_seconds=180,
            priority=2,
            description="Claude Sonnet for balanced capability and cost"
        ),
        3: TierConfig(
            tier_number=3,
            name="Tier 3 - Maximum Capability",
            provider="anthropic",
            model="claude-opus-4-5",
            max_attempts=10,
            cost_per_attempt=0.75,
            timeout_seconds=300,
            priority=3,
            description="Claude Opus for complex, high-stakes tasks"
        )
    }

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize ModelTierLoader and load configuration immediately.

        Args:
            config_path: Path to model_tiers.json. Falls back to
                DEFAULT_CONFIG_PATH when not provided (or empty).
        """
        self.config_path = config_path or self.DEFAULT_CONFIG_PATH
        self.tiers: Dict[int, TierConfig] = {}
        self.escalation_rules: List[EscalationRule] = []
        self.model_aliases: Dict[str, str] = {}
        self.budget_allocation: Dict[str, float] = {}
        self._raw_config: Dict[str, Any] = {}

        self._load_config()

    def _load_config(self) -> None:
        """
        Load configuration from the JSON file, or fall back to defaults.

        A missing file, unreadable file, invalid JSON, or a config that fails
        to parse all result in the built-in defaults being used. The failure
        is logged rather than raised.
        """
        config_file = Path(self.config_path)

        if not config_file.exists():
            logger.info("Config file not found at %s. Using defaults.", config_file)
            self._use_defaults()
            return

        try:
            # JSON is UTF-8 by spec; don't depend on the platform locale.
            with open(config_file, encoding="utf-8") as f:
                self._raw_config = json.load(f)
            self._parse_config(self._raw_config)
        except Exception as e:
            # Broad catch is deliberate: any bad config degrades to defaults.
            logger.warning(
                "Failed to load config from %s: %s. Using defaults.", config_file, e
            )
            # Don't keep a raw config that no longer matches the active settings.
            self._raw_config = {}
            self._use_defaults()
        else:
            logger.info("Loaded model tier config from %s", config_file)

    def _parse_config(self, config: Dict[str, Any]) -> None:
        """
        Parse the JSON configuration into typed objects.

        Tier entries must provide ``provider``, ``model``, ``max_attempts`` and
        ``cost_per_attempt``. A missing key raises KeyError, and a tier key
        that is not an integer string raises ValueError. The caller
        (_load_config) catches these and falls back to defaults.
        """
        # Parse tiers; JSON object keys are strings, so convert to int.
        tiers_config = config.get("tiers", {})
        for tier_num, tier_data in tiers_config.items():
            tier_number = int(tier_num)
            self.tiers[tier_number] = TierConfig(
                tier_number=tier_number,
                name=tier_data.get("name", f"Tier {tier_number}"),
                provider=tier_data["provider"],
                model=tier_data["model"],
                max_attempts=tier_data["max_attempts"],
                cost_per_attempt=tier_data["cost_per_attempt"],
                timeout_seconds=tier_data.get("timeout_seconds", 120),
                priority=tier_data.get("priority", tier_number),
                description=tier_data.get("description", "")
            )

        # Parse escalation rules (only the three known rule names are read).
        escalation_config = config.get("escalation_rules", {})
        if "tier_1_to_2" in escalation_config:
            rule = escalation_config["tier_1_to_2"]
            self.escalation_rules.append(EscalationRule(
                from_tier=1,
                to_tier=2,
                after_attempts=rule.get("after_attempts", 20),
                on_error_types=rule.get("on_error_types", [])
            ))
        if "tier_2_to_3" in escalation_config:
            rule = escalation_config["tier_2_to_3"]
            self.escalation_rules.append(EscalationRule(
                from_tier=2,
                to_tier=3,
                after_attempts=rule.get("after_attempts", 10),
                on_error_types=rule.get("on_error_types", [])
            ))
        if "tier_3_failure" in escalation_config:
            rule = escalation_config["tier_3_failure"]
            self.escalation_rules.append(EscalationRule(
                from_tier=3,
                to_tier=None,
                # Terminal rule: triggers once tier 3 has used all its attempts.
                after_attempts=self.tiers[3].max_attempts if 3 in self.tiers else 10,
                on_error_types=[],
                action=rule.get("action", "mark_for_human_review"),
                notify=rule.get("notify", True)
            ))

        # Parse aliases and budget
        self.model_aliases = config.get("model_aliases", {})
        self.budget_allocation = config.get("budget_allocation", {})

    def _use_defaults(self) -> None:
        """
        Install the built-in default configuration on this instance.

        All settings are reassigned, not merged, so any partial state from a
        failed _parse_config is discarded.
        """
        # Copy each TierConfig. A plain dict.copy() would share the mutable
        # class-level instances, letting one loader's edits corrupt every
        # other loader and DEFAULT_TIERS itself.
        self.tiers = {
            num: dataclasses.replace(tier)
            for num, tier in self.DEFAULT_TIERS.items()
        }
        self.escalation_rules = [
            EscalationRule(1, 2, 20, ["timeout", "context_overflow", "persistent_failure"]),
            EscalationRule(2, 3, 10, ["timeout", "context_overflow", "persistent_failure"]),
            EscalationRule(3, None, 10, [], "mark_for_human_review", True)
        ]
        self.model_aliases = {
            "gemini-flash": "gemini-2.5-flash",
            "gemini-2.0-flash": "gemini-2.5-flash",
            "claude-sonnet": "claude-sonnet-4",
            "claude-opus": "claude-opus-4-5"
        }
        self.budget_allocation = {
            "gemini_budget": 300.0,
            "anthropic_budget": 30.0
        }

    def get_tier(self, tier_number: int) -> Optional[TierConfig]:
        """
        Get tier configuration by number.

        Args:
            tier_number: 1, 2, or 3 with the default configuration

        Returns:
            TierConfig or None if not found
        """
        return self.tiers.get(tier_number)

    def get_tier_for_model(self, model: str) -> Optional[TierConfig]:
        """
        Get tier configuration for a model name.

        Args:
            model: Model name or alias

        Returns:
            The first TierConfig whose model matches the name or its resolved
            alias, or None if not found
        """
        # Resolve alias if exists
        resolved_model = self.model_aliases.get(model, model)

        # Search tiers
        for tier in self.tiers.values():
            if tier.model == resolved_model or tier.model == model:
                return tier
        return None

    def get_next_tier(self, current_tier: int) -> Optional[TierConfig]:
        """
        Get the next tier for escalation.

        Args:
            current_tier: Current tier number

        Returns:
            TierConfig for ``current_tier + 1``, or None if there is none
            (i.e. already at the highest tier)
        """
        next_tier_num = current_tier + 1
        return self.tiers.get(next_tier_num)

    def should_escalate(self,
                       current_tier: int,
                       attempt: int,
                       error_type: Optional[str] = None) -> bool:
        """
        Check if escalation should occur.

        Escalation is signalled when ``attempt`` reaches the tier's
        ``max_attempts``, or when ``error_type`` is listed in an escalation
        rule for this tier.

        Args:
            current_tier: Current tier number
            attempt: Current attempt number
            error_type: Type of error encountered (optional)

        Returns:
            True if should escalate; False for unknown tiers
        """
        tier_config = self.get_tier(current_tier)
        if tier_config is None:
            return False

        # Check if max attempts reached.
        # NOTE(review): EscalationRule.after_attempts is not consulted here;
        # only the tier's max_attempts is. Confirm this is intended.
        if attempt >= tier_config.max_attempts:
            return True

        # Check error type escalation rules
        for rule in self.escalation_rules:
            if rule.from_tier == current_tier:
                if error_type and error_type in rule.on_error_types:
                    return True

        return False

    def get_all_tiers(self) -> List[TierConfig]:
        """Get all tier configurations ordered by tier number."""
        return [self.tiers[i] for i in sorted(self.tiers.keys())]

    def calculate_max_budget(self) -> float:
        """Calculate maximum possible spend if every tier uses all its attempts."""
        return sum(tier.max_cost for tier in self.tiers.values())

    def get_summary(self) -> Dict[str, Any]:
        """Get a JSON-serialisable summary of the tier configuration."""
        return {
            "tiers": {
                tier_num: tier.to_dict()
                for tier_num, tier in self.tiers.items()
            },
            "total_max_attempts": sum(t.max_attempts for t in self.tiers.values()),
            "max_budget_usage": self.calculate_max_budget(),
            "escalation_rules": [
                {
                    "from": r.from_tier,
                    "to": r.to_tier,
                    "after_attempts": r.after_attempts,
                    "error_types": r.on_error_types
                }
                for r in self.escalation_rules
            ],
            "model_aliases": self.model_aliases
        }


# Lazily created loader cached by load_model_tiers(); None until first use.
_tier_loader: Optional[ModelTierLoader] = None


def load_model_tiers(config_path: Optional[str] = None) -> ModelTierLoader:
    """
    Return the shared ModelTierLoader, creating it on first use.

    A truthy ``config_path`` always builds a fresh loader for that path and
    replaces the cached one. Calling with no path (or an empty string) reuses
    the cached loader when one already exists.

    Args:
        config_path: Optional path to a model_tiers.json file.

    Returns:
        The cached ModelTierLoader instance.
    """
    global _tier_loader
    must_build = _tier_loader is None or bool(config_path)
    if must_build:
        _tier_loader = ModelTierLoader(config_path)
    return _tier_loader


def get_tier(tier_number: int) -> Optional[TierConfig]:
    """
    Look up a tier on the shared loader returned by load_model_tiers().

    Args:
        tier_number: Tier to fetch (1, 2, or 3 with the default config).

    Returns:
        The matching TierConfig, or None if that tier is not configured.
    """
    shared_loader = load_model_tiers()
    return shared_loader.get_tier(tier_number)


if __name__ == "__main__":
    # Manual smoke test: load the tier config (from file, or the built-in
    # defaults) and print it as JSON and as a human-readable listing.
    logging.basicConfig(level=logging.INFO)

    tier_loader = load_model_tiers()

    print("Model Tier Summary:")
    print(json.dumps(tier_loader.get_summary(), indent=2))

    print("\n\nTier Details:")
    for cfg in tier_loader.get_all_tiers():
        detail_lines = (
            f"\n{cfg.name}:",
            f"  Model: {cfg.model}",
            f"  Provider: {cfg.provider}",
            f"  Max Attempts: {cfg.max_attempts}",
            f"  Cost/Attempt: ${cfg.cost_per_attempt}",
            f"  Max Cost: ${cfg.max_cost:.2f}",
        )
        # One print of newline-joined lines == one print per line.
        print("\n".join(detail_lines))
