# retry_engine.py
"""
Configurable retry engine with exponential backoff, jitter, and max retries.
"""

import asyncio
import logging
import random
import time
import uuid
from contextlib import contextmanager
from datetime import datetime, timedelta
from enum import Enum
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

import psycopg2
import psycopg2.extras
from psycopg2 import pool

# Module-level logger, named after this module per the logging convention.
logger = logging.getLogger(__name__)

# Generic type variable; declared for typed retry wrappers but not currently
# used in any signature in this module.
T = TypeVar('T')


class RetryStrategy(Enum):
    """Delay-growth strategies consumed by ``RetryConfig.calculate_delay``."""
    EXPONENTIAL = "EXPONENTIAL"  # delay = base_delay * exponential_base ** attempt
    LINEAR = "LINEAR"            # delay = base_delay * attempt
    CONSTANT = "CONSTANT"        # delay = base_delay for every attempt


class RetryConfig:
    """Configuration for retry behavior.

    Args:
        max_retries: Total number of attempts allowed.
        base_delay: Unit delay in seconds for every strategy.
        max_delay: Hard upper bound (seconds) for any computed delay.
        strategy: How the delay grows with the attempt number.
        jitter: When True, randomize the delay by +/- ``jitter_factor`` of
            its value to avoid thundering-herd retries.
        jitter_factor: Fraction of the delay used as the jitter half-range.
        exponential_base: Growth base for ``RetryStrategy.EXPONENTIAL``.
        retry_on_exceptions: Exception types that trigger a retry.
        retryable_status_codes: HTTP status codes treated as transient;
            defaults to [500, 502, 503, 504, 429].
    """
    
    def __init__(
        self,
        max_retries: int = 5,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        strategy: RetryStrategy = RetryStrategy.EXPONENTIAL,
        jitter: bool = True,
        jitter_factor: float = 0.1,
        exponential_base: float = 2.0,
        retry_on_exceptions: tuple = (Exception,),
        retryable_status_codes: Optional[List[int]] = None
    ):
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.strategy = strategy
        self.jitter = jitter
        self.jitter_factor = jitter_factor
        self.exponential_base = exponential_base
        self.retry_on_exceptions = retry_on_exceptions
        self.retryable_status_codes = retryable_status_codes or [500, 502, 503, 504, 429]
    
    def calculate_delay(self, attempt: int) -> float:
        """Return the delay in seconds before the given (1-based) attempt.

        The result is always within ``[0, max_delay]``.  Bug fix: the cap is
        now applied *after* jitter; previously jitter was added post-clamp and
        could push the delay above ``max_delay``.
        """
        if self.strategy == RetryStrategy.LINEAR:
            delay = self.base_delay * attempt
        elif self.strategy == RetryStrategy.CONSTANT:
            delay = self.base_delay
        else:  # EXPONENTIAL (default)
            delay = self.base_delay * (self.exponential_base ** attempt)
        
        delay = min(delay, self.max_delay)
        
        if self.jitter:
            # Symmetric jitter of +/- jitter_factor around the nominal delay.
            half_range = delay * self.jitter_factor
            delay += random.uniform(-half_range, half_range)
        
        # Clamp after jitter so max_delay is a true upper bound.
        return min(max(0.0, delay), self.max_delay)


