"""
AIVA Production Circuit Breaker System
=======================================
A production-grade circuit breaker implementation with comprehensive
failure detection, recovery strategies, and graceful degradation.

Components:
    - CircuitBreaker: Main circuit breaker pattern implementation
    - StateManager: Manages Open/Half-Open/Closed state transitions
    - FailureDetector: Detects and classifies service failures
    - RecoveryStrategy: Implements automatic recovery mechanisms
    - FallbackHandler: Provides graceful degradation capabilities
    - MetricsCollector: Tracks failure rates and performance metrics

Usage:
    from prod_01_circuit_breaker import (
        CircuitBreaker,
        CircuitBreakerConfig,
        FallbackHandler,
        ExponentialBackoffRecovery
    )

    # Configure circuit breaker
    config = CircuitBreakerConfig(
        failure_threshold=5,
        recovery_timeout=30.0,
        half_open_max_calls=3,
        failure_rate_threshold=0.5
    )

    # Create with fallback
    fallback = FallbackHandler(default_value={"status": "degraded"})

    cb = CircuitBreaker(
        name="external_api",
        config=config,
        fallback_handler=fallback
    )

    # Use as decorator
    @cb.protect
    async def call_external_api(data):
        return await api_client.post(data)

    # Or use directly
    result = await cb.execute(call_external_api, data)

Author: AIVA Genesis System
Version: 1.0.0
"""

from __future__ import annotations

import asyncio
import hashlib
import json
import logging
import threading
import time
import traceback
import uuid
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum, auto
from functools import wraps
from pathlib import Path
from typing import (
    Any,
    Callable,
    Coroutine,
    Deque,
    Dict,
    Generic,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)

# Configure logging
# NOTE(review): basicConfig() is executed at import time, so merely importing
# this module can install a root log handler; confirm that this is intended
# for library use.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("aiva.circuit_breaker")

# Type variables for generic support
T = TypeVar("T")  # value type handled by FallbackHandler (Generic[T])
R = TypeVar("R")  # not referenced in the portion of the file reviewed here


# =============================================================================
# ENUMS AND CONSTANTS
# =============================================================================


class CircuitState(Enum):
    """Circuit breaker states with automatic transition rules.

    Each member's value is its lowercase name as a string; MetricsCollector
    stores these string values as the current state.
    """

    # Normal operation, requests flow through
    CLOSED = "closed"
    # Service failed, requests rejected immediately
    OPEN = "open"
    # Testing if service has recovered
    HALF_OPEN = "half_open"
    # Manually opened, no auto-recovery
    FORCED_OPEN = "forced_open"
    # Circuit breaker bypassed
    DISABLED = "disabled"


class FailureType(Enum):
    """Failure categories used to classify and count unsuccessful calls.

    Values are consecutive integers starting at 1, listed explicitly so that
    the number behind each member is visible at a glance.
    """

    TIMEOUT = 1
    CONNECTION_ERROR = 2
    SERVER_ERROR = 3
    CLIENT_ERROR = 4
    RATE_LIMITED = 5
    CIRCUIT_OPEN = 6
    UNKNOWN = 7


class RecoveryMode(Enum):
    """Recovery strategy modes.

    NOTE(review): strategy classes for immediate, exponential, Fibonacci and
    adaptive recovery are defined in this file; no linear-backoff class is
    visible in the reviewed portion - confirm whether one exists elsewhere.
    """

    # Try immediately after timeout
    IMMEDIATE = "immediate"
    # Linear increase in wait time
    LINEAR_BACKOFF = "linear_backoff"
    # Exponential increase
    EXPONENTIAL_BACKOFF = "exponential_backoff"
    # Fibonacci sequence backoff
    FIBONACCI_BACKOFF = "fibonacci_backoff"
    # Learn optimal recovery timing
    ADAPTIVE = "adaptive"


class MetricType(Enum):
    """Types of metrics collected.

    Each member's value is its lowercase name as a string.
    NOTE(review): no use of this enum is visible in the reviewed portion of
    the file - presumably consumed by reporting code elsewhere; confirm.
    """

    # Counters
    CALL_COUNT = "call_count"
    SUCCESS_COUNT = "success_count"
    FAILURE_COUNT = "failure_count"
    REJECTION_COUNT = "rejection_count"
    TIMEOUT_COUNT = "timeout_count"
    # Measurements and derived values
    LATENCY = "latency"
    FAILURE_RATE = "failure_rate"
    # Events
    STATE_CHANGE = "state_change"


# =============================================================================
# CONFIGURATION
# =============================================================================


@dataclass
class CircuitBreakerConfig:
    """
    Configuration for circuit breaker behavior.

    Attributes:
        failure_threshold: Number of failures before opening circuit
        recovery_timeout: Seconds to wait before half-open transition
        half_open_max_calls: Max test calls in half-open state
        success_threshold: Successes needed in half-open to close
        failure_rate_threshold: Failure rate to trigger open (0.0-1.0)
        sliding_window_size: Window size for failure rate calculation
        slow_call_duration_threshold: Threshold for slow call detection (ms)
        slow_call_rate_threshold: Rate of slow calls to trigger open
        timeout: Default call timeout in seconds
        record_exceptions: Exception types to track as failures
        ignore_exceptions: Exception types to ignore
        automatic_transition_enabled: Allow automatic state transitions
        metrics_enabled: Enable metrics collection
        metrics_window_seconds: Time window for metrics aggregation
        persist_state: Whether to persist state to disk
        state_file_path: Path for state persistence
    """

    failure_threshold: int = 5
    recovery_timeout: float = 30.0
    half_open_max_calls: int = 3
    success_threshold: int = 2
    failure_rate_threshold: float = 0.5
    sliding_window_size: int = 100
    slow_call_duration_threshold: float = 2000.0  # milliseconds
    slow_call_rate_threshold: float = 0.8
    timeout: float = 10.0
    record_exceptions: Tuple[Type[Exception], ...] = (Exception,)
    ignore_exceptions: Tuple[Type[Exception], ...] = ()
    automatic_transition_enabled: bool = True
    metrics_enabled: bool = True
    metrics_window_seconds: int = 60
    persist_state: bool = False
    state_file_path: Optional[str] = None

    def validate(self) -> None:
        """
        Validate configuration values.

        The checks run in a fixed order and the first violation is reported.

        Raises:
            ValueError: If a count is below 1, a duration is negative, or a
                rate lies outside the inclusive range 0..1.
        """
        if self.failure_threshold < 1:
            raise ValueError("failure_threshold must be >= 1")
        if self.recovery_timeout < 0:
            raise ValueError("recovery_timeout must be >= 0")
        if not 0 <= self.failure_rate_threshold <= 1:
            raise ValueError("failure_rate_threshold must be between 0 and 1")
        if self.sliding_window_size < 1:
            raise ValueError("sliding_window_size must be >= 1")

        # A half-open circuit with no permitted probe calls, or one that needs
        # zero successes to close, cannot perform a meaningful recovery test.
        if self.half_open_max_calls < 1:
            raise ValueError("half_open_max_calls must be >= 1")
        if self.success_threshold < 1:
            raise ValueError("success_threshold must be >= 1")
        if self.slow_call_duration_threshold < 0:
            raise ValueError("slow_call_duration_threshold must be >= 0")
        if not 0 <= self.slow_call_rate_threshold <= 1:
            raise ValueError("slow_call_rate_threshold must be between 0 and 1")


# =============================================================================
# METRICS COLLECTOR
# =============================================================================


@dataclass
class CallRecord:
    """Record of a single call through the circuit breaker.

    Instances are created by MetricsCollector.record_call.

    Attributes:
        timestamp: Wall-clock time (time.time()) at which the call was recorded
        duration_ms: Call duration in milliseconds, as reported by the caller
        success: Whether the call was reported as successful
        failure_type: Failure classification, if one was supplied
        exception_type: Class name of the exception, if one was supplied
        exception_message: str() of the exception, if one was supplied
    """

    timestamp: float
    duration_ms: float
    success: bool
    failure_type: Optional[FailureType] = None
    exception_type: Optional[str] = None
    exception_message: Optional[str] = None


@dataclass
class MetricSnapshot:
    """Point-in-time snapshot of metrics.

    Instances are produced by MetricsCollector.snapshot.

    Attributes:
        timestamp: Wall-clock time (time.time()) when the snapshot was taken
        total_calls: Cumulative call counter (rejections are included)
        successful_calls: Cumulative count of successful calls
        failed_calls: Cumulative count of failed calls
        rejected_calls: Cumulative count of rejected calls
        timeout_calls: Cumulative count of recorded timeouts
        avg_latency_ms: Mean latency of the calls inside the time window
        p50_latency_ms: Median latency over the retained call records
        p95_latency_ms: 95th percentile latency over the retained call records
        p99_latency_ms: 99th percentile latency over the retained call records
        failure_rate: Failure rate of the calls inside the time window
        slow_call_rate: Share of retained call records slower than the
            slow-call threshold (2000 ms by default)
        current_state: State name last reported to the collector
        state_duration_seconds: Seconds elapsed since that state was entered
    """

    timestamp: float
    total_calls: int
    successful_calls: int
    failed_calls: int
    rejected_calls: int
    timeout_calls: int
    avg_latency_ms: float
    p50_latency_ms: float
    p95_latency_ms: float
    p99_latency_ms: float
    failure_rate: float
    slow_call_rate: float
    current_state: str
    state_duration_seconds: float


