"""
PM-013: Budget Alert System
Alert when budget thresholds are crossed for Genesis.

Acceptance Criteria:
- [x] GIVEN budget check WHEN <20% THEN log warning
- [x] AND when <10% THEN alert in HANDOFF.md
- [x] AND when exhausted THEN halt new executions

Dependencies: PM-002
"""

import os
import json
import logging
from datetime import datetime
from typing import Optional, Dict, Any, List, Callable
from dataclasses import dataclass, field
from pathlib import Path
from enum import Enum

from core.api_token_manager import TokenManager, get_token_manager, BudgetExhaustedError

# Module-level logger; every alert message produced below is routed through it.
logger = logging.getLogger(name=__name__)


class AlertLevel(Enum):
    """Severity attached to a budget alert.

    Members are declared from the mildest (INFO) to HALT. Each member's value
    is the lowercase string used when an alert is serialized.
    """

    INFO = "info"
    WARNING = "warning"
    CRITICAL = "critical"
    HALT = "halt"


@dataclass
class BudgetAlert:
    """Record of one budget-threshold crossing for a single provider.

    ``timestamp`` defaults to the creation time as a naive UTC ISO-8601
    string, and ``acknowledged`` starts out False.
    """
    alert_id: str
    level: AlertLevel
    provider: str
    message: str
    percentage_remaining: float
    amount_remaining: float
    threshold_crossed: float
    timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat())
    acknowledged: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a plain dict, with ``level`` replaced by its string value."""
        ordered_names = (
            "alert_id",
            "level",
            "provider",
            "message",
            "percentage_remaining",
            "amount_remaining",
            "threshold_crossed",
            "timestamp",
            "acknowledged",
        )
        payload: Dict[str, Any] = {}
        for name in ordered_names:
            payload[name] = getattr(self, name)
        # Re-assigning an existing key keeps its position, so key order is unchanged.
        payload["level"] = self.level.value
        return payload


class BudgetAlertSystem:
    """
    Monitor budgets and generate alerts at thresholds.

    Default thresholds (percentage of budget remaining, inclusive):
    - <=20%: Log warning
    - <=10%: Critical alert, also appended to HANDOFF.md
    - <=5%:  Critical alert, also appended to HANDOFF.md
    - <=0%:  Halt new executions (also appended to HANDOFF.md)

    Each threshold fires at most once per provider until reset_provider() is
    called. When several thresholds are crossed between two checks, a single
    alert is emitted for the most severe (lowest) of them, and the others are
    marked as crossed. As a result, a sudden drop to 0% halts immediately
    instead of needing one check per intermediate threshold.
    """

    # Providers whose state is created up front; other provider names are
    # accepted and get their state created on first use.
    PROVIDERS = ("anthropic", "gemini")

    DEFAULT_THRESHOLDS = {
        20.0: AlertLevel.WARNING,
        10.0: AlertLevel.CRITICAL,
        5.0: AlertLevel.CRITICAL,
        0.0: AlertLevel.HALT
    }

    def __init__(self,
                 token_manager: Optional[TokenManager] = None,
                 handoff_path: str = "HANDOFF.md",
                 thresholds: Optional[Dict[float, AlertLevel]] = None):
        """
        Initialize BudgetAlertSystem.

        Args:
            token_manager: TokenManager instance. The shared instance from
                get_token_manager() is used when omitted.
            handoff_path: Path of the HANDOFF.md file that critical and halt
                alerts are appended to.
            thresholds: Custom percentage-threshold -> level mapping. A missing
                or empty mapping falls back to DEFAULT_THRESHOLDS.
        """
        self.token_manager = token_manager if token_manager is not None else get_token_manager()
        self.handoff_path = Path(handoff_path)
        # Copy, so that mutating self.thresholds never alters the class-level default.
        self.thresholds: Dict[float, AlertLevel] = dict(thresholds or self.DEFAULT_THRESHOLDS)

        # Thresholds already alerted on, per provider
        self._crossed_thresholds: Dict[str, set] = {p: set() for p in self.PROVIDERS}

        # Alert history (oldest first) and the counter used to build unique IDs
        self._alerts: List[BudgetAlert] = []
        self._alert_counter = 0

        # Subscribers notified of every alert
        self._callbacks: List[Callable[[BudgetAlert], None]] = []

        # Execution halt flag per provider; cleared only by reset_provider()
        self._execution_halted: Dict[str, bool] = {p: False for p in self.PROVIDERS}

    def add_callback(self, callback: Callable[[BudgetAlert], None]) -> None:
        """Add a callback for alert notifications."""
        self._callbacks.append(callback)

    def _generate_alert_id(self, provider: str) -> str:
        """Generate a unique alert ID: alert_<provider>_<UTC timestamp>_<counter>."""
        self._alert_counter += 1
        timestamp = datetime.utcnow().strftime("%Y%m%d%H%M%S")
        return f"alert_{provider}_{timestamp}_{self._alert_counter}"

    def check_budget(self, provider: str) -> Optional[BudgetAlert]:
        """
        Check budget for a provider and generate an alert if a threshold was crossed.

        Args:
            provider: Provider name, e.g. 'anthropic' or 'gemini'

        Returns:
            The alert for the most severe newly crossed threshold, or None if
            no new threshold was crossed since the last check/reset.
        """
        percentage = self.token_manager.get_remaining_percentage(provider)
        remaining = self.token_manager.get_remaining(provider)
        # setdefault: unknown providers get fresh state instead of a KeyError.
        crossed = self._crossed_thresholds.setdefault(provider, set())

        newly_crossed = [
            threshold
            for threshold in sorted(self.thresholds, reverse=True)
            if percentage <= threshold and threshold not in crossed
        ]
        if not newly_crossed:
            return None

        # Mark every threshold passed since the last check so none fires later,
        # and report only the lowest (most severe) one.
        crossed.update(newly_crossed)
        threshold = newly_crossed[-1]

        alert = self._create_alert(provider, threshold, self.thresholds[threshold], percentage, remaining)
        self._process_alert(alert)
        return alert

    def _create_alert(self,
                      provider: str,
                      threshold: float,
                      level: AlertLevel,
                      percentage: float,
                      remaining: float) -> BudgetAlert:
        """Create a budget alert whose message prefix matches its level."""
        name = provider.upper()
        if level == AlertLevel.HALT:
            message = f"EXECUTION HALTED: {name} budget exhausted (${remaining:.2f} remaining)"
        elif level == AlertLevel.CRITICAL:
            message = f"CRITICAL: {name} budget at {percentage:.1f}% (${remaining:.2f} remaining)"
        elif level == AlertLevel.INFO:
            message = f"INFO: {name} budget at {percentage:.1f}% (${remaining:.2f} remaining)"
        else:
            message = f"WARNING: {name} budget at {percentage:.1f}% (${remaining:.2f} remaining)"

        return BudgetAlert(
            alert_id=self._generate_alert_id(provider),
            level=level,
            provider=provider,
            message=message,
            percentage_remaining=percentage,
            amount_remaining=remaining,
            threshold_crossed=threshold
        )

    def _process_alert(self, alert: BudgetAlert) -> None:
        """Process an alert: store it, log it, apply its side effects, notify callbacks."""
        self._alerts.append(alert)

        if alert.level == AlertLevel.HALT:
            logger.error(alert.message)
            self._execution_halted[alert.provider] = True
            # Exhaustion is below every HANDOFF threshold, so record it there too.
            self._write_to_handoff(alert)
        elif alert.level == AlertLevel.CRITICAL:
            logger.error(alert.message)
            self._write_to_handoff(alert)
        elif alert.level == AlertLevel.INFO:
            logger.info(alert.message)
        else:
            logger.warning(alert.message)

        for callback in self._callbacks:
            try:
                callback(alert)
            except Exception as e:
                # Best-effort: a faulty subscriber must not break budget checking.
                logger.warning(f"Alert callback error: {e}", exc_info=True)

    def _write_to_handoff(self, alert: BudgetAlert) -> None:
        """Append an alert entry to the HANDOFF file (best-effort; failures are logged)."""
        entry = f"""
