"""
Genesis Circuit Breaker
========================
Implements the Circuit Breaker pattern for graceful service degradation.

States:
- CLOSED: Normal operation, requests flow through
- OPEN: Service failing; requests are rejected immediately by raising CircuitOpenError, so callers can use a fallback
- HALF_OPEN: Testing if service recovered

Usage:
    from circuit_breaker import CircuitBreaker, CircuitState

    # Create breaker for Redis
    redis_breaker = CircuitBreaker(
        name="redis",
        failure_threshold=3,
        recovery_timeout=30,
        half_open_requests=2
    )

    # Use as decorator
    @redis_breaker
    def get_from_redis(key):
        return redis_client.get(key)

    # Or use context manager
    with redis_breaker:
        result = redis_client.get(key)

    # Check state
    if redis_breaker.state == CircuitState.OPEN:
        use_fallback()
"""

import time
import threading
from enum import Enum
from typing import Callable, Any, Optional, Dict, List
from dataclasses import dataclass, field
from functools import wraps
from datetime import datetime
import json
from pathlib import Path


class CircuitState(Enum):
    """Lifecycle states of a circuit breaker.

    Each member's value is the lowercase string used in stats output and in
    the recorded state-change history.
    """

    # Normal operation: requests are passed through to the service.
    CLOSED = "closed"
    # Tripped: requests are rejected without calling the service.
    OPEN = "open"
    # Probing: requests are let through to check whether the service recovered.
    HALF_OPEN = "half_open"


@dataclass
class CircuitStats:
    """Cumulative counters and recent state history for one circuit breaker.

    The owning breaker updates the fields directly; ``to_dict`` renders them
    as a JSON-friendly summary.
    """

    # Every recorded call; the breaker also counts rejected calls here.
    total_calls: int = 0
    successful_calls: int = 0
    failed_calls: int = 0
    # Calls turned away while the circuit was open.
    rejected_calls: int = 0
    # Wall-clock (time.time()) timestamps of the most recent outcomes.
    last_failure_time: Optional[float] = None
    last_success_time: Optional[float] = None
    # One {"from", "to", "timestamp"} dict per state transition.
    state_changes: List[Dict] = field(default_factory=list)

    def to_dict(self) -> Dict:
        """Summarize the stats as a plain dict.

        ``success_rate`` is successful_calls / total_calls, or 0.0 when no
        calls were recorded. Only the 10 most recent state changes are
        included.
        """
        denominator = max(1, self.total_calls)  # avoid division by zero
        recent_changes = self.state_changes[-10:]
        return dict(
            total_calls=self.total_calls,
            successful_calls=self.successful_calls,
            failed_calls=self.failed_calls,
            rejected_calls=self.rejected_calls,
            success_rate=self.successful_calls / denominator,
            last_failure=self.last_failure_time,
            last_success=self.last_success_time,
            state_changes=recent_changes,
        )


