#!/usr/bin/env python3
"""
GENESIS RATE LIMITER
=====================
Advanced rate limiting with multiple algorithms and quotas.

Features:
    - Token bucket algorithm
    - Sliding window rate limiting
    - Fixed window counters
    - Per-key/per-user limits
    - Quota management
    - Decorator support
    - Pluggable storage backend interface (in-memory implementation included)

Usage:
    limiter = RateLimiter(rate=10, period=60)  # 10 requests per minute

    if limiter.allow("user:123").allowed:
        process_request()
    else:
        raise RateLimitExceeded("rate limit exceeded")

    @rate_limit(rate=5, period=60)
    def api_endpoint():
        return {"data": "..."}
"""

import asyncio
import hashlib
import inspect
import json
import math
import threading
import time
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from functools import wraps
from pathlib import Path
from typing import Dict, List, Any, Optional, Callable, Union, Tuple


class RateLimitAlgorithm(Enum):
    """Rate limiting algorithms selectable via ``RateLimiter(algorithm=...)``.

    NOTE: ``LEAKY_BUCKET`` has no dedicated implementation in this module;
    ``RateLimiter`` builds a ``TokenBucket`` for it.
    """
    TOKEN_BUCKET = "token_bucket"        # TokenBucket: refilling capacity, permits bursts
    SLIDING_WINDOW = "sliding_window"    # SlidingWindowCounter: log of accepted request times
    FIXED_WINDOW = "fixed_window"        # FixedWindowCounter: counter reset once per period
    LEAKY_BUCKET = "leaky_bucket"        # currently served by a TokenBucket (see docstring)


class RateLimitExceeded(Exception):
    """Signal that a caller has used up its rate-limit allowance.

    Args:
        message: Human-readable explanation, passed through to ``Exception``.
        retry_after: Optional number of seconds the caller should wait before
            trying again; exposed as the ``retry_after`` attribute.
    """

    def __init__(self, message: str, retry_after: float = None):
        self.retry_after = retry_after
        super().__init__(message)


@dataclass
class RateLimitResult:
    """Result of a rate limit check.

    The fields mirror the conventional ``X-RateLimit-*`` response headers so
    the result can be handed to an HTTP client via ``to_headers``.
    """
    allowed: bool
    remaining: int
    limit: int
    reset_at: float  # Unix timestamp
    retry_after: Optional[float] = None  # Seconds until retry

    def to_headers(self) -> Dict[str, str]:
        """Convert to HTTP rate limit headers.

        Times are rounded *up* to whole seconds. The headers only carry
        integers, and truncating would tell clients to come back before the
        limit has actually reset (a 0.4 s wait would become ``Retry-After: 0``).

        Returns:
            ``X-RateLimit-Limit``, ``X-RateLimit-Remaining`` and
            ``X-RateLimit-Reset`` always. ``Retry-After`` is added whenever
            ``retry_after`` is set, including a value of 0.
        """
        headers = {
            "X-RateLimit-Limit": str(self.limit),
            "X-RateLimit-Remaining": str(self.remaining),
            "X-RateLimit-Reset": str(math.ceil(self.reset_at)),
        }
        if self.retry_after is not None:
            # Retry-After must be a non-negative integer number of seconds.
            headers["Retry-After"] = str(max(0, math.ceil(self.retry_after)))
        return headers


@dataclass
class RateLimitConfig:
    """Configuration for a rate limit.

    ``RateLimiter`` builds one of these from its constructor arguments and
    reads ``rate``, ``period``, ``algorithm`` and ``burst`` from it.
    """
    rate: int                  # Number of requests allowed per period
    period: float              # Time period in seconds
    algorithm: RateLimitAlgorithm = RateLimitAlgorithm.TOKEN_BUCKET
    burst: Optional[int] = None  # Max burst (token bucket capacity); TokenBucket uses rate when falsy
    key_prefix: str = ""       # Prefix for rate limit keys (not read anywhere in this module)
    group: Optional[str] = None  # Group for shared limits (not read anywhere in this module)


