"""
Genesis Retry Utilities
=======================
Exponential backoff with jitter for resilient operations.

Usage:
    from retry_utils import retry, retry_call, RetryConfig

    # As decorator
    @retry(max_attempts=3, base_delay=1.0)
    def flaky_operation():
        # May fail sometimes
        pass

    # With custom config
    config = RetryConfig(max_attempts=5, max_delay=60.0)

    @retry(config=config)
    def another_operation():
        pass

    # Programmatic retry
    result = retry_call(flaky_operation, max_attempts=3)
"""

import time
import random
import functools
from typing import Callable, Any, Optional, Tuple, Type, Union, List
from dataclasses import dataclass, field
from datetime import datetime
import threading


@dataclass
class RetryConfig:
    """
    Reusable bundle of retry settings.

    Pass an instance as ``retry(config=...)`` instead of spelling out the
    individual keyword arguments. Field order doubles as the positional
    constructor signature, so it must not change.
    """
    # Total number of times the operation may be invoked.
    max_attempts: int = field(default=3)
    # Pause (seconds) before the first retry, prior to jitter.
    base_delay: float = field(default=1.0)
    # Upper bound (seconds) applied to the exponential term.
    max_delay: float = field(default=60.0)
    # Growth factor between consecutive pauses.
    exponential_base: float = field(default=2.0)
    # Randomise pauses so many clients do not retry in lockstep.
    jitter: bool = field(default=True)
    # (low, high) bounds of the random jitter multiplier.
    jitter_range: Tuple[float, float] = field(default=(0.5, 1.5))
    # Exception classes treated as transient and therefore retried.
    exceptions: Tuple[Type[Exception], ...] = field(default=(Exception,))
    # Optional hook called as on_retry(attempt, exception, delay).
    on_retry: Optional[Callable[[int, Exception, float], None]] = field(default=None)


@dataclass
class RetryStats:
    """
    Aggregate counters describing the outcome of retried calls.

    The class performs no locking of its own; an instance shared between
    threads must be guarded externally.
    """
    total_calls: int = 0               # Calls recorded via record_success/record_failure
    successful_first_try: int = 0      # Succeeded on the first attempt
    successful_after_retry: int = 0    # Succeeded after at least one retry
    failed_all_attempts: int = 0       # Calls recorded via record_failure
    total_retries: int = 0             # Sum of (attempts - 1) over recorded calls
    total_delay_seconds: float = 0.0   # Maintained by callers; record_* never touch it
    last_error: Optional[str] = None   # str() of the most recently recorded error
    errors_by_type: dict = field(default_factory=dict)  # exception class name -> count

    def record_success(self, attempts: int) -> None:
        """Record a call that eventually succeeded.

        Args:
            attempts: Number of invocations made, including the successful one.
        """
        self.total_calls += 1
        if attempts == 1:
            self.successful_first_try += 1
        else:
            self.successful_after_retry += 1
            self.total_retries += attempts - 1

    def record_failure(self, attempts: int, error: Exception) -> None:
        """Record a call that ended by raising ``error``.

        Args:
            attempts: Number of invocations made before giving up.
            error: The exception that ended the call.
        """
        self.total_calls += 1
        self.failed_all_attempts += 1
        self.total_retries += attempts - 1
        self.last_error = str(error)
        error_type = type(error).__name__
        self.errors_by_type[error_type] = self.errors_by_type.get(error_type, 0) + 1

    def to_dict(self) -> dict:
        """Return the counters plus derived rates as a new dict.

        ``errors_by_type`` is copied so the result can be read or mutated
        after the caller releases any lock without touching live state.
        Rates divide by ``max(1, total_calls)`` so an empty record yields 0.0.
        """
        return {
            "total_calls": self.total_calls,
            "successful_first_try": self.successful_first_try,
            "successful_after_retry": self.successful_after_retry,
            "failed_all_attempts": self.failed_all_attempts,
            "total_retries": self.total_retries,
            "success_rate": (self.successful_first_try + self.successful_after_retry) /
                           max(1, self.total_calls),
            "first_try_rate": self.successful_first_try / max(1, self.total_calls),
            "total_delay_seconds": round(self.total_delay_seconds, 2),
            "errors_by_type": dict(self.errors_by_type),
            "last_error": self.last_error
        }


# Process-wide retry statistics: retry_call() updates them, get_retry_stats()
# reads them and reset_retry_stats() replaces them, each while holding
# _stats_lock.
_stats_lock = threading.Lock()
_global_stats = RetryStats()


def calculate_delay(
    attempt: int,
    base_delay: float = 1.0,
    max_delay: float = 60.0,
    exponential_base: float = 2.0,
    jitter: bool = True,
    jitter_range: Tuple[float, float] = (0.5, 1.5)
) -> float:
    """
    Calculate delay for the given attempt number.

    Uses exponential backoff with optional jitter:
    delay = min(base * (exp_base ^ attempt), max_delay) * jitter

    Jitter is applied *after* the cap, so with jitter enabled the result
    can exceed ``max_delay`` by up to ``jitter_range[1]`` times.

    Args:
        attempt: Current attempt number (0-indexed)
        base_delay: Base delay in seconds
        max_delay: Maximum delay cap (applied before jitter)
        exponential_base: Multiplier for backoff
        jitter: Whether to add randomness
        jitter_range: Range for jitter multiplier

    Returns:
        Delay in seconds
    """
    try:
        backoff = base_delay * (exponential_base ** attempt)
    except OverflowError:
        # For large attempt numbers the power (or its conversion to float)
        # exceeds the float range; such a value would be capped anyway.
        backoff = max_delay
    delay = min(backoff, max_delay)

    # Add jitter to prevent thundering herd
    if jitter:
        delay *= random.uniform(*jitter_range)

    return delay


