"""
PM-004: Performance Metrics Collector
Track session performance to detect degradation in Genesis.

Acceptance Criteria:
- [x] GIVEN session completes WHEN logged THEN metrics stored
- [x] AND includes: duration, tokens, success/failure
- [x] AND stored in Redis sorted set `genesis:perf_metrics`

Dependencies: None
"""

import os
import json
import logging
import time
from contextlib import contextmanager
from dataclasses import dataclass, asdict, field
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List

try:
    import redis
except ImportError:
    redis = None

logger = logging.getLogger(__name__)


@dataclass
class SessionMetrics:
    """Record of one session execution: identity, timing, tokens, outcome, cost."""
    # Identity
    session_id: str
    task_id: Optional[str]
    model: str
    tier: int
    attempt: int

    # Timing
    duration_ms: int
    start_time: str
    end_time: str

    # Tokens
    input_tokens: int
    output_tokens: int
    total_tokens: int

    # Result
    success: bool
    error_type: Optional[str] = None
    error_message: Optional[str] = None

    # Cost
    cost: float = 0.0

    # Additional context
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation of this record."""
        return asdict(self)

    def to_json(self) -> str:
        """Serialize this record to a JSON string."""
        return json.dumps(asdict(self))

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SessionMetrics":
        """Build a SessionMetrics from a plain dict (inverse of to_dict)."""
        return cls(**data)

    @classmethod
    def from_json(cls, json_str: str) -> "SessionMetrics":
        """Deserialize a SessionMetrics from a JSON string (inverse of to_json)."""
        return cls(**json.loads(json_str))


class PerformanceMetricsCollector:
    """
    Collect and store session performance metrics.

    Features:
    - Track duration, tokens, success/failure
    - Store in Redis sorted set by timestamp
    - Context manager for automatic timing
    - Performance baseline calculation

    All Redis access is best-effort: if Redis is unavailable, writes are
    dropped with a warning and reads return empty results.
    """

    REDIS_KEY = "genesis:perf_metrics"           # Global sorted set of all sessions
    REDIS_KEY_PREFIX = "genesis:perf_metrics:"   # Prefix for per-model sorted sets
    MAX_STORED_METRICS = 10000                   # Keep last N metrics per sorted set

    def __init__(self, redis_url: Optional[str] = None):
        """
        Initialize PerformanceMetricsCollector.

        Args:
            redis_url: Redis connection URL. Defaults to env REDIS_URL.
        """
        self.redis_url = redis_url or os.getenv("REDIS_URL", "redis://localhost:6379")
        self._redis: Optional[Any] = None
        self._session_counter = 0  # Monotonic suffix to de-duplicate generated IDs

    @property
    def redis_client(self) -> Optional[Any]:
        """Lazily connect to Redis; returns None when Redis is unavailable."""
        if self._redis is None and redis is not None:
            try:
                self._redis = redis.from_url(self.redis_url)
                self._redis.ping()  # Fail fast if the server is unreachable
            except Exception as e:
                logger.warning(f"Redis connection failed: {e}")
                self._redis = None
        return self._redis

    def _generate_session_id(self) -> str:
        """Generate a unique session ID from a UTC timestamp plus a counter."""
        self._session_counter += 1
        timestamp = datetime.utcnow().strftime("%Y%m%d%H%M%S%f")
        return f"sess_{timestamp}_{self._session_counter}"

    def log_session(self,
                   model: str,
                   tier: int,
                   attempt: int,
                   duration_ms: int,
                   input_tokens: int,
                   output_tokens: int,
                   success: bool,
                   task_id: Optional[str] = None,
                   session_id: Optional[str] = None,
                   cost: float = 0.0,
                   error_type: Optional[str] = None,
                   error_message: Optional[str] = None,
                   metadata: Optional[Dict[str, Any]] = None) -> SessionMetrics:
        """
        Log session metrics.

        Args:
            model: Model used
            tier: Execution tier (1, 2, or 3)
            attempt: Attempt number within tier
            duration_ms: Duration in milliseconds
            input_tokens: Input token count
            output_tokens: Output token count
            success: Whether session succeeded
            task_id: Task identifier
            session_id: Session identifier (auto-generated if not provided)
            cost: Cost of the session
            error_type: Error type if failed
            error_message: Error message if failed
            metadata: Additional metadata

        Returns:
            SessionMetrics object
        """
        now = datetime.utcnow()

        metrics = SessionMetrics(
            session_id=session_id or self._generate_session_id(),
            task_id=task_id,
            model=model,
            tier=tier,
            attempt=attempt,
            duration_ms=duration_ms,
            # FIX: start_time was previously a raw epoch-float string derived
            # from naive-datetime .timestamp() (which misreads UTC as local
            # time). Subtract the duration directly and store ISO-8601,
            # matching end_time's format.
            start_time=(now - timedelta(milliseconds=duration_ms)).isoformat(),
            end_time=now.isoformat(),
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=input_tokens + output_tokens,
            success=success,
            cost=cost,
            error_type=error_type,
            error_message=error_message,
            metadata=metadata or {}
        )

        # Persist (best-effort) in Redis
        self._store_metrics(metrics)

        logger.debug(
            f"Logged session: {metrics.session_id} "
            f"model={model} tier={tier} attempt={attempt} "
            f"duration={duration_ms}ms tokens={metrics.total_tokens} "
            f"success={success}"
        )

        return metrics

    def _store_metrics(self, metrics: SessionMetrics) -> None:
        """Store metrics in the global and per-model Redis sorted sets."""
        client = self.redis_client
        if not client:
            return

        try:
            # FIX: use time.time() for the score; the previous
            # datetime.utcnow().timestamp() skews the epoch by the local
            # UTC offset because utcnow() returns a naive datetime.
            timestamp = time.time()
            payload = metrics.to_json()  # Serialize once, reuse for both keys
            model_key = f"{self.REDIS_KEY_PREFIX}{metrics.model}"

            # Store in sorted sets keyed by timestamp: global + per-model
            client.zadd(self.REDIS_KEY, {payload: timestamp})
            client.zadd(model_key, {payload: timestamp})

            # FIX: trim BOTH sorted sets to the retention cap; previously only
            # the global key was trimmed, so per-model keys grew unbounded.
            client.zremrangebyrank(self.REDIS_KEY, 0, -self.MAX_STORED_METRICS - 1)
            client.zremrangebyrank(model_key, 0, -self.MAX_STORED_METRICS - 1)

        except Exception as e:
            logger.warning(f"Failed to store metrics in Redis: {e}")

    def get_recent_metrics(self,
                          count: int = 100,
                          model: Optional[str] = None) -> List[SessionMetrics]:
        """
        Get recent metrics, newest first.

        Args:
            count: Number of metrics to retrieve
            model: Filter by model (optional)

        Returns:
            List of SessionMetrics (empty when Redis is unavailable)
        """
        if not self.redis_client:
            return []

        try:
            key = f"{self.REDIS_KEY_PREFIX}{model}" if model else self.REDIS_KEY
            raw_metrics = self.redis_client.zrevrange(key, 0, count - 1)

            # Entries may come back as bytes or str depending on client config
            return [
                SessionMetrics.from_json(m.decode() if isinstance(m, bytes) else m)
                for m in raw_metrics
            ]
        except Exception as e:
            logger.warning(f"Failed to retrieve metrics: {e}")
            return []

    def get_baseline(self,
                    model: str,
                    window_size: int = 50) -> Dict[str, float]:
        """
        Calculate performance baseline for a model.

        Averages are computed over successful sessions only; success_rate is
        over all sessions in the window.

        Args:
            model: Model name
            window_size: Number of recent sessions to use

        Returns:
            Baseline metrics dict with avg_duration_ms, avg_tokens,
            success_rate (percent), sample_size, and model.
        """
        metrics = self.get_recent_metrics(count=window_size, model=model)

        if not metrics:
            # No data yet: zeroed baseline ("model" included for consistency
            # with the populated branch).
            return {
                "avg_duration_ms": 0,
                "avg_tokens": 0,
                "success_rate": 0,
                "sample_size": 0,
                "model": model
            }

        successful = [m for m in metrics if m.success]
        n_ok = len(successful)

        return {
            "avg_duration_ms": sum(m.duration_ms for m in successful) / n_ok if n_ok else 0,
            "avg_tokens": sum(m.total_tokens for m in successful) / n_ok if n_ok else 0,
            "success_rate": n_ok / len(metrics) * 100,  # metrics is non-empty here
            "sample_size": len(metrics),
            "model": model
        }

    def get_tier_statistics(self) -> Dict[str, Dict[str, Any]]:
        """
        Get statistics per tier over the last 1000 sessions.

        Returns:
            Dict keyed "tier_1".."tier_3"; tiers with no sessions are omitted.
        """
        metrics = self.get_recent_metrics(count=1000)

        # Bucket sessions by tier; unknown tiers are ignored
        tier_stats: Dict[int, List[SessionMetrics]] = {1: [], 2: [], 3: []}
        for m in metrics:
            if m.tier in tier_stats:
                tier_stats[m.tier].append(m)

        result: Dict[str, Dict[str, Any]] = {}
        for tier, tier_metrics in tier_stats.items():
            if not tier_metrics:
                continue
            successful = [m for m in tier_metrics if m.success]
            n = len(tier_metrics)
            result[f"tier_{tier}"] = {
                "total_sessions": n,
                "successful_sessions": len(successful),
                "success_rate": len(successful) / n * 100,
                "avg_duration_ms": sum(m.duration_ms for m in tier_metrics) / n,
                "avg_tokens": sum(m.total_tokens for m in tier_metrics) / n,
                "total_cost": sum(m.cost for m in tier_metrics)
            }
        return result

    @contextmanager
    def track_session(self,
                     model: str,
                     tier: int,
                     attempt: int,
                     task_id: Optional[str] = None,
                     **kwargs):
        """
        Context manager for tracking session metrics.

        Times the enclosed block and logs a session on exit — including when
        the block raises (the exception is recorded and re-raised).

        Usage:
            with collector.track_session("gemini-flash", 1, 1) as tracker:
                # ... execute session ...
                tracker.set_tokens(1000, 500)
                tracker.set_success(True)

        Args:
            model: Model name
            tier: Execution tier
            attempt: Attempt number
            task_id: Task identifier
            **kwargs: Additional metadata
        """

        class SessionTracker:
            """Mutable accumulator the caller fills in during the session."""

            def __init__(self):
                self.input_tokens = 0
                self.output_tokens = 0
                self.success = False
                self.error_type = None
                self.error_message = None
                self.cost = 0.0
                self.metadata = kwargs

            def set_tokens(self, input_tokens: int, output_tokens: int):
                self.input_tokens = input_tokens
                self.output_tokens = output_tokens

            def set_success(self, success: bool):
                self.success = success

            def set_error(self, error_type: str, error_message: str):
                # Recording an error implies the session failed
                self.error_type = error_type
                self.error_message = error_message
                self.success = False

            def set_cost(self, cost: float):
                self.cost = cost

        tracker = SessionTracker()
        start_time = time.time()

        try:
            yield tracker
        except Exception as e:
            # FIX: previously an exception inside the block was logged as a
            # bare failure with no error context. Capture it (unless the
            # caller already recorded an error) and let it propagate.
            if tracker.error_type is None:
                tracker.set_error(type(e).__name__, str(e))
            raise
        finally:
            duration_ms = int((time.time() - start_time) * 1000)
            self.log_session(
                model=model,
                tier=tier,
                attempt=attempt,
                duration_ms=duration_ms,
                input_tokens=tracker.input_tokens,
                output_tokens=tracker.output_tokens,
                success=tracker.success,
                task_id=task_id,
                cost=tracker.cost,
                error_type=tracker.error_type,
                error_message=tracker.error_message,
                metadata=tracker.metadata
            )


# Module-level singleton, created lazily by get_metrics_collector()
_metrics_collector: Optional[PerformanceMetricsCollector] = None


def get_metrics_collector() -> PerformanceMetricsCollector:
    """Return the shared PerformanceMetricsCollector, creating it on first use."""
    global _metrics_collector
    if _metrics_collector is not None:
        return _metrics_collector
    _metrics_collector = PerformanceMetricsCollector()
    return _metrics_collector


if __name__ == "__main__":
    # Manual smoke test for PerformanceMetricsCollector.
    logging.basicConfig(level=logging.INFO)

    collector = PerformanceMetricsCollector()

    # Log a couple of hand-built sessions via the direct API.
    sample_sessions = [
        dict(model="gemini-flash", tier=1, attempt=1, duration_ms=1500,
             input_tokens=1000, output_tokens=500, success=True,
             task_id="test-001"),
        dict(model="claude-sonnet", tier=2, attempt=1, duration_ms=3000,
             input_tokens=2000, output_tokens=1500, success=True,
             task_id="test-002"),
    ]
    for session in sample_sessions:
        collector.log_session(**session)

    # Exercise the context-manager interface as well.
    with collector.track_session("gemini-flash", 1, 2, task_id="test-003") as tracker:
        time.sleep(0.1)  # Simulate work
        tracker.set_tokens(500, 200)
        tracker.set_success(True)

    print("\nTier Statistics:")
    print(json.dumps(collector.get_tier_statistics(), indent=2))