## BUDGET ALERT - {alert.timestamp}
**Level:** {alert.level.value.upper()}
**Provider:** {alert.provider}
**Message:** {alert.message}
**Remaining:** ${alert.amount_remaining:.2f} ({alert.percentage_remaining:.1f}%)

"""
        try:
            with open(self.handoff_path, "a", encoding="utf-8") as f:
                f.write(entry)
        except OSError as e:
            logger.warning(f"Failed to write to HANDOFF.md: {e}")
            return

        logger.info(f"Budget alert written to {self.handoff_path}")

    def is_execution_halted(self, provider: str) -> bool:
        """Check if execution is halted for a provider."""
        return self._execution_halted.get(provider, False)

    def can_execute(self, provider: str) -> bool:
        """Check if execution is allowed for a provider (may set the halt flag)."""
        if self.is_execution_halted(provider):
            return False

        # Catch exhaustion that happened since the last threshold check.
        # A later top-up does NOT lift the halt; call reset_provider() for that.
        if self.token_manager.get_remaining(provider) <= 0:
            self._execution_halted[provider] = True
            return False

        return True

    def check_before_execution(self, provider: str, estimated_cost: float = 0.0) -> bool:
        """
        Check budget before executing a task.

        Args:
            provider: Provider name, e.g. 'anthropic' or 'gemini'
            estimated_cost: Estimated cost of the task (0 skips the affordability check)

        Returns:
            True if execution should proceed, False if the task is unaffordable

        Raises:
            BudgetExhaustedError if execution is halted or the budget is exhausted
        """
        # Evaluate thresholds first, so alerts (including the HALT alert) are
        # emitted even when this task ends up being rejected.
        self.check_budget(provider)

        if not self.can_execute(provider):
            raise BudgetExhaustedError(provider, 0.0)

        remaining = self.token_manager.get_remaining(provider)
        if estimated_cost > 0 and remaining < estimated_cost:
            logger.warning(
                f"Insufficient budget for task: need ${estimated_cost:.4f}, "
                f"have ${remaining:.2f}"
            )
            return False

        return True

    def check_all_budgets(self) -> List[BudgetAlert]:
        """Check all known provider budgets and return any new alerts."""
        results = (self.check_budget(provider) for provider in self.PROVIDERS)
        return [alert for alert in results if alert is not None]

    def get_alerts(self,
                   provider: Optional[str] = None,
                   level: Optional[AlertLevel] = None,
                   unacknowledged_only: bool = False) -> List[BudgetAlert]:
        """
        Get alerts with optional filtering.

        Args:
            provider: Filter by provider
            level: Filter by level
            unacknowledged_only: Only return unacknowledged alerts

        Returns:
            A new list of matching alerts, oldest first. The alert objects are
            shared with the history.
        """
        return [
            a for a in self._alerts
            if (not provider or a.provider == provider)
            and (not level or a.level == level)
            and (not unacknowledged_only or not a.acknowledged)
        ]

    def acknowledge_alert(self, alert_id: str) -> bool:
        """Acknowledge an alert; return False if no alert has that ID."""
        for alert in self._alerts:
            if alert.alert_id == alert_id:
                alert.acknowledged = True
                return True
        return False

    def reset_provider(self, provider: str) -> None:
        """Reset threshold/halt state for a provider (e.g., after budget top-up).

        Alert history is kept.
        """
        self._crossed_thresholds[provider] = set()
        self._execution_halted[provider] = False
        logger.info(f"Reset alert state for {provider}")

    def _provider_status(self, provider: str) -> Dict[str, Any]:
        """Build the status entry for one provider."""
        # Evaluate can_execute() first: it may set the halt flag, and the
        # reported execution_halted value must reflect that.
        can_execute = self.can_execute(provider)
        return {
            "execution_halted": self._execution_halted.get(provider, False),
            "thresholds_crossed": sorted(self._crossed_thresholds.get(provider, set()), reverse=True),
            "percentage_remaining": self.token_manager.get_remaining_percentage(provider),
            "amount_remaining": self.token_manager.get_remaining(provider),
            "can_execute": can_execute
        }

    def get_status(self) -> Dict[str, Any]:
        """Get current budget alert status."""
        return {
            "providers": {provider: self._provider_status(provider) for provider in self.PROVIDERS},
            "total_alerts": len(self._alerts),
            "unacknowledged_alerts": sum(1 for a in self._alerts if not a.acknowledged),
            "thresholds": {str(k): v.value for k, v in self.thresholds.items()}
        }


# Lazily created instance shared by the module-level helper functions.
_alert_system: Optional[BudgetAlertSystem] = None


def get_budget_alert_system() -> BudgetAlertSystem:
    """Return the shared BudgetAlertSystem, constructing it on first use."""
    global _alert_system
    if _alert_system is not None:
        return _alert_system
    _alert_system = BudgetAlertSystem()
    return _alert_system


def check_budget_before_execution(provider: str, estimated_cost: float = 0.0) -> bool:
    """Gate a task on the shared alert system's pre-execution budget check.

    Delegates to ``check_before_execution`` on the instance returned by
    ``get_budget_alert_system()``. Returns its result and lets any exception
    it raises propagate.
    """
    alert_system = get_budget_alert_system()
    return alert_system.check_before_execution(provider, estimated_cost=estimated_cost)


if __name__ == "__main__":
    # Manual smoke test: print the current status, then run one check pass
    # over every provider and print whatever alerts it produced.
    logging.basicConfig(level=logging.INFO)

    def alert_callback(alert: BudgetAlert):
        print(f"CALLBACK: {alert.message}")

    system = BudgetAlertSystem()
    system.add_callback(alert_callback)

    print("Budget Alert Status:")
    print(json.dumps(system.get_status(), indent=2))

    new_alerts = system.check_all_budgets()
    if not new_alerts:
        print("\nNo alerts triggered (budgets healthy)")
    else:
        print(f"\nNew alerts: {len(new_alerts)}")
        for new_alert in new_alerts:
            print(json.dumps(new_alert.to_dict(), indent=2))