class MetricsCollector:
    """
    Collects and aggregates circuit breaker metrics.

    Keeps cumulative counters, a bounded deque of recent CallRecord entries,
    a bounded log of state transitions and a bounded deque of MetricSnapshot
    objects. Every public method acquires one re-entrant lock, which is why
    the methods can safely call each other.
    """

    # Maximum number of state transitions kept in the transition log.
    _MAX_STATE_HISTORY = 100

    def __init__(
        self,
        window_size: int = 100,
        window_seconds: int = 60,
        history_size: int = 1000,
        slow_call_threshold_ms: float = 2000.0,
    ):
        """
        Args:
            window_size: Maximum number of call records retained
            window_seconds: Age limit (seconds) for sliding-window statistics
            history_size: Maximum number of snapshots retained
            slow_call_threshold_ms: Calls strictly slower than this many
                milliseconds count as slow in the slow-call rate
        """
        self._window_size = window_size
        self._window_seconds = window_seconds
        self._history_size = history_size
        self._slow_call_threshold_ms = slow_call_threshold_ms

        self._calls: Deque[CallRecord] = deque(maxlen=window_size)
        self._history: Deque[MetricSnapshot] = deque(maxlen=history_size)
        self._lock = threading.RLock()

        # Counters (cumulative)
        self._total_calls = 0
        self._successful_calls = 0
        self._failed_calls = 0
        self._rejected_calls = 0
        self._timeout_calls = 0

        # State tracking
        self._state_entered_at: float = time.time()
        self._current_state: str = CircuitState.CLOSED.value
        self._state_history: List[Dict[str, Any]] = []

        # Failure type distribution
        self._failure_distribution: Dict[FailureType, int] = {
            ft: 0 for ft in FailureType
        }

    def record_call(
        self,
        duration_ms: float,
        success: bool,
        failure_type: Optional[FailureType] = None,
        exception: Optional[Exception] = None,
    ) -> None:
        """Record a call through the circuit breaker.

        Args:
            duration_ms: Call duration in milliseconds
            success: Whether the call succeeded
            failure_type: Classification counted in the failure distribution
                (only used when ``success`` is False)
            exception: Exception whose class name and message are stored
        """
        with self._lock:
            record = CallRecord(
                timestamp=time.time(),
                duration_ms=duration_ms,
                success=success,
                failure_type=failure_type,
                exception_type=type(exception).__name__ if exception else None,
                exception_message=str(exception) if exception else None,
            )
            self._calls.append(record)

            self._total_calls += 1
            if success:
                self._successful_calls += 1
            else:
                self._failed_calls += 1
                if failure_type:
                    self._failure_distribution[failure_type] += 1

    def record_rejection(self) -> None:
        """Record a rejected call (circuit open)."""
        with self._lock:
            self._total_calls += 1
            self._rejected_calls += 1
            self._failure_distribution[FailureType.CIRCUIT_OPEN] += 1

    def record_timeout(self) -> None:
        """Record a timeout.

        Only the timeout counter and the failure distribution are updated;
        the total call counter is left untouched.
        """
        with self._lock:
            self._timeout_calls += 1
            self._failure_distribution[FailureType.TIMEOUT] += 1

    def record_state_change(self, old_state: str, new_state: str) -> None:
        """Record a state transition and restart the state-duration clock."""
        with self._lock:
            now = time.time()
            duration = now - self._state_entered_at

            self._state_history.append(
                {
                    "from_state": old_state,
                    "to_state": new_state,
                    "timestamp": now,
                    "duration_in_previous_state": duration,
                }
            )

            self._state_entered_at = now
            self._current_state = new_state

            # Keep history bounded: drop everything but the newest entries.
            if len(self._state_history) > self._MAX_STATE_HISTORY:
                del self._state_history[: -self._MAX_STATE_HISTORY]

    def get_sliding_window_stats(self) -> Dict[str, Any]:
        """Get statistics for the current sliding window.

        Only retained call records younger than ``window_seconds`` are
        considered. The returned dict always has the same keys; an empty
        window reports a success rate of 1.0 and zero latencies.
        """
        with self._lock:
            cutoff = time.time() - self._window_seconds

            # Filter to recent calls
            recent_calls = [c for c in self._calls if c.timestamp >= cutoff]

            if not recent_calls:
                return {
                    "calls_in_window": 0,
                    "success_rate": 1.0,
                    "failure_rate": 0.0,
                    "avg_latency_ms": 0.0,
                    "min_latency_ms": 0.0,
                    "max_latency_ms": 0.0,
                }

            total = len(recent_calls)
            successes = sum(1 for c in recent_calls if c.success)
            latencies = [c.duration_ms for c in recent_calls]

            return {
                "calls_in_window": total,
                "success_rate": successes / total,
                "failure_rate": (total - successes) / total,
                "avg_latency_ms": sum(latencies) / total,
                "min_latency_ms": min(latencies),
                "max_latency_ms": max(latencies),
            }

    def get_percentiles(self) -> Dict[str, float]:
        """Calculate latency percentiles over all retained call records.

        Percentiles are linearly interpolated between the two nearest ranks.
        With no records every percentile is reported as 0.0.
        """
        with self._lock:
            latencies = sorted(c.duration_ms for c in self._calls)
            if not latencies:
                return {"p50": 0.0, "p75": 0.0, "p90": 0.0, "p95": 0.0, "p99": 0.0}

            def percentile(p: float) -> float:
                # Fractional rank of the requested percentile.
                k = (len(latencies) - 1) * p / 100
                f = int(k)
                c = f + 1 if f + 1 < len(latencies) else f
                return latencies[f] + (k - f) * (latencies[c] - latencies[f])

            return {
                "p50": percentile(50),
                "p75": percentile(75),
                "p90": percentile(90),
                "p95": percentile(95),
                "p99": percentile(99),
            }

    def get_failure_distribution(self) -> Dict[str, int]:
        """Get distribution of failure types, keyed by FailureType name."""
        with self._lock:
            return {ft.name: count for ft, count in self._failure_distribution.items()}

    def get_slow_call_rate(self) -> float:
        """Return the share of retained calls slower than the slow threshold.

        A call is slow when its duration is strictly greater than
        ``slow_call_threshold_ms``. Returns 0.0 when no calls are retained.
        """
        with self._lock:
            if not self._calls:
                return 0.0
            slow_calls = sum(
                1
                for c in self._calls
                if c.duration_ms > self._slow_call_threshold_ms
            )
            return slow_calls / len(self._calls)

    def snapshot(self) -> MetricSnapshot:
        """Create a point-in-time snapshot of all metrics.

        The snapshot is also appended to the bounded snapshot history.
        """
        with self._lock:
            now = time.time()
            window_stats = self.get_sliding_window_stats()
            percentiles = self.get_percentiles()

            snapshot = MetricSnapshot(
                timestamp=now,
                total_calls=self._total_calls,
                successful_calls=self._successful_calls,
                failed_calls=self._failed_calls,
                rejected_calls=self._rejected_calls,
                timeout_calls=self._timeout_calls,
                avg_latency_ms=window_stats.get("avg_latency_ms", 0.0),
                p50_latency_ms=percentiles["p50"],
                p95_latency_ms=percentiles["p95"],
                p99_latency_ms=percentiles["p99"],
                failure_rate=window_stats.get("failure_rate", 0.0),
                slow_call_rate=self.get_slow_call_rate(),
                current_state=self._current_state,
                state_duration_seconds=now - self._state_entered_at,
            )

            self._history.append(snapshot)
            return snapshot

    def get_health_score(self) -> float:
        """
        Calculate overall health score (0.0 - 1.0).

        Starts from the sliding-window success rate, subtracts a latency
        penalty (average latency / 5000, capped at 0.3) and a state penalty
        (0.5 when open, 0.2 when half-open). The result is floored at 0.0 and
        rounded to three decimals.
        """
        with self._lock:
            stats = self.get_sliding_window_stats()

            # Base score from success rate
            success_score = stats.get("success_rate", 1.0)

            # Penalize high latency
            avg_latency = stats.get("avg_latency_ms", 0)
            latency_penalty = min(avg_latency / 5000, 0.3)  # Max 30% penalty

            # Penalize non-closed states
            state_penalty = 0.0
            if self._current_state == CircuitState.OPEN.value:
                state_penalty = 0.5
            elif self._current_state == CircuitState.HALF_OPEN.value:
                state_penalty = 0.2

            score = max(0.0, success_score - latency_penalty - state_penalty)
            return round(score, 3)

    def export_metrics(self) -> Dict[str, Any]:
        """Export all metrics as a dictionary."""
        with self._lock:
            return {
                "totals": {
                    "total_calls": self._total_calls,
                    "successful_calls": self._successful_calls,
                    "failed_calls": self._failed_calls,
                    "rejected_calls": self._rejected_calls,
                    "timeout_calls": self._timeout_calls,
                },
                "sliding_window": self.get_sliding_window_stats(),
                "percentiles": self.get_percentiles(),
                "failure_distribution": self.get_failure_distribution(),
                "state": {
                    "current": self._current_state,
                    "duration_seconds": time.time() - self._state_entered_at,
                    "transitions": len(self._state_history),
                },
                "health_score": self.get_health_score(),
            }

    def reset(self) -> None:
        """Reset all metrics.

        Clears call records, counters, the failure distribution and the
        transition log, and restarts the state-duration clock. The current
        state name and the stored snapshots are left as they are.
        """
        with self._lock:
            self._calls.clear()
            self._total_calls = 0
            self._successful_calls = 0
            self._failed_calls = 0
            self._rejected_calls = 0
            self._timeout_calls = 0
            self._failure_distribution = {ft: 0 for ft in FailureType}
            self._state_history.clear()
            self._state_entered_at = time.time()


