"""
AIVA Production Rate Limiting System
=====================================
Comprehensive rate limiting infrastructure for the AIVA API.

Components:
- TokenBucket: Classic token bucket algorithm for burst handling
- SlidingWindow: Sliding window rate limiter for smooth limiting
- LeakyBucket: Leaky bucket for request smoothing
- AdaptiveRateLimiter: Dynamic limit adjustment based on system load
- DistributedRateLimiter: Cross-node rate limiting via Redis
- QuotaManager: Per-user quota management and enforcement

Author: Genesis System
Version: 1.0.0
"""

from __future__ import annotations

import asyncio
import hashlib
import json
import logging
import time
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum, auto
from typing import (
    Any,
    Callable,
    Coroutine,
    Dict,
    Generic,
    List,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
)
import threading
from contextlib import asynccontextmanager
from functools import wraps

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("aiva.rate_limiter")


# =============================================================================
# ENUMS AND CONSTANTS
# =============================================================================

class RateLimitStrategy(Enum):
    """Supported rate-limiting algorithms (one per limiter class below)."""
    TOKEN_BUCKET = 1    # burst-friendly token bucket
    SLIDING_WINDOW = 2  # rolling-window request counting
    LEAKY_BUCKET = 3    # constant-rate smoothing
    ADAPTIVE = 4        # load-aware dynamic limits
    DISTRIBUTED = 5     # Redis-coordinated, cross-node


class QuotaPeriod(Enum):
    """Quota accounting periods; each member's value is its length in seconds."""
    SECOND = 1
    MINUTE = 60
    HOUR = 60 * 60
    DAY = 24 * 60 * 60
    WEEK = 7 * 24 * 60 * 60
    MONTH = 30 * 24 * 60 * 60  # approximated as a 30-day month


class RateLimitAction(Enum):
    """What to do with a request once its rate limit is exceeded."""
    REJECT = 1    # refuse the request outright
    QUEUE = 2     # hold it for later processing
    THROTTLE = 3  # slow it down
    DEGRADE = 4   # serve a reduced-quality response


# Default configurations shared by the limiter constructors below.
DEFAULT_TOKEN_RATE = 100  # tokens per second (TokenBucket refill speed)
DEFAULT_BUCKET_SIZE = 1000  # max tokens (TokenBucket burst capacity)
DEFAULT_WINDOW_SIZE = 60  # seconds (SlidingWindow span)
DEFAULT_MAX_REQUESTS = 1000  # per window (SlidingWindow admission cap)


# =============================================================================
# DATA CLASSES
# =============================================================================

@dataclass
class RateLimitResult:
    """Outcome of a single rate-limit check."""
    allowed: bool                          # True if the request may proceed
    remaining: int                         # capacity left after this check
    reset_at: float                        # unix time when the limit fully resets
    retry_after: Optional[float] = None    # seconds to wait when rejected
    limit: int = 0                         # configured ceiling for this limiter
    current: int = 0                       # capacity consumed so far
    strategy: str = ""                     # name of the limiter that produced this
    metadata: Dict[str, Any] = field(default_factory=dict)  # limiter-specific extras

    def to_headers(self) -> Dict[str, str]:
        """Render the result as conventional HTTP rate-limit response headers."""
        headers: Dict[str, str] = {}
        headers["X-RateLimit-Limit"] = str(self.limit)
        headers["X-RateLimit-Remaining"] = str(self.remaining)
        headers["X-RateLimit-Reset"] = str(int(self.reset_at))
        if self.retry_after is not None:
            headers["Retry-After"] = str(int(self.retry_after))
        return headers


@dataclass
class QuotaConfig:
    """Configuration for user quotas.

    Each tier (second/minute/hour/day) is an independent ceiling;
    QuotaManager builds one SlidingWindow per tier from these values.
    """
    user_id: str  # user this configuration applies to
    requests_per_second: int = 10
    requests_per_minute: int = 300
    requests_per_hour: int = 10000
    requests_per_day: int = 100000
    burst_allowance: float = 1.5  # 150% of normal rate
    priority: int = 0  # Higher = more priority
    # Extra named limits; not consumed anywhere in this module's visible
    # code — presumably per-endpoint/per-feature overrides (TODO confirm).
    custom_limits: Dict[str, int] = field(default_factory=dict)


@dataclass
class SystemMetrics:
    """Point-in-time system health sample consumed by AdaptiveRateLimiter."""
    cpu_usage: float = 0.0  # fraction 0..1 (compared against 0.8/0.3 thresholds)
    memory_usage: float = 0.0  # fraction 0..1 (compared against 0.85 threshold)
    request_latency_p99: float = 0.0  # seconds (compared against 1.0s threshold)
    error_rate: float = 0.0  # fraction 0..1 (compared against 0.05 threshold)
    queue_depth: int = 0  # pending requests; not read by the adaptive logic here
    active_connections: int = 0  # open connections; not read by the adaptive logic here
    timestamp: float = field(default_factory=time.time)  # sample creation time


# =============================================================================
# BASE RATE LIMITER
# =============================================================================