class CircuitBreaker:
    """
    Circuit breaker for protecting against cascading failures.

    CLOSED: calls pass through. Consecutive failures are counted, and reaching
    ``failure_threshold`` opens the circuit; any success resets the count.
    OPEN: calls are rejected with ``CircuitOpenError`` until
    ``recovery_timeout`` seconds have elapsed since the last failure. The next
    state check then moves the circuit to HALF_OPEN.
    HALF_OPEN: ``half_open_requests`` successes close the circuit; any
    recorded failure reopens it immediately.

    Args:
        name: Identifier for this circuit
        failure_threshold: Consecutive failures (while CLOSED) before opening
        recovery_timeout: Seconds after the last failure before trying half-open
        half_open_requests: Successful requests needed to close from half-open
        exceptions: Exception types recorded as failures (default: Exception);
            any other exception propagates without being recorded
    """

    # Cap on retained state-change history. get_stats() reports only the last
    # 10 entries, so trimming older ones bounds memory for long-lived breakers
    # without changing what is reported.
    _MAX_STATE_HISTORY = 100

    def __init__(
        self,
        name: str,
        failure_threshold: int = 5,
        recovery_timeout: float = 30.0,
        half_open_requests: int = 2,
        exceptions: tuple = (Exception,)
    ):
        self.name = name
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_requests = half_open_requests
        self.exceptions = exceptions

        self._state = CircuitState.CLOSED
        # Failures since the last success/reset; compared with
        # failure_threshold only while CLOSED.
        self._failure_count = 0
        # Successes since entering HALF_OPEN.
        self._success_count = 0
        # time.monotonic() reading of the most recent failure; None after reset.
        self._last_failure_time: Optional[float] = None
        # Re-entrant because `state` is read from methods already holding it.
        self._lock = threading.RLock()
        self._stats = CircuitStats()

    @property
    def state(self) -> CircuitState:
        """Get current state, checking for automatic recovery.

        Reading this property moves an OPEN circuit to HALF_OPEN once the
        recovery timeout has elapsed.
        """
        with self._lock:
            if self._state == CircuitState.OPEN:
                if self._should_attempt_recovery():
                    self._transition_to(CircuitState.HALF_OPEN)
            return self._state

    @property
    def is_available(self) -> bool:
        """Check if circuit allows requests (i.e. it is not OPEN)."""
        return self.state != CircuitState.OPEN

    def _should_attempt_recovery(self) -> bool:
        """Check if enough time passed since the last failure to try recovery."""
        if self._last_failure_time is None:
            return True
        # Monotonic clock: unaffected by wall-clock jumps (NTP, manual changes).
        return time.monotonic() - self._last_failure_time >= self.recovery_timeout

    def _transition_to(self, new_state: CircuitState) -> None:
        """Transition to a new state, resetting counters and logging the change.

        Callers must hold ``self._lock``.
        """
        old_state = self._state
        self._state = new_state

        if new_state == CircuitState.CLOSED:
            self._failure_count = 0
            self._success_count = 0
        elif new_state == CircuitState.HALF_OPEN:
            self._success_count = 0

        history = self._stats.state_changes
        history.append({
            "from": old_state.value,
            "to": new_state.value,
            "timestamp": datetime.now().isoformat()
        })
        if len(history) > self._MAX_STATE_HISTORY:
            # Drop the oldest entries so a flapping breaker cannot grow memory forever.
            del history[:-self._MAX_STATE_HISTORY]

        print(f"[CircuitBreaker:{self.name}] {old_state.value} -> {new_state.value}")

    def record_success(self) -> None:
        """Record a successful call."""
        with self._lock:
            self._stats.total_calls += 1
            self._stats.successful_calls += 1
            self._stats.last_success_time = time.time()

            if self._state == CircuitState.CLOSED:
                # A success breaks the failure streak: only consecutive
                # failures should trip the breaker.
                self._failure_count = 0
            elif self._state == CircuitState.HALF_OPEN:
                self._success_count += 1
                if self._success_count >= self.half_open_requests:
                    self._transition_to(CircuitState.CLOSED)

    def record_failure(self, exc: Optional[Exception] = None) -> None:
        """Record a failed call.

        Args:
            exc: The exception that caused the failure. It is accepted for
                API symmetry but not currently used.
        """
        with self._lock:
            self._stats.total_calls += 1
            self._stats.failed_calls += 1
            self._stats.last_failure_time = time.time()  # wall clock, for reporting
            self._last_failure_time = time.monotonic()   # monotonic, for timeout math
            self._failure_count += 1

            if self._state == CircuitState.HALF_OPEN:
                # Any failure in half-open reopens immediately
                self._transition_to(CircuitState.OPEN)
            elif self._state == CircuitState.CLOSED:
                if self._failure_count >= self.failure_threshold:
                    self._transition_to(CircuitState.OPEN)

    def record_rejection(self) -> None:
        """Record a rejected call (circuit open)."""
        with self._lock:
            self._stats.total_calls += 1
            self._stats.rejected_calls += 1

    def reset(self) -> None:
        """Manually reset the circuit to closed state."""
        with self._lock:
            self._transition_to(CircuitState.CLOSED)
            self._failure_count = 0
            self._success_count = 0
            self._last_failure_time = None

    def get_stats(self) -> Dict:
        """Get current statistics.

        The reported state comes from the ``state`` property. An OPEN circuit
        whose recovery timeout has elapsed is therefore reported as (and moved
        to) HALF_OPEN, agreeing with ``is_available``.
        """
        with self._lock:
            # Read state first so a lazy OPEN -> HALF_OPEN transition shows up
            # in the state_changes history below.
            current_state = self.state
            stats = self._stats.to_dict()
            stats["name"] = self.name
            stats["state"] = current_state.value
            stats["failure_count"] = self._failure_count
            return stats

    def __call__(self, func: Callable) -> Callable:
        """Use as decorator: each call of *func* goes through ``self.call``."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            return self.call(func, *args, **kwargs)
        return wrapper

    def __enter__(self):
        """Context manager entry.

        Raises:
            CircuitOpenError: if the circuit is OPEN (the rejection is recorded).
        """
        if not self.is_available:
            self.record_rejection()
            raise CircuitOpenError(f"Circuit '{self.name}' is OPEN")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: record the outcome of the ``with`` body.

        A clean exit counts as a success. An exception of a type listed in
        ``exceptions`` counts as a failure; other exceptions are not recorded.
        Exceptions are never suppressed.
        """
        if exc_type is None:
            self.record_success()
        elif issubclass(exc_type, self.exceptions):
            self.record_failure(exc_val)
        return False  # Don't suppress exceptions

    def call(self, func: Callable, *args, **kwargs) -> Any:
        """Execute function with circuit breaker protection.

        Returns:
            Whatever *func* returns.

        Raises:
            CircuitOpenError: if the circuit is OPEN (func is not called).
            Any exception raised by *func* is re-raised after being recorded
            (when it matches ``exceptions``).
        """
        if not self.is_available:
            self.record_rejection()
            raise CircuitOpenError(f"Circuit '{self.name}' is OPEN")

        try:
            result = func(*args, **kwargs)
        except self.exceptions as e:
            self.record_failure(e)
            raise
        # Outside the try: a problem while recording success must not be
        # misreported as a failure of func.
        self.record_success()
        return result