class RetryEngine:
    """Retry engine with persistence for tracking retry state.

    Each operation is keyed by ``correlation_id`` and its progress is stored
    in ``genesis_bridge.retry_tracking`` so retries survive process restarts.
    A write-through in-process cache avoids a DB read while the cached
    schedule is still current.
    """
    
    def __init__(self, db_pool: pool.ThreadedConnectionPool):
        self.db_pool = db_pool
        self._setup_retry_table()
        # correlation_id -> last known state dict; kept in sync by
        # _update_retry_state (write-through cache).
        self._local_cache: Dict[str, Dict] = {}
    
    @contextmanager
    def _connection(self):
        """Yield a pooled connection and always return it to the pool.

        Bug fix: the previous code used ``with self.db_pool.getconn() as
        conn``, which manages the transaction but never calls ``putconn`` —
        every call permanently consumed one pooled connection.
        """
        conn = self.db_pool.getconn()
        try:
            # Connection as context manager: commit on success, roll back on
            # exception — so no explicit conn.commit() is needed by callers.
            with conn:
                yield conn
        finally:
            self.db_pool.putconn(conn)
    
    def _setup_retry_table(self):
        """Create the retry tracking table and its indexes (idempotent).

        NOTE(review): assumes the ``genesis_bridge`` schema already exists —
        confirm it is created by migrations before this engine is built.
        """
        try:
            with self._connection() as conn:
                with conn.cursor() as cur:
                    cur.execute("""
                        CREATE TABLE IF NOT EXISTS genesis_bridge.retry_tracking (
                            id SERIAL PRIMARY KEY,
                            correlation_id VARCHAR(255) NOT NULL UNIQUE,
                            operation_name VARCHAR(255) NOT NULL,
                            payload_hash VARCHAR(64) NOT NULL,
                            attempt_count INTEGER DEFAULT 0,
                            max_retries INTEGER DEFAULT 5,
                            last_attempt_at TIMESTAMP,
                            next_attempt_at TIMESTAMP,
                            status VARCHAR(50) DEFAULT 'PENDING',
                            last_error TEXT,
                            created_at TIMESTAMP DEFAULT NOW(),
                            completed_at TIMESTAMP,
                            metadata JSONB
                        );
                        CREATE INDEX IF NOT EXISTS idx_retry_tracking_status 
                            ON genesis_bridge.retry_tracking(status);
                        CREATE INDEX IF NOT EXISTS idx_retry_tracking_next_attempt 
                            ON genesis_bridge.retry_tracking(next_attempt_at);
                        CREATE INDEX IF NOT EXISTS idx_retry_tracking_correlation 
                            ON genesis_bridge.retry_tracking(correlation_id);
                    """)
        except Exception as e:
            # Best-effort: the engine can still run without persistence, so
            # log instead of failing construction.
            logger.error(f"Failed to setup retry tracking table: {e}")
    
    def _calculate_delay(self, config: RetryConfig, attempt: int) -> float:
        """Delegate delay computation to the config (seam kept for tests)."""
        return config.calculate_delay(attempt)
    
    async def execute_with_retry(
        self,
        func: Callable,
        config: RetryConfig,
        correlation_id: str,
        operation_name: str,
        payload_hash: Optional[str] = None,
        *args,
        **kwargs
    ) -> Any:
        """Execute ``func`` (sync or async) with retries and persisted state.

        Returns the function result, or ``None`` when the operation is
        already recorded as COMPLETED.  Re-raises non-retryable errors and
        raises ``Exception`` once retries are exhausted.

        Note: positional ``*args`` follow ``payload_hash`` in the signature,
        so callers passing positional arguments for ``func`` must supply
        ``payload_hash`` (even if None) explicitly.
        """
        # Resume from any persisted state for this correlation id.
        existing = self._get_retry_state(correlation_id)
        if existing and existing["status"] == "COMPLETED":
            logger.info(f"[{correlation_id}] Operation already completed")
            return None
        
        if existing and existing["status"] == "FAILED_PERMANENT":
            logger.warning(f"[{correlation_id}] Operation failed permanently")
            raise Exception(f"Operation {correlation_id} failed permanently: {existing.get('last_error')}")
        
        attempt = existing["attempt_count"] + 1 if existing else 1
        # Seed from persisted state so the "exhausted" message is meaningful
        # even when the loop below never runs (attempt already > max_retries).
        last_error = existing.get("last_error") if existing else None
        
        while attempt <= config.max_retries:
            try:
                if asyncio.iscoroutinefunction(func):
                    result = await func(*args, **kwargs)
                else:
                    result = func(*args, **kwargs)
                
                # Success - mark as completed.
                self._update_retry_state(
                    correlation_id=correlation_id,
                    operation_name=operation_name,
                    payload_hash=payload_hash,
                    attempt_count=attempt,
                    max_retries=config.max_retries,
                    status="COMPLETED",
                    last_error=None,
                    metadata={"result": str(result)[:500]}
                )
                
                logger.info(f"[{correlation_id}] Operation succeeded on attempt {attempt}")
                return result
                
            except config.retry_on_exceptions as e:
                last_error = str(e)
                logger.warning(
                    f"[{correlation_id}] Attempt {attempt}/{config.max_retries} failed: {last_error}"
                )
                
                # Permanent failure: record it and re-raise immediately.
                if not self._is_retryable_error(e, config):
                    self._update_retry_state(
                        correlation_id=correlation_id,
                        operation_name=operation_name,
                        payload_hash=payload_hash,
                        attempt_count=attempt,
                        max_retries=config.max_retries,
                        status="FAILED_PERMANENT",
                        last_error=last_error,
                        metadata={"non_retryable": True}
                    )
                    raise
                
                # Schedule the next attempt and persist progress.
                delay = self._calculate_delay(config, attempt)
                next_attempt_at = datetime.utcnow() + timedelta(seconds=delay)
                # NOTE(review): next_attempt_at uses client-side UTC while the
                # SQL comparisons use the server's NOW(); confirm clocks/time
                # zones agree for these TIMESTAMP (without tz) columns.
                
                self._update_retry_state(
                    correlation_id=correlation_id,
                    operation_name=operation_name,
                    payload_hash=payload_hash,
                    attempt_count=attempt,
                    max_retries=config.max_retries,
                    status="IN_PROGRESS",
                    last_error=last_error,
                    next_attempt_at=next_attempt_at,
                    metadata={"delay": delay}
                )
                
                # Sleep only when another attempt remains.
                if attempt < config.max_retries:
                    await asyncio.sleep(delay)
                
                attempt += 1
        
        # All retries exhausted.
        self._update_retry_state(
            correlation_id=correlation_id,
            operation_name=operation_name,
            payload_hash=payload_hash,
            attempt_count=attempt - 1,
            max_retries=config.max_retries,
            status="FAILED_PERMANENT",
            last_error=f"Max retries ({config.max_retries}) exhausted. Last error: {last_error}",
            metadata={"exhausted": True}
        )
        
        raise Exception(f"Max retries ({config.max_retries}) exhausted: {last_error}")
    
    def _is_retryable_error(self, error: Exception, config: RetryConfig) -> bool:
        """Classify an exception as retryable.

        HTTP-style errors (exposing ``response.status_code``) are retryable
        iff the status is in ``config.retryable_status_codes``.  Bug fix: a
        known non-retryable status (e.g. 404) previously fell through to the
        default ``True``.  Errors exposing a ``retryable`` attribute decide
        for themselves; everything else defaults to retryable.
        """
        if hasattr(error, 'response'):
            status_code = getattr(error.response, 'status_code', None)
            if status_code is not None:
                # A recognized HTTP status decides retryability outright.
                return status_code in config.retryable_status_codes
        
        if hasattr(error, 'retryable'):
            return error.retryable
        
        return True
    
    def _get_retry_state(self, correlation_id: str) -> Optional[Dict]:
        """Return the persisted state for ``correlation_id``, or None.

        Serves from the local cache while the cached ``next_attempt_at`` is
        still in the future.  Bug fix: a cached entry whose
        ``next_attempt_at`` is None previously raised TypeError on the
        ``datetime < None`` comparison.
        """
        cached = self._local_cache.get(correlation_id)
        if cached is not None:
            next_at = cached.get("next_attempt_at")
            if next_at is not None and datetime.utcnow() < next_at:
                return cached
        
        try:
            with self._connection() as conn:
                with conn.cursor() as cur:
                    cur.execute("""
                        SELECT attempt_count, max_retries, status, last_error, 
                               next_attempt_at, metadata
                        FROM genesis_bridge.retry_tracking
                        WHERE correlation_id = %s
                    """, (correlation_id,))
                    row = cur.fetchone()
                    
                    if row:
                        state = {
                            "attempt_count": row[0],
                            "max_retries": row[1],
                            "status": row[2],
                            "last_error": row[3],
                            "next_attempt_at": row[4],
                            "metadata": row[5],
                        }
                        self._local_cache[correlation_id] = state
                        return state
        except Exception as e:
            logger.error(f"Failed to get retry state: {e}")
        
        return None
    
    def _update_retry_state(
        self,
        correlation_id: str,
        operation_name: str,
        payload_hash: Optional[str],
        attempt_count: int,
        max_retries: int,
        status: str,
        last_error: Optional[str],
        next_attempt_at: Optional[datetime] = None,
        metadata: Optional[Dict] = None
    ):
        """Upsert the retry row and refresh the local cache.

        Failures are logged rather than raised: persistence is best-effort
        and must not mask the outcome of the operation being retried.
        """
        try:
            with self._connection() as conn:
                with conn.cursor() as cur:
                    cur.execute("""
                        INSERT INTO genesis_bridge.retry_tracking 
                        (correlation_id, operation_name, payload_hash, attempt_count, 
                         max_retries, status, last_error, next_attempt_at, metadata)
                        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                        ON CONFLICT (correlation_id) DO UPDATE SET
                            attempt_count = EXCLUDED.attempt_count,
                            status = EXCLUDED.status,
                            last_error = EXCLUDED.last_error,
                            next_attempt_at = EXCLUDED.next_attempt_at,
                            last_attempt_at = NOW(),
                            metadata = EXCLUDED.metadata,
                            completed_at = CASE 
                                WHEN EXCLUDED.status = 'COMPLETED' THEN NOW() 
                                ELSE genesis_bridge.retry_tracking.completed_at 
                            END
                    """, (
                        correlation_id,
                        operation_name,
                        # Column is NOT NULL; coalesce so a missing hash does
                        # not make every upsert fail (silently, via the broad
                        # except below).
                        payload_hash or "",
                        attempt_count,
                        max_retries,
                        status,
                        last_error,
                        next_attempt_at,
                        # Requires the module-level ``import psycopg2.extras``;
                        # bare ``import psycopg2`` does not expose .extras.
                        psycopg2.extras.Json(metadata) if metadata else None
                    ))
            
            # Write-through cache update.
            self._local_cache[correlation_id] = {
                "attempt_count": attempt_count,
                "max_retries": max_retries,
                "status": status,
                "last_error": last_error,
                "next_attempt_at": next_attempt_at,
                "metadata": metadata,
            }
            
        except Exception as e:
            logger.error(f"Failed to update retry state: {e}")
    
    def get_pending_retries(self, limit: int = 100) -> List[Dict]:
        """Return up to ``limit`` IN_PROGRESS entries whose retry is due."""
        try:
            with self._connection() as conn:
                with conn.cursor() as cur:
                    cur.execute("""
                        SELECT correlation_id, operation_name, payload_hash, 
                               attempt_count, max_retries, last_error, metadata
                        FROM genesis_bridge.retry_tracking
                        WHERE status = 'IN_PROGRESS'
                          AND (next_attempt_at IS NULL OR next_attempt_at <= NOW())
                          AND attempt_count < max_retries
                        ORDER BY next_attempt_at ASC
                        LIMIT %s
                    """, (limit,))
                    
                    return [
                        {
                            "correlation_id": row[0],
                            "operation_name": row[1],
                            "payload_hash": row[2],
                            "attempt_count": row[3],
                            "max_retries": row[4],
                            "last_error": row[5],
                            "metadata": row[6],
                        }
                        for row in cur.fetchall()
                    ]
        except Exception as e:
            logger.error(f"Failed to get pending retries: {e}")
            return []
    
    def clear_completed_retries(self, older_than_hours: int = 24):
        """Delete finished entries older than ``older_than_hours``; return count.

        Uses ``INTERVAL '1 hour' * %s`` so the hour count is a real bound
        parameter instead of a placeholder spliced inside a string literal.
        """
        try:
            with self._connection() as conn:
                with conn.cursor() as cur:
                    cur.execute("""
                        DELETE FROM genesis_bridge.retry_tracking
                        WHERE status IN ('COMPLETED', 'FAILED_PERMANENT')
                          AND completed_at < NOW() - INTERVAL '1 hour' * %s
                    """, (older_than_hours,))
                    deleted = cur.rowcount
                    logger.info(f"Cleared {deleted} completed retry entries")
                    return deleted
        except Exception as e:
            logger.error(f"Failed to clear completed retries: {e}")
            return 0