class BaseRateLimiter(ABC):
    """Abstract base class for rate limiters.

    Provides the shared plumbing: an asyncio lock for serializing async
    state changes in subclasses, a thread lock guarding the statistics
    counters, and allow/reject bookkeeping. Subclasses implement
    ``acquire``/``release``/``reset``.
    """

    def __init__(self, name: str = "base"):
        self.name = name
        self._lock = asyncio.Lock()          # serializes subclass async state changes
        self._sync_lock = threading.Lock()   # guards self.stats
        self.stats = {
            "allowed": 0,
            "rejected": 0,
            "total": 0,
        }

    @abstractmethod
    async def acquire(self, key: str, cost: int = 1) -> RateLimitResult:
        """Attempt to acquire rate limit permission for *key* at *cost*."""

    @abstractmethod
    async def release(self, key: str, cost: int = 1) -> None:
        """Release acquired rate limit capacity (if the strategy supports it)."""

    @abstractmethod
    def reset(self, key: str) -> None:
        """Reset rate limit state for a key."""

    def get_stats(self) -> Dict[str, Any]:
        """Return a consistent snapshot of this limiter's statistics.

        Bug fix: the counters are copied while holding ``_sync_lock`` so a
        concurrent ``_record_result`` cannot yield a torn snapshot (e.g.
        ``total`` already incremented but ``allowed`` not yet).
        """
        with self._sync_lock:
            allowed = self.stats["allowed"]
            rejected = self.stats["rejected"]
            total = self.stats["total"]
        return {
            "name": self.name,
            "allowed": allowed,
            "rejected": rejected,
            "total": total,
            "rejection_rate": rejected / total if total > 0 else 0,
        }

    def _record_result(self, allowed: bool) -> None:
        """Record one allow/reject outcome under the stats lock."""
        with self._sync_lock:
            self.stats["total"] += 1
            if allowed:
                self.stats["allowed"] += 1
            else:
                self.stats["rejected"] += 1

# =============================================================================
# TOKEN BUCKET RATE LIMITER
# =============================================================================

class TokenBucket(BaseRateLimiter):
    """
    Token bucket rate limiter.

    Permits short bursts up to ``bucket_size`` while enforcing a long-term
    average of ``rate`` tokens per second. Buckets are created lazily per
    key and credited continuously based on wall-clock time.
    """

    def __init__(
        self,
        rate: float = DEFAULT_TOKEN_RATE,
        bucket_size: int = DEFAULT_BUCKET_SIZE,
        name: str = "token_bucket",
    ):
        super().__init__(name)
        self.rate = rate                 # refill speed, tokens per second
        self.bucket_size = bucket_size   # burst capacity
        self._buckets: Dict[str, Dict[str, float]] = {}

    def _get_bucket(self, key: str) -> Dict[str, float]:
        """Return the bucket for *key*, creating a full one on first use."""
        bucket = self._buckets.get(key)
        if bucket is None:
            bucket = {
                "tokens": float(self.bucket_size),
                "last_update": time.time(),
            }
            self._buckets[key] = bucket
        return bucket

    def _refill(self, bucket: Dict[str, float]) -> None:
        """Credit tokens for the time elapsed since the last update, capped at capacity."""
        now = time.time()
        earned = (now - bucket["last_update"]) * self.rate
        bucket["tokens"] = min(self.bucket_size, bucket["tokens"] + earned)
        bucket["last_update"] = now

    async def acquire(self, key: str, cost: int = 1) -> RateLimitResult:
        """
        Try to take *cost* tokens from *key*'s bucket.

        Args:
            key: Identifier for the rate limit (e.g. user_id, ip).
            cost: Number of tokens required.

        Returns:
            RateLimitResult describing the decision and bucket state.
        """
        async with self._lock:
            bucket = self._get_bucket(key)
            self._refill(bucket)

            retry_after: Optional[float] = None
            if bucket["tokens"] >= cost:
                allowed = True
                bucket["tokens"] -= cost
                remaining = int(bucket["tokens"])
            else:
                allowed = False
                remaining = 0
                # Wait until the shortfall has been refilled.
                retry_after = (cost - bucket["tokens"]) / self.rate

            # "Reset" is the moment the bucket would be full again.
            reset_at = time.time() + (self.bucket_size - bucket["tokens"]) / self.rate

            self._record_result(allowed)

            return RateLimitResult(
                allowed=allowed,
                remaining=remaining,
                reset_at=reset_at,
                retry_after=retry_after,
                limit=self.bucket_size,
                current=self.bucket_size - remaining,
                strategy="token_bucket",
                metadata={
                    "rate": self.rate,
                    "bucket_size": self.bucket_size,
                    "tokens": bucket["tokens"],
                },
            )

    async def release(self, key: str, cost: int = 1) -> None:
        """No-op: token buckets refill on their own, there is nothing to give back."""

    def reset(self, key: str) -> None:
        """Restore the key's bucket to full capacity."""
        with self._sync_lock:
            if key in self._buckets:
                self._buckets[key] = {
                    "tokens": float(self.bucket_size),
                    "last_update": time.time(),
                }


# =============================================================================
# SLIDING WINDOW RATE LIMITER
# =============================================================================

