# circuit_breaker.py
"""
Circuit breaker pattern for external service calls.
"""

import logging
import threading
import time
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Callable, Dict, Optional
from functools import wraps

import psycopg2
from psycopg2 import pool

logger = logging.getLogger(__name__)


class CircuitState(Enum):
    """Lifecycle phase of a circuit breaker; each member's value equals its name."""

    # Calls flow through to the protected service normally.
    CLOSED = "CLOSED"
    # Calls are rejected until the recovery timeout elapses.
    OPEN = "OPEN"
    # A limited number of trial calls probe whether the service recovered.
    HALF_OPEN = "HALF_OPEN"


class CircuitBreakerConfig:
    """Tunable thresholds consumed by a circuit breaker.

    Attributes:
        failure_threshold: Consecutive failures (while CLOSED) that open the circuit.
        success_threshold: Successes (while HALF_OPEN) required to close it again.
        timeout: Seconds an OPEN circuit waits before admitting trial calls.
        half_open_max_calls: Maximum trial calls admitted while HALF_OPEN.
        excluded_exceptions: Exception types that are not counted as failures.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        success_threshold: int = 2,
        timeout: float = 30.0,
        half_open_max_calls: int = 3,
        excluded_exceptions: tuple = ()
    ):
        # Thresholds that drive state transitions (assigned left to right).
        self.failure_threshold, self.success_threshold = (
            failure_threshold,
            success_threshold,
        )
        # Recovery timing and half-open admission control.
        self.timeout = timeout
        self.half_open_max_calls = half_open_max_calls
        # Exceptions that bypass failure accounting entirely.
        self.excluded_exceptions = excluded_exceptions


class CircuitBreaker:
    """Circuit breaker guarding calls to one external service.

    Transitions implemented below:

    * CLOSED -> OPEN after ``config.failure_threshold`` consecutive failures.
    * OPEN -> HALF_OPEN, checked lazily whenever ``state`` is read, once
      ``config.timeout`` seconds have passed since the last failure.
    * HALF_OPEN -> CLOSED after ``config.success_threshold`` successes.
    * HALF_OPEN -> OPEN on any failure.

    When a connection pool is given, state is mirrored to the
    ``genesis_bridge.circuit_breakers`` table and restored by a later instance
    with the same ``name``. With ``db_pool=None`` the breaker works purely in
    memory. Database errors are logged and never propagated to callers.
    """

    def __init__(
        self,
        name: str,
        db_pool: Optional[pool.ThreadedConnectionPool],
        config: Optional[CircuitBreakerConfig] = None
    ):
        """Create the breaker and restore any persisted state for ``name``.

        Args:
            name: Service name; also the unique key of the persisted row.
            db_pool: Pool used for persistence, or None to disable persistence.
            config: Thresholds; defaults to ``CircuitBreakerConfig()``.
        """
        self.name = name
        self.db_pool = db_pool
        self.config = config or CircuitBreakerConfig()

        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._success_count = 0
        self._last_failure_time: Optional[datetime] = None
        self._last_state_change: datetime = self._utcnow()
        self._half_open_calls = 0
        # Reentrant: `state` is read from inside other locked methods.
        self._lock = threading.RLock()

        self._setup_circuit_table()
        self._load_state_from_db()

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Naive values match the ``TIMESTAMP`` (without time zone) columns and
        stay comparable with timestamps read back from the database.
        """
        from datetime import timezone  # avoids deprecated datetime.utcnow()
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _run_sql(self, sql: str, params: Optional[tuple] = None, fetch: bool = False):
        """Execute one statement on a pooled connection and commit it.

        The connection is always handed back with ``putconn``. psycopg2's
        ``with conn:`` block only ends the transaction and does NOT return the
        connection to the pool, which previously leaked one connection per call.

        Args:
            sql: Statement to execute.
            params: Query parameters, or None for a parameterless statement.
            fetch: When true, return the first result row.

        Returns:
            The first row (or None) when ``fetch`` is true, otherwise None.

        Raises:
            Whatever psycopg2 raises; callers decide how to log it.
        """
        conn = self.db_pool.getconn()
        discard = False
        try:
            with conn.cursor() as cur:
                cur.execute(sql, params)
                row = cur.fetchone() if fetch else None
            conn.commit()
            return row
        except Exception:
            try:
                conn.rollback()
            except Exception:
                # The connection is unusable; close it instead of recycling it.
                discard = True
            raise
        finally:
            self.db_pool.putconn(conn, close=discard)

    def _transition(self, new_state: CircuitState) -> None:
        """Enter ``new_state``, clearing per-phase counters. Caller holds the lock."""
        self._state = new_state
        self._success_count = 0
        self._half_open_calls = 0
        self._last_state_change = self._utcnow()

    def _maybe_half_open(self) -> bool:
        """Move OPEN -> HALF_OPEN once the timeout has elapsed. Caller holds the lock.

        Returns:
            True if a transition happened, meaning the state should be persisted.
        """
        if self._state != CircuitState.OPEN:
            return False
        # Fall back to the state-change time so a row persisted with a NULL
        # last_failure_time cannot keep the circuit open forever.
        opened_at = self._last_failure_time or self._last_state_change
        if (self._utcnow() - opened_at).total_seconds() < self.config.timeout:
            return False
        self._transition(CircuitState.HALF_OPEN)
        return True

    # ------------------------------------------------------------------ #
    # Persistence
    # ------------------------------------------------------------------ #

    def _setup_circuit_table(self):
        """Create the circuit breaker state table if missing (no-op without a pool)."""
        if self.db_pool is None:
            return
        try:
            self._run_sql("""
                CREATE TABLE IF NOT EXISTS genesis_bridge.circuit_breakers (
                    id SERIAL PRIMARY KEY,
                    service_name VARCHAR(255) NOT NULL UNIQUE,
                    state VARCHAR(50) NOT NULL,
                    failure_count INTEGER DEFAULT 0,
                    success_count INTEGER DEFAULT 0,
                    last_failure_time TIMESTAMP,
                    last_state_change TIMESTAMP NOT NULL,
                    half_open_calls INTEGER DEFAULT 0,
                    total_calls INTEGER DEFAULT 0,
                    total_failures INTEGER DEFAULT 0,
                    created_at TIMESTAMP DEFAULT NOW(),
                    updated_at TIMESTAMP DEFAULT NOW()
                );
            """)
        except Exception as e:
            logger.error("Failed to setup circuit breaker table: %s", e)

    def _load_state_from_db(self):
        """Restore persisted state for ``self.name``, applying any due OPEN timeout."""
        if self.db_pool is None:
            return
        try:
            row = self._run_sql("""
                SELECT state, failure_count, success_count,
                       last_failure_time, last_state_change, half_open_calls
                FROM genesis_bridge.circuit_breakers
                WHERE service_name = %s
            """, (self.name,), fetch=True)
            if row:
                self._state = CircuitState(row[0])
                # Counter columns are nullable; treat NULL as zero.
                self._failure_count = row[1] or 0
                self._success_count = row[2] or 0
                self._last_failure_time = row[3]
                self._last_state_change = row[4] or self._last_state_change
                self._half_open_calls = row[5] or 0
        except Exception as e:
            logger.error("Failed to load circuit state: %s", e)
            return

        # Persist outside the SELECT's connection, so only one connection is held.
        with self._lock:
            if self._maybe_half_open():
                self._save_state()

    def _save_state(self):
        """Upsert the current state row (no-op without a pool; errors are logged).

        NOTE: callers invoke this while holding ``self._lock``, so a slow
        database delays every other call on this breaker.
        """
        if self.db_pool is None:
            return
        try:
            self._run_sql("""
                INSERT INTO genesis_bridge.circuit_breakers
                (service_name, state, failure_count, success_count,
                 last_failure_time, last_state_change, half_open_calls)
                VALUES (%s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (service_name) DO UPDATE SET
                    state = EXCLUDED.state,
                    failure_count = EXCLUDED.failure_count,
                    success_count = EXCLUDED.success_count,
                    last_failure_time = EXCLUDED.last_failure_time,
                    last_state_change = EXCLUDED.last_state_change,
                    half_open_calls = EXCLUDED.half_open_calls,
                    updated_at = NOW()
            """, (
                self.name,
                self._state.value,
                self._failure_count,
                self._success_count,
                self._last_failure_time,
                self._last_state_change,
                self._half_open_calls,
            ))
        except Exception as e:
            logger.error("Failed to save circuit state: %s", e)

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    @property
    def state(self) -> CircuitState:
        """Current state, after applying a due OPEN -> HALF_OPEN transition."""
        with self._lock:
            if self._maybe_half_open():
                self._save_state()
            return self._state

    def can_execute(self) -> bool:
        """Return True if a call may proceed right now.

        CLOSED always admits. HALF_OPEN admits while fewer than
        ``config.half_open_max_calls`` trial calls have been taken. OPEN rejects.
        """
        with self._lock:
            current = self.state  # read once so both checks see the same state
            if current == CircuitState.CLOSED:
                return True
            if current == CircuitState.HALF_OPEN:
                return self._half_open_calls < self.config.half_open_max_calls
            return False

    def _record_success(self):
        """Record a successful call; may close a HALF_OPEN circuit."""
        with self._lock:
            self._failure_count = 0  # failures are counted consecutively

            if self._state == CircuitState.HALF_OPEN:
                self._success_count += 1
                if self._success_count >= self.config.success_threshold:
                    self._transition(CircuitState.CLOSED)
                    logger.info("Circuit %s closed", self.name)
            elif self._state == CircuitState.CLOSED:
                self._success_count = 0

            self._save_state()

    def _record_failure(self):
        """Record a failed call; may open the circuit."""
        with self._lock:
            self._failure_count += 1
            self._last_failure_time = self._utcnow()

            if self._state == CircuitState.HALF_OPEN:
                self._transition(CircuitState.OPEN)
                logger.warning("Circuit %s opened after half-open failure", self.name)
            elif (self._state == CircuitState.CLOSED
                  and self._failure_count >= self.config.failure_threshold):
                self._transition(CircuitState.OPEN)
                logger.warning(
                    "Circuit %s opened after %s failures", self.name, self._failure_count
                )

            self._save_state()

    def record_call(self, success: bool):
        """Record the outcome of one call.

        Args:
            success: True for a successful call, False for a counted failure.
        """
        if success:
            self._record_success()
        else:
            self._record_failure()

    def get_stats(self) -> Dict[str, Any]:
        """Return a JSON-friendly snapshot of this breaker's state and counters."""
        with self._lock:
            return {
                "service_name": self.name,
                "state": self.state.value,
                "failure_count": self._failure_count,
                "success_count": self._success_count,
                "last_failure_time": (
                    self._last_failure_time.isoformat() if self._last_failure_time else None
                ),
                "last_state_change": self._last_state_change.isoformat(),
                "half_open_calls": self._half_open_calls,
                "can_execute": self.can_execute(),
            }

    def reset(self):
        """Force the breaker back to CLOSED, clear all counters, and persist."""
        with self._lock:
            self._transition(CircuitState.CLOSED)
            self._failure_count = 0
            self._last_failure_time = None
            self._save_state()
            logger.info("Circuit %s manually reset", self.name)


