"""
PM-014: Performance Degradation Detector
Detect context rot symptoms and performance issues in Genesis.

Acceptance Criteria:
- [x] GIVEN metrics WHEN analyzed THEN baseline established
- [x] AND if duration +50% THEN flag review
- [x] AND if tokens +2x THEN flag potential issue

Dependencies: PM-004
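
Usage (a minimal sketch; assumes a metrics store populated via PM-004):
    detector = get_perf_degradation_detector()
    alerts = detector.analyze_session(session_metrics)  # session_metrics: SessionMetrics
    report = detector.analyze_recent_performance(window=20)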
"""

import json
import logging
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
from enum import Enum
from statistics import mean, stdev

from core.perf_metrics import PerformanceMetricsCollector, get_metrics_collector, SessionMetrics

logger = logging.getLogger(__name__)


class DegradationType(Enum):
    """Types of performance degradation."""
    NONE = "none"
    DURATION_INCREASE = "duration_increase"
    TOKEN_BLOAT = "token_bloat"
    SUCCESS_RATE_DROP = "success_rate_drop"
    ERROR_RATE_SPIKE = "error_rate_spike"
    COST_SPIKE = "cost_spike"
    CONTEXT_ROT = "context_rot"  # Combined degradation pattern


@dataclass
class DegradationAlert:
    """Alert for detected performance degradation."""
    degradation_type: DegradationType
    severity: str  # "low", "medium", "high", "critical"
    model: str
    message: str

    # Metrics
    current_value: float
    baseline_value: float
    deviation_percentage: float

    # Recommendations
    recommendation: str

    # Metadata
    timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat())
    sample_size: int = 0

    def to_dict(self) -> Dict[str, Any]:
        return {
            "degradation_type": self.degradation_type.value,
            "severity": self.severity,
            "model": self.model,
            "message": self.message,
            "current_value": self.current_value,
            "baseline_value": self.baseline_value,
            "deviation_percentage": self.deviation_percentage,
            "recommendation": self.recommendation,
            "timestamp": self.timestamp,
            "sample_size": self.sample_size
        }


@dataclass
class PerformanceBaseline:
    """Baseline metrics for a model."""
    model: str
    avg_duration_ms: float
    avg_tokens: float
    avg_cost: float
    success_rate: float
    sample_size: int
    calculated_at: str = field(default_factory=lambda: datetime.utcnow().isoformat())

    # Standard deviations (for anomaly detection)
    duration_stdev: float = 0.0
    tokens_stdev: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        return {
            "model": self.model,
            "avg_duration_ms": self.avg_duration_ms,
            "avg_tokens": self.avg_tokens,
            "avg_cost": self.avg_cost,
            "success_rate": self.success_rate,
            "sample_size": self.sample_size,
            "calculated_at": self.calculated_at,
            "duration_stdev": self.duration_stdev,
            "tokens_stdev": self.tokens_stdev
        }


class PerformanceDegradationDetector:
    """
    Detect context rot and performance degradation.

    Detection Thresholds:
    - Duration +50% from baseline: Flag for review
    - Tokens +100% (2x): Flag potential context rot
    - Success rate -20%: Flag reliability issue
    - Error rate +50%: Flag system issue

    Baseline:
    - Established from successful sessions within the last N recorded sessions
    - Recalculated at most once per hour per model (cached) to adapt to system changes
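
    Example (illustrative; requires enough recorded sessions to form a baseline):
        detector = PerformanceDegradationDetector()
        baseline = detector.calculate_baseline("some-model")
        alerts = detector.analyze_session(session)  # session: SessionMetrics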
    """

    # Degradation thresholds
    DURATION_THRESHOLD = 0.50  # +50%
    TOKEN_THRESHOLD = 1.00     # +100% (2x)
    SUCCESS_RATE_THRESHOLD = 0.20  # -20%
    ERROR_RATE_THRESHOLD = 0.50  # +50%
    COST_THRESHOLD = 0.75  # +75%

    # Baseline parameters
    DEFAULT_BASELINE_WINDOW = 50  # Sessions
    MIN_SAMPLES_FOR_BASELINE = 10

    def __init__(self,
                 metrics_collector: Optional[PerformanceMetricsCollector] = None,
                 baseline_window: int = DEFAULT_BASELINE_WINDOW):
        """
        Initialize PerformanceDegradationDetector.

        Args:
            metrics_collector: PerformanceMetricsCollector instance
            baseline_window: Number of sessions for baseline calculation
        """
        self.metrics_collector = metrics_collector or get_metrics_collector()
        self.baseline_window = baseline_window

        # Cached baselines per model
        self._baselines: Dict[str, PerformanceBaseline] = {}
        self._baseline_updated: Dict[str, datetime] = {}

        # Alert history
        self._alerts: List[DegradationAlert] = []

    def calculate_baseline(self, model: str, force_recalculate: bool = False) -> Optional[PerformanceBaseline]:
        """
        Calculate or retrieve baseline for a model.

        Args:
            model: Model name
            force_recalculate: Force recalculation even if cached

        Returns:
            PerformanceBaseline or None if insufficient data
        """
        # Check cache
        if not force_recalculate and model in self._baselines:
            last_update = self._baseline_updated.get(model)
            if last_update and datetime.utcnow() - last_update < timedelta(hours=1):
                return self._baselines[model]

        # Get recent metrics
        metrics = self.metrics_collector.get_recent_metrics(count=self.baseline_window, model=model)

        if len(metrics) < self.MIN_SAMPLES_FOR_BASELINE:
            logger.debug(f"Insufficient samples for baseline: {len(metrics)}/{self.MIN_SAMPLES_FOR_BASELINE}")
            return None

        # Filter to successful sessions only
        successful = [m for m in metrics if m.success]

        if len(successful) < self.MIN_SAMPLES_FOR_BASELINE // 2:
            logger.warning(f"Too few successful sessions for baseline: {len(successful)}")
            return None

        # Calculate statistics
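        # Averages below are computed over successful sessions only; the success
        # rate is measured against all sampled sessions in the window.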
        durations = [m.duration_ms for m in successful]
        tokens = [m.total_tokens for m in successful]
        costs = [m.cost for m in successful]

        baseline = PerformanceBaseline(
            model=model,
            avg_duration_ms=mean(durations),
            avg_tokens=mean(tokens),
            avg_cost=mean(costs) if costs else 0.0,
            success_rate=len(successful) / len(metrics) * 100,
            sample_size=len(metrics),
            duration_stdev=stdev(durations) if len(durations) > 1 else 0.0,
            tokens_stdev=stdev(tokens) if len(tokens) > 1 else 0.0
        )

        # Cache
        self._baselines[model] = baseline
        self._baseline_updated[model] = datetime.utcnow()

        logger.info(f"Calculated baseline for {model}: duration={baseline.avg_duration_ms:.0f}ms, "
                   f"tokens={baseline.avg_tokens:.0f}, success={baseline.success_rate:.1f}%")

        return baseline

    def analyze_session(self, session: SessionMetrics) -> List[DegradationAlert]:
        """
        Analyze a single session for degradation.

        Args:
            session: SessionMetrics to analyze

        Returns:
            List of DegradationAlerts (may be empty)
        """
        alerts = []

        # Get baseline
        baseline = self.calculate_baseline(session.model)
        if not baseline:
            return alerts

        # Check duration degradation
        duration_alert = self._check_duration(session, baseline)
        if duration_alert:
            alerts.append(duration_alert)

        # Check token bloat
        token_alert = self._check_tokens(session, baseline)
        if token_alert:
            alerts.append(token_alert)

        # Check cost spike
        cost_alert = self._check_cost(session, baseline)
        if cost_alert:
            alerts.append(cost_alert)

        # Store alerts
        self._alerts.extend(alerts)

        return alerts

    def _check_duration(self,
                       session: SessionMetrics,
                       baseline: PerformanceBaseline) -> Optional[DegradationAlert]:
        """Check for duration degradation."""
        if baseline.avg_duration_ms == 0:
            return None

        deviation = (session.duration_ms - baseline.avg_duration_ms) / baseline.avg_duration_ms
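        # Example: a 10,000 ms baseline and a 16,000 ms session give a deviation
        # of 0.6 (60% above baseline), which exceeds the 0.50 review threshold.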

        if deviation >= self.DURATION_THRESHOLD:
            severity = self._calculate_severity(deviation, [0.5, 1.0, 2.0])

            return DegradationAlert(
                degradation_type=DegradationType.DURATION_INCREASE,
                severity=severity,
                model=session.model,
                message=f"Session duration {deviation*100:.0f}% above baseline",
                current_value=session.duration_ms,
                baseline_value=baseline.avg_duration_ms,
                deviation_percentage=deviation * 100,
                recommendation="Consider spawning a fresh session to avoid context accumulation",
                sample_size=baseline.sample_size
            )

        return None

    def _check_tokens(self,
                     session: SessionMetrics,
                     baseline: PerformanceBaseline) -> Optional[DegradationAlert]:
        """Check for token bloat (context rot indicator)."""
        if baseline.avg_tokens == 0:
            return None

        deviation = (session.total_tokens - baseline.avg_tokens) / baseline.avg_tokens

        if deviation >= self.TOKEN_THRESHOLD:
            severity = self._calculate_severity(deviation, [1.0, 2.0, 4.0])

            return DegradationAlert(
                degradation_type=DegradationType.TOKEN_BLOAT,
                severity=severity,
                model=session.model,
                message=f"Token usage {deviation*100:.0f}% above baseline (potential context rot)",
                current_value=session.total_tokens,
                baseline_value=baseline.avg_tokens,
                deviation_percentage=deviation * 100,
                recommendation="Context rot detected. Force fresh session and truncate context.",
                sample_size=baseline.sample_size
            )

        return None

    def _check_cost(self,
                   session: SessionMetrics,
                   baseline: PerformanceBaseline) -> Optional[DegradationAlert]:
        """Check for cost spikes."""
        if baseline.avg_cost == 0:
            return None

        deviation = (session.cost - baseline.avg_cost) / baseline.avg_cost

        if deviation >= self.COST_THRESHOLD:
            severity = self._calculate_severity(deviation, [0.75, 1.5, 3.0])

            return DegradationAlert(
                degradation_type=DegradationType.COST_SPIKE,
                severity=severity,
                model=session.model,
                message=f"Session cost {deviation*100:.0f}% above baseline",
                current_value=session.cost,
                baseline_value=baseline.avg_cost,
                deviation_percentage=deviation * 100,
                recommendation="Investigate token usage. Consider model downgrade for simple tasks.",
                sample_size=baseline.sample_size
            )

        return None

    def _calculate_severity(self,
                           deviation: float,
                           thresholds: List[float]) -> str:
        """Calculate severity based on deviation."""
        if deviation >= thresholds[2]:
            return "critical"
        elif deviation >= thresholds[1]:
            return "high"
        elif deviation >= thresholds[0]:
            return "medium"
        return "low"

    def analyze_recent_performance(self,
                                  model: Optional[str] = None,
                                  window: int = 20) -> Dict[str, Any]:
        """
        Analyze recent performance for degradation trends.

        Args:
            model: Model to analyze (all if None)
            window: Number of recent sessions to analyze

        Returns:
            Analysis report dictionary
        """
        metrics = self.metrics_collector.get_recent_metrics(count=window, model=model)

        if not metrics:
            return {"status": "insufficient_data", "sample_size": 0}

        # Group by model
        by_model: Dict[str, List[SessionMetrics]] = {}
        for m in metrics:
            by_model.setdefault(m.model, []).append(m)

        reports = {}
        for model_name, model_metrics in by_model.items():
            baseline = self.calculate_baseline(model_name)

            if not baseline:
                reports[model_name] = {"status": "insufficient_baseline"}
                continue

            successful = [m for m in model_metrics if m.success]
            recent_success_rate = len(successful) / len(model_metrics) * 100 if model_metrics else 0

            # Calculate recent averages
            if successful:
                recent_avg_duration = mean([m.duration_ms for m in successful])
                recent_avg_tokens = mean([m.total_tokens for m in successful])
            else:
                recent_avg_duration = 0
                recent_avg_tokens = 0

            # Check for trends
            reports[model_name] = {
                "baseline_duration_ms": baseline.avg_duration_ms,
                "recent_duration_ms": recent_avg_duration,
                "duration_deviation": (recent_avg_duration - baseline.avg_duration_ms) / baseline.avg_duration_ms * 100 if baseline.avg_duration_ms else 0,
                "baseline_tokens": baseline.avg_tokens,
                "recent_tokens": recent_avg_tokens,
                "token_deviation": (recent_avg_tokens - baseline.avg_tokens) / baseline.avg_tokens * 100 if baseline.avg_tokens else 0,
                "baseline_success_rate": baseline.success_rate,
                "recent_success_rate": recent_success_rate,
                "success_rate_change": recent_success_rate - baseline.success_rate,
                "sample_size": len(model_metrics),
                "degradation_detected": False
            }

            # Flag degradation; token bloat overrides the duration label if both fire
            if reports[model_name]["duration_deviation"] > self.DURATION_THRESHOLD * 100:
                reports[model_name]["degradation_detected"] = True
                reports[model_name]["degradation_type"] = "duration"
            if reports[model_name]["token_deviation"] > self.TOKEN_THRESHOLD * 100:
                reports[model_name]["degradation_detected"] = True
                reports[model_name]["degradation_type"] = "token_bloat"

        return {
            "analyzed_at": datetime.utcnow().isoformat(),
            "total_sessions": len(metrics),
            "models_analyzed": len(by_model),
            "by_model": reports
        }

    def get_alerts(self,
                  model: Optional[str] = None,
                  severity: Optional[str] = None) -> List[DegradationAlert]:
        """Get degradation alerts with optional filtering."""
        alerts = list(self._alerts)

        if model:
            alerts = [a for a in alerts if a.model == model]
        if severity:
            alerts = [a for a in alerts if a.severity == severity]

        return alerts

    def get_baselines(self) -> Dict[str, PerformanceBaseline]:
        """Get all calculated baselines."""
        return self._baselines.copy()


# Singleton instance
_detector: Optional[PerformanceDegradationDetector] = None


def get_perf_degradation_detector() -> PerformanceDegradationDetector:
    """Get or create global PerformanceDegradationDetector instance."""
    global _detector
    if _detector is None:
        _detector = PerformanceDegradationDetector()
    return _detector


if __name__ == "__main__":
    # Test the PerformanceDegradationDetector
    logging.basicConfig(level=logging.INFO)

    detector = PerformanceDegradationDetector()

    # Analyze recent performance
    print("Performance Analysis:")
    analysis = detector.analyze_recent_performance()
    print(json.dumps(analysis, indent=2, default=str))

    # Show baselines
    print("\nBaselines:")
    for model, baseline in detector.get_baselines().items():
        print(f"\n{model}:")
        print(json.dumps(baseline.to_dict(), indent=2))
