#!/usr/bin/env python3
"""
GENESIS RETRY MANAGER
======================
Robust retry logic with exponential backoff and circuit breaker.

Features:
    - Exponential backoff
    - Jitter for avoiding thundering herd
    - Circuit breaker pattern
    - Retry decorators
    - Custom retry conditions
    - Statistics tracking

Usage:
    @retry(max_attempts=3, backoff=2.0)
    def flaky_operation():
        ...

    ctx = RetryContext(max_attempts=5)
    while True:
        try:
            result = operation()
            break
        except Exception as exc:
            if not ctx.should_retry(exc):
                raise
            ctx.wait()
"""

import asyncio
import inspect
import random
import threading
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from functools import wraps
from typing import Dict, List, Any, Optional, Callable, Type, Tuple, Union


class RetryStrategy(Enum):
    """How the pause between consecutive retries grows.

    Members:
        FIXED: the same ``base_delay`` before every retry.
        EXPONENTIAL: ``base_delay * exponential_base ** attempt``.
        FIBONACCI: a Fibonacci sequence seeded with ``base_delay`` twice.
        LINEAR: ``base_delay * (attempt + 1)``.
    """

    FIXED = "fixed"
    EXPONENTIAL = "exponential"
    FIBONACCI = "fibonacci"
    LINEAR = "linear"


class CircuitState(Enum):
    """Lifecycle of a circuit breaker.

    Members:
        CLOSED: healthy; requests pass through.
        OPEN: tripped; requests are rejected.
        HALF_OPEN: probing; a limited number of trial requests are let through.
    """

    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"


@dataclass
class RetryConfig:
    """Retry configuration.

    Attributes:
        max_attempts: total attempts, including the first call, that
            ``RetryManager`` makes before giving up.
        base_delay: starting delay in seconds fed into the backoff strategy.
        max_delay: upper bound in seconds on any delay from ``calculate_delay``.
        strategy: how the delay grows with the attempt number.
        exponential_base: growth factor for ``RetryStrategy.EXPONENTIAL``.
        jitter: whether to randomise each delay.
        jitter_factor: jitter range as a fraction of the delay (0.2 -> +/-20%).
        retryable_exceptions: exception types that trigger a retry.
        non_retryable_exceptions: exception types that never trigger a retry;
            they take precedence over ``retryable_exceptions``.
    """
    max_attempts: int = 3
    base_delay: float = 1.0
    max_delay: float = 60.0
    strategy: RetryStrategy = RetryStrategy.EXPONENTIAL
    exponential_base: float = 2.0
    jitter: bool = True
    jitter_factor: float = 0.2
    retryable_exceptions: Tuple[Type[Exception], ...] = (Exception,)
    non_retryable_exceptions: Tuple[Type[Exception], ...] = ()

    def calculate_delay(self, attempt: int) -> float:
        """Return the delay in seconds before the retry that follows ``attempt``.

        ``attempt`` is zero-based (0 = the delay after the first failed call).
        The strategy's raw delay is capped at ``max_delay``. With ``jitter``
        enabled, a uniform offset within ``+/- jitter_factor * delay`` is added
        and the result is clamped to ``[0, max_delay]``, so jitter can never
        push a delay past the configured maximum.
        """
        if self.strategy == RetryStrategy.FIXED:
            delay = self.base_delay

        elif self.strategy == RetryStrategy.EXPONENTIAL:
            try:
                delay = self.base_delay * (self.exponential_base ** attempt)
            except OverflowError:
                # Float exponentiation raises (rather than returning inf) once
                # the result leaves float range; such a delay is capped anyway.
                delay = self.max_delay

        elif self.strategy == RetryStrategy.FIBONACCI:
            # base, base, 2*base, 3*base, 5*base, ...
            a, b = self.base_delay, self.base_delay
            for _ in range(attempt):
                a, b = b, a + b
            delay = a

        elif self.strategy == RetryStrategy.LINEAR:
            delay = self.base_delay * (1 + attempt)

        else:
            delay = self.base_delay

        # Apply max
        delay = min(delay, self.max_delay)

        # Apply jitter, then re-clamp so the cap still holds.
        if self.jitter:
            jitter_range = delay * self.jitter_factor
            delay += random.uniform(-jitter_range, jitter_range)
            delay = min(max(0.0, delay), self.max_delay)

        return delay

    def is_retryable(self, exception: Exception) -> bool:
        """Return True if ``exception`` should trigger a retry.

        ``non_retryable_exceptions`` is checked first and wins over
        ``retryable_exceptions``. A non-exception value such as ``None`` is
        never retryable.
        """
        if isinstance(exception, self.non_retryable_exceptions):
            return False
        return isinstance(exception, self.retryable_exceptions)