class SlidingWindow(BaseRateLimiter):
    """
    Sliding Window Rate Limiter.

    Tracks request timestamps inside a rolling window of ``window_size``
    seconds and admits a request only while the count stays within
    ``max_requests``. Smoother than fixed windows because entries age
    out continuously.
    """

    def __init__(
        self,
        window_size: int = DEFAULT_WINDOW_SIZE,
        max_requests: int = DEFAULT_MAX_REQUESTS,
        name: str = "sliding_window",
    ):
        super().__init__(name)
        self.window_size = window_size      # seconds covered by the window
        self.max_requests = max_requests    # admissions allowed per window
        self._windows: Dict[str, deque] = {}  # key -> timestamps, oldest first
        self._cleanup_interval = window_size * 2  # cadence of the global sweep
        self._last_cleanup = time.time()

    def _get_window(self, key: str) -> deque:
        """Get or create the timestamp deque for a key."""
        if key not in self._windows:
            self._windows[key] = deque()
        return self._windows[key]

    def _cleanup_window(self, window: deque, now: float) -> None:
        """Drop timestamps that have aged out of the window."""
        cutoff = now - self.window_size
        while window and window[0] < cutoff:
            window.popleft()

    def _maybe_global_cleanup(self) -> None:
        """Periodically sweep every window and delete empty keys to bound memory."""
        now = time.time()
        if now - self._last_cleanup > self._cleanup_interval:
            self._last_cleanup = now
            empty_keys = []
            for key, window in self._windows.items():
                self._cleanup_window(window, now)
                if not window:
                    empty_keys.append(key)
            for key in empty_keys:
                del self._windows[key]

    async def acquire(self, key: str, cost: int = 1) -> RateLimitResult:
        """
        Attempt to acquire *cost* request slots in the sliding window.

        Args:
            key: Identifier for the rate limit.
            cost: Number of request slots to acquire.

        Returns:
            RateLimitResult with status and metadata.
        """
        async with self._lock:
            now = time.time()
            window = self._get_window(key)
            self._cleanup_window(window, now)
            self._maybe_global_cleanup()

            current_count = len(window)
            allowed = current_count + cost <= self.max_requests

            if allowed:
                # Record one timestamp per slot consumed.
                for _ in range(cost):
                    window.append(now)
                remaining = self.max_requests - len(window)
                retry_after = None
            else:
                remaining = 0
                # Bug fix: `needed` old entries must expire before `cost` new
                # slots fit; the previous code always waited only for the
                # single oldest entry, under-estimating the wait for cost > 1.
                needed = current_count + cost - self.max_requests
                if 0 < needed <= len(window):
                    retry_after = (window[needed - 1] + self.window_size) - now
                elif window:
                    # cost alone exceeds max_requests: best effort — wait for
                    # the newest entry (the request may never fit in one window).
                    retry_after = (window[-1] + self.window_size) - now
                else:
                    retry_after = 0

            reset_at = now + self.window_size

            self._record_result(allowed)

            return RateLimitResult(
                allowed=allowed,
                remaining=remaining,
                reset_at=reset_at,
                retry_after=retry_after,
                limit=self.max_requests,
                current=len(window),
                strategy="sliding_window",
                metadata={
                    "window_size": self.window_size,
                    "request_count": len(window),
                },
            )

    async def release(self, key: str, cost: int = 1) -> None:
        """No-op: window entries expire on their own."""
        pass

    def reset(self, key: str) -> None:
        """Clear all recorded requests for a key."""
        with self._sync_lock:
            if key in self._windows:
                self._windows[key].clear()


# =============================================================================
# LEAKY BUCKET RATE LIMITER
# =============================================================================

class LeakyBucket(BaseRateLimiter):
    """
    Leaky bucket rate limiter.

    Requests pour "water" into a per-key bucket that drains at a fixed
    ``leak_rate``; a request is admitted only if the new level fits under
    ``bucket_capacity``, which smooths bursts into a steady output rate.
    """

    def __init__(
        self,
        leak_rate: float = 10.0,  # requests per second
        bucket_capacity: int = 100,
        name: str = "leaky_bucket",
    ):
        super().__init__(name)
        self.leak_rate = leak_rate              # drain speed, units per second
        self.bucket_capacity = bucket_capacity  # maximum water level
        self._buckets: Dict[str, Dict[str, Any]] = {}

    def _get_bucket(self, key: str) -> Dict[str, Any]:
        """Return the bucket for *key*, creating an empty one on first use."""
        bucket = self._buckets.get(key)
        if bucket is None:
            bucket = {
                "water_level": 0.0,
                "last_leak": time.time(),
            }
            self._buckets[key] = bucket
        return bucket

    def _leak(self, bucket: Dict[str, Any]) -> None:
        """Drain the bucket according to the time elapsed since the last drain."""
        now = time.time()
        drained = (now - bucket["last_leak"]) * self.leak_rate
        bucket["water_level"] = max(0, bucket["water_level"] - drained)
        bucket["last_leak"] = now

    async def acquire(self, key: str, cost: int = 1) -> RateLimitResult:
        """
        Try to pour *cost* units of water into *key*'s bucket.

        Args:
            key: Identifier for the rate limit.
            cost: Amount of water this request represents.

        Returns:
            RateLimitResult describing the decision and bucket state.
        """
        async with self._lock:
            bucket = self._get_bucket(key)
            self._leak(bucket)

            candidate_level = bucket["water_level"] + cost
            allowed = candidate_level <= self.bucket_capacity

            retry_after = None
            if allowed:
                bucket["water_level"] = candidate_level
                remaining = int(self.bucket_capacity - bucket["water_level"])
            else:
                remaining = 0
                # Wait until enough has leaked for the overflow to fit.
                retry_after = (candidate_level - self.bucket_capacity) / self.leak_rate

            # "Reset" is the moment the bucket would be fully drained.
            reset_at = time.time() + (bucket["water_level"] / self.leak_rate)

            self._record_result(allowed)

            return RateLimitResult(
                allowed=allowed,
                remaining=remaining,
                reset_at=reset_at,
                retry_after=retry_after,
                limit=self.bucket_capacity,
                current=int(bucket["water_level"]),
                strategy="leaky_bucket",
                metadata={
                    "leak_rate": self.leak_rate,
                    "water_level": bucket["water_level"],
                    "capacity": self.bucket_capacity,
                },
            )

    async def release(self, key: str, cost: int = 1) -> None:
        """Manually drain *cost* units of water (explicit give-back)."""
        async with self._lock:
            bucket = self._buckets.get(key)
            if bucket is not None:
                bucket["water_level"] = max(0, bucket["water_level"] - cost)

    def reset(self, key: str) -> None:
        """Empty the bucket immediately."""
        with self._sync_lock:
            if key in self._buckets:
                self._buckets[key] = {
                    "water_level": 0.0,
                    "last_leak": time.time(),
                }


# =============================================================================
# ADAPTIVE RATE LIMITER
# =============================================================================