# =============================================================================
# FAILURE DETECTOR
# =============================================================================


class FailureDetector:
    """
    Detects and classifies service failures.

    Classification is heuristic: it looks at the exception's class hierarchy
    and searches the lower-cased message and class name for known phrases:
    - Timeout failures
    - Connection failures
    - Rate limiting
    - Server errors (5xx status codes mentioned in the message)
    - Client errors (4xx status codes mentioned in the message)

    The detector also tracks a streak of consecutive failures.
    """

    # Common exception patterns
    TIMEOUT_PATTERNS = frozenset(
        [
            "timeout",
            "timed out",
            "deadline exceeded",
            "connection timed out",
            "read timed out",
        ]
    )

    CONNECTION_PATTERNS = frozenset(
        [
            "connection refused",
            "connection reset",
            "connection closed",
            "connection error",
            "network unreachable",
            "host unreachable",
            "no route to host",
            "connection aborted",
        ]
    )

    RATE_LIMIT_PATTERNS = frozenset(
        [
            "rate limit",
            "too many requests",
            "429",
            "throttled",
            "quota exceeded",
        ]
    )

    def __init__(
        self,
        slow_call_threshold_ms: float = 2000.0,
        consecutive_failure_threshold: int = 3,
        failure_reset_seconds: float = 60.0,
        min_rate_samples: int = 10,
    ):
        """
        Args:
            slow_call_threshold_ms: Duration (ms) from which a call is slow
            consecutive_failure_threshold: Streak length that counts as a
                breach in record_failure
            failure_reset_seconds: A failure arriving more than this many
                seconds after the previous one starts a new streak
            min_rate_samples: Minimum number of calls before the rate-based
                check in should_trip is applied
        """
        self._slow_call_threshold = slow_call_threshold_ms
        self._consecutive_failure_threshold = consecutive_failure_threshold
        self._failure_reset_seconds = failure_reset_seconds
        self._min_rate_samples = min_rate_samples
        self._consecutive_failures = 0
        self._last_failure_time: Optional[float] = None
        self._lock = threading.Lock()

    def classify_exception(self, exception: Exception) -> FailureType:
        """Classify an exception into a failure type.

        Checks run in priority order: timeout, connection, rate limit,
        server error, client error; anything else is UNKNOWN.
        """
        exc_str = str(exception).lower()
        exc_type = type(exception).__name__.lower()

        def mentions(patterns) -> bool:
            return any(p in exc_str or p in exc_type for p in patterns)

        # Check timeouts. The isinstance test also catches TimeoutError
        # subclasses whose name and message contain none of the phrases.
        if isinstance(exception, TimeoutError) or mentions(self.TIMEOUT_PATTERNS):
            return FailureType.TIMEOUT

        # Check connection failures. Built-in ConnectionError subclasses
        # (e.g. BrokenPipeError, or ConnectionResetError raised without a
        # message) do not match any phrase, so test the hierarchy as well.
        if isinstance(exception, ConnectionError) or mentions(
            self.CONNECTION_PATTERNS
        ):
            return FailureType.CONNECTION_ERROR

        # Check rate limiting (message only)
        if any(p in exc_str for p in self.RATE_LIMIT_PATTERNS):
            return FailureType.RATE_LIMITED

        # Check for HTTP status codes in exception.
        # NOTE(review): plain substring search - a message such as
        # "processed 5000 rows" is also classified as a server error.
        if any(code in exc_str for code in ("500", "502", "503", "504")):
            return FailureType.SERVER_ERROR
        if any(code in exc_str for code in ("400", "401", "403", "404")):
            return FailureType.CLIENT_ERROR

        return FailureType.UNKNOWN

    def is_slow_call(self, duration_ms: float) -> bool:
        """Check if a call is considered slow (duration >= threshold)."""
        return duration_ms >= self._slow_call_threshold

    def record_failure(self) -> Tuple[bool, int]:
        """
        Record a failure and check for consecutive failure threshold.

        The streak restarts when the previous failure is older than
        ``failure_reset_seconds``.

        Returns:
            Tuple of (threshold_breached, consecutive_count)
        """
        with self._lock:
            now = time.time()

            # Reset if too much time passed since last failure
            if (
                self._last_failure_time is not None
                and now - self._last_failure_time > self._failure_reset_seconds
            ):
                self._consecutive_failures = 0

            self._consecutive_failures += 1
            self._last_failure_time = now

            breached = self._consecutive_failures >= self._consecutive_failure_threshold
            return breached, self._consecutive_failures

    def record_success(self) -> None:
        """Record a success, resetting consecutive failure count."""
        with self._lock:
            self._consecutive_failures = 0

    def should_trip(
        self,
        failure_count: int,
        total_count: int,
        failure_threshold: int,
        failure_rate_threshold: float,
    ) -> bool:
        """
        Determine if circuit should trip based on failure patterns.

        Trips when the failure count reaches ``failure_threshold``, or when
        at least ``min_rate_samples`` calls were seen and the failure rate
        reaches ``failure_rate_threshold``.
        """
        # Count-based check
        if failure_count >= failure_threshold:
            return True

        # Rate-based check (only if we have enough samples)
        if total_count > 0 and total_count >= self._min_rate_samples:
            return failure_count / total_count >= failure_rate_threshold

        return False

    def get_failure_summary(self) -> Dict[str, Any]:
        """Get current failure detection state."""
        with self._lock:
            return {
                "consecutive_failures": self._consecutive_failures,
                "threshold": self._consecutive_failure_threshold,
                "last_failure_time": self._last_failure_time,
                "threshold_breached": (
                    self._consecutive_failures >= self._consecutive_failure_threshold
                ),
            }


# =============================================================================
# RECOVERY STRATEGY
# =============================================================================


class RecoveryStrategy(ABC):
    """Interface every recovery strategy has to implement."""

    @abstractmethod
    def get_recovery_delay(self, attempt: int) -> float:
        """Return the delay to wait before recovery attempt ``attempt``."""

    @abstractmethod
    def should_attempt_recovery(
        self, time_since_open: float, attempt: int
    ) -> bool:
        """Tell whether a recovery attempt is due.

        Args:
            time_since_open: Time elapsed since the circuit opened
            attempt: Number of the recovery attempt being considered
        """

    @abstractmethod
    def reset(self) -> None:
        """Discard any state the strategy accumulated."""


class ImmediateRecovery(RecoveryStrategy):
    """Constant-delay strategy: every attempt waits exactly the base timeout."""

    def __init__(self, base_timeout: float = 30.0):
        # The one and only delay this strategy ever hands out.
        self._base_timeout = base_timeout

    def get_recovery_delay(self, attempt: int) -> float:
        """Return the base timeout; the attempt number has no influence."""
        return self._base_timeout

    def should_attempt_recovery(
        self, time_since_open: float, attempt: int
    ) -> bool:
        """Report whether at least the base timeout has elapsed."""
        waited_long_enough = time_since_open >= self._base_timeout
        return waited_long_enough

    def reset(self) -> None:
        """Nothing to reset - the strategy keeps no mutable state."""
        return None


class ExponentialBackoffRecovery(RecoveryStrategy):
    """
    Exponential backoff recovery strategy.

    The delay for attempt ``n`` is ``base_timeout * multiplier ** n`` capped
    at ``max_timeout``, plus a random jitter of up to ``jitter`` (a fraction
    of the delay) in either direction. The jittered delay is computed once
    per attempt and reused, so repeated polling of should_attempt_recovery
    compares against a stable threshold instead of a freshly drawn one.
    """

    def __init__(
        self,
        base_timeout: float = 30.0,
        max_timeout: float = 600.0,
        multiplier: float = 2.0,
        jitter: float = 0.1,
    ):
        """
        Args:
            base_timeout: Delay for attempt 0 and lower bound of every delay
            max_timeout: Upper bound of every delay
            multiplier: Growth factor applied per attempt
            jitter: Maximum random deviation as a fraction of the delay
        """
        self._base_timeout = base_timeout
        self._max_timeout = max_timeout
        self._multiplier = multiplier
        self._jitter = jitter
        # Attempt number the cached delay belongs to.
        self._attempt = 0
        self._cached_delay: Optional[float] = None
        self._lock = threading.Lock()

    def get_recovery_delay(self, attempt: int) -> float:
        """Return the jittered delay for ``attempt``.

        Calls with the same attempt number return the same value until a
        different attempt is requested or reset() is called.
        """
        import random

        with self._lock:
            if self._cached_delay is not None and self._attempt == attempt:
                return self._cached_delay

            try:
                raw_delay = self._base_timeout * (self._multiplier ** attempt)
            except OverflowError:
                # The power no longer fits a float; the cap applies anyway.
                raw_delay = self._max_timeout
            delay = min(raw_delay, self._max_timeout)

            # Add jitter to prevent thundering herd
            jitter_range = delay * self._jitter
            delay += random.uniform(-jitter_range, jitter_range)

            # Keep the result within [base_timeout, max_timeout]; the lower
            # bound wins if the two are configured inconsistently.
            delay = max(self._base_timeout, min(delay, self._max_timeout))

            self._attempt = attempt
            self._cached_delay = delay
            return delay

    def should_attempt_recovery(
        self, time_since_open: float, attempt: int
    ) -> bool:
        """Report whether the delay for ``attempt`` has elapsed."""
        return time_since_open >= self.get_recovery_delay(attempt)

    def reset(self) -> None:
        """Forget the cached delay so the next attempt draws new jitter."""
        with self._lock:
            self._attempt = 0
            self._cached_delay = None


