#!/usr/bin/env python3
"""
AIVA Production Load Balancer
=============================

A high-performance, production-grade load balancer for AIVA multi-instance deployment.
Supports multiple load balancing strategies, health checking, session affinity,
circuit breakers, and comprehensive monitoring.

Genesis Protocol - Production Infrastructure Layer
"""

import asyncio
import hashlib
import json
import logging
import random
import threading
import time
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
import aiohttp
import ssl
from contextlib import asynccontextmanager


# =============================================================================
# CONFIGURATION
# =============================================================================

@dataclass
class LoadBalancerConfig:
    """Tunable settings for the load balancer and its helper components.

    Every field has a default, so ``LoadBalancerConfig()`` is a usable
    configuration. Several fields are not read anywhere in the reviewed part
    of this file; their comments are hedged accordingly.
    """
    # Pause between health-check rounds (the health checker sleeps this long).
    health_check_interval: float = 10.0  # seconds
    # Total time allowed for one health probe before it counts as a timeout.
    health_check_timeout: float = 5.0    # seconds
    # Path appended to a backend's base URL to build the probe URL.
    health_check_path: str = "/health"
    # Consecutive failed probes before a backend is marked unhealthy; also
    # handed to the circuit breaker as its failure threshold.
    max_consecutive_failures: int = 3
    # Minimum time a circuit stays open before recovery is attempted.
    circuit_breaker_timeout: float = 30.0  # seconds
    # Lifetime of a sticky-session binding.
    session_ttl: float = 3600.0  # 1 hour
    # NOTE(review): presumably seconds, like the timeouts above - confirm.
    connection_timeout: float = 10.0
    # NOTE(review): presumably seconds - confirm against the request path.
    request_timeout: float = 60.0
    # NOTE(review): not read in the visible code; Backend carries its own
    # max_connections field with the same default.
    max_connections_per_backend: int = 100
    # NOTE(review): retry policy; consumer is outside the visible code.
    retry_attempts: int = 3
    # NOTE(review): presumably seconds between retries - confirm.
    retry_delay: float = 1.0
    # Read by the health checker when it builds its HTTP connector.
    enable_ssl_verification: bool = True
    # NOTE(review): not read in the visible code.
    metrics_enabled: bool = True
    # Name of a ``logging`` level constant; resolved with getattr(logging, ...).
    log_level: str = "INFO"


class BackendStatus(Enum):
    """Health status of a backend instance.

    Only ``HEALTHY`` backends are reported as available by
    ``Backend.is_available``; every other status excludes the backend from
    selection.
    """
    # Last health probe succeeded; eligible for traffic.
    HEALTHY = "healthy"
    # Set by the health checker once consecutive probe failures reach the limit.
    UNHEALTHY = "unhealthy"
    # Accepting no new connections, finishing existing ones. The health
    # checker skips backends in this state.
    DRAINING = "draining"
    # Temporarily blocked; the health checker restores HEALTHY only after the
    # circuit-breaker timeout has elapsed and a probe succeeds.
    CIRCUIT_OPEN = "circuit_open"
    # Initial status of a newly created Backend, before any probe has run.
    UNKNOWN = "unknown"


@dataclass
class BackendMetrics:
    """Counters and timestamps tracked for one backend instance.

    The two derived properties are computed on demand from the raw counters,
    so they always reflect the current field values.
    """
    total_requests: int = 0
    successful_requests: int = 0
    failed_requests: int = 0
    total_latency_ms: float = 0.0
    active_connections: int = 0
    consecutive_failures: int = 0
    last_success: Optional[datetime] = None
    last_failure: Optional[datetime] = None
    last_health_check: Optional[datetime] = None
    circuit_opened_at: Optional[datetime] = None

    @property
    def average_latency_ms(self) -> float:
        """Mean latency per request in milliseconds; 0.0 before any request."""
        request_count = self.total_requests
        if request_count != 0:
            return self.total_latency_ms / request_count
        return 0.0

    @property
    def success_rate(self) -> float:
        """Fraction of requests that succeeded; 1.0 before any request."""
        request_count = self.total_requests
        if request_count != 0:
            return self.successful_requests / request_count
        return 1.0


@dataclass
class Backend:
    """One backend AIVA instance together with its status and metrics."""
    id: str
    host: str
    port: int
    weight: int = 1
    max_connections: int = 100
    status: BackendStatus = BackendStatus.UNKNOWN
    metrics: BackendMetrics = field(default_factory=BackendMetrics)
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def url(self) -> str:
        """Base URL of the backend (plain HTTP, host and port)."""
        base_url = f"http://{self.host}:{self.port}"
        return base_url

    @property
    def is_available(self) -> bool:
        """True when the backend is HEALTHY and below its connection limit."""
        if self.status != BackendStatus.HEALTHY:
            return False
        # Healthy: the only remaining gate is spare connection capacity.
        return self.metrics.active_connections < self.max_connections


@dataclass
class Request:
    """An incoming request that the load balancer has to route."""
    id: str
    method: str
    path: str
    headers: Dict[str, str]
    body: Optional[bytes] = None
    client_ip: str = ""
    session_id: Optional[str] = None
    # Timestamp taken when the Request object is constructed.
    created_at: datetime = field(default_factory=datetime.now)

    @staticmethod
    def generate_id() -> str:
        """Return a new random request identifier (UUID4 in string form)."""
        fresh_id = uuid.uuid4()
        return str(fresh_id)


@dataclass
class Response:
    """A backend's reply, tagged with the backend id and observed latency."""
    status_code: int
    headers: Dict[str, str]
    body: bytes
    backend_id: str
    latency_ms: float

    @property
    def is_success(self) -> bool:
        """True for 2xx and 3xx status codes (200 inclusive to 400 exclusive)."""
        code = self.status_code
        if code < 200:
            return False
        return code < 400


# =============================================================================
# LOAD BALANCING STRATEGIES
# =============================================================================

class LoadBalancingStrategy(ABC):
    """Interface that every load balancing strategy must implement.

    Concrete strategies provide backend selection, a completion hook, and a
    short name used for logging and metrics.
    """

    @abstractmethod
    def select_backend(
        self,
        backends: List[Backend],
        request: Request
    ) -> Optional[Backend]:
        """Pick the backend that should serve ``request``.

        Returns the chosen backend, or ``None`` when no backend can be used.
        """

    @abstractmethod
    def on_request_complete(
        self,
        backend: Backend,
        request: Request,
        success: bool
    ) -> None:
        """Hook invoked after ``request`` has finished on ``backend``."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Short strategy identifier for logging and metrics."""