@dataclass
class QuotaConfig:
    """Configuration for a usage quota enforced by ``QuotaManager``.

    ``period`` is converted to seconds by ``QuotaManager``, which recognises
    "minute", "hour", "day", "week" and "month" (30 days). Any other value is
    treated as one hour.
    """
    name: str                # Name the quota is registered and looked up under
    limit: int               # Units allowed per period (before any rollover)
    period: str  # "hour", "day", "month"
    soft_limit: Optional[int] = None  # Warning threshold: consume() flags usage at or above it
    rollover: bool = False  # Unused quota rolls over: unused part of `limit` carries into the next period


class RateLimitBackend(ABC):
    """Abstract backend for rate limit storage.

    Values are plain dicts addressed by string keys. Within this module,
    ``QuotaManager`` reads and writes through ``get``/``set``; ``RateLimiter``
    accepts a backend but keeps its own per-key state in memory.
    """

    @abstractmethod
    def get(self, key: str) -> Optional[Dict[str, Any]]:
        """Get rate limit data for key, or ``None`` when there is none."""
        pass

    @abstractmethod
    def set(self, key: str, data: Dict[str, Any], ttl: Optional[int] = None) -> bool:
        """Set rate limit data.

        ``ttl`` is a lifetime in seconds. The in-memory implementation treats
        a falsy ``ttl`` as "never expires" and always returns True.
        """
        pass

    @abstractmethod
    def increment(self, key: str, amount: int = 1) -> int:
        """Increment counter and return new value."""
        pass

    @abstractmethod
    def delete(self, key: str) -> bool:
        """Delete rate limit data; returns whether the key existed (in MemoryBackend)."""
        pass


class MemoryBackend(RateLimitBackend):
    """In-memory, process-local rate limit storage.

    Every operation holds a re-entrant lock, so one instance can be shared
    between threads. Entries written with a ``ttl`` disappear once it has
    elapsed; entries written without one never expire.
    """

    def __init__(self):
        self._data: Dict[str, Dict[str, Any]] = {}
        self._expires: Dict[str, float] = {}  # key -> absolute expiry (time.time())
        self._lock = threading.RLock()

    def _cleanup_expired(self):
        """Remove all expired entries. Caller must hold ``self._lock``."""
        now = time.time()
        expired = [k for k, exp in self._expires.items() if exp < now]
        for key in expired:
            self._data.pop(key, None)
            self._expires.pop(key, None)

    def _expire_key(self, key: str):
        """Drop ``key`` if its TTL has passed. Caller must hold ``self._lock``."""
        expires_at = self._expires.get(key)
        if expires_at is not None and expires_at < time.time():
            self._data.pop(key, None)
            self._expires.pop(key, None)

    def get(self, key: str) -> Optional[Dict[str, Any]]:
        """Return a copy of the data stored for ``key``, or None if absent/expired.

        A copy is returned (mirroring ``set``, which stores a copy) so callers
        cannot mutate the stored entry outside the lock.
        """
        with self._lock:
            self._cleanup_expired()
            data = self._data.get(key)
            return None if data is None else data.copy()

    def set(self, key: str, data: Dict[str, Any], ttl: Optional[int] = None) -> bool:
        """Store a copy of ``data`` under ``key``; a falsy ``ttl`` means no expiry."""
        with self._lock:
            self._data[key] = data.copy()
            if ttl:
                self._expires[key] = time.time() + ttl
            else:
                # Overwriting without a TTL must not inherit an expiry from an
                # earlier write, which would later silently delete this data.
                self._expires.pop(key, None)
            return True

    def increment(self, key: str, amount: int = 1) -> int:
        """Add ``amount`` to ``data["count"]`` for ``key`` and return the new count.

        Missing or expired keys start counting from 0.
        """
        with self._lock:
            # Without this check an expired counter would keep counting from
            # its stale value until an unrelated get() purged it.
            self._expire_key(key)
            entry = self._data.setdefault(key, {"count": 0})
            entry["count"] = entry.get("count", 0) + amount
            return entry["count"]

    def delete(self, key: str) -> bool:
        """Delete ``key``; return True if it was present."""
        with self._lock:
            if key in self._data:
                del self._data[key]
                self._expires.pop(key, None)
                return True
            return False