class CircuitBreakerManager:
    """Registry that hands out one shared CircuitBreaker per service name."""

    def __init__(self, db_pool: pool.ThreadedConnectionPool):
        """Args:
            db_pool: Connection pool passed to every breaker this manager creates.
        """
        self.db_pool = db_pool
        self._breakers: Dict[str, CircuitBreaker] = {}
        # Guards _breakers; held while a new breaker is constructed so two
        # threads cannot create duplicate breakers for the same name.
        self._lock = threading.Lock()

    def get_breaker(
        self,
        name: str,
        config: Optional[CircuitBreakerConfig] = None
    ) -> CircuitBreaker:
        """Return the breaker for ``name``, creating it on first use.

        ``config`` is only applied when the breaker is created. Later calls
        return the existing breaker unchanged, whatever ``config`` they pass.
        """
        with self._lock:
            breaker = self._breakers.get(name)
            if breaker is None:
                breaker = CircuitBreaker(name, self.db_pool, config)
                self._breakers[name] = breaker
            return breaker

    def get_all_stats(self) -> Dict[str, Dict]:
        """Return ``get_stats()`` for every registered breaker, keyed by name."""
        # Snapshot under the lock: iterating the live dict while get_breaker()
        # inserts concurrently raises "dictionary changed size during iteration".
        with self._lock:
            snapshot = list(self._breakers.items())
        return {name: breaker.get_stats() for name, breaker in snapshot}