class RoundRobin(LoadBalancingStrategy):
    """
    Round-robin strategy: hands out available backends in rotation.

    A single cursor is shared by all callers and advanced under a lock.
    """

    def __init__(self):
        self._index = 0
        self._lock = threading.Lock()

    @property
    def name(self) -> str:
        """Strategy identifier."""
        return "round_robin"

    def select_backend(
        self,
        backends: List[Backend],
        request: Request
    ) -> Optional[Backend]:
        """Return the next available backend in rotation, or None if none is available."""
        candidates = [candidate for candidate in backends if candidate.is_available]
        candidate_count = len(candidates)
        if candidate_count == 0:
            return None

        with self._lock:
            # The cursor is reduced modulo the *current* candidate count, so
            # it stays in range even when the available set shrinks or grows.
            slot = self._index % candidate_count
            self._index = (self._index + 1) % candidate_count

        return candidates[slot]

    def on_request_complete(
        self,
        backend: Backend,
        request: Request,
        success: bool
    ) -> None:
        """No-op: rotation does not depend on request outcomes."""
        return None


class LeastConnections(LoadBalancingStrategy):
    """
    Least-connections strategy: prefers the backend with the fewest active
    connections.

    Selection reads ``backend.metrics.active_connections``. The strategy also
    keeps its own per-backend counter, which is updated on select/complete
    but is not consulted when choosing a backend.
    """

    def __init__(self):
        self._connection_counts: Dict[str, int] = defaultdict(int)
        self._lock = threading.Lock()

    @property
    def name(self) -> str:
        """Strategy identifier."""
        return "least_connections"

    def select_backend(
        self,
        backends: List[Backend],
        request: Request
    ) -> Optional[Backend]:
        """Return the available backend with the lowest active connection count.

        Ties go to the backend that appears first in ``backends``. Returns
        None when no backend is available.
        """
        candidates = [candidate for candidate in backends if candidate.is_available]
        if not candidates:
            return None

        def current_load(candidate: Backend) -> int:
            return candidate.metrics.active_connections

        with self._lock:
            chosen = min(candidates, key=current_load)
            self._connection_counts[chosen.id] += 1

        return chosen

    def on_request_complete(
        self,
        backend: Backend,
        request: Request,
        success: bool
    ) -> None:
        """Decrement the internal counter for ``backend``, never below zero."""
        with self._lock:
            outstanding = self._connection_counts[backend.id]
            if outstanding > 0:
                self._connection_counts[backend.id] = outstanding - 1


class WeightedStrategy(LoadBalancingStrategy):
    """
    Weighted random strategy: a backend's chance of selection is proportional
    to its effective weight (configured weight times a per-backend multiplier).

    In adaptive mode the multiplier drifts up on successes and down on
    failures, always staying within [0.1, 2.0].
    """

    def __init__(self, adaptive: bool = True):
        self._adaptive = adaptive
        # Unknown backends start with a neutral multiplier of 1.0.
        self._weight_multipliers: Dict[str, float] = defaultdict(lambda: 1.0)
        self._lock = threading.Lock()

    @property
    def name(self) -> str:
        """Strategy identifier."""
        return "weighted"

    def select_backend(
        self,
        backends: List[Backend],
        request: Request
    ) -> Optional[Backend]:
        """Pick an available backend at random, biased by effective weight.

        Returns None when no backend is available. If the effective weights
        sum to zero or less, falls back to a uniform random choice.
        """
        candidates = [candidate for candidate in backends if candidate.is_available]
        if not candidates:
            return None

        with self._lock:
            shares = [
                candidate.weight * self._weight_multipliers[candidate.id]
                for candidate in candidates
            ]
            share_total = sum(shares)
            if share_total <= 0:
                return random.choice(candidates)

            # Draw a point on [0, total] and walk the cumulative shares until
            # the running sum reaches it.
            target = random.uniform(0, share_total)
            running = 0.0
            for candidate, share in zip(candidates, shares):
                running += share
                if target <= running:
                    return candidate

            # Reached only if floating-point rounding leaves the draw above
            # the final running sum.
            return candidates[-1]

    def on_request_complete(
        self,
        backend: Backend,
        request: Request,
        success: bool
    ) -> None:
        """Adjust the backend's multiplier after a request (adaptive mode only)."""
        if self._adaptive:
            with self._lock:
                current = self._weight_multipliers[backend.id]
                if success:
                    # Slow recovery: +1% per success, capped at 2.0.
                    adjusted = min(2.0, current * 1.01)
                else:
                    # Fast penalty: -10% per failure, floored at 0.1.
                    adjusted = max(0.1, current * 0.9)
                self._weight_multipliers[backend.id] = adjusted

    def set_weight(self, backend_id: str, weight_multiplier: float) -> None:
        """Set a backend's multiplier manually, clamped to [0.1, 2.0]."""
        capped = min(2.0, weight_multiplier)
        clamped = max(0.1, capped)
        with self._lock:
            self._weight_multipliers[backend_id] = clamped

    def get_effective_weights(self, backends: List[Backend]) -> Dict[str, float]:
        """Return ``{backend id: weight * multiplier}`` for the given backends."""
        effective: Dict[str, float] = {}
        with self._lock:
            for candidate in backends:
                effective[candidate.id] = (
                    candidate.weight * self._weight_multipliers[candidate.id]
                )
        return effective


class IPHashStrategy(LoadBalancingStrategy):
    """
    IP-hash strategy: maps a client IP onto the list of available backends.

    The same IP lands on the same backend for as long as the set and order of
    available backends stays the same.
    """

    @property
    def name(self) -> str:
        """Strategy identifier."""
        return "ip_hash"

    def select_backend(
        self,
        backends: List[Backend],
        request: Request
    ) -> Optional[Backend]:
        """Return the available backend indexed by the hash of ``request.client_ip``.

        Returns None when no backend is available.
        """
        candidates = [candidate for candidate in backends if candidate.is_available]
        if not candidates:
            return None

        # MD5 is used purely as a bucketing hash here. Reading the digest as a
        # big-endian integer gives the same value as parsing its hex form.
        digest = hashlib.md5(request.client_ip.encode()).digest()
        bucket = int.from_bytes(digest, "big") % len(candidates)
        return candidates[bucket]

    def on_request_complete(
        self,
        backend: Backend,
        request: Request,
        success: bool
    ) -> None:
        """No-op: hashing does not depend on request outcomes."""
        return None


class RandomStrategy(LoadBalancingStrategy):
    """
    Random strategy: each request goes to a uniformly chosen available backend.
    """

    @property
    def name(self) -> str:
        """Strategy identifier."""
        return "random"

    def select_backend(
        self,
        backends: List[Backend],
        request: Request
    ) -> Optional[Backend]:
        """Return a uniformly random available backend, or None if there is none."""
        candidates = [candidate for candidate in backends if candidate.is_available]
        return random.choice(candidates) if candidates else None

    def on_request_complete(
        self,
        backend: Backend,
        request: Request,
        success: bool
    ) -> None:
        """No-op: random selection keeps no per-request state."""
        return None