def _record_retry_outcome(
    attempts: int,
    total_delay: float,
    error: Optional[BaseException] = None,
) -> None:
    """Fold one finished retry_call into the module-level statistics.

    Args:
        attempts: Number of times the wrapped function was invoked.
        total_delay: Seconds slept between those invocations.
        error: The exception that ended the call, or None on success.
    """
    with _stats_lock:
        if error is None:
            _global_stats.record_success(attempts)
        else:
            _global_stats.record_failure(attempts, error)
        _global_stats.total_delay_seconds += total_delay


def retry_call(
    func: Callable,
    *args,
    max_attempts: int = 3,
    base_delay: float = 1.0,
    max_delay: float = 60.0,
    exponential_base: float = 2.0,
    jitter: bool = True,
    exceptions: Tuple[Type[Exception], ...] = (Exception,),
    on_retry: Optional[Callable[[int, Exception, float], None]] = None,
    jitter_range: Tuple[float, float] = (0.5, 1.5),
    **kwargs
) -> Any:
    """
    Call a function with retry logic.

    Each retry is announced with two ``print`` lines on stdout, and every
    call (successful or not) is recorded in the module-level statistics.

    Args:
        func: Function to call
        *args: Positional arguments for func
        max_attempts: Maximum number of attempts (must be >= 1)
        base_delay: Initial delay between retries
        max_delay: Maximum delay cap
        exponential_base: Backoff multiplier
        jitter: Add randomness to delays
        exceptions: Exception types to catch and retry
        on_retry: Callback(attempt, exception, delay) on each retry
        jitter_range: (low, high) bounds of the random jitter multiplier
        **kwargs: Keyword arguments for func

    Returns:
        Result of successful function call

    Raises:
        ValueError: If max_attempts is less than 1.
        The last exception from func if all attempts fail, or the first
        exception not matching ``exceptions`` (which is never retried).
    """
    if max_attempts < 1:
        # Otherwise the loop would run zero times and silently return None
        # without ever invoking func.
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

    total_delay = 0.0

    for attempt in range(1, max_attempts + 1):
        try:
            result = func(*args, **kwargs)
        except exceptions as e:
            if attempt >= max_attempts:
                _record_retry_outcome(attempt, total_delay, e)
                raise

            # ``attempt`` is 1-based; calculate_delay expects a 0-based index.
            delay = calculate_delay(
                attempt=attempt - 1,
                base_delay=base_delay,
                max_delay=max_delay,
                exponential_base=exponential_base,
                jitter=jitter,
                jitter_range=jitter_range,
            )
            total_delay += delay

            if on_retry is not None:
                on_retry(attempt, e, delay)

            print(f"[Retry] Attempt {attempt}/{max_attempts} failed: {e}")
            print(f"[Retry] Waiting {delay:.2f}s before retry...")

            time.sleep(delay)
        except Exception as e:
            # Not a retryable type: fail fast, but still count the call.
            _record_retry_outcome(attempt, total_delay, e)
            raise
        else:
            _record_retry_outcome(attempt, total_delay)
            return result

    # Unreachable: every iteration returns or raises and max_attempts >= 1.
    raise AssertionError("retry_call loop exited without a result")


def retry(
    func: Optional[Callable] = None,
    *,
    max_attempts: int = 3,
    base_delay: float = 1.0,
    max_delay: float = 60.0,
    exponential_base: float = 2.0,
    jitter: bool = True,
    exceptions: Tuple[Type[Exception], ...] = (Exception,),
    on_retry: Optional[Callable[[int, Exception, float], None]] = None,
    config: Optional["RetryConfig"] = None,
    jitter_range: Tuple[float, float] = (0.5, 1.5),
) -> Callable:
    """
    Decorator for retry logic with exponential backoff.

    Can be used with or without parentheses:
        @retry
        def func(): ...

        @retry(max_attempts=5)
        def func(): ...

        @retry(config=RetryConfig(...))
        def func(): ...

    When ``config`` is given, its values replace all other keyword
    arguments passed to ``retry`` (including ``jitter_range``).
    """
    if config is not None:
        max_attempts = config.max_attempts
        base_delay = config.base_delay
        max_delay = config.max_delay
        exponential_base = config.exponential_base
        jitter = config.jitter
        jitter_range = config.jitter_range
        exceptions = config.exceptions
        on_retry = config.on_retry

    def decorator(f: Callable) -> Callable:
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # Bind the caller's arguments first so keyword names such as
            # ``exceptions`` or ``func`` reach ``f`` instead of colliding
            # with retry_call's own parameters.
            bound = functools.partial(f, *args, **kwargs)
            return retry_call(
                bound,
                max_attempts=max_attempts,
                base_delay=base_delay,
                max_delay=max_delay,
                exponential_base=exponential_base,
                jitter=jitter,
                exceptions=exceptions,
                on_retry=on_retry,
                jitter_range=jitter_range,
            )
        return wrapper

    # Handle @retry without parentheses
    if func is not None:
        return decorator(func)

    return decorator