class FibonacciBackoffRecovery(RecoveryStrategy):
    """
    Fibonacci sequence backoff recovery.

    The delay for attempt ``n`` is ``base_timeout * F(n + 2)`` capped at
    ``max_timeout``, i.e. the multipliers run 1, 2, 3, 5, 8, ...
    Provides a middle ground between linear and exponential backoff.
    """

    # Highest Fibonacci index ever evaluated. F(1000) is about 4e208, far
    # beyond any realistic max_timeout / base_timeout ratio, and capping here
    # keeps the memo table small and the product within float range.
    _MAX_FIB_INDEX = 1000

    def __init__(
        self,
        base_timeout: float = 30.0,
        max_timeout: float = 600.0,
    ):
        """
        Args:
            base_timeout: Delay unit multiplied by the Fibonacci number
            max_timeout: Upper bound of every delay
        """
        self._base_timeout = base_timeout
        self._max_timeout = max_timeout
        self._fib_cache: Dict[int, int] = {0: 0, 1: 1}

    def _fibonacci(self, n: int) -> int:
        """Return F(n), extending the memo table iteratively.

        Raises:
            ValueError: If ``n`` is negative.
        """
        if n < 0:
            raise ValueError("Fibonacci index must be >= 0")

        cache = self._fib_cache
        if n in cache:
            return cache[n]

        # The table always holds a contiguous run 0..k, so continue from k.
        k = len(cache) - 1
        prev, curr = cache[k - 1], cache[k]
        while k < n:
            prev, curr = curr, prev + curr
            k += 1
            cache[k] = curr
        return curr

    def get_recovery_delay(self, attempt: int) -> float:
        """Return the capped Fibonacci delay for ``attempt``."""
        index = min(attempt + 2, self._MAX_FIB_INDEX)  # Start from F(2) = 1
        delay = self._base_timeout * self._fibonacci(index)
        return min(delay, self._max_timeout)

    def should_attempt_recovery(
        self, time_since_open: float, attempt: int
    ) -> bool:
        """Report whether the delay for ``attempt`` has elapsed."""
        return time_since_open >= self.get_recovery_delay(attempt)

    def reset(self) -> None:
        """Nothing to reset; the memo table stays valid."""
        pass


class AdaptiveRecovery(RecoveryStrategy):
    """
    Recovery strategy whose timeout follows observed recovery outcomes.

    A successful recovery attempt shrinks the timeout, a failed one grows
    it; the timeout always stays between ``min_timeout`` and ``max_timeout``.
    """

    def __init__(
        self,
        initial_timeout: float = 30.0,
        min_timeout: float = 10.0,
        max_timeout: float = 600.0,
        learning_rate: float = 0.1,
    ):
        self._initial_timeout = initial_timeout
        self._current_timeout = initial_timeout
        self._min_timeout = min_timeout
        self._max_timeout = max_timeout
        self._learning_rate = learning_rate
        # The most recent (delay_used, success) pairs, at most 50 of them.
        self._success_history: Deque[Tuple[float, bool]] = deque(maxlen=50)
        self._lock = threading.Lock()

    def record_recovery_attempt(self, delay_used: float, success: bool) -> None:
        """Store the outcome of a recovery attempt and adapt the timeout."""
        with self._lock:
            self._success_history.append((delay_used, success))

            if success:
                # Service came back: try sooner next time, but not below min.
                shrunk = self._current_timeout * (1 - self._learning_rate)
                self._current_timeout = max(self._min_timeout, shrunk)
                return

            # Service still failing: back off twice as fast, up to the max.
            grown = self._current_timeout * (1 + self._learning_rate * 2)
            self._current_timeout = min(self._max_timeout, grown)

    def get_recovery_delay(self, attempt: int) -> float:
        """Return the current timeout, lengthened by 20% per attempt."""
        with self._lock:
            attempt_scale = 1 + attempt * 0.2
            return self._current_timeout * attempt_scale

    def should_attempt_recovery(
        self, time_since_open: float, attempt: int
    ) -> bool:
        """Report whether the delay for ``attempt`` has elapsed."""
        required = self.get_recovery_delay(attempt)
        return time_since_open >= required

    def reset(self) -> None:
        """Restore the initial timeout; recorded outcomes are kept."""
        with self._lock:
            self._current_timeout = self._initial_timeout

    def get_stats(self) -> Dict[str, Any]:
        """Get adaptive recovery statistics."""
        with self._lock:
            sample_count = len(self._success_history)
            success_count = 0
            for _, succeeded in self._success_history:
                if succeeded:
                    success_count += 1

            if sample_count > 0:
                success_rate = success_count / sample_count
            else:
                success_rate = 0.0

            return {
                "current_timeout": self._current_timeout,
                "initial_timeout": self._initial_timeout,
                "recovery_success_rate": success_rate,
                "samples": sample_count,
            }


# =============================================================================
# FALLBACK HANDLER
# =============================================================================


class FallbackHandler(Generic[T]):
    """
    Handles graceful degradation when circuit is open.

    Fallback sources, in priority order:
    1. Cached value (if a cache key is given and the entry is still fresh)
    2. Custom fallback function
    3. Default value

    The internal lock protects only the fallback counter and the cache. It
    is released before the fallback function runs, so a slow (or awaited)
    fallback never blocks other threads using the handler.
    """

    def __init__(
        self,
        default_value: Optional[T] = None,
        fallback_function: Optional[Callable[..., T]] = None,
        cache_duration_seconds: float = 300.0,
        log_fallbacks: bool = True,
    ):
        """
        Args:
            default_value: Value returned when nothing else is available;
                None means "no default"
            fallback_function: Callable invoked with the original call
                arguments; may return a coroutine when used through
                execute_fallback_async
            cache_duration_seconds: Lifetime of cached results in seconds
            log_fallbacks: Whether each fallback execution logs a warning
        """
        self._default_value = default_value
        self._fallback_function = fallback_function
        self._cache_duration = cache_duration_seconds
        self._log_fallbacks = log_fallbacks

        self._cache: Dict[str, Tuple[T, float]] = {}
        self._fallback_count = 0
        self._lock = threading.RLock()

    def cache_result(self, key: str, value: T) -> None:
        """Cache a successful result for later fallback use."""
        with self._lock:
            self._cache[key] = (value, time.time())

    def get_cached(self, key: str) -> Optional[T]:
        """Get cached value if still valid.

        An expired entry is removed and None is returned.
        """
        with self._lock:
            if key not in self._cache:
                return None

            value, cached_at = self._cache[key]
            if time.time() - cached_at > self._cache_duration:
                del self._cache[key]
                return None

            return value

    def _start_fallback(self, cache_key: Optional[str], label: str) -> Optional[T]:
        """Count the fallback, log it and return a cached value if one exists."""
        with self._lock:
            self._fallback_count += 1
            count = self._fallback_count
            cached = self.get_cached(cache_key) if cache_key else None

        if self._log_fallbacks:
            logger.warning("Executing %s (count: %d)", label, count)
        if cached is not None:
            logger.info("Using cached fallback for key: %s", cache_key)
        return cached

    def _default_or_raise(self) -> T:
        """Return the default value, or raise when none is configured."""
        if self._default_value is not None:
            return self._default_value

        raise CircuitBreakerOpenError(
            "Circuit is open and no fallback available"
        )

    def execute_fallback(
        self,
        *args,
        cache_key: Optional[str] = None,
        **kwargs,
    ) -> T:
        """
        Execute fallback strategy.

        Priority:
        1. Cached value (if key provided and cache hit)
        2. Fallback function (if provided)
        3. Default value

        A failing fallback function is logged and the default value is tried.

        Raises:
            CircuitBreakerOpenError: If no source yields a value.
        """
        cached = self._start_fallback(cache_key, "fallback")
        if cached is not None:
            return cached

        # Try fallback function (outside the lock: it is user code)
        if self._fallback_function:
            try:
                return self._fallback_function(*args, **kwargs)
            except Exception as e:
                logger.error("Fallback function failed: %s", e)

        return self._default_or_raise()

    async def execute_fallback_async(
        self,
        *args,
        cache_key: Optional[str] = None,
        **kwargs,
    ) -> T:
        """Async version of execute_fallback.

        If the fallback function returns a coroutine it is awaited. No lock
        is held while awaiting.

        Raises:
            CircuitBreakerOpenError: If no source yields a value.
        """
        cached = self._start_fallback(cache_key, "async fallback")
        if cached is not None:
            return cached

        # Try fallback function (outside the lock: it may suspend)
        if self._fallback_function:
            try:
                result = self._fallback_function(*args, **kwargs)
                if asyncio.iscoroutine(result):
                    return await result
                return result
            except Exception as e:
                logger.error("Async fallback function failed: %s", e)

        return self._default_or_raise()

    def get_stats(self) -> Dict[str, Any]:
        """Get fallback statistics."""
        with self._lock:
            return {
                "fallback_count": self._fallback_count,
                "cache_size": len(self._cache),
                "has_default": self._default_value is not None,
                "has_function": self._fallback_function is not None,
            }

    def clear_cache(self) -> None:
        """Clear the fallback cache."""
        with self._lock:
            self._cache.clear()