class TokenBucket:
    """Token bucket rate limiter.

    The bucket holds at most ``burst`` tokens and refills continuously at
    ``rate / period`` tokens per second. A request takes ``tokens`` tokens and
    is refused when the bucket does not currently hold that many.

    Elapsed time is measured with ``time.monotonic()``. Wall-clock adjustments
    (NTP steps, manual clock changes) therefore cannot drain or overfill the
    bucket.
    """

    def __init__(self, rate: int, period: float, burst: int = None):
        self.rate = rate  # Tokens added per period
        self.period = period  # Period in seconds
        self.burst = burst or rate  # Max tokens (capacity); a falsy burst means `rate`

        self._tokens = float(self.burst)  # the bucket starts full
        # Only differences of this clock are used; monotonic time never jumps
        # backwards, so `elapsed` below cannot go negative on a clock change.
        self._last_update = time.monotonic()
        self._lock = threading.RLock()

    def _refill(self):
        """Add the tokens accrued since the previous refill, capped at ``burst``.

        Caller must hold ``self._lock``.
        """
        now = time.monotonic()
        # Defensive clamp: the passage of time must never remove tokens.
        elapsed = max(0.0, now - self._last_update)
        self._last_update = now

        tokens_per_second = self.rate / self.period
        self._tokens = min(self.burst, self._tokens + elapsed * tokens_per_second)

    def allow(self, tokens: int = 1) -> Tuple[bool, float]:
        """Try to take ``tokens`` tokens from the bucket.

        Returns:
            ``(True, tokens_remaining)`` when allowed, otherwise
            ``(False, seconds_until_enough_tokens)``.

        NOTE: if ``tokens`` exceeds ``burst``, the request can never succeed
        and the returned wait time is not meaningful.
        """
        with self._lock:
            self._refill()

            if self._tokens >= tokens:
                self._tokens -= tokens
                return True, self._tokens

            # Time for the missing tokens to accrue at the refill rate.
            deficit = tokens - self._tokens
            tokens_per_second = self.rate / self.period
            return False, deficit / tokens_per_second

    def get_status(self) -> Dict:
        """Refill, then report the current token count and the configuration."""
        with self._lock:
            self._refill()
            return {
                "tokens": round(self._tokens, 2),
                "capacity": self.burst,
                "rate": self.rate,
                "period": self.period
            }


class SlidingWindowCounter:
    """Sliding window log rate limiter.

    Remembers the time of every accepted request within the last ``period``
    seconds and accepts a new request only while fewer than ``rate`` are
    remembered. Times come from ``time.monotonic()``; only durations are
    reported to callers.
    """

    def __init__(self, rate: int, period: float):
        self.rate = rate
        self.period = period
        # Appended in non-decreasing order, so the oldest entry is always at
        # the left end and expiry is an O(1) popleft() instead of a rebuild.
        self._timestamps: "deque[float]" = deque()
        self._lock = threading.RLock()

    def _cleanup(self, now: Optional[float] = None):
        """Drop entries at least ``period`` seconds old. Caller must hold the lock."""
        if now is None:
            now = time.monotonic()
        cutoff = now - self.period
        stamps = self._timestamps
        while stamps and stamps[0] <= cutoff:
            stamps.popleft()

    def allow(self) -> Tuple[bool, Union[int, float]]:
        """Record a request if the window has room.

        Returns:
            ``(True, remaining_slots)`` when allowed, otherwise
            ``(False, seconds_until_the_oldest_entry_expires)``.
        """
        with self._lock:
            now = time.monotonic()
            self._cleanup(now)

            if len(self._timestamps) < self.rate:
                self._timestamps.append(now)
                return True, self.rate - len(self._timestamps)

            # Window full: wait for the oldest entry to slide out. With
            # rate <= 0 the log is empty (no request is ever accepted); report
            # a full period instead of crashing on an empty log.
            oldest = self._timestamps[0] if self._timestamps else now
            return False, max(0, oldest + self.period - now)

    def get_count(self) -> int:
        """Number of accepted requests still inside the window."""
        with self._lock:
            self._cleanup()
            return len(self._timestamps)


