# error_handler.py
"""
Centralized error handling with categorized exceptions and correlation IDs.
"""

import inspect
import json
import logging
import traceback
import uuid
from datetime import datetime, timezone
from enum import Enum
from functools import wraps
from typing import Any, Dict, Optional

import psycopg2
import psycopg2.extras
from fastapi import HTTPException, Request
from psycopg2 import pool
from pydantic import BaseModel

class CorrelationIdFilter(logging.Filter):
    """Ensure every log record carries a ``correlation_id`` attribute.

    The root format string below references ``%(correlation_id)s``. Without a
    default, any record logged without ``extra={"correlation_id": ...}`` fails
    to format, and logging prints a "--- Logging error ---" traceback instead
    of the message. That includes records from other modules and libraries
    that propagate to the root handler.
    """

    def __init__(self, default: str = "-"):
        super().__init__()
        self.default = default

    def filter(self, record: logging.LogRecord) -> bool:
        # Only fill the gap; never overwrite an ID supplied via ``extra``.
        if not hasattr(record, "correlation_id"):
            record.correlation_id = self.default
        return True


# Same destination as basicConfig's default (a stderr StreamHandler), but with
# the defaulting filter attached. basicConfig installs its formatter on it.
_log_handler = logging.StreamHandler()
_log_handler.addFilter(CorrelationIdFilter())

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - [%(correlation_id)s] - %(message)s',
    handlers=[_log_handler],
)
logger = logging.getLogger(__name__)


class ErrorCategory(Enum):
    """Coarse classification attached to every BridgeError.

    Each member's value is the same string as its name, so the text written
    to the ``category`` column round-trips through ``ErrorCategory(value)``.
    """

    # Infrastructure / dependency failures
    DATABASE = "DATABASE"
    NETWORK = "NETWORK"
    EXTERNAL_SERVICE = "EXTERNAL_SERVICE"

    # Problems with the incoming request or caller
    VALIDATION = "VALIDATION"
    AUTHENTICATION = "AUTHENTICATION"
    RATE_LIMIT = "RATE_LIMIT"

    # Capacity and latency problems
    QUEUE_OVERFLOW = "QUEUE_OVERFLOW"
    TIMEOUT = "TIMEOUT"

    # Fallback when nothing more specific applies
    UNKNOWN = "UNKNOWN"


class ErrorSeverity(Enum):
    """How serious an error is, declared from least to most severe.

    Values mirror member names (``ErrorSeverity.HIGH.value == "HIGH"``).
    """

    LOW = "LOW"            # used by ValidationError
    MEDIUM = "MEDIUM"      # BridgeError's default severity
    HIGH = "HIGH"
    CRITICAL = "CRITICAL"  # used by QueueOverflowError / AuthenticationError


