"""
PM-015: Success Rate Analytics
Track tier success rates for Genesis.

Acceptance Criteria:
- [x] GIVEN execution WHEN logged THEN tier success recorded
- [x] AND shows: Tier1 %, Tier2 %, Tier3 %, Failed %
- [x] AND stored in Redis `genesis:tier_success_rates`

Dependencies: PM-009
"""

import os
import json
import logging
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
from collections import defaultdict

try:
    import redis
except ImportError:
    redis = None

logger = logging.getLogger(__name__)


@dataclass
class TierSuccessRecord:
    """Outcome of one task run, tagged with the tier it finished on.

    ``timestamp`` defaults to the ISO-8601 text of ``datetime.utcnow()`` taken
    at construction time.
    """
    task_id: str
    final_tier: int
    success: bool
    total_attempts: int
    total_cost: float
    duration_ms: int
    timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dict, keyed in declaration order."""
        names = (
            "task_id",
            "final_tier",
            "success",
            "total_attempts",
            "total_cost",
            "duration_ms",
            "timestamp",
        )
        return {name: getattr(self, name) for name in names}

    def to_json(self) -> str:
        """Serialize the record to a JSON object string."""
        payload = self.to_dict()
        return json.dumps(payload)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TierSuccessRecord":
        """Build a record from a mapping whose keys are the field names."""
        record = cls(**data)
        return record

    @classmethod
    def from_json(cls, json_str: str) -> "TierSuccessRecord":
        """Build a record from a JSON string such as ``to_json`` produces."""
        decoded = json.loads(json_str)
        return cls.from_dict(decoded)


@dataclass
class SuccessRateStats:
    """Aggregated success, cost and timing figures for a set of task records."""
    # Overall counts and success percentage
    total_tasks: int
    successful_tasks: int
    failed_tasks: int
    overall_success_rate: float

    # Tier 1 figures
    tier1_total: int
    tier1_success: int
    tier1_rate: float

    # Tier 2 figures
    tier2_total: int
    tier2_success: int
    tier2_rate: float

    # Tier 3 figures
    tier3_total: int
    tier3_success: int
    tier3_rate: float

    # Tasks handed to human review (failed all tiers)
    human_review_count: int
    human_review_rate: float

    # Cost figures
    total_cost: float
    avg_cost_per_task: float
    avg_cost_per_success: float

    # Timing and effort figures
    avg_duration_ms: float
    avg_attempts: float

    # Covered period, as timestamp strings
    period_start: str
    period_end: str

    def to_dict(self) -> Dict[str, Any]:
        """Return the statistics as a nested dict grouped by topic.

        Top-level keys: the four overall figures, then ``by_tier``,
        ``human_review``, ``cost``, ``performance`` and ``period``.
        """
        report: Dict[str, Any] = {
            "total_tasks": self.total_tasks,
            "successful_tasks": self.successful_tasks,
            "failed_tasks": self.failed_tasks,
            "overall_success_rate": self.overall_success_rate,
        }

        # One sub-dict per tier, read from the tier<N>_* fields.
        report["by_tier"] = {
            f"tier{n}": {
                "total": getattr(self, f"tier{n}_total"),
                "success": getattr(self, f"tier{n}_success"),
                "rate": getattr(self, f"tier{n}_rate"),
            }
            for n in (1, 2, 3)
        }

        report["human_review"] = {
            "count": self.human_review_count,
            "rate": self.human_review_rate,
        }
        report["cost"] = {
            "total": self.total_cost,
            "avg_per_task": self.avg_cost_per_task,
            "avg_per_success": self.avg_cost_per_success,
        }
        report["performance"] = {
            "avg_duration_ms": self.avg_duration_ms,
            "avg_attempts": self.avg_attempts,
        }
        report["period"] = {
            "start": self.period_start,
            "end": self.period_end,
        }
        return report


class SuccessRateAnalytics:
    """
    Track and analyze tier success rates.

    Features:
    - Record task completions by tier
    - Calculate success rates per tier
    - Track escalation patterns
    - Store in Redis for persistence
    - Time-based analytics (hourly, daily)
    """

    REDIS_KEY = "genesis:tier_success_rates"
    REDIS_KEY_RECORDS = "genesis:success_records"
    MAX_RECORDS = 10000

    def __init__(self, redis_url: Optional[str] = None):
        """
        Initialize SuccessRateAnalytics.

        Args:
            redis_url: Redis connection URL
        """
        self.redis_url = redis_url or os.getenv("REDIS_URL", "redis://localhost:6379")
        self._redis: Optional[Any] = None

        # In-memory records
        self._records: List[TierSuccessRecord] = []

        # Counters (for quick stats without full record scan)
        self._counters = {
            "total": 0,
            "success": 0,
            "failed": 0,
            "tier1_total": 0,
            "tier1_success": 0,
            "tier2_total": 0,
            "tier2_success": 0,
            "tier3_total": 0,
            "tier3_success": 0,
            "human_review": 0
        }

        # Load from Redis if available
        self._load_state()

    @property
    def redis_client(self) -> Optional[Any]:
        """Lazy Redis client initialization."""
        if self._redis is None and redis is not None:
            try:
                self._redis = redis.from_url(self.redis_url)
                self._redis.ping()
            except Exception as e:
                logger.warning(f"Redis connection failed: {e}")
                self._redis = None
        return self._redis

    def _load_state(self) -> None:
        """Load state from Redis."""
        if self.redis_client:
            try:
                # Load counters
                counters = self.redis_client.hgetall(self.REDIS_KEY)
                if counters:
                    for key, value in counters.items():
                        key_str = key.decode() if isinstance(key, bytes) else key
                        value_str = value.decode() if isinstance(value, bytes) else value
                        if key_str in self._counters:
                            self._counters[key_str] = int(value_str)
                    logger.info(f"Loaded success rate counters from Redis")
            except Exception as e:
                logger.warning(f"Failed to load state from Redis: {e}")

    def _save_state(self) -> None:
        """Save state to Redis."""
        if self.redis_client:
            try:
                self.redis_client.hset(self.REDIS_KEY, mapping={
                    k: str(v) for k, v in self._counters.items()
                })
            except Exception as e:
                logger.warning(f"Failed to save state to Redis: {e}")

    def record_execution(self,
                        task_id: str,
                        final_tier: int,
                        success: bool,
                        total_attempts: int = 1,
                        total_cost: float = 0.0,
                        duration_ms: int = 0) -> TierSuccessRecord:
        """
        Record a task execution result.

        Args:
            task_id: Task identifier
            final_tier: Tier where task completed (or 0 if failed all)
            success: Whether task succeeded
            total_attempts: Total attempts across all tiers
            total_cost: Total cost
            duration_ms: Total duration

        Returns:
            TierSuccessRecord
        """
        record = TierSuccessRecord(
            task_id=task_id,
            final_tier=final_tier,
            success=success,
            total_attempts=total_attempts,
            total_cost=total_cost,
            duration_ms=duration_ms
        )

        # Update counters
        self._counters["total"] += 1

        if success:
            self._counters["success"] += 1

            # Track by final tier
            if final_tier == 1:
                self._counters["tier1_total"] += 1
                self._counters["tier1_success"] += 1
            elif final_tier == 2:
                self._counters["tier2_total"] += 1
                self._counters["tier2_success"] += 1
            elif final_tier == 3:
                self._counters["tier3_total"] += 1
                self._counters["tier3_success"] += 1
        else:
            self._counters["failed"] += 1
            self._counters["human_review"] += 1

        # Store record
        self._records.append(record)
        if len(self._records) > self.MAX_RECORDS:
            self._records = self._records[-self.MAX_RECORDS:]

        # Store in Redis
        if self.redis_client:
            try:
                self.redis_client.zadd(
                    self.REDIS_KEY_RECORDS,
                    {record.to_json(): datetime.utcnow().timestamp()}
                )
                self.redis_client.zremrangebyrank(
                    self.REDIS_KEY_RECORDS, 0, -self.MAX_RECORDS - 1
                )
            except Exception as e:
                logger.warning(f"Failed to store record in Redis: {e}")

        # Save counters
        self._save_state()

        logger.info(
            f"Recorded execution: task={task_id}, tier={final_tier}, "
            f"success={success}, attempts={total_attempts}"
        )

        return record

    def get_success_rates(self) -> Dict[str, float]:
        """
        Get current success rates by tier.

        Returns:
            Dictionary with success rates
        """
        total = self._counters["total"]

        if total == 0:
            return {
                "tier1": 0.0,
                "tier2": 0.0,
                "tier3": 0.0,
                "overall": 0.0,
                "failed": 0.0
            }

        return {
            "tier1": (self._counters["tier1_success"] / total * 100) if total else 0.0,
            "tier2": (self._counters["tier2_success"] / total * 100) if total else 0.0,
            "tier3": (self._counters["tier3_success"] / total * 100) if total else 0.0,
            "overall": (self._counters["success"] / total * 100) if total else 0.0,
            "failed": (self._counters["failed"] / total * 100) if total else 0.0
        }

    def get_statistics(self, period_hours: Optional[int] = None) -> SuccessRateStats:
        """
        Get comprehensive success rate statistics.

        Args:
            period_hours: Limit to last N hours (None for all time)

        Returns:
            SuccessRateStats
        """
        # Filter records by period if specified
        records = self._records
        if period_hours:
            cutoff = datetime.utcnow() - timedelta(hours=period_hours)
            records = [
                r for r in records
                if datetime.fromisoformat(r.timestamp) > cutoff
            ]

        if not records:
            return SuccessRateStats(
                total_tasks=0, successful_tasks=0, failed_tasks=0,
                overall_success_rate=0.0,
                tier1_total=0, tier1_success=0, tier1_rate=0.0,
                tier2_total=0, tier2_success=0, tier2_rate=0.0,
                tier3_total=0, tier3_success=0, tier3_rate=0.0,
                human_review_count=0, human_review_rate=0.0,
                total_cost=0.0, avg_cost_per_task=0.0, avg_cost_per_success=0.0,
                avg_duration_ms=0.0, avg_attempts=0.0,
                period_start="", period_end=""
            )

        # Calculate stats from records
        total = len(records)
        successful = [r for r in records if r.success]
        failed = [r for r in records if not r.success]

        tier1 = [r for r in records if r.final_tier == 1]
        tier2 = [r for r in records if r.final_tier == 2]
        tier3 = [r for r in records if r.final_tier == 3]

        tier1_success = [r for r in tier1 if r.success]
        tier2_success = [r for r in tier2 if r.success]
        tier3_success = [r for r in tier3 if r.success]

        total_cost = sum(r.total_cost for r in records)
        total_duration = sum(r.duration_ms for r in records)
        total_attempts = sum(r.total_attempts for r in records)

        timestamps = [datetime.fromisoformat(r.timestamp) for r in records]
        period_start = min(timestamps).isoformat() if timestamps else ""
        period_end = max(timestamps).isoformat() if timestamps else ""

        return SuccessRateStats(
            total_tasks=total,
            successful_tasks=len(successful),
            failed_tasks=len(failed),
            overall_success_rate=(len(successful) / total * 100) if total else 0.0,

            tier1_total=len(tier1),
            tier1_success=len(tier1_success),
            tier1_rate=(len(tier1_success) / len(tier1) * 100) if tier1 else 0.0,

            tier2_total=len(tier2),
            tier2_success=len(tier2_success),
            tier2_rate=(len(tier2_success) / len(tier2) * 100) if tier2 else 0.0,

            tier3_total=len(tier3),
            tier3_success=len(tier3_success),
            tier3_rate=(len(tier3_success) / len(tier3) * 100) if tier3 else 0.0,

            human_review_count=len(failed),
            human_review_rate=(len(failed) / total * 100) if total else 0.0,

            total_cost=total_cost,
            avg_cost_per_task=(total_cost / total) if total else 0.0,
            avg_cost_per_success=(total_cost / len(successful)) if successful else 0.0,

            avg_duration_ms=(total_duration / total) if total else 0.0,
            avg_attempts=(total_attempts / total) if total else 0.0,

            period_start=period_start,
            period_end=period_end
        )

    def get_escalation_analysis(self) -> Dict[str, Any]:
        """
        Analyze escalation patterns.

        Returns:
            Escalation analysis dictionary
        """
        if not self._records:
            return {"insufficient_data": True}

        # Calculate escalation rates
        total = len(self._records)
        tier1_only = sum(1 for r in self._records if r.final_tier == 1 and r.success)
        escalated_to_2 = sum(1 for r in self._records if r.final_tier >= 2)
        escalated_to_3 = sum(1 for r in self._records if r.final_tier >= 3)
        failed_all = sum(1 for r in self._records if not r.success)

        return {
            "total_tasks": total,
            "completed_tier1": tier1_only,
            "completed_tier1_rate": (tier1_only / total * 100) if total else 0,
            "escalated_to_tier2": escalated_to_2,
            "escalated_to_tier2_rate": (escalated_to_2 / total * 100) if total else 0,
            "escalated_to_tier3": escalated_to_3,
            "escalated_to_tier3_rate": (escalated_to_3 / total * 100) if total else 0,
            "failed_all_tiers": failed_all,
            "failed_all_tiers_rate": (failed_all / total * 100) if total else 0,
            "avg_attempts_tier1": self._avg_attempts_for_tier(1),
            "avg_attempts_tier2": self._avg_attempts_for_tier(2),
            "avg_attempts_tier3": self._avg_attempts_for_tier(3)
        }

    def _avg_attempts_for_tier(self, tier: int) -> float:
        """Calculate average attempts for a tier."""
        tier_records = [r for r in self._records if r.final_tier == tier]
        if not tier_records:
            return 0.0
        return sum(r.total_attempts for r in tier_records) / len(tier_records)

    def reset_counters(self) -> None:
        """Reset all counters (use with caution)."""
        for key in self._counters:
            self._counters[key] = 0
        self._records = []
        self._save_state()
        logger.info("Success rate counters reset")


# Module-wide shared instance, created on first request
_analytics: Optional[SuccessRateAnalytics] = None


def get_success_analytics() -> SuccessRateAnalytics:
    """Return the shared SuccessRateAnalytics, building it on the first call."""
    global _analytics
    if _analytics is not None:
        return _analytics
    _analytics = SuccessRateAnalytics()
    return _analytics


if __name__ == "__main__":
    # Manual smoke run: record a few sample executions, then print each report
    logging.basicConfig(level=logging.INFO)

    analytics = SuccessRateAnalytics()

    # (task_id, final_tier, success, total_attempts)
    sample_runs = [
        ("task-001", 1, True, 3),
        ("task-002", 1, True, 5),
        ("task-003", 2, True, 25),
        ("task-004", 3, True, 35),
        ("task-005", 0, False, 40),
    ]
    for sample_id, sample_tier, sample_ok, sample_attempts in sample_runs:
        analytics.record_execution(
            sample_id,
            final_tier=sample_tier,
            success=sample_ok,
            total_attempts=sample_attempts,
        )

    print("Success Rates:")
    rates = analytics.get_success_rates()
    print(json.dumps(rates, indent=2))

    print("\nStatistics:")
    stats = analytics.get_statistics()
    print(json.dumps(stats.to_dict(), indent=2))

    print("\nEscalation Analysis:")
    escalation = analytics.get_escalation_analysis()
    print(json.dumps(escalation, indent=2))