class FixedWindowCounter:
    """Fixed window counter rate limiter.

    At most ``rate`` requests are accepted per window of ``period`` seconds.
    The first window starts at construction. Later windows start at the first
    call made after the previous window ended.
    """

    def __init__(self, rate: int, period: float):
        self.rate = rate
        self.period = period
        self._count = 0
        # Wall-clock time on purpose: the window end is handed back to callers
        # as an absolute Unix timestamp (`reset_at`).
        self._window_start = time.time()
        self._lock = threading.RLock()

    def _reset_if_needed(self, now: Optional[float] = None):
        """Start a new window if the current one has ended. Caller must hold the lock."""
        if now is None:
            now = time.time()
        if now - self._window_start >= self.period:
            self._count = 0
            self._window_start = now

    def allow(self) -> Tuple[bool, int, float]:
        """Count a request if the current window still has room.

        Returns:
            ``(allowed, remaining, reset_at)``, where ``reset_at`` is the Unix
            timestamp at which the current window ends. ``remaining`` is 0
            when the request is denied.
        """
        with self._lock:
            self._reset_if_needed(time.time())
            reset_at = self._window_start + self.period

            if self._count >= self.rate:
                return False, 0, reset_at

            self._count += 1
            return True, self.rate - self._count, reset_at


class RateLimiter:
    """
    Unified rate limiter with multiple algorithms.

    One in-process limiter object (``TokenBucket``, ``SlidingWindowCounter``
    or ``FixedWindowCounter``) is created lazily per key and kept in
    ``self._limiters``. ``allow`` translates that object's native result into
    a ``RateLimitResult``.

    NOTE(review): ``backend`` is stored but not used by this class, and
    ``self._limiters`` gains one entry per distinct key until ``reset`` is
    called. Confirm this is acceptable for high-cardinality keys.
    """

    def __init__(
        self,
        rate: int = 100,
        period: float = 60.0,
        algorithm: RateLimitAlgorithm = RateLimitAlgorithm.TOKEN_BUCKET,
        burst: int = None,
        backend: RateLimitBackend = None
    ):
        self.config = RateLimitConfig(
            rate=rate,
            period=period,
            algorithm=algorithm,
            burst=burst
        )
        self.backend = backend or MemoryBackend()
        self._limiters: Dict[str, Any] = {}
        # Guards both `_limiters` and `_stats`.
        self._lock = threading.RLock()

        # Stats tracking
        self._stats = {
            "allowed": 0,
            "denied": 0,
            "total": 0
        }

    def _new_limiter(self):
        """Build a fresh limiter for the configured algorithm (token bucket by default)."""
        algorithm = self.config.algorithm
        if algorithm == RateLimitAlgorithm.SLIDING_WINDOW:
            return SlidingWindowCounter(self.config.rate, self.config.period)
        if algorithm == RateLimitAlgorithm.FIXED_WINDOW:
            return FixedWindowCounter(self.config.rate, self.config.period)
        # TOKEN_BUCKET, and LEAKY_BUCKET which has no dedicated implementation.
        return TokenBucket(self.config.rate, self.config.period, self.config.burst)

    def _get_limiter(self, key: str):
        """Get or create the limiter for ``key``."""
        with self._lock:
            limiter = self._limiters.get(key)
            if limiter is None:
                limiter = self._limiters[key] = self._new_limiter()
            return limiter

    def _record(self, allowed: bool) -> None:
        """Update the counters under the lock; `+=` on a dict entry is a non-atomic read-modify-write."""
        with self._lock:
            self._stats["total"] += 1
            self._stats["allowed" if allowed else "denied"] += 1

    def _denied(self, reset_at: float, retry_after: float) -> RateLimitResult:
        """Count a denial and build its result."""
        self._record(False)
        return RateLimitResult(
            allowed=False,
            remaining=0,
            limit=self.config.rate,
            reset_at=reset_at,
            retry_after=retry_after
        )

    def allow(self, key: str = "default", cost: int = 1) -> RateLimitResult:
        """Check whether a request for ``key`` is allowed, consuming capacity if so.

        Args:
            key: Key whose limit is checked; each distinct key has its own state.
            cost: Tokens to consume. NOTE: only the token bucket honours
                ``cost``; the window algorithms always count one request.

        Returns:
            A ``RateLimitResult``. On denial, ``retry_after`` is the number of
            seconds to wait.
        """
        limiter = self._get_limiter(key)

        if isinstance(limiter, TokenBucket):
            allowed, remaining_or_wait = limiter.allow(cost)
            if not allowed:
                return self._denied(time.time() + remaining_or_wait, remaining_or_wait)
            self._record(True)
            return RateLimitResult(
                allowed=True,
                remaining=int(remaining_or_wait),
                limit=self.config.rate,
                reset_at=time.time() + self.config.period
            )

        if isinstance(limiter, SlidingWindowCounter):
            allowed, remaining_or_wait = limiter.allow()
            if not allowed:
                return self._denied(time.time() + remaining_or_wait, remaining_or_wait)
            self._record(True)
            return RateLimitResult(
                allowed=True,
                remaining=remaining_or_wait,
                limit=self.config.rate,
                reset_at=time.time() + self.config.period
            )

        if isinstance(limiter, FixedWindowCounter):
            allowed, remaining, reset_at = limiter.allow()
            if not allowed:
                return self._denied(reset_at, max(0, reset_at - time.time()))
            self._record(True)
            return RateLimitResult(
                allowed=True,
                remaining=remaining,
                limit=self.config.rate,
                reset_at=reset_at
            )

        # Not reachable with the limiters built by _new_limiter; kept as a
        # permissive default.
        self._record(True)
        return RateLimitResult(
            allowed=True,
            remaining=self.config.rate,
            limit=self.config.rate,
            reset_at=time.time() + self.config.period
        )

    def check(self, key: str = "default") -> bool:
        """Simple check if request is allowed (convenience method)."""
        return self.allow(key).allowed

    def reset(self, key: str = None):
        """Reset rate limit state for ``key``, or for every key when ``key`` is None.

        The explicit ``None`` check means an empty-string key resets only
        itself rather than wiping all keys.
        """
        with self._lock:
            if key is not None:
                self._limiters.pop(key, None)
            else:
                self._limiters.clear()

    def get_stats(self) -> Dict:
        """Return a consistent snapshot of counters, denial rate and configuration."""
        with self._lock:
            stats = dict(self._stats)
            active_keys = len(self._limiters)
        total = stats["total"]
        return {
            **stats,
            "denial_rate": stats["denied"] / total if total > 0 else 0,
            "active_keys": active_keys,
            "config": {
                "rate": self.config.rate,
                "period": self.config.period,
                "algorithm": self.config.algorithm.value
            }
        }