class AdaptiveRateLimiter(BaseRateLimiter):
    """
    Load-aware rate limiter that wraps another limiter.

    Keeps a rolling history of SystemMetrics samples and periodically
    derives an adaptation factor clamped to
    [min_limit_factor, max_limit_factor]. The factor scales the effective
    cost of every acquire: a stressed system is limited more aggressively,
    an idle one more loosely.
    """

    def __init__(
        self,
        base_limiter: BaseRateLimiter,
        min_limit_factor: float = 0.1,
        max_limit_factor: float = 2.0,
        adjustment_interval: float = 5.0,
        name: str = "adaptive",
    ):
        super().__init__(name)
        self.base_limiter = base_limiter                # limiter doing the real work
        self.min_limit_factor = min_limit_factor        # hard floor for the factor
        self.max_limit_factor = max_limit_factor        # hard ceiling for the factor
        self.adjustment_interval = adjustment_interval  # seconds between re-evaluations
        self.current_factor = 1.0                       # 1.0 == no adaptation
        self._metrics_history: deque = deque(maxlen=60)
        self._last_adjustment = time.time()

        # Trip points that trigger factor changes.
        self.thresholds = {
            "cpu_high": 0.8,
            "cpu_low": 0.3,
            "memory_high": 0.85,
            "latency_high": 1.0,  # seconds
            "error_rate_high": 0.05,  # 5%
        }

    def update_metrics(self, metrics: SystemMetrics) -> None:
        """Record a metrics sample and re-evaluate the adaptation factor."""
        self._metrics_history.append(metrics)
        self._maybe_adjust()

    def _maybe_adjust(self) -> None:
        """Recompute the adaptation factor once per adjustment interval."""
        now = time.time()
        if now - self._last_adjustment < self.adjustment_interval:
            return

        if not self._metrics_history:
            return

        self._last_adjustment = now

        # Average each signal over the retained history.
        samples = list(self._metrics_history)
        count = len(samples)
        cpu_avg = sum(m.cpu_usage for m in samples) / count
        memory_avg = sum(m.memory_usage for m in samples) / count
        latency_avg = sum(m.request_latency_p99 for m in samples) / count
        error_avg = sum(m.error_rate for m in samples) / count

        # Accumulate a signed delta from each signal's threshold check.
        delta = 0.0

        if cpu_avg > self.thresholds["cpu_high"]:
            delta -= 0.1
        elif cpu_avg < self.thresholds["cpu_low"]:
            delta += 0.05

        if memory_avg > self.thresholds["memory_high"]:
            delta -= 0.15

        if latency_avg > self.thresholds["latency_high"]:
            delta -= 0.2

        if error_avg > self.thresholds["error_rate_high"]:
            delta -= 0.25

        # Clamp the updated factor to the configured band.
        proposed = self.current_factor + delta
        self.current_factor = max(
            self.min_limit_factor,
            min(self.max_limit_factor, proposed)
        )

        logger.info(
            f"Adaptive rate limiter adjusted: factor={self.current_factor:.2f} "
            f"(cpu={cpu_avg:.2f}, mem={memory_avg:.2f}, lat={latency_avg:.3f}, err={error_avg:.3f})"
        )

    async def acquire(self, key: str, cost: int = 1) -> RateLimitResult:
        """
        Delegate to the base limiter with a load-adjusted cost.

        A factor below 1.0 inflates the effective cost, tightening limits.
        """
        effective_cost = max(1, int(cost / self.current_factor))

        result = await self.base_limiter.acquire(key, effective_cost)
        result.strategy = "adaptive"
        result.metadata["adaptation_factor"] = self.current_factor
        result.metadata["original_cost"] = cost
        result.metadata["adjusted_cost"] = effective_cost

        self._record_result(result.allowed)

        return result

    async def release(self, key: str, cost: int = 1) -> None:
        """Release on the base limiter using the same cost adjustment as acquire."""
        effective_cost = max(1, int(cost / self.current_factor))
        await self.base_limiter.release(key, effective_cost)

    def reset(self, key: str) -> None:
        """Reset the wrapped limiter's state for the key."""
        self.base_limiter.reset(key)

    def get_stats(self) -> Dict[str, Any]:
        """Own stats plus the current factor and the wrapped limiter's stats."""
        stats = super().get_stats()
        stats["adaptation_factor"] = self.current_factor
        stats["base_limiter_stats"] = self.base_limiter.get_stats()
        return stats


# =============================================================================
# DISTRIBUTED RATE LIMITER
# =============================================================================

class RedisBackend:
    """Thin Redis wrapper with an in-process dictionary fallback.

    When no client is supplied — or a Redis call raises — values live in a
    local dict so rate limiting degrades gracefully instead of failing.
    The local store supports neither expiration nor Lua scripts.
    """

    def __init__(self, redis_client: Any = None):
        self.redis = redis_client
        self._local_fallback: Dict[str, Any] = {}
        self._use_local = redis_client is None  # no client at all -> pure local mode

    async def get(self, key: str) -> Optional[str]:
        """Fetch a value, falling back to the local store on Redis errors."""
        if not self._use_local:
            try:
                return await self.redis.get(key)
            except Exception as e:
                logger.warning(f"Redis get failed, using local: {e}")
        return self._local_fallback.get(key)

    async def set(self, key: str, value: str, ex: int = None) -> bool:
        """Store a value (optional TTL applies to Redis only); always reports success."""
        if not self._use_local:
            try:
                await self.redis.set(key, value, ex=ex)
                return True
            except Exception as e:
                logger.warning(f"Redis set failed, using local: {e}")
        self._local_fallback[key] = value
        return True

    async def incr(self, key: str) -> int:
        """Increment a counter; locally this is a read-modify-write on the dict."""
        if not self._use_local:
            try:
                return await self.redis.incr(key)
            except Exception as e:
                logger.warning(f"Redis incr failed: {e}")
        bumped = int(self._local_fallback.get(key, 0)) + 1
        self._local_fallback[key] = str(bumped)
        return bumped

    async def expire(self, key: str, seconds: int) -> bool:
        """Set a TTL on a key; the local store silently pretends it worked."""
        if self._use_local:
            return True  # Local fallback doesn't support expiration
        try:
            await self.redis.expire(key, seconds)
        except Exception:
            return False
        return True

    async def eval_script(self, script: str, keys: List[str], args: List[Any]) -> Any:
        """Run a Lua script atomically on Redis; unsupported in local mode."""
        if self._use_local:
            raise NotImplementedError("Local fallback doesn't support Lua scripts")
        return await self.redis.eval(script, len(keys), *keys, *args)