class BridgeError(Exception):
    """Base exception for all bridge errors.

    Carries a category/severity classification, a correlation ID (generated
    when not supplied), free-form ``details``, and whether the failed operation
    is worth retrying.
    """

    def __init__(
        self,
        message: str,
        category: ErrorCategory = ErrorCategory.UNKNOWN,
        severity: ErrorSeverity = ErrorSeverity.MEDIUM,
        correlation_id: Optional[str] = None,
        details: Optional[Dict[str, Any]] = None,
        retryable: bool = False
    ):
        super().__init__(message)
        self.message = message
        self.category = category
        self.severity = severity
        self.correlation_id = correlation_id or str(uuid.uuid4())
        # Copy so later changes on either side are not shared. The caller may
        # reuse its dict, and handlers add context to ours.
        self.details = dict(details) if details else {}
        self.retryable = retryable
        # Naive UTC timestamp. datetime.utcnow() is deprecated since Python
        # 3.12; computing it from an aware "now" and dropping tzinfo keeps the
        # value and isoformat() output identical.
        self.timestamp = datetime.now(timezone.utc).replace(tzinfo=None)

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly snapshot of this error.

        ``details`` is a shallow copy, so callers may add keys to it (for
        example logging context) without modifying the exception itself.
        """
        return {
            "message": self.message,
            "category": self.category.value,
            "severity": self.severity.value,
            "correlation_id": self.correlation_id,
            "details": dict(self.details),
            "retryable": self.retryable,
            "timestamp": self.timestamp.isoformat()
        }


class DatabaseError(BridgeError):
    """Raised when a database operation fails.

    Always classified as DATABASE with HIGH severity; retryable unless the
    caller says otherwise.
    """

    def __init__(self, message: str, correlation_id: Optional[str] = None,
                 details: Optional[Dict] = None, retryable: bool = True):
        # Positional order of BridgeError.__init__:
        # message, category, severity, correlation_id, details, retryable
        super().__init__(
            message,
            ErrorCategory.DATABASE,
            ErrorSeverity.HIGH,
            correlation_id,
            details,
            retryable,
        )


class NetworkError(BridgeError):
    """Raised for network-level failures.

    Classified as NETWORK with MEDIUM severity; retryable by default.
    """

    def __init__(self, message: str, correlation_id: Optional[str] = None,
                 details: Optional[Dict] = None, retryable: bool = True):
        # Positional order of BridgeError.__init__:
        # message, category, severity, correlation_id, details, retryable
        super().__init__(
            message,
            ErrorCategory.NETWORK,
            ErrorSeverity.MEDIUM,
            correlation_id,
            details,
            retryable,
        )


class ExternalServiceError(BridgeError):
    """Raised when an external service (Telnyx, Claude Code) misbehaves.

    Classified as EXTERNAL_SERVICE with HIGH severity; retryable by default.
    """

    def __init__(self, message: str, correlation_id: Optional[str] = None,
                 details: Optional[Dict] = None, retryable: bool = True):
        # Positional order of BridgeError.__init__:
        # message, category, severity, correlation_id, details, retryable
        super().__init__(
            message,
            ErrorCategory.EXTERNAL_SERVICE,
            ErrorSeverity.HIGH,
            correlation_id,
            details,
            retryable,
        )


class ValidationError(BridgeError):
    """Raised when an incoming payload fails validation.

    Classified as VALIDATION with LOW severity and never retryable: resending
    the same payload cannot succeed.
    """

    def __init__(self, message: str, correlation_id: Optional[str] = None,
                 details: Optional[Dict] = None):
        # Positional order of BridgeError.__init__:
        # message, category, severity, correlation_id, details, retryable
        super().__init__(
            message,
            ErrorCategory.VALIDATION,
            ErrorSeverity.LOW,
            correlation_id,
            details,
            False,
        )


class QueueOverflowError(BridgeError):
    """Raised when a queue has no room for more work.

    Classified as QUEUE_OVERFLOW with CRITICAL severity; not retryable.
    """

    def __init__(self, message: str, correlation_id: Optional[str] = None,
                 details: Optional[Dict] = None):
        # Positional order of BridgeError.__init__:
        # message, category, severity, correlation_id, details, retryable
        super().__init__(
            message,
            ErrorCategory.QUEUE_OVERFLOW,
            ErrorSeverity.CRITICAL,
            correlation_id,
            details,
            False,
        )


class RateLimitError(BridgeError):
    """Rate limiting errors.

    Classified as RATE_LIMIT with MEDIUM severity and always retryable.
    ``retry_after`` is stored under ``details["retry_after"]``; it is
    presumably a number of seconds, so confirm with the callers.
    """

    def __init__(self, message: str, retry_after: int = 60,
                 correlation_id: Optional[str] = None, details: Optional[Dict] = None):
        # Build a new dict instead of writing into the caller's ``details``,
        # which previously mutated a dict the caller might reuse.
        merged_details = {**(details or {}), "retry_after": retry_after}
        super().__init__(
            message=message,
            category=ErrorCategory.RATE_LIMIT,
            severity=ErrorSeverity.MEDIUM,
            correlation_id=correlation_id,
            details=merged_details,
            retryable=True
        )


class AuthenticationError(BridgeError):
    """Raised for authentication or authorization failures.

    Classified as AUTHENTICATION with CRITICAL severity; not retryable.
    """

    def __init__(self, message: str, correlation_id: Optional[str] = None,
                 details: Optional[Dict] = None):
        # Positional order of BridgeError.__init__:
        # message, category, severity, correlation_id, details, retryable
        super().__init__(
            message,
            ErrorCategory.AUTHENTICATION,
            ErrorSeverity.CRITICAL,
            correlation_id,
            details,
            False,
        )


class ErrorHandler:
    """Centralized error handler with logging and persistence.

    Errors are written to ``genesis_bridge.error_log`` through a psycopg2
    ``ThreadedConnectionPool`` and mirrored to the module logger. Database
    failures while persisting are logged and swallowed, so handling one error
    never raises a second one.
    """

    # HTTP status per error category; categories not listed (UNKNOWN) map to 500.
    _STATUS_CODES = {
        ErrorCategory.DATABASE: 503,
        ErrorCategory.NETWORK: 503,
        ErrorCategory.EXTERNAL_SERVICE: 502,
        ErrorCategory.VALIDATION: 400,
        ErrorCategory.AUTHENTICATION: 401,
        ErrorCategory.RATE_LIMIT: 429,
        ErrorCategory.QUEUE_OVERFLOW: 503,
        ErrorCategory.TIMEOUT: 504,
    }

    def __init__(self, db_pool: pool.ThreadedConnectionPool):
        self.db_pool = db_pool
        self._setup_error_logging_table()

    def _execute(self, sql: str, params: Optional[tuple] = None) -> None:
        """Execute ``sql`` on a pooled connection, commit, and release the connection.

        psycopg2's ``with conn:`` block only commits or rolls back the
        transaction; it does not close or return the connection. The explicit
        ``putconn`` in ``finally`` is what hands the connection back to the
        pool. Without it, every call would leak one pooled connection until the
        pool is exhausted.

        Raises whatever ``getconn`` or ``execute`` raise (e.g. ``PoolError``,
        ``psycopg2.Error``); callers decide how to handle it.
        """
        conn = self.db_pool.getconn()
        try:
            with conn:  # commit on success, rollback on exception
                with conn.cursor() as cur:
                    cur.execute(sql, params)
        finally:
            self.db_pool.putconn(conn)

    def _setup_error_logging_table(self):
        """Create the error log table and its indexes if they do not exist.

        Failures are logged rather than raised, so construction of the
        handler does not fail when the database is unavailable.
        """
        try:
            self._execute("""
                CREATE TABLE IF NOT EXISTS genesis_bridge.error_log (
                    id SERIAL PRIMARY KEY,
                    correlation_id VARCHAR(255) NOT NULL,
                    error_type VARCHAR(255) NOT NULL,
                    category VARCHAR(50) NOT NULL,
                    severity VARCHAR(50) NOT NULL,
                    message TEXT NOT NULL,
                    details JSONB,
                    retryable BOOLEAN DEFAULT FALSE,
                    created_at TIMESTAMP DEFAULT NOW(),
                    resolved_at TIMESTAMP,
                    resolved_by VARCHAR(255)
                );
                CREATE INDEX IF NOT EXISTS idx_error_log_correlation_id 
                    ON genesis_bridge.error_log(correlation_id);
                CREATE INDEX IF NOT EXISTS idx_error_log_category 
                    ON genesis_bridge.error_log(category);
                CREATE INDEX IF NOT EXISTS idx_error_log_created_at 
                    ON genesis_bridge.error_log(created_at);
            """)
        except Exception as e:
            # ``extra`` supplies the field the logging format string requires.
            logger.error(
                "Failed to setup error logging table: %s", e,
                extra={"correlation_id": "-"},
            )

    def log_error(self, error: Exception, context: Optional[Dict] = None) -> str:
        """Persist ``error`` to ``genesis_bridge.error_log`` and log it.

        Args:
            error: A ``BridgeError`` (its category, severity, correlation ID
                and details are used), or any other exception (recorded as
                UNKNOWN / HIGH with its own traceback in ``details``).
            context: Optional extra data stored under ``details["context"]``.

        Returns:
            The correlation ID the error was recorded under.

        Database failures are logged and swallowed; the error is still written
        to the module logger either way.
        """
        if isinstance(error, BridgeError):
            error_data = error.to_dict()
            correlation_id = error.correlation_id or str(uuid.uuid4())
        else:
            correlation_id = str(uuid.uuid4())
            error_data = {
                "message": str(error),
                "category": ErrorCategory.UNKNOWN.value,
                "severity": ErrorSeverity.HIGH.value,
                # Format the exception's own traceback. traceback.format_exc()
                # only sees the exception currently being handled, and yields
                # "NoneType: None" when called outside an except block.
                "details": {
                    "traceback": "".join(
                        traceback.format_exception(type(error), error, error.__traceback__)
                    )
                },
                "retryable": False,
            }

        # Work on a copy so adding context never mutates the caller's exception.
        details = dict(error_data.get("details") or {})
        if context:
            details["context"] = context

        try:
            self._execute(
                """
                INSERT INTO genesis_bridge.error_log 
                (correlation_id, error_type, category, severity, message, details, retryable)
                VALUES (%s, %s, %s, %s, %s, %s, %s)
                """,
                (
                    correlation_id,
                    type(error).__name__,
                    error_data.get("category", ErrorCategory.UNKNOWN.value),
                    error_data.get("severity", ErrorSeverity.HIGH.value),
                    error_data.get("message", str(error)),
                    # default=str: best-effort persistence. A non-JSON value in
                    # details/context is stringified instead of failing the insert.
                    psycopg2.extras.Json(
                        details, dumps=lambda obj: json.dumps(obj, default=str)
                    ),
                    error_data.get("retryable", False),
                ),
            )
        except Exception as e:
            logger.error(
                "Failed to log error to database: %s", e,
                extra={"correlation_id": correlation_id},
            )

        logger.error(
            "[%s] %s: %s",
            correlation_id,
            error_data.get("category", "UNKNOWN"),
            error_data.get("message", str(error)),
            extra={"correlation_id": correlation_id},
        )

        return correlation_id

    def handle_http_exception(self, error: BridgeError) -> HTTPException:
        """Convert a bridge error into a FastAPI ``HTTPException``.

        The status code comes from ``_STATUS_CODES`` (500 for unmapped
        categories). The response body is ``error.to_dict()`` with any
        ``traceback`` entry removed from ``details``. Tracebacks are persisted
        by ``log_error`` and must not be echoed to API clients. Rate-limit
        errors that carry ``retry_after`` also get a ``Retry-After`` header.
        """
        status_code = self._STATUS_CODES.get(error.category, 500)

        detail = error.to_dict()
        # New dict: never strip keys from the exception's own details.
        detail["details"] = {
            key: value
            for key, value in (detail.get("details") or {}).items()
            if key != "traceback"
        }

        headers = None
        retry_after = detail["details"].get("retry_after")
        if error.category is ErrorCategory.RATE_LIMIT and retry_after is not None:
            headers = {"Retry-After": str(retry_after)}

        return HTTPException(
            status_code=status_code,
            detail=detail,
            headers=headers,
        )


def with_error_handling(handler: 'ErrorHandler'):
    """Decorator for automatic error handling.

    Wraps a plain or ``async def`` function so that failures are logged through
    ``handler`` and re-raised as FastAPI ``HTTPException``s:

    * An ``HTTPException`` raised by the wrapped function is re-raised
      untouched. It already describes the intended response, and wrapping it
      would turn a deliberate 4xx into a 500.
    * A ``BridgeError`` is logged with ``{"function": <name>}`` context and
      converted with ``handler.handle_http_exception``.
    * Any other ``Exception`` is wrapped in an UNKNOWN / HIGH ``BridgeError``.
      Its details hold the function name and traceback. It is then logged and
      converted the same way.

    The raised HTTP exception is chained (``raise ... from``) to the original.
    """
    def decorator(func):
        def _to_http_exception(exc: Exception):
            """Log ``exc`` through ``handler`` and return the HTTPException to raise."""
            if isinstance(exc, BridgeError):
                # BridgeError normally already has an ID; only fill a gap.
                exc.correlation_id = exc.correlation_id or str(uuid.uuid4())
                handler.log_error(exc, {"function": func.__name__})
                return handler.handle_http_exception(exc)

            error = BridgeError(
                message=str(exc),
                category=ErrorCategory.UNKNOWN,
                severity=ErrorSeverity.HIGH,
                correlation_id=str(uuid.uuid4()),
                details={
                    "function": func.__name__,
                    "traceback": "".join(
                        traceback.format_exception(type(exc), exc, exc.__traceback__)
                    ),
                },
            )
            handler.log_error(error)
            return handler.handle_http_exception(error)

        if inspect.iscoroutinefunction(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                try:
                    return await func(*args, **kwargs)
                except HTTPException:
                    raise
                except Exception as exc:
                    raise _to_http_exception(exc) from exc
            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except HTTPException:
                raise
            except Exception as exc:
                raise _to_http_exception(exc) from exc
        return sync_wrapper
    return decorator


def get_correlation_id(request: Request) -> str:
    """Extract or generate correlation ID from request.

    Returns the ``X-Correlation-ID`` header when it is usable, otherwise a new
    UUID4 string. The header is untrusted client input, so it is stripped and
    accepted only when it is:

    * non-empty;
    * printable, with no control characters that could forge log lines;
    * at most 255 characters, to fit the ``VARCHAR(255)`` ``correlation_id``
      column, where longer values would make the error-log insert fail.
    """
    supplied = (request.headers.get("X-Correlation-ID") or "").strip()
    if supplied and len(supplied) <= 255 and supplied.isprintable():
        return supplied
    # Generated lazily: only when the header is missing or rejected.
    return str(uuid.uuid4())