# =============================================================================
# HEALTH CHECKER
# =============================================================================

class HealthChecker:
    """
    Periodically probes backend instances over HTTP and updates their status.

    A probe is a GET of ``backend.url + config.health_check_path``; only an
    HTTP 200 response counts as success. Callbacks registered with
    ``add_status_callback`` are invoked whenever a probe changes a backend's
    status.
    """

    def __init__(
        self,
        config: LoadBalancerConfig,
        logger: Optional[logging.Logger] = None
    ):
        """
        Args:
            config: Supplies the probe interval, timeout, path, failure limit,
                circuit-breaker timeout, and the SSL verification switch.
            logger: Logger to use; defaults to this module's logger.
        """
        self.config = config
        self.logger = logger or logging.getLogger(__name__)
        self._running = False
        self._task: Optional[asyncio.Task] = None
        self._session: Optional[aiohttp.ClientSession] = None
        self._callbacks: List[Callable[[Backend, BackendStatus], None]] = []

    def add_status_callback(
        self,
        callback: Callable[[Backend, BackendStatus], None]
    ) -> None:
        """Register ``callback(backend, new_status)`` for status changes.

        Callbacks are called synchronously from the probe coroutine, in
        registration order.
        """
        self._callbacks.append(callback)

    async def start(self, backends: List[Backend]) -> None:
        """Open the HTTP session and launch the background probe loop.

        Does nothing if the checker is already running. The loop keeps a
        reference to ``backends`` and iterates it afresh on every round.

        Args:
            backends: Backends to probe; also sizes the connection pool
                (two connections per backend).
        """
        if self._running:
            return

        self._running = True
        # aiohttp treats ``ssl=None`` as "use the default certificate check",
        # so verification can only be switched off by passing ``ssl=False``.
        if self.config.enable_ssl_verification:
            ssl_setting = ssl.create_default_context()
        else:
            ssl_setting = False
        connector = aiohttp.TCPConnector(
            limit=len(backends) * 2,
            ssl=ssl_setting
        )
        self._session = aiohttp.ClientSession(connector=connector)
        self._task = asyncio.create_task(self._run_health_checks(backends))
        self.logger.info("Health checker started")

    async def stop(self) -> None:
        """Cancel the probe loop, wait for it to finish, and close the session."""
        self._running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                # Expected result of cancelling our own task.
                pass
        if self._session:
            await self._session.close()
        self.logger.info("Health checker stopped")

    async def _run_health_checks(self, backends: List[Backend]) -> None:
        """Main loop: probe every backend concurrently, then sleep one interval."""
        while self._running:
            try:
                # return_exceptions=True keeps one failing probe from
                # aborting the probes of the other backends.
                await asyncio.gather(
                    *[self._check_backend(b) for b in backends],
                    return_exceptions=True
                )
            except Exception as e:
                self.logger.error(f"Health check error: {e}")

            await asyncio.sleep(self.config.health_check_interval)

    async def _check_backend(self, backend: Backend) -> None:
        """Probe one backend and update its status and metrics in place.

        DRAINING backends are skipped. A 200 response resets the failure
        streak and marks the backend HEALTHY, except that a CIRCUIT_OPEN
        backend is only restored once the circuit-breaker timeout has passed.
        Any other response, a timeout, or any exception counts as a failure.
        Status callbacks run afterwards if the status changed.
        """
        if backend.status == BackendStatus.DRAINING:
            return  # Don't health check draining backends

        old_status = backend.status
        url = f"{backend.url}{self.config.health_check_path}"

        try:
            async with self._session.get(
                url,
                timeout=aiohttp.ClientTimeout(total=self.config.health_check_timeout)
            ) as response:
                if response.status == 200:
                    backend.metrics.consecutive_failures = 0
                    backend.metrics.last_health_check = datetime.now()

                    if backend.status == BackendStatus.CIRCUIT_OPEN:
                        # Circuit breaker recovery: stay open until the
                        # configured timeout has elapsed.
                        if self._should_close_circuit(backend):
                            backend.status = BackendStatus.HEALTHY
                            self.logger.info(
                                f"Circuit closed for backend {backend.id}"
                            )
                    else:
                        backend.status = BackendStatus.HEALTHY
                else:
                    self._handle_health_check_failure(backend)

        except asyncio.TimeoutError:
            self.logger.warning(f"Health check timeout for {backend.id}")
            self._handle_health_check_failure(backend)

        except Exception as e:
            self.logger.warning(f"Health check failed for {backend.id}: {e}")
            self._handle_health_check_failure(backend)

        # Notify callbacks if status changed
        if backend.status != old_status:
            for callback in self._callbacks:
                try:
                    callback(backend, backend.status)
                except Exception as e:
                    # A faulty callback must not break probing or prevent the
                    # remaining callbacks from running.
                    self.logger.error(f"Status callback error: {e}")

    def _handle_health_check_failure(self, backend: Backend) -> None:
        """Record a failed probe and mark the backend UNHEALTHY at the limit.

        A backend whose status is CIRCUIT_OPEN keeps that status; only its
        failure counters are updated.
        """
        backend.metrics.consecutive_failures += 1
        backend.metrics.last_failure = datetime.now()

        if backend.metrics.consecutive_failures >= self.config.max_consecutive_failures:
            if backend.status != BackendStatus.CIRCUIT_OPEN:
                backend.status = BackendStatus.UNHEALTHY
                self.logger.warning(
                    f"Backend {backend.id} marked unhealthy after "
                    f"{backend.metrics.consecutive_failures} consecutive failures"
                )

    def _should_close_circuit(self, backend: Backend) -> bool:
        """Return True once the circuit has been open for the configured timeout.

        A backend with no recorded ``circuit_opened_at`` is treated as ready
        to close immediately.
        """
        if not backend.metrics.circuit_opened_at:
            return True

        elapsed = (datetime.now() - backend.metrics.circuit_opened_at).total_seconds()
        return elapsed >= self.config.circuit_breaker_timeout

    async def check_now(self, backend: Backend) -> BackendStatus:
        """Probe ``backend`` immediately and return its resulting status."""
        await self._check_backend(backend)
        return backend.status


# =============================================================================
# SESSION AFFINITY
# =============================================================================