@dataclass
class RetryStats:
    """Running counters for all calls recorded under one operation name.

    As populated by ``RetryManager``: ``total_attempts`` counts every
    invocation of the operation, ``retries`` counts the sleeps taken between
    attempts, ``total_delay_seconds`` sums those sleeps, ``last_attempt`` is an
    ISO-8601 timestamp string and ``last_error`` the ``str()`` of the most
    recent exception.
    """
    total_attempts: int = 0
    successful_attempts: int = 0
    failed_attempts: int = 0
    retries: int = 0
    total_delay_seconds: float = 0.0
    last_attempt: Optional[str] = None
    last_error: Optional[str] = None

    @property
    def success_rate(self) -> float:
        """Fraction of attempts that succeeded; 0.0 before any attempt."""
        attempts = self.total_attempts
        return self.successful_attempts / attempts if attempts else 0.0

    def to_dict(self) -> Dict:
        """Return a JSON-friendly snapshot with delay and rate rounded."""
        snapshot: Dict[str, Any] = dict(
            total_attempts=self.total_attempts,
            successful=self.successful_attempts,
            failed=self.failed_attempts,
            retries=self.retries,
            total_delay_seconds=round(self.total_delay_seconds, 2),
            success_rate=round(self.success_rate, 4),
            last_attempt=self.last_attempt,
            last_error=self.last_error,
        )
        return snapshot


class CircuitBreaker:
    """
    Circuit breaker for protecting against cascading failures.

    States (see ``CircuitState``):

    * CLOSED: requests flow; ``failure_threshold`` failures without an
      intervening success open the circuit.
    * OPEN: requests are rejected until ``recovery_timeout`` seconds have
      passed since the most recent failure; the breaker then moves to
      HALF_OPEN the next time its state is read.
    * HALF_OPEN: at most ``half_open_max_calls`` trial requests are admitted.
      Any failure re-opens the circuit; ``half_open_max_calls`` successes
      close it.

    Every read and write of the internal state happens under a re-entrant
    lock, so an instance may be shared between threads.
    """

    def __init__(
        self,
        name: str,
        failure_threshold: int = 5,
        recovery_timeout: float = 30.0,
        half_open_max_calls: int = 3
    ):
        """
        Args:
            name: identifier reported by ``get_status``.
            failure_threshold: failures in CLOSED state that open the circuit.
            recovery_timeout: seconds to stay OPEN after the last failure.
            half_open_max_calls: trial calls admitted in HALF_OPEN, and the
                number of successes required there to close again.
        """
        self.name = name
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_max_calls = half_open_max_calls

        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._success_count = 0  # successes in the current HALF_OPEN window
        # time.monotonic() of the most recent failure; monotonic so wall-clock
        # adjustments cannot shorten or stretch the recovery window.
        self._last_failure_time: Optional[float] = None
        self._half_open_calls = 0  # trial calls admitted in the current HALF_OPEN window
        self._lock = threading.RLock()

    @property
    def state(self) -> CircuitState:
        """Current state; lazily moves OPEN -> HALF_OPEN once the timeout has elapsed."""
        with self._lock:
            if (
                self._state == CircuitState.OPEN
                and self._last_failure_time is not None
                and time.monotonic() - self._last_failure_time >= self.recovery_timeout
            ):
                self._state = CircuitState.HALF_OPEN
                self._half_open_calls = 0
                # Each trial window starts with a clean tally, so successes from
                # an earlier window that ended in failure cannot close the circuit.
                self._success_count = 0
            return self._state

    def allow_request(self) -> bool:
        """Return True if a request may proceed now.

        In HALF_OPEN this also reserves one of the ``half_open_max_calls``
        trial slots. The state check and the reservation happen under one
        lock acquisition so a concurrent re-open cannot slip in between.
        """
        with self._lock:
            state = self.state  # may transition OPEN -> HALF_OPEN (RLock allows re-entry)

            if state == CircuitState.CLOSED:
                return True
            if state == CircuitState.HALF_OPEN and self._half_open_calls < self.half_open_max_calls:
                self._half_open_calls += 1
                return True
            return False  # OPEN, or HALF_OPEN with no trial slots left

    def record_success(self):
        """Record a successful call."""
        with self._lock:
            if self._state == CircuitState.HALF_OPEN:
                self._success_count += 1
                if self._success_count >= self.half_open_max_calls:
                    self._state = CircuitState.CLOSED
                    self._failure_count = 0
                    self._success_count = 0
            elif self._state == CircuitState.CLOSED:
                # A success breaks any run of failures.
                self._failure_count = 0

    def record_failure(self):
        """Record a failed call; may open (or re-open) the circuit."""
        with self._lock:
            self._failure_count += 1
            self._last_failure_time = time.monotonic()

            if self._state == CircuitState.HALF_OPEN:
                # A single failed trial call re-opens the circuit.
                self._state = CircuitState.OPEN
                self._half_open_calls = 0
                self._success_count = 0
            elif self._state == CircuitState.CLOSED:
                if self._failure_count >= self.failure_threshold:
                    self._state = CircuitState.OPEN

    def reset(self):
        """Force the breaker back to a fresh CLOSED state."""
        with self._lock:
            self._state = CircuitState.CLOSED
            self._failure_count = 0
            self._success_count = 0
            self._half_open_calls = 0
            self._last_failure_time = None

    def get_status(self) -> Dict:
        """Return a consistent snapshot of the breaker's name, state and failure counters."""
        with self._lock:
            return {
                "name": self.name,
                "state": self.state.value,
                "failure_count": self._failure_count,
                "failure_threshold": self.failure_threshold
            }