class DistributedRateLimiter(BaseRateLimiter):
    """
    Distributed Rate Limiter using Redis.

    Provides consistent rate limiting across multiple nodes using Redis
    as the coordination backend. A Lua script performs the token-bucket
    read/refill/deduct/write cycle atomically on the Redis server; on any
    Redis failure a process-local TokenBucket takes over.
    """

    # Lua script for atomic token bucket operation (runs server-side).
    TOKEN_BUCKET_SCRIPT = """
    local key = KEYS[1]
    local rate = tonumber(ARGV[1])
    local capacity = tonumber(ARGV[2])
    local now = tonumber(ARGV[3])
    local cost = tonumber(ARGV[4])

    local data = redis.call('GET', key)
    local tokens, last_update

    if data then
        local parsed = cjson.decode(data)
        tokens = parsed.tokens
        last_update = parsed.last_update
    else
        tokens = capacity
        last_update = now
    end

    -- Refill tokens
    local elapsed = now - last_update
    tokens = math.min(capacity, tokens + (elapsed * rate))

    local allowed = 0
    if tokens >= cost then
        tokens = tokens - cost
        allowed = 1
    end

    -- Save state
    redis.call('SET', key, cjson.encode({tokens = tokens, last_update = now}))
    redis.call('EXPIRE', key, 3600)

    return {allowed, tokens}
    """

    def __init__(
        self,
        redis_client: Any = None,
        rate: float = DEFAULT_TOKEN_RATE,
        bucket_size: int = DEFAULT_BUCKET_SIZE,
        key_prefix: str = "ratelimit:",
        name: str = "distributed",
    ):
        super().__init__(name)
        self.backend = RedisBackend(redis_client)
        self.rate = rate                  # tokens per second
        self.bucket_size = bucket_size    # burst capacity
        self.key_prefix = key_prefix      # namespace for Redis keys
        self._script_sha: Optional[str] = None  # reserved for SCRIPT LOAD caching

        # Process-local limiter used whenever Redis is unavailable.
        self._local_fallback = TokenBucket(rate, bucket_size)

    def _make_key(self, key: str) -> str:
        """Namespace a caller-supplied key into the Redis keyspace."""
        return f"{self.key_prefix}{key}"

    async def acquire(self, key: str, cost: int = 1) -> RateLimitResult:
        """
        Acquire rate limit capacity across distributed nodes.

        Args:
            key: Identifier for the rate limit.
            cost: Number of tokens required.

        Returns:
            RateLimitResult; on any Redis error the local TokenBucket's
            result is returned instead.
        """
        redis_key = self._make_key(key)
        now = time.time()

        try:
            if self.backend._use_local:
                raise NotImplementedError("Using local fallback")

            result = await self.backend.eval_script(
                self.TOKEN_BUCKET_SCRIPT,
                [redis_key],
                [self.rate, self.bucket_size, now, cost]
            )

            allowed = bool(result[0])
            tokens = float(result[1])

        except Exception as e:
            logger.warning(f"Distributed rate limit failed, using local: {e}")
            return await self._local_fallback.acquire(key, cost)

        remaining = int(tokens)
        # Bug fix: only the shortfall (cost - tokens) must refill before a
        # retry can succeed; the previous formula charged the full cost and
        # over-estimated the wait whenever tokens > 0.
        retry_after = None if allowed else max(0.0, (cost - tokens) / self.rate)
        reset_at = now + ((self.bucket_size - tokens) / self.rate)

        self._record_result(allowed)

        return RateLimitResult(
            allowed=allowed,
            remaining=remaining,
            reset_at=reset_at,
            retry_after=retry_after,
            limit=self.bucket_size,
            current=self.bucket_size - remaining,
            strategy="distributed",
            metadata={
                "node_key": redis_key,
                "rate": self.rate,
                "bucket_size": self.bucket_size,
            },
        )

    async def release(self, key: str, cost: int = 1) -> None:
        """No-op: the distributed bucket refills on its own."""
        pass

    def reset(self, key: str) -> None:
        """Reset the distributed bucket to full capacity.

        Bug fix: ``asyncio.create_task`` requires a running event loop and
        raised ``RuntimeError`` when called from synchronous code. The reset
        is now scheduled on the running loop when one exists and executed
        synchronously via ``asyncio.run`` otherwise.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop (sync caller): perform the reset inline.
            asyncio.run(self._async_reset(key))
        else:
            loop.create_task(self._async_reset(key))

    async def _async_reset(self, key: str) -> None:
        """Write a full-bucket state for the key to the backend."""
        redis_key = self._make_key(key)
        await self.backend.set(
            redis_key,
            json.dumps({"tokens": self.bucket_size, "last_update": time.time()})
        )

# =============================================================================
# QUOTA MANAGER
# =============================================================================

class QuotaManager:
    """
    Per-User Quota Management System.

    Manages and enforces quotas at multiple time scales
    (second, minute, hour, day).

    Features:
    - Multi-tier quota enforcement
    - Per-user configuration
    - Quota usage tracking
    - Burst allowance support
    - Priority-based allocation
    """

    def __init__(
        self,
        default_config: Optional[QuotaConfig] = None,
        redis_client: Any = None,
    ):
        """
        Args:
            default_config: Config used for users without an explicit quota.
            redis_client: Optional Redis client for the storage backend.
        """
        # Explicit per-user quota configurations, keyed by user_id.
        self._configs: Dict[str, QuotaConfig] = {}
        self._default_config = default_config or QuotaConfig(user_id="default")
        # NOTE(review): _usage, _lock and backend are initialized but never
        # read by the methods below — presumably reserved for future
        # distributed usage tracking; confirm before removing.
        self._usage: Dict[str, Dict[str, deque]] = {}
        self._lock = asyncio.Lock()
        self.backend = RedisBackend(redis_client)

        # One SlidingWindow limiter per (user, period) tier.
        self._limiters: Dict[str, Dict[str, SlidingWindow]] = {}

    def set_quota(self, config: QuotaConfig) -> None:
        """Set quota configuration for a user and (re)build its limiters."""
        self._configs[config.user_id] = config
        self._setup_limiters(config.user_id)

    def get_quota(self, user_id: str) -> QuotaConfig:
        """Get quota configuration for a user (default if unconfigured)."""
        return self._configs.get(user_id, self._default_config)

    def _setup_limiters(self, user_id: str) -> None:
        """Create one sliding-window limiter per quota tier for the user."""
        config = self.get_quota(user_id)

        self._limiters[user_id] = {
            "second": SlidingWindow(
                window_size=1,
                max_requests=config.requests_per_second,
            ),
            "minute": SlidingWindow(
                window_size=60,
                max_requests=config.requests_per_minute,
            ),
            "hour": SlidingWindow(
                window_size=3600,
                max_requests=config.requests_per_hour,
            ),
            "day": SlidingWindow(
                window_size=86400,
                max_requests=config.requests_per_day,
            ),
        }

    async def check_quota(
        self,
        user_id: str,
        cost: int = 1,
        quota_type: Optional[str] = None,
    ) -> RateLimitResult:
        """
        Check if user has remaining quota.

        Args:
            user_id: User identifier
            cost: Request cost
            quota_type: Specific quota type to check (or all)

        Returns:
            RateLimitResult indicating quota status
        """
        if user_id not in self._limiters:
            self._setup_limiters(user_id)

        limiters = self._limiters[user_id]
        config = self.get_quota(user_id)

        # Apply burst allowance to cost: a burst_allowance > 1 lets a user
        # temporarily exceed nominal limits by charging a reduced cost.
        effective_cost = cost
        if config.burst_allowance > 1.0:
            effective_cost = max(1, int(cost / config.burst_allowance))

        # Check a specific quota type, or all tiers.
        types_to_check = [quota_type] if quota_type else ["second", "minute", "hour", "day"]

        # NOTE(review): tiers checked before a rejecting tier may already
        # have consumed quota; there is no rollback here — confirm whether
        # SlidingWindow.acquire records rejected requests.
        results = []
        for qt in types_to_check:
            if qt in limiters:
                result = await limiters[qt].acquire(user_id, effective_cost)
                results.append((qt, result))

        # Find the most restrictive rejection (largest retry_after).
        # retry_after is Optional on RateLimitResult, so coerce None to 0
        # to avoid a TypeError when comparing.
        most_restrictive: Optional[Tuple[str, RateLimitResult]] = None
        for qt, result in results:
            if not result.allowed:
                if (
                    most_restrictive is None
                    or (result.retry_after or 0) > (most_restrictive[1].retry_after or 0)
                ):
                    most_restrictive = (qt, result)

        if most_restrictive:
            qt, result = most_restrictive
            result.strategy = f"quota_{qt}"
            result.metadata["quota_type"] = qt
            result.metadata["user_id"] = user_id
            return result

        # All quotas passed - report the tightest remaining allowance.
        min_remaining = min((r.remaining for _, r in results), default=0)
        min_result = results[0][1] if results else RateLimitResult(
            allowed=True, remaining=0, reset_at=time.time()
        )

        return RateLimitResult(
            allowed=True,
            remaining=min_remaining,
            reset_at=min_result.reset_at,
            retry_after=None,
            limit=min_result.limit,
            current=min_result.current,
            strategy="quota_multi",
            metadata={
                "user_id": user_id,
                "quotas_checked": types_to_check,
                "priority": config.priority,
            },
        )

    async def consume_quota(
        self,
        user_id: str,
        cost: int = 1,
    ) -> RateLimitResult:
        """
        Consume quota for a user (check and deduct).

        This is a convenience method that checks and consumes in one call.
        """
        return await self.check_quota(user_id, cost)

    def get_usage(self, user_id: str) -> Dict[str, Any]:
        """Get current usage statistics for a user.

        Returns an ``{"error": ...}`` dict when the user has no limiters yet.
        """
        if user_id not in self._limiters:
            return {"error": "User not found"}

        config = self.get_quota(user_id)
        limiters = self._limiters[user_id]

        usage = {}
        for period, limiter in limiters.items():
            stats = limiter.get_stats()
            usage[period] = {
                # Maps the tier name back to its config field, e.g.
                # "minute" -> requests_per_minute.
                "limit": getattr(config, f"requests_per_{period}", 0),
                "used": stats["total"],
                "remaining": stats.get("remaining", 0),
                "rejection_rate": stats["rejection_rate"],
            }

        return {
            "user_id": user_id,
            "priority": config.priority,
            "burst_allowance": config.burst_allowance,
            "usage": usage,
        }

    def reset_quota(self, user_id: str, quota_type: Optional[str] = None) -> None:
        """Reset quota for a user (one tier, or all when quota_type is None)."""
        if user_id not in self._limiters:
            return

        if quota_type:
            if quota_type in self._limiters[user_id]:
                self._limiters[user_id][quota_type].reset(user_id)
        else:
            for limiter in self._limiters[user_id].values():
                limiter.reset(user_id)


# =============================================================================
# RATE LIMITER FACTORY
# =============================================================================

class RateLimiterFactory:
    """Factory for creating rate limiter instances."""

    @staticmethod
    def create(
        strategy: RateLimitStrategy,
        **kwargs,
    ) -> BaseRateLimiter:
        """
        Create a rate limiter with the specified strategy.

        Args:
            strategy: Rate limiting strategy to use
            **kwargs: Configuration options for the limiter

        Returns:
            Configured rate limiter instance

        Raises:
            ValueError: If the strategy is not recognized
        """
        # Dispatch table of lazy builders: only the requested strategy's
        # constructor actually runs.
        builders: Dict[RateLimitStrategy, Callable[[], BaseRateLimiter]] = {
            RateLimitStrategy.TOKEN_BUCKET: lambda: TokenBucket(
                rate=kwargs.get("rate", DEFAULT_TOKEN_RATE),
                bucket_size=kwargs.get("bucket_size", DEFAULT_BUCKET_SIZE),
            ),
            RateLimitStrategy.SLIDING_WINDOW: lambda: SlidingWindow(
                window_size=kwargs.get("window_size", DEFAULT_WINDOW_SIZE),
                max_requests=kwargs.get("max_requests", DEFAULT_MAX_REQUESTS),
            ),
            RateLimitStrategy.LEAKY_BUCKET: lambda: LeakyBucket(
                leak_rate=kwargs.get("leak_rate", 10.0),
                bucket_capacity=kwargs.get("bucket_capacity", 100),
            ),
            RateLimitStrategy.ADAPTIVE: lambda: AdaptiveRateLimiter(
                base_limiter=kwargs.get("base_limiter") or TokenBucket(),
                min_limit_factor=kwargs.get("min_limit_factor", 0.1),
                max_limit_factor=kwargs.get("max_limit_factor", 2.0),
            ),
            RateLimitStrategy.DISTRIBUTED: lambda: DistributedRateLimiter(
                redis_client=kwargs.get("redis_client"),
                rate=kwargs.get("rate", DEFAULT_TOKEN_RATE),
                bucket_size=kwargs.get("bucket_size", DEFAULT_BUCKET_SIZE),
            ),
        }

        builder = builders.get(strategy)
        if builder is None:
            raise ValueError(f"Unknown rate limit strategy: {strategy}")
        return builder()


# =============================================================================
# DECORATORS AND MIDDLEWARE
# =============================================================================

def rate_limit(
    limiter: BaseRateLimiter,
    key_func: Optional[Callable[..., str]] = None,
    cost_func: Optional[Callable[..., int]] = None,
    on_limited: Optional[Callable[[RateLimitResult], Any]] = None,
):
    """
    Decorator for applying rate limiting to async functions.

    Args:
        limiter: Rate limiter instance to use
        key_func: Function to extract rate limit key from args
        cost_func: Function to calculate request cost
        on_limited: Callback when rate limited

    Raises:
        RateLimitExceeded: When limited and no ``on_limited`` callback is set.
    """
    def decorator(func: Callable[..., Coroutine]):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Derive the limit key and request cost from the call, falling
            # back to a shared key and unit cost.
            limit_key = key_func(*args, **kwargs) if key_func else "default"
            request_cost = cost_func(*args, **kwargs) if cost_func else 1

            outcome = await limiter.acquire(limit_key, request_cost)

            # Fast path: within limits, run the wrapped coroutine.
            if outcome.allowed:
                return await func(*args, **kwargs)

            # Limited: delegate to the callback when provided, else raise.
            if on_limited:
                return on_limited(outcome)
            raise RateLimitExceeded(outcome)

        return wrapper
    return decorator


class RateLimitExceeded(Exception):
    """Exception raised when rate limit is exceeded.

    The triggering ``RateLimitResult`` is kept on ``self.result`` so
    callers can inspect ``retry_after`` and other metadata.
    """

    def __init__(self, result: RateLimitResult):
        self.result = result
        # retry_after is Optional on RateLimitResult; formatting None with
        # ':.2f' would raise a TypeError and mask the rate limit error.
        if result.retry_after is not None:
            message = f"Rate limit exceeded. Retry after {result.retry_after:.2f}s"
        else:
            message = "Rate limit exceeded."
        super().__init__(message)


@asynccontextmanager
async def rate_limited(
    limiter: BaseRateLimiter,
    key: str,
    cost: int = 1,
):
    """
    Context manager for rate limiting.

    Acquires a permit on entry (raising RateLimitExceeded when the
    request is rejected) and always releases it on exit.

    Usage:
        async with rate_limited(limiter, "user_123"):
            await do_something()
    """
    outcome = await limiter.acquire(key, cost)
    if not outcome.allowed:
        raise RateLimitExceeded(outcome)

    try:
        yield outcome
    finally:
        # Release the permit even if the managed body raised.
        await limiter.release(key, cost)


# =============================================================================
# COMPOSITE RATE LIMITER
# =============================================================================

class CompositeRateLimiter(BaseRateLimiter):
    """
    Composite rate limiter combining multiple strategies.

    All configured limiters must allow the request for it to proceed.
    An empty limiter list is treated as "no restrictions".
    """

    def __init__(
        self,
        limiters: List[BaseRateLimiter],
        name: str = "composite",
    ):
        """
        Args:
            limiters: Limiters that must all allow a request.
            name: Identifier passed to the base class.
        """
        super().__init__(name)
        self.limiters = limiters

    async def acquire(self, key: str, cost: int = 1) -> RateLimitResult:
        """
        Acquire from all limiters.

        Returns the most restrictive result. On the first rejection,
        permits already taken from earlier limiters are released so the
        composite behaves atomically.
        """
        # Robustness: the min()/max() aggregations below raise ValueError
        # on an empty sequence, so an empty composite simply allows all.
        if not self.limiters:
            self._record_result(True)
            return RateLimitResult(
                allowed=True,
                remaining=0,
                reset_at=time.time(),
                strategy="composite",
                metadata={"strategies": []},
            )

        results: List[RateLimitResult] = []

        for limiter in self.limiters:
            result = await limiter.acquire(key, cost)
            results.append(result)

            if not result.allowed:
                # Fail fast - release permits acquired so far. zip()
                # truncates to the successful prefix, pairing each earlier
                # result with its own limiter.
                for prev_result, prev_limiter in zip(results[:-1], self.limiters):
                    if prev_result.allowed:
                        await prev_limiter.release(key, cost)

                self._record_result(False)
                return result

        # All passed - report the most restrictive combined view.
        min_remaining = min(r.remaining for r in results)
        max_reset = max(r.reset_at for r in results)

        self._record_result(True)

        return RateLimitResult(
            allowed=True,
            remaining=min_remaining,
            reset_at=max_reset,
            limit=min(r.limit for r in results),
            current=max(r.current for r in results),
            strategy="composite",
            metadata={
                "strategies": [r.strategy for r in results],
            },
        )

    async def release(self, key: str, cost: int = 1) -> None:
        """Release permits from every underlying limiter."""
        for limiter in self.limiters:
            await limiter.release(key, cost)

    def reset(self, key: str) -> None:
        """Reset the key's state in every underlying limiter."""
        for limiter in self.limiters:
            limiter.reset(key)