class QuotaManager:
    """
    Manages usage quotas over longer periods.

    Usage for quota ``name`` and subject ``key`` is stored in the backend
    under ``"quota:{name}:{key}"`` as a dict with ``used``, ``period_start``
    and (after a period reset) ``rollover``. Entries are written with a TTL of
    two periods.
    """

    def __init__(self, backend: RateLimitBackend = None):
        self.backend = backend or MemoryBackend()
        self._quotas: Dict[str, QuotaConfig] = {}
        # Serializes each read-check-write cycle on the backend, so concurrent
        # consume() calls on this manager cannot both pass the remaining check.
        self._lock = threading.RLock()

    @staticmethod
    def _data_key(quota_name: str, key: str) -> str:
        """Backend key holding usage for one quota/subject pair."""
        return f"quota:{quota_name}:{key}"

    def define_quota(self, config: QuotaConfig):
        """Register (or replace) a quota under ``config.name``."""
        with self._lock:
            self._quotas[config.name] = config

    def get_usage(self, quota_name: str, key: str) -> Dict:
        """Get current usage for a quota, starting a new period if the last one ended.

        On a period reset with ``rollover`` enabled, the unused part of the
        base ``limit`` (not of earlier rollover) carries into the new period.

        Returns:
            A usage dict, or ``{"error": "quota not found"}`` for an unknown quota.
        """
        with self._lock:
            config = self._quotas.get(quota_name)
            if config is None:
                return {"error": "quota not found"}

            data_key = self._data_key(quota_name, key)
            period_seconds = self._period_to_seconds(config.period)
            data = self.backend.get(data_key) or {"used": 0, "period_start": time.time()}

            # Check if period has reset
            if time.time() - data.get("period_start", 0) > period_seconds:
                rollover = 0
                if config.rollover:
                    rollover = max(0, config.limit - data.get("used", 0))
                data = {"used": 0, "period_start": time.time(), "rollover": rollover}
                self.backend.set(data_key, data, ttl=int(period_seconds * 2))

            used = data.get("used", 0)
            limit = config.limit + data.get("rollover", 0)
            return {
                "quota": quota_name,
                "used": used,
                "limit": limit,
                "remaining": limit - used,
                "period": config.period,
                "soft_limit": config.soft_limit,
                "period_start": data.get("period_start")
            }

    def consume(self, quota_name: str, key: str, amount: int = 1) -> Dict:
        """Consume ``amount`` units of quota for ``key``.

        The remaining-quota check and the write happen under one lock, so
        callers sharing this manager cannot overdraw the quota.

        Returns:
            On success, ``{"allowed": True, "used", "limit", "remaining",
            "warning"}``, where ``warning`` is True once usage reaches
            ``soft_limit``. Otherwise ``{"allowed": False, ..., "error": ...}``.
        """
        with self._lock:
            config = self._quotas.get(quota_name)
            if config is None:
                return {"allowed": False, "error": "quota not found"}

            usage = self.get_usage(quota_name, key)
            if usage.get("remaining", 0) < amount:
                return {
                    "allowed": False,
                    **usage,
                    "error": "quota exceeded"
                }

            data_key = self._data_key(quota_name, key)
            data = self.backend.get(data_key) or {"used": 0, "period_start": time.time()}
            data["used"] = data.get("used", 0) + amount
            period_seconds = self._period_to_seconds(config.period)
            self.backend.set(data_key, data, ttl=int(period_seconds * 2))

            limit = config.limit + data.get("rollover", 0)
            return {
                "allowed": True,
                "used": data["used"],
                "limit": limit,
                "remaining": limit - data["used"],
                # Always a bool (the old `soft_limit and ...` leaked None/0).
                "warning": (
                    config.soft_limit is not None
                    and data["used"] >= config.soft_limit
                )
            }

    def _period_to_seconds(self, period: str) -> float:
        """Convert period string to seconds ("month" = 30 days; unknown -> 1 hour)."""
        periods = {
            "minute": 60,
            "hour": 3600,
            "day": 86400,
            "week": 604800,
            "month": 2592000  # 30 days
        }
        return periods.get(period, 3600)