class RetryManager:
    """
    Central retry management with circuit breakers.

    Holds a registry of named ``CircuitBreaker`` objects and one ``RetryStats``
    record per operation name. ``execute`` (sync) and ``execute_async`` apply
    the same policy to every attempt:

    1. If a circuit breaker is given and rejects the request, raise
       ``CircuitOpenError``.
    2. Call the operation; on success record it and return the result.
    3. On an exception, record it; re-raise it if it is not retryable or this
       was the last attempt, otherwise sleep ``config.calculate_delay(attempt)``
       seconds and try again.
    """

    def __init__(self, default_config: Optional[RetryConfig] = None):
        """Create a manager; ``default_config`` is used when a call passes no config."""
        self.default_config = default_config or RetryConfig()
        self._circuit_breakers: Dict[str, CircuitBreaker] = {}
        self._stats: Dict[str, RetryStats] = {}
        self._lock = threading.RLock()

    def get_circuit_breaker(self, name: str, **kwargs) -> CircuitBreaker:
        """Get or create the circuit breaker called ``name``.

        ``kwargs`` are forwarded to ``CircuitBreaker`` only on creation; they
        are ignored when a breaker with that name already exists.
        """
        with self._lock:
            if name not in self._circuit_breakers:
                self._circuit_breakers[name] = CircuitBreaker(name, **kwargs)
            return self._circuit_breakers[name]

    def get_stats(self, name: str) -> RetryStats:
        """Get (creating on first use) the retry stats for operation ``name``."""
        with self._lock:
            if name not in self._stats:
                self._stats[name] = RetryStats()
            return self._stats[name]

    def _prepare(
        self,
        operation: Callable,
        name: Optional[str],
        config: Optional[RetryConfig],
    ) -> Tuple[RetryConfig, RetryStats, int]:
        """Resolve the effective config, stats record and attempt budget for one call."""
        config = config or self.default_config
        if not name:
            # functools.partial objects and some callables have no __name__.
            name = getattr(operation, "__name__", type(operation).__name__)
        # Always call the operation at least once: with max_attempts <= 0 the
        # loop would otherwise never run and the call would silently return None.
        return config, self.get_stats(name), max(1, config.max_attempts)

    @staticmethod
    def _start_attempt(stats: RetryStats, circuit_breaker: Optional[CircuitBreaker]) -> None:
        """Enforce the circuit breaker, then count and timestamp the attempt."""
        if circuit_breaker is not None and not circuit_breaker.allow_request():
            raise CircuitOpenError(f"Circuit breaker '{circuit_breaker.name}' is open")
        stats.total_attempts += 1
        stats.last_attempt = datetime.now().isoformat()

    @staticmethod
    def _record_success(stats: RetryStats, circuit_breaker: Optional[CircuitBreaker]) -> None:
        """Count a successful attempt and report it to the circuit breaker."""
        stats.successful_attempts += 1
        if circuit_breaker is not None:
            circuit_breaker.record_success()

    @staticmethod
    def _record_failure(
        exc: Exception,
        attempt: int,
        attempts: int,
        config: RetryConfig,
        stats: RetryStats,
        circuit_breaker: Optional[CircuitBreaker],
    ) -> Optional[float]:
        """Count a failed attempt and decide what happens next.

        Returns the delay in seconds to sleep before the next attempt, or
        ``None`` when the caller must re-raise ``exc`` (it is not retryable,
        or ``attempt`` was the last one allowed).
        """
        stats.last_error = str(exc)

        if circuit_breaker is not None:
            circuit_breaker.record_failure()

        if not config.is_retryable(exc) or attempt >= attempts - 1:
            stats.failed_attempts += 1
            return None

        delay = config.calculate_delay(attempt)
        stats.total_delay_seconds += delay
        stats.retries += 1
        return delay

    def execute(
        self,
        operation: Callable,
        name: Optional[str] = None,
        config: Optional[RetryConfig] = None,
        circuit_breaker: Optional[CircuitBreaker] = None
    ) -> Any:
        """Call ``operation()`` synchronously with retry logic.

        Args:
            operation: zero-argument callable.
            name: stats key; defaults to ``operation.__name__``.
            config: retry policy; defaults to ``self.default_config``.
            circuit_breaker: optional breaker consulted before every attempt.

        Returns:
            The value returned by the first successful ``operation()`` call.

        Raises:
            CircuitOpenError: the circuit breaker rejected an attempt.
            Exception: the operation's own exception when it is not retryable
                or the final attempt failed.
        """
        config, stats, attempts = self._prepare(operation, name, config)

        for attempt in range(attempts):
            self._start_attempt(stats, circuit_breaker)
            try:
                result = operation()
            except Exception as exc:
                delay = self._record_failure(exc, attempt, attempts, config, stats, circuit_breaker)
                if delay is None:
                    raise
            else:
                self._record_success(stats, circuit_breaker)
                return result
            time.sleep(delay)

    async def execute_async(
        self,
        operation: Callable,
        name: Optional[str] = None,
        config: Optional[RetryConfig] = None,
        circuit_breaker: Optional[CircuitBreaker] = None
    ) -> Any:
        """Async counterpart of ``execute``; sleeps with ``asyncio.sleep``.

        ``operation`` may be a plain callable, a coroutine function, or a plain
        callable that *returns* an awaitable (e.g. ``lambda: coro_fn(x)``).
        Any awaitable result is awaited inside the retry loop, so exceptions
        it raises are retried like synchronous ones.

        Raises:
            CircuitOpenError: the circuit breaker rejected an attempt.
            Exception: the operation's own exception when it is not retryable
                or the final attempt failed.
        """
        config, stats, attempts = self._prepare(operation, name, config)

        for attempt in range(attempts):
            self._start_attempt(stats, circuit_breaker)
            try:
                result = operation()
                # Checking the *result* rather than iscoroutinefunction(operation)
                # also covers lambdas/partials that wrap a coroutine function.
                if inspect.isawaitable(result):
                    result = await result
            except Exception as exc:
                delay = self._record_failure(exc, attempt, attempts, config, stats, circuit_breaker)
                if delay is None:
                    raise
            else:
                self._record_success(stats, circuit_breaker)
                return result
            await asyncio.sleep(delay)

    def get_all_stats(self) -> Dict[str, Dict]:
        """Get all retry statistics, keyed by operation name."""
        with self._lock:
            return {name: stats.to_dict() for name, stats in self._stats.items()}

    def get_all_circuit_breakers(self) -> Dict[str, Dict]:
        """Get all circuit breaker statuses, keyed by breaker name."""
        with self._lock:
            return {name: cb.get_status() for name, cb in self._circuit_breakers.items()}