# =============================================================================
# MAIN - TESTING AND DEMONSTRATION
# =============================================================================

async def main():
    """Exercise each rate limiting component with a small smoke test."""
    print("=" * 60)
    print("AIVA Rate Limiting System - Test Suite")
    print("=" * 60)

    # --- Token bucket: burst capacity, then refill-limited ---
    print("\n--- Token Bucket Test ---")
    bucket = TokenBucket(rate=10, bucket_size=20)
    for attempt in range(1, 26):
        outcome = await bucket.acquire("test_user", cost=1)
        print(f"  Request {attempt}: allowed={outcome.allowed}, remaining={outcome.remaining}")
    print(f"  Stats: {bucket.get_stats()}")

    # --- Sliding window: hard cap within a rolling window ---
    print("\n--- Sliding Window Test ---")
    window = SlidingWindow(window_size=1, max_requests=5)
    for attempt in range(1, 9):
        outcome = await window.acquire("test_user", cost=1)
        print(f"  Request {attempt}: allowed={outcome.allowed}, remaining={outcome.remaining}")
    print(f"  Stats: {window.get_stats()}")

    # --- Leaky bucket: constant drain smooths bursts ---
    print("\n--- Leaky Bucket Test ---")
    leaky = LeakyBucket(leak_rate=5, bucket_capacity=10)
    for attempt in range(1, 16):
        outcome = await leaky.acquire("test_user", cost=1)
        print(f"  Request {attempt}: allowed={outcome.allowed}, remaining={outcome.remaining}")
        if not outcome.allowed:
            await asyncio.sleep(0.2)  # Wait for leak
    print(f"  Stats: {leaky.get_stats()}")

    # --- Adaptive limiter: tightens limits under simulated load ---
    print("\n--- Adaptive Rate Limiter Test ---")
    inner = TokenBucket(rate=10, bucket_size=50)
    adaptive = AdaptiveRateLimiter(inner)

    # Feed metrics describing a heavily loaded system.
    load = SystemMetrics(
        cpu_usage=0.9,
        memory_usage=0.85,
        request_latency_p99=1.5,
        error_rate=0.08,
    )
    adaptive.update_metrics(load)

    outcome = await adaptive.acquire("test_user", cost=1)
    print(f"  High load result: allowed={outcome.allowed}, factor={outcome.metadata.get('adaptation_factor')}")
    print(f"  Stats: {adaptive.get_stats()}")

    # --- Quota manager: multi-tier per-user quotas ---
    print("\n--- Quota Manager Test ---")
    quotas = QuotaManager()
    quotas.set_quota(QuotaConfig(
        user_id="premium_user",
        requests_per_second=10,
        requests_per_minute=100,
        requests_per_hour=1000,
        requests_per_day=10000,
        burst_allowance=2.0,
        priority=10,
    ))

    for attempt in range(1, 16):
        outcome = await quotas.check_quota("premium_user")
        print(f"  Quota check {attempt}: allowed={outcome.allowed}, remaining={outcome.remaining}")

    print(f"  Usage: {quotas.get_usage('premium_user')}")

    # --- Composite limiter: every strategy must agree ---
    print("\n--- Composite Rate Limiter Test ---")
    combined = CompositeRateLimiter([
        TokenBucket(rate=5, bucket_size=10),
        SlidingWindow(window_size=1, max_requests=3),
    ])

    for attempt in range(1, 6):
        outcome = await combined.acquire("test_user")
        print(f"  Request {attempt}: allowed={outcome.allowed}, strategies={outcome.metadata.get('strategies')}")

    # --- Factory: build and probe one limiter per strategy ---
    print("\n--- Factory Test ---")
    for strategy in RateLimitStrategy:
        try:
            limiter = RateLimiterFactory.create(strategy)
            await limiter.acquire("factory_test")
            print(f"  {strategy.name}: created and tested OK")
        except Exception as exc:
            print(f"  {strategy.name}: {exc}")

    print("\n" + "=" * 60)
    print("All tests completed!")
    print("=" * 60)


# Script entry point: run the demonstration/test suite on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())