class RateLimitMiddleware:
    """Decorator object that applies a ``RateLimiter`` to a function.

    Works for both regular and ``async def`` functions. Each call is checked
    against ``limiter`` under the key returned by
    ``key_func(*args, **kwargs)``, or ``"default"`` when no ``key_func`` is
    given. Calls over the limit raise ``RateLimitExceeded`` without running
    the wrapped function.
    """

    def __init__(self, limiter: RateLimiter, key_func: Callable = None):
        self.limiter = limiter
        self.key_func = key_func or (lambda *a, **k: "default")

    def _enforce(self, args, kwargs) -> None:
        """Raise ``RateLimitExceeded`` if this call is over the limit."""
        key = self.key_func(*args, **kwargs)
        result = self.limiter.allow(key)
        if result.allowed:
            return

        retry_after = result.retry_after
        # Guard the format spec: `:.1f` on None would raise TypeError instead
        # of the intended RateLimitExceeded.
        if retry_after is None:
            message = "Rate limit exceeded."
        else:
            message = f"Rate limit exceeded. Retry after {retry_after:.1f}s"
        raise RateLimitExceeded(message, retry_after=retry_after)

    def __call__(self, func):
        """Wrap ``func`` so that every call is rate limited before it runs."""
        # inspect.iscoroutinefunction supersedes asyncio.iscoroutinefunction,
        # which is deprecated as of Python 3.14.
        if inspect.iscoroutinefunction(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                self._enforce(args, kwargs)
                return await func(*args, **kwargs)

            return async_wrapper

        @wraps(func)
        def wrapper(*args, **kwargs):
            self._enforce(args, kwargs)
            return func(*args, **kwargs)

        return wrapper


# Convenience decorator
def rate_limit(
    rate: int = 10,
    period: float = 60.0,
    algorithm: RateLimitAlgorithm = RateLimitAlgorithm.TOKEN_BUCKET,
    key_func: Callable = None,
    burst: int = None
):
    """Decorator to rate limit a function.

    Each ``rate_limit(...)`` call creates a fresh ``RateLimiter``. Calls to the
    decorated function that exceed it raise ``RateLimitExceeded``.

    Args:
        rate: Requests allowed per ``period``.
        period: Window length in seconds.
        algorithm: Limiting algorithm to use.
        key_func: Called with the wrapped function's arguments to choose the
            rate-limit key; all calls share the ``"default"`` key when omitted.
        burst: Token-bucket capacity (defaults to ``rate`` when not given).
            Forwarded to ``RateLimiter`` and ignored by the window algorithms.
    """
    limiter = RateLimiter(rate=rate, period=period, algorithm=algorithm, burst=burst)
    return RateLimitMiddleware(limiter, key_func)


# Process-wide registry of named limiters, shared by every caller of get_limiter().
_limiters: Dict[str, RateLimiter] = {}
_lock = threading.RLock()


def get_limiter(name: str, **kwargs) -> RateLimiter:
    """Return the limiter registered as ``name``, creating it on first use.

    ``kwargs`` are passed to ``RateLimiter`` only when the limiter is created.
    Later calls with different ``kwargs`` receive the existing instance
    unchanged.
    """
    with _lock:
        limiter = _limiters.get(name)
        if limiter is None:
            limiter = _limiters[name] = RateLimiter(**kwargs)
        return limiter


def _cmd_demo():
    """Run the "demo" command: each algorithm, the decorator and HTTP headers."""
    print("Rate Limiter Demo")
    print("=" * 40)

    # Token bucket
    print("\n1. Token Bucket (5 requests/10 seconds):")
    limiter = RateLimiter(rate=5, period=10, algorithm=RateLimitAlgorithm.TOKEN_BUCKET)

    for i in range(8):
        result = limiter.allow("user:1")
        status = "✓" if result.allowed else "✗"
        print(f"  Request {i+1}: {status} (remaining: {result.remaining})")
        if not result.allowed:
            print(f"    Retry after: {result.retry_after:.1f}s")

    print(f"\n  Stats: {json.dumps(limiter.get_stats(), indent=4)}")

    # Sliding window
    print("\n2. Sliding Window (3 requests/5 seconds):")
    limiter2 = RateLimiter(rate=3, period=5, algorithm=RateLimitAlgorithm.SLIDING_WINDOW)

    for i in range(5):
        result = limiter2.allow("api")
        status = "✓" if result.allowed else "✗"
        print(f"  Request {i+1}: {status}")
        time.sleep(0.3)

    # Fixed window
    print("\n3. Fixed Window (4 requests/10 seconds):")
    limiter3 = RateLimiter(rate=4, period=10, algorithm=RateLimitAlgorithm.FIXED_WINDOW)

    for i in range(6):
        result = limiter3.allow("endpoint")
        status = "✓" if result.allowed else "✗"
        print(f"  Request {i+1}: {status} (remaining: {result.remaining})")

    # Decorator
    print("\n4. Rate limit decorator:")

    @rate_limit(rate=3, period=5)
    def protected_function():
        return "success"

    for i in range(5):
        try:
            result = protected_function()
            print(f"  Call {i+1}: {result}")
        except RateLimitExceeded as e:
            print(f"  Call {i+1}: Rate limited - {e}")

    # HTTP headers (reuses the token-bucket limiter from step 1, new key)
    print("\n5. HTTP Headers:")
    result = limiter.allow("http:client")
    print(f"  {json.dumps(result.to_headers(), indent=4)}")


def _cmd_algorithms():
    """Run the "algorithms" command: compare burst handling of each algorithm."""
    print("Algorithm Comparison")
    print("=" * 40)

    algorithms = [
        ("Token Bucket", RateLimitAlgorithm.TOKEN_BUCKET),
        ("Sliding Window", RateLimitAlgorithm.SLIDING_WINDOW),
        ("Fixed Window", RateLimitAlgorithm.FIXED_WINDOW)
    ]

    for name, algo in algorithms:
        print(f"\n{name}:")
        limiter = RateLimiter(rate=5, period=2, algorithm=algo)

        # Burst of 10 requests
        allowed = 0
        for _ in range(10):
            if limiter.allow(f"test:{name}").allowed:
                allowed += 1

        print(f"  Burst of 10 requests: {allowed} allowed")

        # Wait and try again
        time.sleep(1)
        result = limiter.allow(f"test:{name}")
        print(f"  After 1s wait: {'allowed' if result.allowed else 'denied'}")


def _cmd_quota():
    """Run the "quota" command: consume an hourly quota and show a monthly one."""
    print("Quota Management Demo")
    print("=" * 40)

    qm = QuotaManager()

    # Define quotas
    qm.define_quota(QuotaConfig(
        name="api_calls",
        limit=100,
        period="hour",
        soft_limit=80
    ))

    qm.define_quota(QuotaConfig(
        name="storage",
        limit=1000,
        period="month",
        rollover=True
    ))

    print("\n1. API Calls Quota:")
    for i in range(5):
        result = qm.consume("api_calls", "user:123", amount=20)
        print(f"  Consume 20: used={result['used']}, remaining={result['remaining']}")
        if result.get("warning"):
            print("    ⚠️ Approaching limit!")
        if not result["allowed"]:
            print(f"    ❌ {result.get('error')}")

    print("\n2. Storage Quota:")
    usage = qm.get_usage("storage", "user:123")
    print(f"  Current: {json.dumps(usage, indent=4)}")


def main():
    """CLI and demo for rate limiter.

    Parses a single positional ``command`` ("demo", "algorithms" or "quota")
    and runs the matching demo.
    """
    import argparse

    # Insertion order keeps the argparse choices identical to before.
    commands = {
        "demo": _cmd_demo,
        "algorithms": _cmd_algorithms,
        "quota": _cmd_quota,
    }
    parser = argparse.ArgumentParser(description="Genesis Rate Limiter")
    parser.add_argument("command", choices=list(commands))
    args = parser.parse_args()
    commands[args.command]()


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