class SessionAffinity:
    """
    Sticky-session registry for the load balancer.

    Maps session ids to backend ids with an expiry time, so that requests
    carrying the same session id can be routed to the same backend. A
    background task purges expired bindings once a minute.
    """

    def __init__(
        self,
        session_ttl: float = 3600.0,
        cookie_name: str = "AIVA_SESSION",
        logger: Optional[logging.Logger] = None
    ):
        """
        Args:
            session_ttl: Lifetime of a binding in seconds.
            cookie_name: Name of the cookie that carries the session id.
            logger: Logger to use; defaults to this module's logger.
        """
        self.session_ttl = session_ttl
        self.cookie_name = cookie_name
        self.logger = logger or logging.getLogger(__name__)

        # session_id -> (backend_id, expiry timestamp)
        self._sessions: Dict[str, Tuple[str, datetime]] = {}
        self._lock = threading.Lock()
        self._cleanup_task: Optional[asyncio.Task] = None

    async def start(self) -> None:
        """Launch the background task that purges expired sessions."""
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        self.logger.info("Session affinity manager started")

    async def stop(self) -> None:
        """Cancel the background purge task and wait for it to exit."""
        task = self._cleanup_task
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        self.logger.info("Session affinity manager stopped")

    def get_session_id(self, request: Request) -> Optional[str]:
        """Return the request's session id, or None if it carries none.

        An explicit ``request.session_id`` wins; otherwise the ``Cookie``
        header is searched for the configured cookie name.
        """
        explicit_id = request.session_id
        if explicit_id:
            return explicit_id

        wanted_prefix = f"{self.cookie_name}="
        cookie_header = request.headers.get("Cookie", "")
        for fragment in cookie_header.split(";"):
            pair = fragment.strip()
            if pair.startswith(wanted_prefix):
                # Everything after the first "=" is the cookie value.
                return pair.partition("=")[2]

        return None

    def create_session(self, request: Request) -> str:
        """Return a freshly generated session id (UUID4 string)."""
        return str(uuid.uuid4())

    def get_backend_for_session(self, session_id: str) -> Optional[str]:
        """Return the backend id bound to ``session_id``, or None.

        An expired binding is removed as a side effect of the lookup.
        """
        with self._lock:
            binding = self._sessions.get(session_id)
            if binding is None:
                return None
            bound_backend, expiry = binding
            if datetime.now() < expiry:
                return bound_backend
            del self._sessions[session_id]
        return None

    def bind_session(self, session_id: str, backend_id: str) -> None:
        """Bind ``session_id`` to ``backend_id`` for one TTL period."""
        with self._lock:
            expiry = datetime.now() + timedelta(seconds=self.session_ttl)
            self._sessions[session_id] = (backend_id, expiry)
            self.logger.debug(f"Session {session_id[:8]}... bound to backend {backend_id}")

    def unbind_session(self, session_id: str) -> None:
        """Drop the binding for ``session_id`` if one exists."""
        with self._lock:
            self._sessions.pop(session_id, None)

    def refresh_session(self, session_id: str) -> None:
        """Extend an existing binding by one TTL period; unknown ids are ignored."""
        with self._lock:
            binding = self._sessions.get(session_id)
            if binding is not None:
                bound_backend = binding[0]
                expiry = datetime.now() + timedelta(seconds=self.session_ttl)
                self._sessions[session_id] = (bound_backend, expiry)

    def get_session_cookie(self, session_id: str) -> str:
        """Build the cookie header value that carries ``session_id``."""
        return f"{self.cookie_name}={session_id}; Path=/; HttpOnly; SameSite=Strict"

    async def _cleanup_loop(self) -> None:
        """Purge expired sessions every 60 seconds until cancelled."""
        while True:
            try:
                await asyncio.sleep(60)  # One purge per minute
                self._cleanup_expired()
            except asyncio.CancelledError:
                return
            except Exception as e:
                # Keep the loop alive; a failed purge is retried next minute.
                self.logger.error(f"Session cleanup error: {e}")

    def _cleanup_expired(self) -> None:
        """Delete every binding whose expiry time has been reached."""
        now = datetime.now()

        with self._lock:
            stale_ids = [
                session_id
                for session_id, (_, expiry) in self._sessions.items()
                if now >= expiry
            ]
            for session_id in stale_ids:
                del self._sessions[session_id]

        if stale_ids:
            self.logger.debug(f"Cleaned up {len(stale_ids)} expired sessions")

    @property
    def active_sessions(self) -> int:
        """Number of bindings currently stored (expired but unpurged ones included)."""
        with self._lock:
            return len(self._sessions)


# =============================================================================
# CIRCUIT BREAKER
# =============================================================================