def circuit_breaker(
    service_name: str,
    config: Optional[CircuitBreakerConfig] = None,
    manager: Optional[CircuitBreakerManager] = None
):
    """Decorator that protects a sync or async callable with a circuit breaker.

    Args:
        service_name: Name of the protected service (the breaker's key).
        config: Breaker configuration. When a ``manager`` is given, it only
            takes effect if the manager has not created this breaker yet.
        manager: Optional shared registry. Without one, each decorated
            function lazily creates its own breaker (with ``db_pool=None``)
            once and reuses it, so failures accumulate across calls.

    The wrapped callable raises ``Exception("Circuit breaker <name> is OPEN")``
    when the breaker rejects the call. Otherwise it returns or raises exactly
    what the original callable does. Exceptions matching the breaker's
    ``config.excluded_exceptions`` are re-raised without being counted.
    """
    import inspect

    def _admit(breaker: CircuitBreaker) -> bool:
        """Reject the call, or admit it; True means a half-open trial slot was taken."""
        # Check and reserve under one lock acquisition, so concurrent callers
        # cannot exceed half_open_max_calls. The breaker's lock is reentrant,
        # so the nested can_execute()/state calls are safe.
        with breaker._lock:
            if not breaker.can_execute():
                raise Exception(f"Circuit breaker {service_name} is OPEN")
            if breaker.state == CircuitState.HALF_OPEN:
                breaker._half_open_calls += 1
                return True
            return False

    def _record_error(breaker: CircuitBreaker, exc: Exception) -> bool:
        """Count ``exc`` as a failure unless excluded; return True if it was recorded."""
        # Use the breaker's own config: with a manager, the breaker may have
        # been created with a different config than this decorator's.
        excluded = breaker.config.excluded_exceptions
        if excluded and isinstance(exc, excluded):
            return False
        breaker.record_call(success=False)
        return True

    def _release_probe(breaker: CircuitBreaker) -> None:
        """Return a half-open trial slot whose call produced no success/failure verdict."""
        with breaker._lock:
            if breaker._state == CircuitState.HALF_OPEN and breaker._half_open_calls > 0:
                breaker._half_open_calls -= 1

    def decorator(func):
        own_breaker = None  # lazily created when no manager is supplied
        own_lock = threading.Lock()

        def _get_breaker() -> CircuitBreaker:
            """Return the shared (manager) or per-function breaker."""
            nonlocal own_breaker
            if manager is not None:
                return manager.get_breaker(service_name, config)
            with own_lock:
                if own_breaker is None:
                    own_breaker = CircuitBreaker(service_name, None, config)
                return own_breaker

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            breaker = _get_breaker()
            took_probe = _admit(breaker)
            settled = False
            try:
                result = await func(*args, **kwargs)
            except Exception as exc:
                settled = _record_error(breaker, exc)
                raise
            else:
                breaker.record_call(success=True)
                settled = True
                return result
            finally:
                # Excluded errors and cancellation (a BaseException) give no
                # verdict; free the trial slot so HALF_OPEN cannot wedge.
                if took_probe and not settled:
                    _release_probe(breaker)

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            breaker = _get_breaker()
            took_probe = _admit(breaker)
            settled = False
            try:
                result = func(*args, **kwargs)
            except Exception as exc:
                settled = _record_error(breaker, exc)
                raise
            else:
                breaker.record_call(success=True)
                settled = True
                return result
            finally:
                # Same slot-release rule as the async wrapper.
                if took_probe and not settled:
                    _release_probe(breaker)

        if inspect.iscoroutinefunction(func):
            return async_wrapper
        return sync_wrapper

    return decorator