class CircuitOpenError(Exception):
    """Raised instead of invoking the protected call while the circuit is OPEN.

    It signals a rejection by the breaker itself, raised before the protected
    code runs, not a failure reported by the underlying service.
    """


class CircuitBreakerRegistry:
    """
    Registry for managing multiple circuit breakers.

    Usage:
        registry = CircuitBreakerRegistry()
        registry.register("redis", failure_threshold=3)
        registry.register("qdrant", failure_threshold=5)

        # Get a breaker
        redis_cb = registry.get("redis")

        # Get all stats
        stats = registry.get_all_stats()

    Args:
        persist_path: Optional JSON file that ``persist()`` writes stats to;
            when omitted, ``persist()`` does nothing.
    """

    def __init__(self, persist_path: Optional[str] = None):
        self._breakers: Dict[str, CircuitBreaker] = {}
        # Guards _breakers; readers iterate a snapshot taken under it.
        self._lock = threading.Lock()
        self._persist_path = Path(persist_path) if persist_path else None

    def register(
        self,
        name: str,
        failure_threshold: int = 5,
        recovery_timeout: float = 30.0,
        half_open_requests: int = 2,
        exceptions: tuple = (Exception,)
    ) -> CircuitBreaker:
        """Register a new circuit breaker.

        If a breaker named *name* already exists, it is returned unchanged
        and the remaining arguments are ignored.
        """
        with self._lock:
            if name in self._breakers:
                return self._breakers[name]

            breaker = CircuitBreaker(
                name=name,
                failure_threshold=failure_threshold,
                recovery_timeout=recovery_timeout,
                half_open_requests=half_open_requests,
                exceptions=exceptions
            )
            self._breakers[name] = breaker
            return breaker

    def get(self, name: str) -> Optional[CircuitBreaker]:
        """Get a circuit breaker by name, or None if it is not registered."""
        return self._breakers.get(name)

    def get_or_create(
        self,
        name: str,
        **kwargs
    ) -> CircuitBreaker:
        """Get existing or create new circuit breaker.

        *kwargs* are forwarded to ``register`` and only take effect when the
        breaker is created.
        """
        existing = self._breakers.get(name)
        if existing is not None:
            return existing
        # register() re-checks under the lock, so concurrent creators all
        # receive the same breaker.
        return self.register(name, **kwargs)

    def _snapshot(self) -> Dict[str, CircuitBreaker]:
        """Return a copy of the name -> breaker map, taken under the lock."""
        with self._lock:
            return dict(self._breakers)

    def get_all_stats(self) -> Dict[str, Dict]:
        """Get stats from all circuit breakers.

        Iterates a snapshot, so a concurrent ``register`` cannot raise
        "dictionary changed size during iteration".
        """
        return {name: cb.get_stats() for name, cb in self._snapshot().items()}

    def get_health_summary(self) -> Dict:
        """Get health summary of all circuits.

        ``healthy`` is True when no circuit is open; half-open circuits do
        not count as unhealthy.
        """
        stats = self.get_all_stats()
        open_circuits = [n for n, s in stats.items() if s["state"] == "open"]
        half_open = [n for n, s in stats.items() if s["state"] == "half_open"]

        return {
            "total_circuits": len(stats),
            "open": len(open_circuits),
            "half_open": len(half_open),
            "closed": len(stats) - len(open_circuits) - len(half_open),
            "open_circuits": open_circuits,
            "half_open_circuits": half_open,
            "healthy": len(open_circuits) == 0
        }

    def reset_all(self) -> None:
        """Reset all circuit breakers (iterating a snapshot of the registry)."""
        for cb in self._snapshot().values():
            cb.reset()

    def persist(self) -> None:
        """Persist circuit breaker stats to file as JSON.

        This is best-effort: errors are printed, not raised. The data is
        written to a sibling temp file and then atomically moved into place,
        so readers never see a truncated file.
        """
        if not self._persist_path:
            return

        try:
            stats = self.get_all_stats()
            self._persist_path.parent.mkdir(parents=True, exist_ok=True)
            tmp_path = self._persist_path.with_name(self._persist_path.name + ".tmp")
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(stats, f, indent=2)
            tmp_path.replace(self._persist_path)
        except Exception as e:
            # Deliberate best-effort boundary: persistence must never break callers.
            print(f"[!] Failed to persist circuit breaker stats: {e}")