class CircuitBreaker:
    """
    Per-backend circuit breaker.

    Each backend id has its own state machine: CLOSED (requests flow), OPEN
    (requests blocked until the recovery timeout passes), and HALF_OPEN
    (trial period after the timeout). Backends never seen before are CLOSED.
    """

    class State(Enum):
        CLOSED = "closed"      # Normal operation
        OPEN = "open"          # Blocking requests
        HALF_OPEN = "half_open"  # Testing if backend recovered

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 30.0,
        half_open_max_calls: int = 3,
        logger: Optional[logging.Logger] = None
    ):
        """
        Args:
            failure_threshold: Failures recorded since the last success
                (outside HALF_OPEN) that open the circuit.
            recovery_timeout: Seconds after the last recorded failure before
                an OPEN circuit moves to HALF_OPEN.
            half_open_max_calls: Successes required in HALF_OPEN to close the
                circuit again.
            logger: Logger to use; defaults to this module's logger.
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_max_calls = half_open_max_calls
        self.logger = logger or logging.getLogger(__name__)

        self._states: Dict[str, CircuitBreaker.State] = {}
        self._failure_counts: Dict[str, int] = defaultdict(int)
        self._last_failure_times: Dict[str, datetime] = {}
        self._half_open_calls: Dict[str, int] = defaultdict(int)
        self._lock = threading.Lock()

    def is_allowed(self, backend_id: str) -> bool:
        """Return True if a request to ``backend_id`` may proceed.

        Calling this on an OPEN circuit whose recovery timeout has elapsed
        moves it to HALF_OPEN as a side effect.
        """
        states = CircuitBreaker.State
        with self._lock:
            current = self._states.get(backend_id, states.CLOSED)

            if current is states.CLOSED:
                return True

            if current is states.HALF_OPEN:
                # Trial calls are allowed until enough successes were recorded.
                return self._half_open_calls[backend_id] < self.half_open_max_calls

            if current is states.OPEN:
                failed_at = self._last_failure_times.get(backend_id)
                if failed_at is not None:
                    waited = (datetime.now() - failed_at).total_seconds()
                    if waited >= self.recovery_timeout:
                        self._states[backend_id] = states.HALF_OPEN
                        self._half_open_calls[backend_id] = 0
                        self.logger.info(f"Circuit half-open for {backend_id}")
                        return True

            return False

    def record_success(self, backend_id: str) -> None:
        """Record a successful request.

        In HALF_OPEN this counts towards closing the circuit; in any other
        state it resets the failure count.
        """
        states = CircuitBreaker.State
        with self._lock:
            current = self._states.get(backend_id, states.CLOSED)

            if current is not states.HALF_OPEN:
                self._failure_counts[backend_id] = 0
                return

            self._half_open_calls[backend_id] += 1
            if self._half_open_calls[backend_id] >= self.half_open_max_calls:
                # Enough trial calls succeeded: resume normal operation.
                self._states[backend_id] = states.CLOSED
                self._failure_counts[backend_id] = 0
                self.logger.info(f"Circuit closed for {backend_id}")

    def record_failure(self, backend_id: str) -> None:
        """Record a failed request.

        A failure in HALF_OPEN reopens the circuit at once. Otherwise the
        failure count grows and the circuit opens when it reaches the
        threshold.
        """
        states = CircuitBreaker.State
        with self._lock:
            current = self._states.get(backend_id, states.CLOSED)
            failed_at = datetime.now()

            if current is states.HALF_OPEN:
                self._states[backend_id] = states.OPEN
                self._last_failure_times[backend_id] = failed_at
                self.logger.warning(f"Circuit reopened for {backend_id}")
                return

            self._failure_counts[backend_id] += 1
            self._last_failure_times[backend_id] = failed_at

            if self._failure_counts[backend_id] >= self.failure_threshold:
                self._states[backend_id] = states.OPEN
                self.logger.warning(
                    f"Circuit opened for {backend_id} after "
                    f"{self._failure_counts[backend_id]} failures"
                )

    def get_state(self, backend_id: str) -> State:
        """Return the stored circuit state for ``backend_id`` (CLOSED if unknown)."""
        with self._lock:
            return self._states.get(backend_id, CircuitBreaker.State.CLOSED)

    def reset(self, backend_id: str) -> None:
        """Force the circuit for ``backend_id`` back to CLOSED and zero its counters."""
        with self._lock:
            self._failure_counts[backend_id] = 0
            self._half_open_calls[backend_id] = 0
            self._states[backend_id] = CircuitBreaker.State.CLOSED
            self.logger.info(f"Circuit manually reset for {backend_id}")


# =============================================================================
# LOAD BALANCER METRICS
# =============================================================================

@dataclass
class LoadBalancerMetrics:
    """Snapshot of load-balancer-wide counters and gauges."""
    total_requests: int = 0
    successful_requests: int = 0
    failed_requests: int = 0
    requests_per_second: float = 0.0
    average_latency_ms: float = 0.0
    active_backends: int = 0
    unhealthy_backends: int = 0
    active_sessions: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Return the snapshot as a plain dict, with floats rounded to 2 places.

        Adds a derived ``success_rate`` percentage, which is 0.0 when no
        request has been recorded.
        """
        # max(1, ...) guards the division when nothing has been recorded yet.
        denominator = max(1, self.total_requests)
        success_percent = self.successful_requests / denominator * 100

        snapshot: Dict[str, Any] = {}
        snapshot["total_requests"] = self.total_requests
        snapshot["successful_requests"] = self.successful_requests
        snapshot["failed_requests"] = self.failed_requests
        snapshot["requests_per_second"] = round(self.requests_per_second, 2)
        snapshot["average_latency_ms"] = round(self.average_latency_ms, 2)
        snapshot["active_backends"] = self.active_backends
        snapshot["unhealthy_backends"] = self.unhealthy_backends
        snapshot["active_sessions"] = self.active_sessions
        snapshot["success_rate"] = round(success_percent, 2)
        return snapshot


class MetricsCollector:
    """Collects request timings and derives rate and latency statistics.

    Keeps three rolling views, all guarded by one lock:

    * request timestamps from roughly the last 60 seconds,
    * the most recent 1000 latencies,
    * per-second request counts for the last 60 completed seconds,
      including seconds in which no request arrived.
    """

    def __init__(self):
        # Creation time; recorded but not read by any method of this class.
        self._start_time = time.time()
        self._request_times: List[float] = []
        self._latencies: List[float] = []
        self._lock = threading.Lock()

        # Per-second tracking
        self._requests_this_second = 0
        self._current_second = int(time.time())
        self._rps_samples: List[int] = []

    def _roll_rps_window(self, current_second: int) -> None:
        """Close the second being counted if the clock has moved past it.

        Must be called with ``self._lock`` held. Seconds that passed without
        any request are stored as zero samples, so idle time lowers the
        reported rate instead of being skipped.
        """
        gap = current_second - self._current_second
        if gap == 0:
            return

        self._rps_samples.append(self._requests_this_second)
        if gap > 1:
            # Only the last 60 samples are kept, so more zeros are pointless.
            self._rps_samples.extend([0] * min(gap - 1, 60))
        del self._rps_samples[:-60]

        self._requests_this_second = 0
        self._current_second = current_second

    def record_request(self, latency_ms: float, success: bool) -> None:
        """Record a completed request.

        Args:
            latency_ms: Observed latency of the request in milliseconds.
            success: Outcome of the request. Accepted for interface
                compatibility; this collector does not use it.
        """
        now = time.time()

        with self._lock:
            self._request_times.append(now)
            self._latencies.append(latency_ms)

            # Keep only the last 60 seconds of timestamps. Entries are
            # appended in arrival order, so expired ones sit at the front and
            # can be dropped in place without rebuilding the list.
            cutoff = now - 60
            expired = 0
            for stamp in self._request_times:
                if stamp > cutoff:
                    break
                expired += 1
            if expired:
                del self._request_times[:expired]

            # Keep the last 1000 latencies.
            del self._latencies[:-1000]

            # Track RPS
            self._roll_rps_window(int(now))
            self._requests_this_second += 1

    def get_rps(self) -> float:
        """Return the mean requests per second over the last completed seconds.

        Averages up to 60 one-second samples. The second still in progress is
        excluded. Returns 0.0 while no second has completed yet.
        """
        with self._lock:
            # Bring the window up to date so idle time since the last request
            # is reflected in the result.
            self._roll_rps_window(int(time.time()))
            if not self._rps_samples:
                return 0.0
            return sum(self._rps_samples) / len(self._rps_samples)

    def get_average_latency(self) -> float:
        """Return the mean of the retained latencies in ms (0.0 if none)."""
        with self._lock:
            if not self._latencies:
                return 0.0
            return sum(self._latencies) / len(self._latencies)

    def get_request_count(self, window_seconds: float = 60.0) -> int:
        """Return how many retained requests fall inside the trailing window.

        Only about 60 seconds of timestamps are retained, so larger windows
        cannot report more than that.
        """
        now = time.time()
        cutoff = now - window_seconds
        with self._lock:
            return sum(1 for t in self._request_times if t > cutoff)


# =============================================================================
# MAIN LOAD BALANCER
# =============================================================================