class RetryContext:
    """
    Caller-driven retry loop, usable as a context manager.

    ``should_continue()`` gates each attempt, ``success()`` ends the loop,
    and ``failed(exc)`` records an error and sleeps before the next attempt
    (there is no sleep after the final attempt). Exceptions raised inside
    the ``with`` body are not suppressed.

    Usage:
        with RetryContext(max_attempts=3) as ctx:
            while ctx.should_continue():
                try:
                    result = risky_operation()
                    ctx.success()
                    break
                except Exception as e:
                    ctx.failed(e)
    """

    def __init__(
        self,
        max_attempts: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        exponential_base: float = 2.0,
        jitter: bool = True
    ):
        # Backoff settings, forwarded to calculate_delay() by failed().
        self.max_attempts = max_attempts
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.exponential_base = exponential_base
        self.jitter = jitter

        # Loop bookkeeping.
        self._attempt = 0
        self._succeeded = False
        self._last_error: Optional[Exception] = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Never swallow exceptions from the with-body.
        return False

    def should_continue(self) -> bool:
        """Return True while attempts remain and no success was recorded."""
        if self._succeeded:
            return False
        return self._attempt < self.max_attempts

    def success(self) -> None:
        """Flag the operation as done so the loop stops."""
        self._succeeded = True

    def failed(self, error: Exception) -> None:
        """Store ``error``, count the attempt, and back off if more remain."""
        self._last_error = error
        self._attempt += 1

        if self._attempt >= self.max_attempts:
            return  # Out of attempts: nothing left to wait for.

        pause = calculate_delay(
            attempt=self._attempt - 1,
            base_delay=self.base_delay,
            max_delay=self.max_delay,
            exponential_base=self.exponential_base,
            jitter=self.jitter
        )
        print(f"[RetryContext] Attempt {self._attempt}/{self.max_attempts} failed")
        print(f"[RetryContext] Waiting {pause:.2f}s...")
        time.sleep(pause)

    @property
    def attempts(self) -> int:
        """Number of failed attempts recorded so far."""
        return self._attempt

    @property
    def succeeded(self) -> bool:
        """Whether success() has been called."""
        return self._succeeded

    @property
    def last_error(self) -> Optional[Exception]:
        """Most recent error passed to failed(), or None."""
        return self._last_error


def get_retry_stats() -> dict:
    """Return the process-wide retry counters as a dict.

    The counters are read while ``_stats_lock`` is held.
    """
    with _stats_lock:
        result = _global_stats.to_dict()
    return result


def reset_retry_stats() -> None:
    """Discard all accumulated retry statistics and start again from zero."""
    global _global_stats
    fresh = RetryStats()
    # Swap under the lock so no retry_call updates a half-replaced object.
    with _stats_lock:
        _global_stats = fresh


# CLI interface
if __name__ == "__main__":
    import sys

    usage = "Usage: python retry_utils.py [stats|demo|delays]"

    if len(sys.argv) > 1:
        cmd = sys.argv[1]

        if cmd == "stats":
            # Statistics live only in this process's memory, so a fresh CLI
            # invocation always reports zeroed counters.
            print("Global Retry Stats:")
            stats = get_retry_stats()
            for k, v in stats.items():
                print(f"  {k}: {v}")

        elif cmd == "demo":
            print("Demo: Retry with exponential backoff")
            print("-" * 40)

            # These are module-level names (an ``if`` block is not a new
            # scope), which is why flaky_operation rebinds via ``global``.
            fail_count = 0
            max_fails = 2

            @retry(max_attempts=5, base_delay=0.5)
            def flaky_operation():
                global fail_count
                fail_count += 1
                if fail_count <= max_fails:
                    raise ConnectionError(f"Simulated failure {fail_count}")
                return "Success!"

            try:
                result = flaky_operation()
                print(f"\nResult: {result}")
            except Exception as e:
                print(f"\nFailed: {e}")

            print(f"\nStats: {get_retry_stats()}")

        elif cmd == "delays":
            print("Delay progression (5 attempts):")
            for i in range(5):
                delay = calculate_delay(i, base_delay=1.0, jitter=False)
                delay_j = calculate_delay(i, base_delay=1.0, jitter=True)
                print(f"  Attempt {i+1}: {delay:.1f}s (with jitter: ~{delay_j:.1f}s)")

        else:
            # Report misuse on stderr with a non-zero status (2, as argparse
            # uses for usage errors) so shell scripts can detect it.
            print(f"Unknown command: {cmd}", file=sys.stderr)
            print(usage, file=sys.stderr)
            sys.exit(2)
    else:
        print("Genesis Retry Utilities")
        print(usage)