# Global registry instance, shared by get_circuit_breaker(), get_registry()
# and get_all_status() below. It is created empty at import time with no
# persist_path, so calling persist() on it does nothing.
_global_registry = CircuitBreakerRegistry()


def get_circuit_breaker(name: str, **kwargs) -> CircuitBreaker:
    """Return the breaker called *name* from the module-level registry.

    A new breaker is created (with *kwargs* forwarded to the registry's
    ``register``) only if none exists yet; for an existing breaker, *kwargs*
    are ignored.
    """
    registry = _global_registry
    return registry.get_or_create(name, **kwargs)


def get_registry() -> CircuitBreakerRegistry:
    """Return the module-level CircuitBreakerRegistry instance (not a copy)."""
    registry = _global_registry
    return registry


def get_all_status() -> Dict:
    """Return the health summary of every breaker in the module-level registry.

    The result is ``CircuitBreakerRegistry.get_health_summary()`` for the
    global registry.
    """
    summary = _global_registry.get_health_summary()
    return summary


# CLI interface
if __name__ == "__main__":
    import sys

    usage_text = "Usage: python circuit_breaker.py [status|demo]"
    cli_args = sys.argv[1:]

    if not cli_args:
        print("Genesis Circuit Breaker")
        print(usage_text)
    elif cli_args[0] == "status":
        # A fresh process starts with an empty global registry.
        summary = _global_registry.get_health_summary()
        print(json.dumps(summary, indent=2))
    elif cli_args[0] == "demo":
        # Demo circuit breaker behavior: opens after 3 failures, probes after 5s.
        demo = CircuitBreaker("demo", failure_threshold=3, recovery_timeout=5)

        print("Simulating failures...")
        for call_no in range(1, 6):
            try:
                with demo:
                    # The body raises on calls 1-4. The breaker opens after the
                    # third failure, so calls 4 and 5 are rejected with
                    # CircuitOpenError before the body runs.
                    if call_no <= 4:
                        raise ConnectionError("Simulated failure")
            except (ConnectionError, CircuitOpenError) as err:
                print(f"  Call {call_no}: {type(err).__name__}")

        print(f"\nCircuit state: {demo.state.value}")
        print(f"Stats: {demo.get_stats()}")

        print("\nWaiting for recovery timeout (5s)...")
        time.sleep(6)  # one second past recovery_timeout

        print(f"Circuit state after timeout: {demo.state.value}")
    else:
        print(f"Unknown command: {cli_args[0]}")
        print(usage_text)