class LoadBalancer:
    """
    Production-grade load balancer for AIVA multi-instance deployment.

    Features:
    - Multiple load balancing strategies
    - Health checking with circuit breakers
    - Session affinity (sticky sessions)
    - Connection pooling
    - Comprehensive metrics
    - Graceful shutdown

    Typical lifecycle: register backends with ``add_backend()``, call
    ``start()`` (or use ``async with lb.running():``), route traffic through
    ``forward_request()``, then call ``stop()``.
    """

    def __init__(
        self,
        config: Optional[LoadBalancerConfig] = None,
        strategy: Optional[LoadBalancingStrategy] = None,
        logger: Optional[logging.Logger] = None
    ):
        """Wire up the balancer and its subsystems.

        Args:
            config: Settings to use; a default ``LoadBalancerConfig`` is
                created when omitted.
            strategy: Backend selection strategy; defaults to ``RoundRobin``.
            logger: Logger to use; a stream logger named "LoadBalancer" is
                created when omitted.
        """
        self.config = config or LoadBalancerConfig()
        self.strategy = strategy or RoundRobin()
        self.logger = logger or self._create_logger()

        # Registry of backends keyed by backend id, guarded by a lock.
        self._backends: Dict[str, Backend] = {}
        self._backends_lock = threading.Lock()

        self.health_checker = HealthChecker(self.config, self.logger)
        self.session_affinity = SessionAffinity(
            session_ttl=self.config.session_ttl,
            logger=self.logger
        )
        self.circuit_breaker = CircuitBreaker(
            failure_threshold=self.config.max_consecutive_failures,
            recovery_timeout=self.config.circuit_breaker_timeout,
            logger=self.logger
        )
        self.metrics_collector = MetricsCollector()

        # HTTP client session; created in start(), cleared in stop().
        self._session: Optional[aiohttp.ClientSession] = None
        self._running = False

        # Total request counters
        self._total_requests = 0
        self._successful_requests = 0
        self._failed_requests = 0

        # Setup health check callback
        self.health_checker.add_status_callback(self._on_backend_status_change)

    def _create_logger(self) -> logging.Logger:
        """Create a configured logger.

        The level name comes from ``config.log_level`` and is upper-cased
        first, so values such as "info" resolve to ``logging.INFO``.

        Raises:
            AttributeError: If the level name is not an attribute of the
                ``logging`` module.
        """
        logger = logging.getLogger("LoadBalancer")
        logger.setLevel(getattr(logging, str(self.config.log_level).upper()))

        # Avoid stacking duplicate handlers when several balancers are created.
        if not logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter(
                "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
            ))
            logger.addHandler(handler)

        return logger

    async def start(self) -> None:
        """Start the load balancer.

        Creates the pooled HTTP session, then starts session affinity and
        health checking for the backends registered at this moment. Calling
        it while already running only logs a warning.
        """
        if self._running:
            self.logger.warning("Load balancer already running")
            return

        self._running = True

        # Snapshot taken under the registry lock.
        backends = self.list_backends()

        # Create HTTP session with connection pooling.
        # aiohttp treats limit=0 as "no limit", so with no backends registered
        # yet the pool is sized as if there were one, rather than unbounded.
        connector = aiohttp.TCPConnector(
            limit=self.config.max_connections_per_backend * max(1, len(backends)),
            limit_per_host=self.config.max_connections_per_backend,
            ttl_dns_cache=300,
            enable_cleanup_closed=True
        )
        self._session = aiohttp.ClientSession(
            connector=connector,
            timeout=aiohttp.ClientTimeout(
                total=self.config.request_timeout,
                connect=self.config.connection_timeout
            )
        )

        # Start subsystems
        # NOTE(review): only the backends known right now are handed to the
        # health checker; confirm how backends added later get registered.
        await self.session_affinity.start()
        await self.health_checker.start(backends)

        self.logger.info(
            f"Load balancer started with {len(backends)} backends "
            f"using {self.strategy.name} strategy"
        )

    async def stop(self) -> None:
        """Stop the load balancer gracefully.

        Stops the health checker and session affinity, then closes the HTTP
        session and clears the reference so later requests are rejected
        instead of being attempted on a closed session.
        """
        self._running = False
        self.logger.info("Stopping load balancer...")

        # Stop subsystems
        await self.health_checker.stop()
        await self.session_affinity.stop()

        # Close HTTP session
        if self._session is not None:
            await self._session.close()
            self._session = None

        self.logger.info("Load balancer stopped")

    @asynccontextmanager
    async def running(self):
        """Context manager for running the load balancer.

        Calls ``start()`` on entry and always calls ``stop()`` on exit;
        yields the balancer itself.
        """
        await self.start()
        try:
            yield self
        finally:
            await self.stop()

    def add_backend(
        self,
        host: str,
        port: int,
        weight: int = 1,
        max_connections: int = 100,
        backend_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Backend:
        """Add a backend to the load balancer.

        Args:
            host: Backend host name or address.
            port: Backend port.
            weight: Weight stored on the backend for strategies that use it.
            max_connections: Connection cap stored on the backend.
            backend_id: Explicit id; defaults to ``"{host}:{port}"``.
            metadata: Optional metadata; an empty dict is used when omitted.

        Returns:
            The newly created ``Backend``. An existing backend registered
            under the same id is replaced.
        """
        backend_id = backend_id or f"{host}:{port}"

        backend = Backend(
            id=backend_id,
            host=host,
            port=port,
            weight=weight,
            max_connections=max_connections,
            metadata=metadata or {}
        )

        with self._backends_lock:
            self._backends[backend_id] = backend

        self.logger.info(f"Added backend {backend_id} (weight={weight})")
        return backend

    def remove_backend(self, backend_id: str) -> bool:
        """Remove a backend from the load balancer.

        Returns:
            True if the backend was registered and has been removed,
            False if no backend with that id exists.
        """
        with self._backends_lock:
            if backend_id in self._backends:
                del self._backends[backend_id]
                self.logger.info(f"Removed backend {backend_id}")
                return True
        return False

    def get_backend(self, backend_id: str) -> Optional[Backend]:
        """Get a backend by ID, or None when it is not registered."""
        with self._backends_lock:
            return self._backends.get(backend_id)

    def list_backends(self) -> List[Backend]:
        """List all backends as a new list (a snapshot of the registry)."""
        with self._backends_lock:
            return list(self._backends.values())

    def drain_backend(self, backend_id: str) -> bool:
        """Put a backend into draining mode.

        Returns:
            True if the backend exists and its status was set to
            ``BackendStatus.DRAINING``, False otherwise.
        """
        backend = self.get_backend(backend_id)
        if backend:
            backend.status = BackendStatus.DRAINING
            self.logger.info(f"Backend {backend_id} is now draining")
            return True
        return False

    def _unavailable_response(self, reason: str, start_time: float) -> Response:
        """Count a failed request and build a 503 response carrying *reason*."""
        self._failed_requests += 1
        return Response(
            status_code=503,
            headers={"X-LB-Error": reason},
            body=b'{"error": "Service Unavailable"}',
            backend_id="",
            latency_ms=(time.time() - start_time) * 1000
        )

    async def forward_request(self, request: Request) -> Response:
        """Forward a request to an appropriate backend.

        A request carrying a session bound to a backend goes to that backend
        while it is available and its circuit allows traffic. Otherwise the
        configured strategy picks among the available backends whose circuit
        allows traffic.

        Returns:
            The backend's response, or a 503 response when the balancer is
            not running, no backend is available, or the strategy selects
            none.
        """
        start_time = time.time()
        self._total_requests += 1

        # Without a live HTTP session nothing can be forwarded. Fail fast
        # instead of letting the retry loop blame (and trip the circuit of)
        # a backend for a local lifecycle problem.
        if self._session is None:
            self.logger.error("Load balancer is not running")
            return self._unavailable_response("Load balancer not started", start_time)

        # Check for session affinity
        session_id = self.session_affinity.get_session_id(request)
        new_session = False

        if session_id:
            backend_id = self.session_affinity.get_backend_for_session(session_id)
            if backend_id:
                backend = self.get_backend(backend_id)
                # Sticky routing honours the circuit breaker exactly like the
                # strategy path below does.
                if (
                    backend
                    and backend.is_available
                    and self.circuit_breaker.is_allowed(backend.id)
                ):
                    return await self._forward_to_backend(
                        request, backend, session_id, start_time
                    )

        # Select backend using strategy
        available_backends = [
            b for b in self.list_backends()
            if b.is_available and self.circuit_breaker.is_allowed(b.id)
        ]

        if not available_backends:
            self.logger.error("No available backends")
            return self._unavailable_response("No available backends", start_time)

        backend = self.strategy.select_backend(available_backends, request)

        if not backend:
            return self._unavailable_response("Backend selection failed", start_time)

        # Create new session if needed
        if not session_id:
            session_id = self.session_affinity.create_session(request)
            new_session = True

        return await self._forward_to_backend(
            request, backend, session_id, start_time, new_session
        )

    async def _forward_to_backend(
        self,
        request: Request,
        backend: Backend,
        session_id: str,
        start_time: float,
        new_session: bool = False
    ) -> Response:
        """Forward request to a specific backend with retries.

        Makes up to ``config.retry_attempts`` attempts (at least one),
        sleeping ``retry_delay * attempt_number`` between failed attempts.
        Any response received from the backend, successful or not, ends the
        loop; only raised exceptions trigger a retry.

        Returns:
            The backend's response with ``latency_ms`` filled in, or a 502
            response once every attempt has raised.
        """
        last_error = None
        # Guard against a zero or negative setting, which would otherwise
        # report a 502 without ever contacting the backend.
        attempts = max(1, self.config.retry_attempts)

        for attempt in range(attempts):
            try:
                response = await self._do_forward(
                    request, backend, session_id, new_session
                )

                latency_ms = (time.time() - start_time) * 1000
                response.latency_ms = latency_ms

                # Update metrics
                backend.metrics.total_requests += 1
                backend.metrics.total_latency_ms += latency_ms

                if response.is_success:
                    self._successful_requests += 1
                    backend.metrics.successful_requests += 1
                    # Only a successful response counts as a success
                    # timestamp and ends a run of consecutive failures.
                    backend.metrics.last_success = datetime.now()
                    backend.metrics.consecutive_failures = 0
                    self.circuit_breaker.record_success(backend.id)
                    self.session_affinity.bind_session(session_id, backend.id)
                else:
                    self._failed_requests += 1
                    backend.metrics.failed_requests += 1

                self.strategy.on_request_complete(backend, request, response.is_success)
                self.metrics_collector.record_request(latency_ms, response.is_success)

                return response

            except Exception as e:
                last_error = e
                self.logger.warning(
                    f"Request to {backend.id} failed (attempt {attempt + 1}): {e}"
                )

                # Linear back-off; no sleep after the final attempt.
                if attempt < attempts - 1:
                    await asyncio.sleep(self.config.retry_delay * (attempt + 1))

        # All retries failed
        # NOTE(review): backend.metrics.total_requests is not incremented on
        # this path, only failed_requests - confirm the derived success rate
        # and average latency account for that.
        self._failed_requests += 1
        backend.metrics.failed_requests += 1
        backend.metrics.consecutive_failures += 1
        backend.metrics.last_failure = datetime.now()
        self.circuit_breaker.record_failure(backend.id)
        self.strategy.on_request_complete(backend, request, False)

        latency_ms = (time.time() - start_time) * 1000
        self.metrics_collector.record_request(latency_ms, False)

        return Response(
            status_code=502,
            headers={"X-LB-Error": str(last_error)},
            body=b'{"error": "Bad Gateway"}',
            backend_id=backend.id,
            latency_ms=latency_ms
        )

    async def _do_forward(
        self,
        request: Request,
        backend: Backend,
        session_id: str,
        new_session: bool
    ) -> Response:
        """Perform the actual HTTP request to backend.

        Copies the request headers, adds the X-Forwarded-For, X-Request-ID
        and X-Backend-ID headers, and sends the request to
        ``backend.url + request.path``. The backend's active connection
        counter is raised for the duration of the call.

        Returns:
            A ``Response`` built from the backend reply with ``latency_ms``
            set to 0; the caller fills in the real latency.

        Raises:
            RuntimeError: If the balancer has no open HTTP session.
        """
        session = self._session
        if session is None:
            raise RuntimeError("Load balancer is not running")

        backend.metrics.active_connections += 1

        try:
            url = f"{backend.url}{request.path}"

            # Prepare headers
            headers = dict(request.headers)
            headers["X-Forwarded-For"] = request.client_ip
            headers["X-Request-ID"] = request.id
            headers["X-Backend-ID"] = backend.id

            async with session.request(
                method=request.method,
                url=url,
                headers=headers,
                data=request.body
            ) as resp:
                body = await resp.read()
                response_headers = dict(resp.headers)

                # Add session cookie if new session (this replaces any
                # Set-Cookie value already present in the copied headers)
                if new_session:
                    response_headers["Set-Cookie"] = self.session_affinity.get_session_cookie(session_id)

                response_headers["X-Backend-ID"] = backend.id

                return Response(
                    status_code=resp.status,
                    headers=response_headers,
                    body=body,
                    backend_id=backend.id,
                    latency_ms=0  # Will be set by caller
                )
        finally:
            # Always release the slot, whether the request succeeded or raised.
            backend.metrics.active_connections -= 1

    def _on_backend_status_change(
        self,
        backend: Backend,
        new_status: BackendStatus
    ) -> None:
        """Handle backend status changes.

        Registered as a status callback on the health checker. Logs the
        change; when the new status is UNHEALTHY it stamps
        ``circuit_opened_at`` and marks the backend CIRCUIT_OPEN.
        """
        self.logger.info(f"Backend {backend.id} status changed to {new_status.value}")

        if new_status == BackendStatus.UNHEALTHY:
            # Open circuit breaker
            backend.metrics.circuit_opened_at = datetime.now()
            backend.status = BackendStatus.CIRCUIT_OPEN

    def get_metrics(self) -> LoadBalancerMetrics:
        """Get current load balancer metrics.

        ``active_backends`` counts backends whose status is HEALTHY;
        ``unhealthy_backends`` counts every other backend, whatever its
        status.
        """
        backends = self.list_backends()
        healthy = sum(1 for b in backends if b.status == BackendStatus.HEALTHY)
        unhealthy = len(backends) - healthy

        return LoadBalancerMetrics(
            total_requests=self._total_requests,
            successful_requests=self._successful_requests,
            failed_requests=self._failed_requests,
            requests_per_second=self.metrics_collector.get_rps(),
            average_latency_ms=self.metrics_collector.get_average_latency(),
            active_backends=healthy,
            unhealthy_backends=unhealthy,
            active_sessions=self.session_affinity.active_sessions
        )

    def get_backend_metrics(self) -> Dict[str, Dict[str, Any]]:
        """Get metrics for all backends.

        Returns:
            A dict keyed by backend id. Each value holds the backend's
            status, weight, request counters, average latency (rounded to
            two decimals), active connections, success rate (multiplied by
            100 and rounded to two decimals) and circuit breaker state.
        """
        result = {}
        for backend in self.list_backends():
            result[backend.id] = {
                "status": backend.status.value,
                "weight": backend.weight,
                "total_requests": backend.metrics.total_requests,
                "successful_requests": backend.metrics.successful_requests,
                "failed_requests": backend.metrics.failed_requests,
                "average_latency_ms": round(backend.metrics.average_latency_ms, 2),
                "active_connections": backend.metrics.active_connections,
                "success_rate": round(backend.metrics.success_rate * 100, 2),
                "circuit_state": self.circuit_breaker.get_state(backend.id).value
            }
        return result

    def set_strategy(self, strategy: LoadBalancingStrategy) -> None:
        """Change the load balancing strategy and log the switch."""
        old_name = self.strategy.name
        self.strategy = strategy
        self.logger.info(f"Strategy changed from {old_name} to {strategy.name}")


# =============================================================================
# FACTORY FUNCTIONS
# =============================================================================

def create_load_balancer(
    strategy: str = "round_robin",
    config: Optional[LoadBalancerConfig] = None,
    **kwargs
) -> LoadBalancer:
    """
    Build a LoadBalancer whose strategy is chosen by name.

    Args:
        strategy: One of 'round_robin', 'least_connections', 'weighted',
                  'ip_hash', 'random'
        config: Optional LoadBalancerConfig passed through to the balancer
        **kwargs: Strategy options; only ``adaptive`` (default True) is
                  read, and only by the 'weighted' strategy

    Returns:
        Configured LoadBalancer instance

    Raises:
        ValueError: If ``strategy`` is not one of the known names
    """
    def _make_weighted():
        # The weighted strategy is the only one that takes an option.
        return WeightedStrategy(adaptive=kwargs.get("adaptive", True))

    factories = {
        "round_robin": RoundRobin,
        "least_connections": LeastConnections,
        "weighted": _make_weighted,
        "ip_hash": IPHashStrategy,
        "random": RandomStrategy
    }

    factory = factories.get(strategy)
    if factory is None:
        raise ValueError(f"Unknown strategy: {strategy}. Available: {list(factories)}")

    return LoadBalancer(config=config, strategy=factory())


# =============================================================================
# EXAMPLE USAGE AND TESTING
# =============================================================================

async def example_usage():
    """Run a small demo: build a balancer, forward one request, print metrics."""

    # Weighted balancer with a custom configuration
    balancer = create_load_balancer(
        strategy="weighted",
        config=LoadBalancerConfig(
            health_check_interval=5.0,
            max_consecutive_failures=3,
            session_ttl=3600.0
        ),
        adaptive=True
    )

    # Register three backends on port 8000 with descending weights
    for host, weight in (
        ("aiva-1.local", 3),
        ("aiva-2.local", 2),
        ("aiva-3.local", 1),
    ):
        balancer.add_backend(host, 8000, weight=weight)

    async with balancer.running():
        # Build one sample request and send it through the balancer
        sample = Request(
            id=Request.generate_id(),
            method="POST",
            path="/api/v1/query",
            headers={"Content-Type": "application/json"},
            body=b'{"query": "What is Genesis?"}',
            client_ip="192.168.1.100"
        )
        reply = await balancer.forward_request(sample)

        print(f"Response status: {reply.status_code}")
        print(f"Backend: {reply.backend_id}")
        print(f"Latency: {reply.latency_ms:.2f}ms")

        # Balancer-wide metrics
        print("\nLoad Balancer Metrics:")
        print(json.dumps(balancer.get_metrics().to_dict(), indent=2))

        # Per-backend metrics
        print("\nBackend Metrics:")
        print(json.dumps(balancer.get_backend_metrics(), indent=2))


if __name__ == "__main__":
    # Banner listing the components this module provides.
    print(
        "AIVA Production Load Balancer",
        "=" * 50,
        "Components:",
        "  - LoadBalancer: Main orchestrator",
        "  - RoundRobin: Round-robin strategy",
        "  - LeastConnections: Least connections strategy",
        "  - WeightedStrategy: Weighted distribution",
        "  - HealthChecker: Backend health checks",
        "  - SessionAffinity: Sticky sessions",
        "  - CircuitBreaker: Failure isolation",
        "  - MetricsCollector: Performance metrics",
        sep="\n",
    )
    print()

    # Try the demo; if it raises for any reason, print usage guidance instead.
    try:
        asyncio.run(example_usage())
    except Exception as exc:
        print(f"Example requires running backends: {exc}")
        print("\nTo use in production:")
        print("""
    from prod_03_load_balancer import create_load_balancer, LoadBalancerConfig

    config = LoadBalancerConfig(
        health_check_interval=10.0,
        session_ttl=3600.0
    )

    lb = create_load_balancer("weighted", config=config)
    lb.add_backend("aiva-1.internal", 8000, weight=3)
    lb.add_backend("aiva-2.internal", 8000, weight=2)

    async with lb.running():
        response = await lb.forward_request(request)
        """)