class CircuitOpenError(Exception):
    """Raised instead of invoking an operation when its circuit breaker rejects the request."""


class MaxRetriesExceeded(Exception):
    """Signals that every allowed retry attempt has been used up.

    NOTE(review): nothing in this module raises it -- RetryManager re-raises
    the operation's own exception instead. Presumably kept for callers
    elsewhere; confirm before removing.
    """


# Module-level RetryManager shared by every @retry-decorated function (and
# reported by main()); stats and circuit breakers accumulate here for the
# lifetime of the process.
_manager: RetryManager = RetryManager()


def retry(
    max_attempts: int = 3,
    base_delay: float = 1.0,
    max_delay: float = 60.0,
    strategy: RetryStrategy = RetryStrategy.EXPONENTIAL,
    exponential_base: float = 2.0,
    jitter: bool = True,
    retryable_exceptions: Tuple[Type[Exception], ...] = (Exception,),
    circuit_breaker_name: Optional[str] = None,
    non_retryable_exceptions: Tuple[Type[Exception], ...] = (),
):
    """
    Decorator for adding retry logic to a function.

    Works with both plain and ``async def`` functions. Calls are routed through
    the module-level ``_manager``, and statistics are recorded under the
    wrapped function's ``__name__``.

    Args:
        max_attempts: total attempts including the first call.
        base_delay: starting backoff delay in seconds.
        max_delay: cap on any single backoff delay in seconds.
        strategy: backoff strategy.
        exponential_base: growth factor for ``RetryStrategy.EXPONENTIAL``.
        jitter: randomise delays.
        retryable_exceptions: exception types that trigger a retry.
        circuit_breaker_name: if given, the breaker of that name is fetched
            from (or created with default thresholds in) ``_manager`` and is
            consulted before every attempt.
        non_retryable_exceptions: exception types re-raised immediately even
            if they also match ``retryable_exceptions``.
    """
    config = RetryConfig(
        max_attempts=max_attempts,
        base_delay=base_delay,
        max_delay=max_delay,
        strategy=strategy,
        exponential_base=exponential_base,
        jitter=jitter,
        retryable_exceptions=retryable_exceptions,
        non_retryable_exceptions=non_retryable_exceptions,
    )

    def decorator(func):
        cb = _manager.get_circuit_breaker(circuit_breaker_name) if circuit_breaker_name else None

        if asyncio.iscoroutinefunction(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                # Hand the manager a coroutine *function*, so each attempt creates
                # and awaits a fresh coroutine. A plain ``lambda: func(...)`` is
                # not a coroutine function, so the manager would hand back the
                # un-awaited coroutine and never see (or retry) its exceptions.
                async def attempt():
                    return await func(*args, **kwargs)

                return await _manager.execute_async(
                    attempt,
                    name=func.__name__,
                    config=config,
                    circuit_breaker=cb
                )

            return async_wrapper

        @wraps(func)
        def wrapper(*args, **kwargs):
            return _manager.execute(
                lambda: func(*args, **kwargs),
                name=func.__name__,
                config=config,
                circuit_breaker=cb
            )

        return wrapper

    return decorator


class RetryContext:
    """
    Context manager for retry logic.

    Holds a ``RetryConfig`` and an attempt counter for hand-written retry
    loops::

        with RetryContext(max_attempts=3) as ctx:
            while True:
                try:
                    result = operation()
                    break
                except Exception as exc:
                    if not ctx.should_retry(exc):
                        raise
                    ctx.wait()

    ``max_attempts`` is the total number of attempts including the first one,
    the same meaning ``RetryManager`` gives ``RetryConfig.max_attempts``.
    Leaving the ``with`` block never suppresses exceptions.
    """

    def __init__(
        self,
        max_attempts: int = 3,
        base_delay: float = 1.0,
        strategy: RetryStrategy = RetryStrategy.EXPONENTIAL
    ):
        """
        Args:
            max_attempts: total attempts allowed, including the first.
            base_delay: starting delay in seconds for the backoff strategy.
            strategy: backoff strategy used by ``wait``/``wait_async``.
        """
        self.config = RetryConfig(
            max_attempts=max_attempts,
            base_delay=base_delay,
            strategy=strategy
        )
        # Number of waits (retries) taken so far; the attempt currently
        # running is number ``self.attempt + 1``.
        self.attempt = 0
        self.last_exception: Optional[Exception] = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return False  # never swallow exceptions

    def should_retry(self, exception: Optional[Exception] = None) -> bool:
        """Return True if another attempt is allowed.

        A given ``exception`` is remembered as ``last_exception``. The
        retryability check uses ``last_exception`` (so calling without an
        argument re-uses the previously reported exception); if no exception
        has ever been reported, only the attempt budget is checked.
        """
        if exception is not None:
            self.last_exception = exception

        if self.last_exception is not None and not self.config.is_retryable(self.last_exception):
            return False

        # ``self.attempt + 1`` attempts have run so far; permit another only
        # while the total stays within max_attempts.
        return self.attempt + 1 < self.config.max_attempts

    def wait(self):
        """Sleep for the backoff delay of the current attempt, then advance the counter."""
        delay = self.config.calculate_delay(self.attempt)
        time.sleep(delay)
        self.attempt += 1

    async def wait_async(self):
        """Async variant of ``wait`` using ``asyncio.sleep``."""
        delay = self.config.calculate_delay(self.attempt)
        await asyncio.sleep(delay)
        self.attempt += 1


def _demo_retry_decorator():
    """Demo section 1: a function that fails twice before succeeding under @retry."""
    print("\n1. @retry decorator:")
    calls = {"count": 0}

    @retry(max_attempts=3, base_delay=0.5, jitter=False)
    def flaky_operation():
        calls["count"] += 1
        current = calls["count"]
        if current >= 3:
            print(f"  Attempt {current}: success!")
            return "done"
        print(f"  Attempt {current}: failing...")
        raise ValueError("Simulated failure")

    try:
        outcome = flaky_operation()
    except Exception as exc:
        print(f"  Failed: {exc}")
    else:
        print(f"  Result: {outcome}")


def _demo_circuit_breaker():
    """Demo section 2: trip a breaker with failures, then wait out its recovery timeout."""
    print("\n2. Circuit breaker:")
    breaker = CircuitBreaker("demo", failure_threshold=3, recovery_timeout=2.0)

    for request_no in range(1, 6):
        allowed = breaker.allow_request()
        verdict = "allowed" if allowed else "blocked"
        print(f"  Request {request_no}: {verdict} (state: {breaker.state.value})")
        if allowed:
            breaker.record_failure()

    print("  Waiting for recovery...")
    time.sleep(2.5)
    print(f"  State after wait: {breaker.state.value}")


def _demo_strategies():
    """Demo section 3: the first five delays produced by each backoff strategy."""
    print("\n3. Retry strategies:")
    for strategy in RetryStrategy:
        cfg = RetryConfig(base_delay=1.0, strategy=strategy)
        shown = [round(cfg.calculate_delay(step), 2) for step in range(5)]
        print(f"  {strategy.value}: {shown}")


def _demo_retry_context():
    """Demo section 4: a manual loop driven by RetryContext that succeeds on attempt 3."""
    print("\n4. RetryContext:")
    succeeded = False
    ctx = RetryContext(max_attempts=4, base_delay=0.2)
    attempt = 0

    while not succeeded:
        attempt += 1
        try:
            if attempt < 3:
                raise ConnectionError("Connection failed")
        except Exception as exc:
            print(f"  Attempt {attempt}: {exc}")
            if not ctx.should_retry(exc):
                break
            ctx.wait()
        else:
            print(f"  Attempt {attempt}: success!")
            succeeded = True

    print(f"  Final success: {succeeded}")


def _show_stats():
    """Print the module-level manager's retry stats and circuit-breaker statuses as JSON."""
    import json

    print("Retry Stats:")
    print(json.dumps(_manager.get_all_stats(), indent=2))
    print("\nCircuit Breakers:")
    print(json.dumps(_manager.get_all_circuit_breakers(), indent=2))


def main():
    """CLI and demo for retry manager.

    ``demo`` runs the walkthrough sections in order; ``stats`` prints the
    state held by the in-process ``_manager`` (only data recorded in this
    same process is visible).
    """
    import argparse
    import json

    parser = argparse.ArgumentParser(description="Genesis Retry Manager")
    parser.add_argument("command", choices=["demo", "stats"])
    command = parser.parse_args().command

    if command == "stats":
        _show_stats()
    elif command == "demo":
        print("Retry Manager Demo")
        print("=" * 40)
        _demo_retry_decorator()
        _demo_circuit_breaker()
        _demo_strategies()
        _demo_retry_context()
        print("\n5. Statistics:")
        print(json.dumps(_manager.get_all_stats(), indent=2))


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