# =============================================================================
# STATE MANAGER
# =============================================================================


class StateManager:
    """
    Manages circuit breaker state transitions.

    Handles:
    - State validation
    - Transition rules
    - State persistence
    - State history

    Every public method takes the instance's re-entrant lock, so the
    methods may call each other while the lock is already held.
    """

    # Valid state transitions: maps a current state to the set of states
    # that may be entered directly from it.
    VALID_TRANSITIONS: Dict[CircuitState, Set[CircuitState]] = {
        CircuitState.CLOSED: {
            CircuitState.OPEN,
            CircuitState.FORCED_OPEN,
            CircuitState.DISABLED,
        },
        CircuitState.OPEN: {
            CircuitState.HALF_OPEN,
            CircuitState.FORCED_OPEN,
            CircuitState.DISABLED,
        },
        CircuitState.HALF_OPEN: {
            CircuitState.CLOSED,
            CircuitState.OPEN,
            CircuitState.FORCED_OPEN,
            CircuitState.DISABLED,
        },
        CircuitState.FORCED_OPEN: {
            CircuitState.CLOSED,
            CircuitState.DISABLED,
        },
        CircuitState.DISABLED: {
            CircuitState.CLOSED,
            CircuitState.OPEN,
            CircuitState.FORCED_OPEN,
        },
    }

    def __init__(
        self,
        name: str,
        initial_state: CircuitState = CircuitState.CLOSED,
        persist_path: Optional[Path] = None,
        on_state_change: Optional[Callable[[CircuitState, CircuitState], None]] = None,
    ):
        """Create a state manager.

        Args:
            name: Name used as a prefix in log messages and stored in the
                persisted state file.
            initial_state: State to start in when nothing is loaded from
                ``persist_path``.
            persist_path: Optional JSON file used to save the state after
                each transition and to restore it on construction.
            on_state_change: Optional callback invoked with
                ``(old_state, new_state)`` after each successful
                transition.
        """
        self._name = name
        self._state = initial_state
        self._persist_path = persist_path
        self._on_state_change = on_state_change
        self._lock = threading.RLock()

        self._state_entered_at = time.time()
        self._half_open_calls = 0
        self._half_open_successes = 0
        self._recovery_attempts = 0

        # Bounded history: only the 100 most recent transitions are kept.
        self._history: Deque[Dict[str, Any]] = deque(maxlen=100)

        # Load persisted state if available
        if persist_path and persist_path.exists():
            self._load_state()

    @property
    def state(self) -> CircuitState:
        """Get current state."""
        with self._lock:
            return self._state

    @property
    def time_in_state(self) -> float:
        """Get time spent in current state (seconds)."""
        with self._lock:
            return time.time() - self._state_entered_at

    def can_transition(self, to_state: CircuitState) -> bool:
        """Check if transition to target state is valid."""
        with self._lock:
            return to_state in self.VALID_TRANSITIONS.get(self._state, set())

    def transition(self, to_state: CircuitState, reason: str = "") -> bool:
        """
        Attempt state transition.

        On success the state and its entry timestamp are updated, the
        relevant counters are reset, a history entry is appended, the
        state is persisted (when a path is configured) and the
        state-change callback is invoked. All of this happens while the
        lock is held, including the callback.

        Args:
            to_state: State to move to.
            reason: Free-text explanation stored in the history entry and
                included in the log message.

        Returns True if transition succeeded, False otherwise.
        """
        with self._lock:
            if not self.can_transition(to_state):
                logger.warning(
                    f"Invalid state transition: {self._state.value} -> {to_state.value}"
                )
                return False

            old_state = self._state
            self._state = to_state
            self._state_entered_at = time.time()

            # Reset counters on state change
            if to_state == CircuitState.HALF_OPEN:
                self._half_open_calls = 0
                self._half_open_successes = 0
            elif to_state == CircuitState.OPEN:
                self._recovery_attempts = 0

            # Record history
            self._history.append(
                {
                    "from": old_state.value,
                    "to": to_state.value,
                    "reason": reason,
                    "timestamp": time.time(),
                }
            )

            # Persist if configured
            if self._persist_path:
                self._save_state()

            # Notify callback; a failing callback must not undo the
            # transition, so its exception is logged and swallowed.
            if self._on_state_change:
                try:
                    self._on_state_change(old_state, to_state)
                except Exception as e:
                    logger.error(f"State change callback failed: {e}")

            logger.info(
                f"[{self._name}] State transition: {old_state.value} -> {to_state.value}"
                f" (reason: {reason})"
            )

            return True

    def record_half_open_call(self, success: bool) -> CircuitState:
        """
        Record a call result in half-open state.

        Calls made while the manager is not half-open are ignored. A
        failed call transitions straight back to OPEN.

        Returns the new state after processing.
        """
        with self._lock:
            if self._state != CircuitState.HALF_OPEN:
                return self._state

            self._half_open_calls += 1
            if success:
                self._half_open_successes += 1
            else:
                # Any failure in half-open immediately reopens
                self.transition(CircuitState.OPEN, "failure during half-open test")

            return self._state

    def increment_recovery_attempts(self) -> int:
        """Increment and return recovery attempt count."""
        with self._lock:
            self._recovery_attempts += 1
            return self._recovery_attempts

    def get_state_info(self) -> Dict[str, Any]:
        """Get detailed state information."""
        with self._lock:
            return {
                "state": self._state.value,
                "time_in_state_seconds": time.time() - self._state_entered_at,
                "half_open_calls": self._half_open_calls,
                "half_open_successes": self._half_open_successes,
                "recovery_attempts": self._recovery_attempts,
                "history_length": len(self._history),
            }

    def get_history(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Get recent state transition history.

        Args:
            limit: Maximum number of entries to return, newest last.

        Returns:
            Up to ``limit`` most recent history entries; an empty list
            when ``limit`` is zero or negative.
        """
        with self._lock:
            # Guard explicitly: ``seq[-0:]`` is ``seq[0:]`` (everything),
            # and a negative limit would slice from the wrong end.
            if limit <= 0:
                return []
            return list(self._history)[-limit:]

    def _save_state(self) -> None:
        """Persist current state to file.

        Best effort: any error is logged and otherwise ignored so that a
        persistence problem never blocks a state transition.
        """
        if not self._persist_path:
            return

        try:
            state_data = {
                "name": self._name,
                "state": self._state.value,
                "state_entered_at": self._state_entered_at,
                "recovery_attempts": self._recovery_attempts,
                "saved_at": time.time(),
            }
            self._persist_path.parent.mkdir(parents=True, exist_ok=True)
            with open(self._persist_path, "w", encoding="utf-8") as f:
                json.dump(state_data, f, indent=2)
        except Exception as e:
            logger.error(f"Failed to persist state: {e}")

    def _load_state(self) -> None:
        """Load state from persistence file.

        The file is parsed and every field is validated/coerced before
        any instance attribute is touched, so a corrupt or partially
        valid file leaves the manager exactly as it was. Errors are
        logged and otherwise ignored.
        """
        if not self._persist_path or not self._persist_path.exists():
            return

        try:
            with open(self._persist_path, encoding="utf-8") as f:
                data = json.load(f)

            loaded_state = CircuitState(data["state"])
            entered_at = float(data.get("state_entered_at", time.time()))
            recovery_attempts = int(data.get("recovery_attempts", 0))
        except Exception as e:
            logger.error(f"Failed to load persisted state: {e}")
            return

        self._state = loaded_state
        self._state_entered_at = entered_at
        self._recovery_attempts = recovery_attempts

        logger.info(
            f"[{self._name}] Loaded persisted state: {self._state.value}"
        )


# =============================================================================
# EXCEPTIONS
# =============================================================================


class CircuitBreakerError(Exception):
    """Root of the circuit breaker exception hierarchy.

    Catch this type to handle every error raised by the circuit breaker
    itself.
    """


class CircuitBreakerOpenError(CircuitBreakerError):
    """Signals that a call was rejected because the circuit is open.

    Attributes are documented inline; ``circuit_name`` defaults to an
    empty string when the raiser does not supply one.
    """

    def __init__(self, message: str, circuit_name: str = ""):
        # Name of the circuit that rejected the call.
        self.circuit_name = circuit_name
        super().__init__(message)


class CircuitBreakerTimeoutError(CircuitBreakerError):
    """Signals that a protected call did not finish within its timeout."""


# =============================================================================
# MAIN CIRCUIT BREAKER
# =============================================================================


class CircuitBreaker:
    """
    Production-grade circuit breaker implementation.

    Provides comprehensive failure detection, automatic recovery,
    and graceful degradation for protecting against cascading failures.

    Features:
        - Multiple state transitions (Closed/Open/Half-Open/Forced/Disabled)
        - Configurable failure thresholds (count and rate-based)
        - Multiple recovery strategies (immediate, exponential, adaptive)
        - Fallback handling with caching
        - Comprehensive metrics collection
        - Thread-safe operation
        - Async support
        - State persistence
    """

    def __init__(
        self,
        name: str,
        config: Optional[CircuitBreakerConfig] = None,
        recovery_strategy: Optional[RecoveryStrategy] = None,
        fallback_handler: Optional[FallbackHandler] = None,
        on_state_change: Optional[Callable[[CircuitState, CircuitState], None]] = None,
    ):
        """Create a circuit breaker and wire up its components.

        Args:
            name: Identifier used in log messages, errors and reports.
            config: Breaker configuration. A default
                ``CircuitBreakerConfig()`` is used when omitted; either
                way ``validate()`` is called on it.
            recovery_strategy: Strategy consulted to decide when an open
                circuit may be probed again. Defaults to
                ``ExponentialBackoffRecovery`` seeded with
                ``config.recovery_timeout``.
            fallback_handler: Optional handler used to produce a result
                when a call is rejected.
            on_state_change: Optional callback invoked with
                ``(old_state, new_state)`` after each state transition.
        """
        self._name = name
        self._config = config or CircuitBreakerConfig()
        self._config.validate()

        # Initialize components
        self._state_manager = StateManager(
            name=name,
            persist_path=(
                Path(self._config.state_file_path)
                if self._config.state_file_path
                else None
            ),
            on_state_change=self._handle_state_change,
        )

        self._failure_detector = FailureDetector(
            slow_call_threshold_ms=self._config.slow_call_duration_threshold,
        )

        self._recovery_strategy = recovery_strategy or ExponentialBackoffRecovery(
            base_timeout=self._config.recovery_timeout,
        )

        self._fallback_handler = fallback_handler
        self._user_state_callback = on_state_change

        # Metrics are optional; every use below is guarded by a None check.
        self._metrics = MetricsCollector(
            window_size=self._config.sliding_window_size,
            window_seconds=self._config.metrics_window_seconds,
        ) if self._config.metrics_enabled else None

        # Tracking
        self._failure_count = 0
        self._success_count = 0
        self._last_failure_time: Optional[float] = None
        self._lock = threading.RLock()

        logger.info(f"CircuitBreaker '{name}' initialized")

    @property
    def name(self) -> str:
        """Get circuit breaker name."""
        return self._name

    @property
    def state(self) -> CircuitState:
        """Get current state, checking for automatic recovery.

        Reading this property is not side-effect free: while the circuit
        is OPEN and automatic transitions are enabled, it may move the
        circuit to HALF_OPEN.
        """
        with self._lock:
            current = self._state_manager.state

            if (
                current == CircuitState.OPEN
                and self._config.automatic_transition_enabled
            ):
                if self._should_attempt_recovery():
                    self._state_manager.transition(
                        CircuitState.HALF_OPEN,
                        "recovery timeout elapsed",
                    )
                    return CircuitState.HALF_OPEN

            return current

    @property
    def is_available(self) -> bool:
        """Check if circuit allows requests."""
        state = self.state
        return state in (
            CircuitState.CLOSED,
            CircuitState.HALF_OPEN,
            CircuitState.DISABLED,
        )

    def _should_attempt_recovery(self) -> bool:
        """Check if conditions are met for recovery attempt.

        Returns True immediately when no failure has been recorded by
        this instance; otherwise defers to the recovery strategy with the
        time elapsed since the last failure and the zero-based attempt
        index.
        """
        if self._last_failure_time is None:
            return True

        time_since_open = time.time() - self._last_failure_time
        # NOTE(review): this increments the attempt counter on every
        # check, i.e. on every read of ``state`` while OPEN (including
        # ``is_available`` and ``get_health``). If the strategy scales
        # its timeout with the attempt index, frequent polling delays
        # recovery -- confirm this is intended.
        attempts = self._state_manager.increment_recovery_attempts()

        return self._recovery_strategy.should_attempt_recovery(
            time_since_open, attempts - 1
        )

    def _handle_state_change(
        self, old_state: CircuitState, new_state: CircuitState
    ) -> None:
        """Handle state transitions.

        Records the change in the metrics collector (when enabled) and
        forwards it to the user callback. Errors raised by the user
        callback are logged and swallowed.
        """
        if self._metrics:
            self._metrics.record_state_change(old_state.value, new_state.value)

        if self._user_state_callback:
            try:
                self._user_state_callback(old_state, new_state)
            except Exception as e:
                logger.error(f"User state callback failed: {e}")

    def _record_success(self, duration_ms: float) -> None:
        """Record a successful call.

        Args:
            duration_ms: Duration of the call in milliseconds.
        """
        with self._lock:
            self._success_count += 1
            self._failure_count = max(0, self._failure_count - 1)  # Decay failures

            self._failure_detector.record_success()

            if self._metrics:
                self._metrics.record_call(duration_ms, success=True)

            # Handle half-open state
            if self._state_manager.state == CircuitState.HALF_OPEN:
                self._state_manager.record_half_open_call(success=True)

                # Check if we should close
                info = self._state_manager.get_state_info()
                if info["half_open_successes"] >= self._config.success_threshold:
                    self._state_manager.transition(
                        CircuitState.CLOSED,
                        f"successful recovery ({info['half_open_successes']} successes)",
                    )
                    self._recovery_strategy.reset()

    def _record_failure(
        self,
        duration_ms: float,
        exception: Optional[Exception] = None,
    ) -> None:
        """Record a failed call.

        Args:
            duration_ms: Duration of the call in milliseconds.
            exception: The exception that caused the failure, if any;
                used to classify the failure type.
        """
        with self._lock:
            self._failure_count += 1
            self._last_failure_time = time.time()

            failure_type = (
                self._failure_detector.classify_exception(exception)
                if exception
                else FailureType.UNKNOWN
            )

            # The detector keeps its own bookkeeping; its return value is
            # not needed here.
            self._failure_detector.record_failure()

            if self._metrics:
                self._metrics.record_call(
                    duration_ms,
                    success=False,
                    failure_type=failure_type,
                    exception=exception,
                )

            current_state = self._state_manager.state

            # Handle based on current state
            if current_state == CircuitState.HALF_OPEN:
                # Any failure in half-open reopens immediately
                self._state_manager.record_half_open_call(success=False)

            elif current_state == CircuitState.CLOSED:
                # Check if we should open
                window_stats = (
                    self._metrics.get_sliding_window_stats()
                    if self._metrics
                    else {"calls_in_window": 0, "failure_rate": 0}
                )

                should_trip = self._failure_detector.should_trip(
                    self._failure_count,
                    window_stats.get("calls_in_window", 0),
                    self._config.failure_threshold,
                    self._config.failure_rate_threshold,
                )

                if should_trip:
                    self._state_manager.transition(
                        CircuitState.OPEN,
                        f"failure threshold reached ({self._failure_count} failures)",
                    )

    def _record_rejection(self) -> None:
        """Record a rejected call."""
        with self._lock:
            if self._metrics:
                self._metrics.record_rejection()

    def execute(
        self,
        func: Callable[..., T],
        *args,
        cache_key: Optional[str] = None,
        **kwargs,
    ) -> T:
        """
        Execute a function with circuit breaker protection.

        Exceptions listed in ``config.ignore_exceptions`` are recorded as
        successes and re-raised; exceptions listed in
        ``config.record_exceptions`` are recorded as failures and
        re-raised. Any other exception propagates without being recorded.

        Args:
            func: Function to execute
            *args: Positional arguments for function
            cache_key: Optional cache key for fallback caching
            **kwargs: Keyword arguments for function

        Returns:
            Result from function or fallback

        Raises:
            CircuitBreakerOpenError: If circuit is open and no fallback
        """
        # Check if circuit allows requests
        if not self.is_available:
            self._record_rejection()

            if self._fallback_handler:
                return self._fallback_handler.execute_fallback(
                    *args, cache_key=cache_key, **kwargs
                )

            raise CircuitBreakerOpenError(
                f"Circuit '{self._name}' is {self.state.value}",
                circuit_name=self._name,
            )

        # Execute with timing
        start_time = time.time()
        try:
            result = func(*args, **kwargs)
            duration_ms = (time.time() - start_time) * 1000

            self._record_success(duration_ms)

            # Cache result for fallback
            if self._fallback_handler and cache_key:
                self._fallback_handler.cache_result(cache_key, result)

            return result

        except self._config.ignore_exceptions:
            # Don't count ignored exceptions as failures
            duration_ms = (time.time() - start_time) * 1000
            self._record_success(duration_ms)
            raise

        except self._config.record_exceptions as e:
            duration_ms = (time.time() - start_time) * 1000
            self._record_failure(duration_ms, e)
            raise

    async def execute_async(
        self,
        func: Callable[..., Coroutine[Any, Any, T]],
        *args,
        cache_key: Optional[str] = None,
        timeout: Optional[float] = None,
        **kwargs,
    ) -> T:
        """
        Execute an async function with circuit breaker protection.

        Exception handling mirrors ``execute``; in addition an
        ``asyncio.TimeoutError`` is recorded as a failure and converted
        into ``CircuitBreakerTimeoutError``.

        Args:
            func: Async function to execute
            *args: Positional arguments for function
            cache_key: Optional cache key for fallback caching
            timeout: Optional timeout override; a falsy value falls back
                to ``config.timeout``
            **kwargs: Keyword arguments for function

        Returns:
            Result from function or fallback

        Raises:
            CircuitBreakerOpenError: If circuit is open and no fallback
            CircuitBreakerTimeoutError: If the call timed out
        """
        # Check if circuit allows requests
        if not self.is_available:
            self._record_rejection()

            if self._fallback_handler:
                return await self._fallback_handler.execute_fallback_async(
                    *args, cache_key=cache_key, **kwargs
                )

            raise CircuitBreakerOpenError(
                f"Circuit '{self._name}' is {self.state.value}",
                circuit_name=self._name,
            )

        # Execute with timing and optional timeout
        effective_timeout = timeout or self._config.timeout
        start_time = time.time()

        try:
            if effective_timeout:
                result = await asyncio.wait_for(
                    func(*args, **kwargs),
                    timeout=effective_timeout,
                )
            else:
                result = await func(*args, **kwargs)

            duration_ms = (time.time() - start_time) * 1000
            self._record_success(duration_ms)

            # Cache result for fallback
            if self._fallback_handler and cache_key:
                self._fallback_handler.cache_result(cache_key, result)

            return result

        except asyncio.TimeoutError as exc:
            duration_ms = (time.time() - start_time) * 1000
            if self._metrics:
                self._metrics.record_timeout()
            self._record_failure(duration_ms, CircuitBreakerTimeoutError("Call timed out"))
            # Chain explicitly so the original timeout stays attached as
            # the cause of the translated error.
            raise CircuitBreakerTimeoutError(
                f"Call timed out after {effective_timeout}s"
            ) from exc

        except self._config.ignore_exceptions:
            duration_ms = (time.time() - start_time) * 1000
            self._record_success(duration_ms)
            raise

        except self._config.record_exceptions as e:
            duration_ms = (time.time() - start_time) * 1000
            self._record_failure(duration_ms, e)
            raise

    def protect(self, func: Callable[..., T]) -> Callable[..., T]:
        """
        Decorator to protect a function.

        Coroutine functions are routed to ``protect_async`` so that the
        awaited result (not the mere creation of the coroutine) is what
        gets recorded as success or failure.

        Usage:
            @circuit_breaker.protect
            def call_api(data):
                return api.post(data)
        """
        if asyncio.iscoroutinefunction(func):
            return self.protect_async(func)

        @wraps(func)
        def wrapper(*args, **kwargs) -> T:
            return self.execute(func, *args, **kwargs)

        return wrapper

    def protect_async(
        self, func: Callable[..., Coroutine[Any, Any, T]]
    ) -> Callable[..., Coroutine[Any, Any, T]]:
        """
        Decorator to protect an async function.

        Usage:
            @circuit_breaker.protect_async
            async def call_api(data):
                return await api.post(data)
        """

        @wraps(func)
        async def wrapper(*args, **kwargs) -> T:
            return await self.execute_async(func, *args, **kwargs)

        return wrapper

    def __call__(self, func: Callable) -> Callable:
        """Allow using circuit breaker as decorator directly."""
        if asyncio.iscoroutinefunction(func):
            return self.protect_async(func)
        return self.protect(func)

    def _close_circuit(self, reason: str) -> bool:
        """Move the circuit to CLOSED using only valid transitions.

        A direct transition is used when the state manager allows it.
        Otherwise (e.g. from OPEN) the circuit is first moved to
        HALF_OPEN and then closed, which means the state-change callback
        fires once per step. Nothing happens when the circuit is already
        closed.

        Returns:
            True when the circuit is closed afterwards, False otherwise.
        """
        with self._lock:
            manager = self._state_manager
            # Read the manager directly: the ``state`` property could
            # trigger an automatic recovery transition as a side effect.
            if manager.state == CircuitState.CLOSED:
                return True
            if not manager.can_transition(CircuitState.CLOSED):
                manager.transition(CircuitState.HALF_OPEN, reason)
            return manager.transition(CircuitState.CLOSED, reason)

    def force_open(self, reason: str = "manual") -> None:
        """Manually force circuit open."""
        with self._lock:
            self._state_manager.transition(
                CircuitState.FORCED_OPEN,
                f"forced open: {reason}",
            )

    def force_close(self, reason: str = "manual") -> None:
        """Manually force circuit closed.

        Also clears the failure count and resets the recovery strategy.
        """
        with self._lock:
            self._close_circuit(f"forced close: {reason}")
            self._failure_count = 0
            self._recovery_strategy.reset()

    def disable(self, reason: str = "manual") -> None:
        """Disable circuit breaker (pass-through mode)."""
        with self._lock:
            self._state_manager.transition(
                CircuitState.DISABLED,
                f"disabled: {reason}",
            )

    def enable(self) -> None:
        """Re-enable a disabled circuit breaker."""
        with self._lock:
            if self._state_manager.state == CircuitState.DISABLED:
                self._state_manager.transition(
                    CircuitState.CLOSED,
                    "re-enabled",
                )

    def reset(self) -> None:
        """Reset circuit breaker to initial state.

        Closes the circuit, clears the success/failure counters and the
        last-failure timestamp, and resets the recovery strategy, the
        metrics collector and the fallback cache.
        """
        with self._lock:
            self._close_circuit("reset")
            self._failure_count = 0
            self._success_count = 0
            self._last_failure_time = None
            self._recovery_strategy.reset()
            if self._metrics:
                self._metrics.reset()
            if self._fallback_handler:
                self._fallback_handler.clear_cache()

    def get_metrics(self) -> Optional[Dict[str, Any]]:
        """Get current metrics, or None when metrics are disabled."""
        if not self._metrics:
            return None
        return self._metrics.export_metrics()

    def get_health(self) -> Dict[str, Any]:
        """Get circuit breaker health status."""
        with self._lock:
            state_info = self._state_manager.get_state_info()
            metrics = self._metrics.export_metrics() if self._metrics else {}

            return {
                "name": self._name,
                "state": state_info["state"],
                "is_available": self.is_available,
                "time_in_state_seconds": state_info["time_in_state_seconds"],
                "failure_count": self._failure_count,
                # Without metrics the breaker reports a perfect score.
                "health_score": metrics.get("health_score", 1.0),
                "failure_detector": self._failure_detector.get_failure_summary(),
                "recovery_attempts": state_info["recovery_attempts"],
            }

    def get_stats(self) -> Dict[str, Any]:
        """Get comprehensive statistics."""
        with self._lock:
            return {
                "name": self._name,
                "health": self.get_health(),
                "metrics": self.get_metrics(),
                "state_history": self._state_manager.get_history(),
                "fallback_stats": (
                    self._fallback_handler.get_stats()
                    if self._fallback_handler
                    else None
                ),
                "config": {
                    "failure_threshold": self._config.failure_threshold,
                    "recovery_timeout": self._config.recovery_timeout,
                    "failure_rate_threshold": self._config.failure_rate_threshold,
                },
            }


# =============================================================================
# CIRCUIT BREAKER REGISTRY
# =============================================================================


class CircuitBreakerRegistry:
    """
    Process-wide collection of named circuit breakers.

    The class is a singleton: every construction returns the same
    instance. It offers lookup, creation and removal of breakers plus
    bulk health reporting and bulk reset / force-open operations.
    """

    # The one shared instance, guarded by the class-level lock below.
    _instance: Optional["CircuitBreakerRegistry"] = None
    _lock = threading.Lock()

    def __new__(cls) -> "CircuitBreakerRegistry":
        """Return the shared instance, creating it on first use."""
        with cls._lock:
            if cls._instance is not None:
                return cls._instance
            created = super().__new__(cls)
            created._initialized = False
            cls._instance = created
            return created

    def __init__(self):
        # __init__ runs on every construction; only the first one may
        # set up the storage, later ones must leave it untouched.
        if not self._initialized:
            self._breakers: Dict[str, CircuitBreaker] = {}
            self._registry_lock = threading.RLock()
            self._initialized = True

    def register(
        self,
        name: str,
        config: Optional[CircuitBreakerConfig] = None,
        **kwargs,
    ) -> CircuitBreaker:
        """Create and store a breaker under ``name``.

        When the name is already taken, the existing breaker is returned
        unchanged and a warning is logged.
        """
        with self._registry_lock:
            if name not in self._breakers:
                created = CircuitBreaker(name=name, config=config, **kwargs)
                self._breakers[name] = created
                logger.info(f"Registered circuit breaker: {name}")
                return created

            logger.warning(f"Circuit breaker '{name}' already exists, returning existing")
            return self._breakers[name]

    def get(self, name: str) -> Optional[CircuitBreaker]:
        """Look up a breaker by name; None when it is not registered."""
        try:
            return self._breakers[name]
        except KeyError:
            return None

    def get_or_create(
        self,
        name: str,
        config: Optional[CircuitBreakerConfig] = None,
        **kwargs,
    ) -> CircuitBreaker:
        """Return the breaker called ``name``, registering it if absent."""
        with self._registry_lock:
            if name in self._breakers:
                return self._breakers[name]
            return self.register(name, config, **kwargs)

    def unregister(self, name: str) -> bool:
        """Drop a breaker from the registry.

        Returns:
            True when a breaker was removed, False when the name was
            unknown.
        """
        with self._registry_lock:
            if name not in self._breakers:
                return False
            del self._breakers[name]
            logger.info(f"Unregistered circuit breaker: {name}")
            return True

    def get_all_health(self) -> Dict[str, Dict[str, Any]]:
        """Collect ``get_health()`` of every breaker, keyed by name."""
        with self._registry_lock:
            report: Dict[str, Dict[str, Any]] = {}
            for name, breaker in self._breakers.items():
                report[name] = breaker.get_health()
            return report

    def get_summary(self) -> Dict[str, Any]:
        """Summarise how many breakers are in each state.

        Breakers that are neither open, half-open nor forced open are
        counted under ``"closed"``. The registry is reported healthy when
        no breaker is open or forced open.
        """
        with self._registry_lock:
            states = {}
            for name, breaker in self._breakers.items():
                states[name] = breaker.state.value

            open_circuits = []
            half_open = []
            forced_open = []
            for name, value in states.items():
                if value == "open":
                    open_circuits.append(name)
                elif value == "half_open":
                    half_open.append(name)
                elif value == "forced_open":
                    forced_open.append(name)

            tripped = len(open_circuits) + len(half_open) + len(forced_open)
            return {
                "total_circuits": len(self._breakers),
                "open": len(open_circuits),
                "half_open": len(half_open),
                "forced_open": len(forced_open),
                "closed": len(states) - tripped,
                "open_circuits": open_circuits,
                "half_open_circuits": half_open,
                "healthy": not open_circuits and not forced_open,
            }

    def reset_all(self) -> None:
        """Call ``reset()`` on every registered breaker."""
        with self._registry_lock:
            for breaker in self._breakers.values():
                breaker.reset()
            logger.info("Reset all circuit breakers")

    def force_open_all(self, reason: str = "global") -> None:
        """Force every registered breaker open (emergency stop)."""
        with self._registry_lock:
            for breaker in self._breakers.values():
                breaker.force_open(reason)
            logger.warning(f"Forced all circuits open: {reason}")


# =============================================================================
# CONVENIENCE FUNCTIONS
# =============================================================================


def get_registry() -> CircuitBreakerRegistry:
    """Return the shared circuit breaker registry.

    ``CircuitBreakerRegistry`` is a singleton, so constructing it always
    yields the same instance.
    """
    registry = CircuitBreakerRegistry()
    return registry


def get_circuit_breaker(name: str, **kwargs) -> CircuitBreaker:
    """Fetch the breaker called ``name`` from the global registry.

    The breaker is created with ``**kwargs`` when it does not exist yet;
    for an existing breaker the keyword arguments are not used.
    """
    registry = get_registry()
    return registry.get_or_create(name, **kwargs)


def circuit_breaker(
    name: str,
    failure_threshold: int = 5,
    recovery_timeout: float = 30.0,
    **kwargs,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """
    Build a decorator that guards a function with a named circuit breaker.

    The breaker is looked up in (or added to) the global registry; the
    returned object is the breaker itself, which is callable as a
    decorator.

    Args:
        name: Registry name of the breaker.
        failure_threshold: Passed to ``CircuitBreakerConfig``.
        recovery_timeout: Passed to ``CircuitBreakerConfig``.
        **kwargs: Further ``CircuitBreakerConfig`` options.

    Usage:
        @circuit_breaker("external_api", failure_threshold=3)
        def call_api(data):
            return api.post(data)
    """
    breaker_config = CircuitBreakerConfig(
        failure_threshold=failure_threshold,
        recovery_timeout=recovery_timeout,
        **kwargs,
    )
    return get_circuit_breaker(name, config=breaker_config)


# =============================================================================
# CLI INTERFACE
# =============================================================================


def _run_status() -> None:
    """Print the registry summary as indented JSON."""
    summary = get_registry().get_summary()
    print(json.dumps(summary, indent=2))


def _run_demo(name: str, iterations: int) -> None:
    """
    Walk one circuit breaker through a scripted failure/recovery scenario.

    The first five calls fail, all later calls succeed. After every call
    the breaker's health is printed; whenever the reported state is
    ``"open"`` the demo sleeps past the recovery timeout before continuing.

    Args:
        name: Name given to the demo circuit breaker.
        iterations: Number of calls to make.
    """
    print("=== Circuit Breaker Demo ===\n")

    # Low thresholds so state changes show up within a handful of calls.
    recovery_timeout = 5.0
    config = CircuitBreakerConfig(
        failure_threshold=3,
        recovery_timeout=recovery_timeout,
        half_open_max_calls=2,
        success_threshold=2,
    )

    fallback = FallbackHandler(
        default_value={"status": "degraded", "message": "using fallback"},
    )

    cb = CircuitBreaker(
        name=name,
        config=config,
        fallback_handler=fallback,
    )

    def unreliable_service(fail: bool = False):
        if fail:
            raise ConnectionError("Service unavailable")
        return {"status": "success"}

    print("Simulating service failures...\n")

    for i in range(iterations):
        # Fail first 5 calls, then succeed
        should_fail = i < 5

        try:
            result = cb.execute(unreliable_service, fail=should_fail)
        except CircuitBreakerOpenError:
            print(f"Call {i+1}: REJECTED (circuit open)")
        except ConnectionError as e:
            print(f"Call {i+1}: FAILED - {e}")
        else:
            print(f"Call {i+1}: SUCCESS - {result}")

        health = cb.get_health()
        print(f"  State: {health['state']}, Failures: {health['failure_count']}\n")

        time.sleep(0.5)

        # Wait longer when circuit opens to demonstrate recovery; the wait
        # is tied to the configured timeout (one second of headroom) so the
        # two values cannot drift apart.
        if health["state"] == "open":
            print("  Waiting for recovery timeout...\n")
            time.sleep(recovery_timeout + 1.0)

    print("\n=== Final Statistics ===")
    print(json.dumps(cb.get_stats(), indent=2, default=str))


def _run_stress(name: str, iterations: int) -> None:
    """
    Run a randomly failing service through a breaker and print the tallies.

    Each call fails with probability 0.3; otherwise it sleeps 10-100 ms and
    succeeds. Successes, failures and open-circuit rejections are counted
    and printed together with the elapsed time and final breaker state.

    Args:
        name: Base name; the breaker is created as ``f"{name}_stress"``.
        iterations: Number of calls to make.
    """
    import random

    print("=== Circuit Breaker Stress Test ===\n")

    config = CircuitBreakerConfig(
        failure_threshold=10,
        recovery_timeout=2.0,
        sliding_window_size=100,
    )

    cb = CircuitBreaker(name=f"{name}_stress", config=config)

    success_count = 0
    failure_count = 0
    rejection_count = 0

    def random_service():
        if random.random() < 0.3:  # 30% failure rate
            raise Exception("Random failure")
        time.sleep(random.uniform(0.01, 0.1))
        return "ok"

    start = time.time()

    for _ in range(iterations):
        try:
            cb.execute(random_service)
        except CircuitBreakerOpenError:
            rejection_count += 1
        except Exception:
            # Deliberately broad: the simulated service raises plain Exception.
            failure_count += 1
        else:
            success_count += 1

    elapsed = time.time() - start

    print(f"Completed {iterations} calls in {elapsed:.2f}s")
    print(f"Successes: {success_count}")
    print(f"Failures: {failure_count}")
    print(f"Rejections: {rejection_count}")
    print(f"\nFinal state: {cb.state.value}")

    # Fetch once so the truthiness check and the printed data are the same
    # snapshot.
    metrics = cb.get_metrics()
    if metrics:
        print("\nMetrics:")
        print(json.dumps(metrics, indent=2))


def main(argv: Optional[List[str]] = None) -> None:
    """
    Command-line interface for testing circuit breaker.

    Commands:
        status: print the registry summary as JSON.
        demo:   scripted failure/recovery walkthrough of one breaker.
        stress: randomly failing calls with a tally of the outcomes.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour when run as a
            script.

    Raises:
        SystemExit: Raised by argparse on invalid arguments or ``--help``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="AIVA Circuit Breaker System")
    parser.add_argument(
        "command",
        choices=["status", "demo", "stress"],
        help="Command to run",
    )
    parser.add_argument(
        "--name",
        default="test",
        help="Circuit breaker name",
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=20,
        help="Number of iterations for demo/stress",
    )

    args = parser.parse_args(argv)

    if args.command == "status":
        _run_status()
    elif args.command == "demo":
        _run_demo(args.name, args.iterations)
    elif args.command == "stress":
        _run_stress(args.name, args.iterations)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
