import time
import threading
from collections import deque
from aiva.error_handling import RateLimitExceededError, handle_rate_limit_error
import logging

logger = logging.getLogger(__name__)


class AdaptiveRateLimiter:
    """Token-bucket rate limiter with a bounded overdraft for priority requests.

    The bucket starts full at ``capacity`` tokens and is replenished at
    ``refill_rate`` tokens per second, never exceeding ``capacity``.
    A request whose ``priority`` is at least ``priority_threshold`` may push
    the balance negative by up to ``capacity * overdraft_ratio`` tokens.

    Requests rejected by :meth:`consume` are appended to ``request_queue`` as
    ``(priority, request)`` tuples and can be retried with
    :meth:`process_queue`.  The queue is FIFO; the stored priority is only
    used for the overdraft check, not for ordering.

    All token and queue mutations performed by this class happen while
    holding ``self.lock`` (a non-reentrant ``threading.Lock``).

    NOTE: the queue does not record how many tokens a rejected request asked
    for, so :meth:`process_queue` replays every queued request at a cost of
    one token.
    """

    def __init__(self, capacity, refill_rate, priority_threshold=5, *, overdraft_ratio=0.1):
        """Create a limiter whose bucket starts full.

        Args:
            capacity: Maximum number of tokens the bucket can hold.
            refill_rate: Tokens added per second.
            priority_threshold: Requests with ``priority >=`` this value may
                use the overdraft allowance.
            overdraft_ratio: Fraction of ``capacity`` that priority requests
                may borrow below zero (default 0.1, i.e. 10%).
        """
        self.capacity = capacity
        self.tokens = capacity
        self.refill_rate = refill_rate  # tokens per second
        self.last_refill = time.monotonic()
        self.lock = threading.Lock()
        self.priority_threshold = priority_threshold
        self.overdraft_ratio = overdraft_ratio
        self.request_queue = deque()  # FIFO of (priority, request) tuples

    # ------------------------------------------------------------------
    # Internal helpers.  The ``*_locked`` ones require self.lock to be held.
    # ------------------------------------------------------------------

    def _refill_locked(self):
        """Credit tokens for the time elapsed since the last refill."""
        now = time.monotonic()
        elapsed = now - self.last_refill
        self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
        self.last_refill = now

    def _refill(self):
        """Refill the bucket, taking the lock for the whole read-modify-write.

        The elapsed time is computed under the lock so that two threads can
        never credit the same interval twice.
        """
        with self.lock:
            self._refill_locked()

    def _overdraft_allows_locked(self, tokens, priority):
        """Return True if a priority request may borrow ``tokens`` right now."""
        if priority < self.priority_threshold:
            return False
        allowance = self.capacity * self.overdraft_ratio
        return self.tokens + allowance >= tokens

    def _try_take_locked(self, tokens, priority):
        """Deduct ``tokens`` if permitted; never touches the queue.

        Returns:
            True if the tokens were deducted (normally or via overdraft),
            False if the request must be rate limited.
        """
        if self.tokens >= tokens:
            self.tokens -= tokens
            logger.debug("Consumed %s tokens. Remaining tokens: %s", tokens, self.tokens)
            return True

        logger.warning(
            "Rate limit exceeded.  Priority: %s, Tokens requested: %s, Tokens available: %s",
            priority, tokens, self.tokens,
        )
        if self._overdraft_allows_locked(tokens, priority):
            # Important requests may proceed slightly over the limit; the
            # balance goes negative and is paid back by later refills.
            self.tokens -= tokens
            logger.warning(
                "Priority request allowed, exceeding limit. Remaining tokens: %s",
                self.tokens,
            )
            return True
        return False

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def consume(self, tokens=1, priority=0, request=None):
        """Try to spend ``tokens``; queue the request if it is rejected.

        Args:
            tokens: Number of tokens the request costs.
            priority: Request priority, compared to ``priority_threshold``.
            request: Opaque payload stored in the queue on rejection.

        Returns:
            True if the request is allowed, False if it was rate limited
            (in which case ``(priority, request)`` was appended to
            ``request_queue``).
        """
        with self.lock:
            # Refill and check under one lock acquisition so no other thread
            # can spend the freshly credited tokens in between.
            self._refill_locked()
            if self._try_take_locked(tokens, priority):
                return True
            self.request_queue.append((priority, request))
            return False

    def process_queue(self):
        """Retry queued requests in FIFO order until one is still limited.

        Each queued request is charged one token.  A request that is still
        rate limited is moved to the back of the queue exactly once and
        processing stops.
        """
        while True:
            with self.lock:
                if not self.request_queue:
                    return
                priority, request = self.request_queue.popleft()
                self._refill_locked()
                if not self._try_take_locked(1, priority):
                    # Still rate limited: requeue at the end (to avoid
                    # starving the entries behind it) and stop.
                    self.request_queue.append((priority, request))
                    return
            # Process the request.  For now, just log it.
            logger.info("Processed queued request with priority %s: %s", priority, request)

    def is_rate_limited(self, tokens=1, priority=0):
        """Report whether a request would be rejected, without spending tokens.

        The bucket is refilled first, so ``tokens``/``last_refill`` may change.

        Returns:
            False if the request could proceed (normally or via the priority
            overdraft), True otherwise.
        """
        with self.lock:
            self._refill_locked()
            if self.tokens >= tokens:
                return False
            return not self._overdraft_allows_locked(tokens, priority)

    def get_status(self):
        """Refill the bucket and return a snapshot of the limiter's state.

        Returns:
            A dict with the keys ``capacity``, ``tokens``, ``refill_rate``
            and ``last_refill`` (a ``time.monotonic()`` timestamp).
        """
        with self.lock:
            self._refill_locked()
            return {
                "capacity": self.capacity,
                "tokens": self.tokens,
                "refill_rate": self.refill_rate,
                "last_refill": self.last_refill,
            }


# Demo driver: runs only when this module is executed as a script.
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    limiter = AdaptiveRateLimiter(capacity=10, refill_rate=2)

    # Fire 15 requests, cycling the priority through 0, 1, 2 and pausing
    # 0.2 s after each one.
    for i in range(15):
        granted = limiter.consume(1, priority=i % 3)
        if not granted:
            print(f"Request {i} rate limited.")
            # Mimic a server rejecting the call: raise the error and hand it
            # straight to the shared rate-limit error handler.
            try:
                raise RateLimitExceededError("Rate limit exceeded. Please try again later.")
            except RateLimitExceededError as e:
                handle_rate_limit_error(e)
        else:
            print(f"Request {i} allowed.")
        time.sleep(0.2)

    # Retry whatever was rejected above.
    print("Processing queue...")
    limiter.process_queue()