def with_retry(
    config: "Optional[RetryConfig]" = None,
    correlation_id_param: str = "correlation_id"
):
    """Decorator that retries the wrapped sync or async callable.

    If the call carries a ``_retry_engine`` keyword argument (async path
    only), the persistent ``RetryEngine`` is used; otherwise a simple
    in-memory retry loop runs with delays from ``config``.

    Args:
        config: Retry configuration; defaults to ``RetryConfig()``.
        correlation_id_param: Name of the keyword argument carrying the
            correlation id; a UUID4 string is generated (and injected into
            the call's kwargs) when the caller does not supply one.

    Bug fixes vs. the previous version: ``uuid`` is now actually imported at
    module level (calls without an explicit correlation id raised NameError);
    the engine path passes ``payload_hash=None`` explicitly so positional
    arguments for ``func`` are no longer swallowed by the ``payload_hash``
    parameter; and retry exhaustion with ``max_retries <= 0`` no longer does
    ``raise None``.
    """
    if config is None:
        config = RetryConfig()
    
    def decorator(func):
        def _ensure_correlation_id(kwargs):
            # Generate a correlation id when absent and inject it so the
            # wrapped function always receives one.
            cid = kwargs.get(correlation_id_param) or str(uuid.uuid4())
            kwargs[correlation_id_param] = cid
            return cid
        
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            correlation_id = _ensure_correlation_id(kwargs)
            
            engine = kwargs.get('_retry_engine')
            if engine:
                # payload_hash passed explicitly: *args follow it in the
                # execute_with_retry signature.
                return await engine.execute_with_retry(
                    func, config, correlation_id, func.__name__, None,
                    *args, **kwargs
                )
            
            # Direct execution without persistence.
            attempt = 0
            last_error = None
            
            while attempt < config.max_retries:
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    last_error = e
                    attempt += 1
                    if attempt < config.max_retries:
                        await asyncio.sleep(config.calculate_delay(attempt))
            
            if last_error is None:
                # max_retries <= 0: func was never even called.
                raise RuntimeError("max_retries must be >= 1")
            raise last_error
        
        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            _ensure_correlation_id(kwargs)
            
            attempt = 0
            last_error = None
            
            while attempt < config.max_retries:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    last_error = e
                    attempt += 1
                    if attempt < config.max_retries:
                        time.sleep(config.calculate_delay(attempt))
            
            if last_error is None:
                raise RuntimeError("max_retries must be >= 1")
            raise last_error
        
        # asyncio is imported at module level; no local re-import needed.
        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return sync_wrapper
    
    return decorator