# dead_letter_queue.py
"""
Dead letter queue for storing failed messages for manual review/retry.
"""

import contextlib
import hashlib
import json
import logging
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional

import psycopg2
from psycopg2 import pool
from psycopg2.extras import Json

logger = logging.getLogger(__name__)


class DLQStatus(Enum):
    """Dead letter queue entry status.

    Values are stored verbatim in the ``status`` column, so they must stay
    in sync with the string literals used in the SQL queries below.
    """
    PENDING = "PENDING"          # waiting for its first/next retry
    RETRYING = "RETRYING"        # a retry attempt is in flight
    RESOLVED = "RESOLVED"        # processed successfully or manually closed
    PERMANENT_FAILURE = "PERMANENT_FAILURE"  # retries exhausted; needs a human


class DLQPriority(Enum):
    """Priority levels for DLQ entries.

    Lower numeric value means higher priority: retry selection orders by
    ``priority ASC``, so CRITICAL entries are picked up first.
    """
    CRITICAL = 1
    HIGH = 2
    MEDIUM = 3
    LOW = 4


@dataclass
class DLQEntry:
    """Dead letter queue entry.

    Mirrors one row of ``genesis_bridge.dead_letter_queue``.
    NOTE(review): ``DeadLetterQueue`` itself works with dicts/tuples and
    never constructs this class — presumably callers use it; confirm.
    """
    correlation_id: str                  # request/message correlation id
    payload: Dict[str, Any]              # original message body (stored as JSONB)
    payload_hash: str                    # SHA-256 hex digest of payload (dedup key)
    error_type: str                      # exception / error class name
    error_message: str                   # human-readable failure description
    error_details: Dict[str, Any]        # structured failure context
    source_service: str                  # service that produced the message
    destination_service: str             # service that failed to consume it
    retry_count: int = 0                 # retries attempted so far
    max_retries: int = 3                 # give up after this many attempts
    status: str = DLQStatus.PENDING.value      # one of the DLQStatus values
    priority: int = DLQPriority.MEDIUM.value   # DLQPriority numeric value
    original_timestamp: Optional[datetime] = None  # when the failure first occurred
    next_retry_at: Optional[datetime] = None       # earliest time for the next retry
    resolved_at: Optional[datetime] = None         # when marked RESOLVED
    resolved_by: Optional[str] = None              # operator/system that resolved it
    metadata: Optional[Dict] = None                # free-form extra context


class DeadLetterQueue:
    """Dead letter queue implementation with PostgreSQL storage.

    Failed messages are persisted to ``genesis_bridge.dead_letter_queue``
    where they can be retried automatically (``get_pending_entries`` +
    ``update_retry``) or reviewed and resolved manually.
    """

    def __init__(self, db_pool: pool.ThreadedConnectionPool):
        """Keep a reference to the shared pool and ensure the table exists."""
        self.db_pool = db_pool
        self._setup_dlq_table()

    @contextlib.contextmanager
    def _pooled_conn(self):
        """Check a connection out of the pool and ALWAYS return it.

        BUGFIX: the previous code used ``with self.db_pool.getconn() as conn``,
        which manages the *transaction* on the connection but never calls
        ``putconn`` — every DLQ operation leaked one pooled connection until
        the pool was exhausted.  ``putconn`` also rolls back any transaction
        that a read-only method leaves open before the connection is reused.
        """
        conn = self.db_pool.getconn()
        try:
            yield conn
        finally:
            self.db_pool.putconn(conn)

    def _setup_dlq_table(self):
        """Create the DLQ table and its indexes (idempotent).

        Failures are logged rather than raised so the service can still start
        when the schema is managed externally (e.g. by migrations).
        """
        try:
            with self._pooled_conn() as conn:
                with conn.cursor() as cur:
                    cur.execute("""
                        CREATE TABLE IF NOT EXISTS genesis_bridge.dead_letter_queue (
                            id SERIAL PRIMARY KEY,
                            correlation_id VARCHAR(255) NOT NULL,
                            payload JSONB NOT NULL,
                            payload_hash VARCHAR(64) NOT NULL,
                            error_type VARCHAR(255) NOT NULL,
                            error_message TEXT NOT NULL,
                            error_details JSONB,
                            source_service VARCHAR(255) NOT NULL,
                            destination_service VARCHAR(255) NOT NULL,
                            retry_count INTEGER DEFAULT 0,
                            max_retries INTEGER DEFAULT 3,
                            status VARCHAR(50) DEFAULT 'PENDING',
                            priority INTEGER DEFAULT 3,
                            original_timestamp TIMESTAMP,
                            next_retry_at TIMESTAMP,
                            resolved_at TIMESTAMP,
                            resolved_by VARCHAR(255),
                            metadata JSONB,
                            created_at TIMESTAMP DEFAULT NOW(),
                            updated_at TIMESTAMP DEFAULT NOW()
                        );

                        CREATE INDEX IF NOT EXISTS idx_dlq_status 
                            ON genesis_bridge.dead_letter_queue(status);
                        CREATE INDEX IF NOT EXISTS idx_dlq_correlation_id 
                            ON genesis_bridge.dead_letter_queue(correlation_id);
                        CREATE INDEX IF NOT EXISTS idx_dlq_payload_hash 
                            ON genesis_bridge.dead_letter_queue(payload_hash);
                        CREATE INDEX IF NOT EXISTS idx_dlq_priority 
                            ON genesis_bridge.dead_letter_queue(priority, status);
                        CREATE INDEX IF NOT EXISTS idx_dlq_next_retry 
                            ON genesis_bridge.dead_letter_queue(next_retry_at);
                    """)
                    conn.commit()
        except Exception as e:
            logger.error(f"Failed to setup DLQ table: {e}")

    @staticmethod
    def compute_hash(payload: Dict[str, Any]) -> str:
        """Compute a deterministic SHA256 hex digest of *payload*.

        Keys are sorted so logically-equal dicts hash identically;
        ``default=str`` serializes non-JSON values (datetimes, UUIDs, ...)
        instead of raising.
        """
        serialized = json.dumps(payload, sort_keys=True, default=str)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def add_entry(
        self,
        correlation_id: str,
        payload: Dict[str, Any],
        error_type: str,
        error_message: str,
        source_service: str,
        destination_service: str,
        error_details: Optional[Dict] = None,
        priority: int = DLQPriority.MEDIUM.value,
        max_retries: int = 3,
        metadata: Optional[Dict] = None
    ) -> int:
        """Add an entry to the dead letter queue.

        Returns the new row id.  Unlike the read/update helpers, database
        errors are re-raised here: losing a failed message silently would
        defeat the purpose of the DLQ.
        """
        payload_hash = self.compute_hash(payload)

        try:
            with self._pooled_conn() as conn:
                with conn.cursor() as cur:
                    cur.execute("""
                        INSERT INTO genesis_bridge.dead_letter_queue 
                        (correlation_id, payload, payload_hash, error_type, error_message,
                         error_details, source_service, destination_service, priority,
                         max_retries, metadata, original_timestamp)
                        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())
                        RETURNING id
                    """, (
                        correlation_id,
                        Json(payload),
                        payload_hash,
                        error_type,
                        error_message,
                        Json(error_details) if error_details else None,
                        source_service,
                        destination_service,
                        priority,
                        max_retries,
                        Json(metadata) if metadata else None
                    ))
                    entry_id = cur.fetchone()[0]
                    conn.commit()

                    logger.info(
                        f"Added DLQ entry {entry_id} for correlation {correlation_id}"
                    )
                    return entry_id
        except Exception as e:
            logger.error(f"Failed to add DLQ entry: {e}")
            raise

    def get_pending_entries(
        self,
        limit: int = 100,
        priority_filter: Optional[List[int]] = None
    ) -> List[Dict]:
        """Get pending DLQ entries ready for retry.

        An entry is "ready" when it is PENDING/RETRYING, its backoff window
        has elapsed (or was never set), and it still has retries left.
        Returns [] on database errors so the retry loop keeps running.
        """
        try:
            with self._pooled_conn() as conn:
                with conn.cursor() as cur:
                    query = """
                        SELECT id, correlation_id, payload, payload_hash, error_type,
                               error_message, source_service, destination_service,
                               retry_count, max_retries, priority, metadata
                        FROM genesis_bridge.dead_letter_queue
                        WHERE status IN ('PENDING', 'RETRYING')
                          AND (next_retry_at IS NULL OR next_retry_at <= NOW())
                          AND retry_count < max_retries
                    """
                    params: List[Any] = []

                    if priority_filter:
                        # BUGFIX: was f-string interpolation of the values into
                        # the SQL text; use a parameterized ANY(...) instead
                        # (injection-safe, and tolerant of non-int input).
                        query += " AND priority = ANY(%s)"
                        params.append(list(priority_filter))

                    query += " ORDER BY priority ASC, original_timestamp ASC LIMIT %s"
                    params.append(limit)

                    cur.execute(query, params)

                    return [
                        {
                            "id": row[0],
                            "correlation_id": row[1],
                            "payload": row[2],
                            "payload_hash": row[3],
                            "error_type": row[4],
                            "error_message": row[5],
                            "source_service": row[6],
                            "destination_service": row[7],
                            "retry_count": row[8],
                            "max_retries": row[9],
                            "priority": row[10],
                            "metadata": row[11]
                        }
                        for row in cur.fetchall()
                    ]
        except Exception as e:
            logger.error(f"Failed to get pending DLQ entries: {e}")
            return []

    def update_retry(self, entry_id: int, status: str, next_retry_at: Optional[datetime] = None):
        """Update retry status of a DLQ entry.

        Moving to RETRYING also bumps ``retry_count`` and schedules the
        next attempt; any other status change leaves the counter alone.
        """
        try:
            with self._pooled_conn() as conn:
                with conn.cursor() as cur:
                    if status == DLQStatus.RETRYING.value:
                        cur.execute("""
                            UPDATE genesis_bridge.dead_letter_queue
                            SET status = %s,
                                retry_count = retry_count + 1,
                                next_retry_at = %s,
                                updated_at = NOW()
                            WHERE id = %s
                        """, (status, next_retry_at, entry_id))
                    else:
                        cur.execute("""
                            UPDATE genesis_bridge.dead_letter_queue
                            SET status = %s,
                                updated_at = NOW()
                            WHERE id = %s
                        """, (status, entry_id))
                    conn.commit()
        except Exception as e:
            logger.error(f"Failed to update DLQ retry: {e}")

    def mark_resolved(
        self,
        entry_id: int,
        resolved_by: str = "system"
    ):
        """Mark a DLQ entry as resolved, recording who resolved it and when."""
        try:
            with self._pooled_conn() as conn:
                with conn.cursor() as cur:
                    cur.execute("""
                        UPDATE genesis_bridge.dead_letter_queue
                        SET status = %s,
                            resolved_at = NOW(),
                            resolved_by = %s,
                            updated_at = NOW()
                        WHERE id = %s
                    """, (DLQStatus.RESOLVED.value, resolved_by, entry_id))
                    conn.commit()

                    logger.info(f"DLQ entry {entry_id} resolved by {resolved_by}")
        except Exception as e:
            logger.error(f"Failed to resolve DLQ entry: {e}")

    def mark_permanent_failure(self, entry_id: int, reason: str):
        """Mark a DLQ entry as a permanent failure, appending *reason*
        to its error message for the operator reviewing it later."""
        try:
            with self._pooled_conn() as conn:
                with conn.cursor() as cur:
                    cur.execute("""
                        UPDATE genesis_bridge.dead_letter_queue
                        SET status = %s,
                            error_message = error_message || ' | Permanent failure: ' || %s,
                            updated_at = NOW()
                        WHERE id = %s
                    """, (DLQStatus.PERMANENT_FAILURE.value, reason, entry_id))
                    conn.commit()

                    logger.warning(f"DLQ entry {entry_id} marked as permanent failure")
        except Exception as e:
            logger.error(f"Failed to mark DLQ entry as permanent failure: {e}")

    def get_entry(self, entry_id: int) -> Optional[Dict]:
        """Get a specific DLQ entry as a column-name -> value dict,
        or None when the id does not exist or the query fails."""
        try:
            with self._pooled_conn() as conn:
                with conn.cursor() as cur:
                    cur.execute("""
                        SELECT * FROM genesis_bridge.dead_letter_queue
                        WHERE id = %s
                    """, (entry_id,))
                    row = cur.fetchone()

                    if row:
                        columns = [desc[0] for desc in cur.description]
                        return dict(zip(columns, row))
        except Exception as e:
            logger.error(f"Failed to get DLQ entry: {e}")

        return None

    def get_statistics(self) -> Dict[str, Any]:
        """Get DLQ counters by status plus average retry count.

        Returns {} on database errors; aggregate NULLs (empty table)
        are coerced to zero.
        """
        try:
            with self._pooled_conn() as conn:
                with conn.cursor() as cur:
                    cur.execute("""
                        SELECT 
                            COUNT(*) as total,
                            COUNT(*) FILTER (WHERE status = 'PENDING') as pending,
                            COUNT(*) FILTER (WHERE status = 'RETRYING') as retrying,
                            COUNT(*) FILTER (WHERE status = 'RESOLVED') as resolved,
                            COUNT(*) FILTER (WHERE status = 'PERMANENT_FAILURE') as permanent_failure,
                            COUNT(*) FILTER (WHERE priority = 1) as critical,
                            AVG(retry_count) as avg_retries
                        FROM genesis_bridge.dead_letter_queue
                    """)
                    row = cur.fetchone()

                    return {
                        "total": row[0] or 0,
                        "pending": row[1] or 0,
                        "retrying": row[2] or 0,
                        "resolved": row[3] or 0,
                        "permanent_failure": row[4] or 0,
                        "critical": row[5] or 0,
                        "avg_retries": float(row[6] or 0)
                    }
        except Exception as e:
            logger.error(f"Failed to get DLQ statistics: {e}")
            return {}

    def clear_resolved(self, older_than_days: int = 7):
        """Delete RESOLVED entries older than *older_than_days* days.

        Returns the number of rows deleted (0 on error).
        """
        try:
            with self._pooled_conn() as conn:
                with conn.cursor() as cur:
                    # BUGFIX: was INTERVAL '%s days' — a placeholder inside a
                    # string literal only worked by accident of client-side
                    # substitution; interval arithmetic binds %s properly.
                    cur.execute("""
                        DELETE FROM genesis_bridge.dead_letter_queue
                        WHERE status = 'RESOLVED'
                          AND resolved_at < NOW() - %s * INTERVAL '1 day'
                    """, (older_than_days,))
                    deleted = cur.rowcount
                    conn.commit()

                    logger.info(f"Cleared {deleted} resolved DLQ entries")
                    return deleted
        except Exception as e:
            logger.error(f"Failed to clear resolved DLQ entries: {e}")
            return 0

    def check_duplicate(self, payload: Dict[str, Any], window_seconds: int = 60) -> bool:
        """Return True when an entry with the same payload hash was created
        within the last *window_seconds* seconds (False on errors)."""
        payload_hash = self.compute_hash(payload)

        try:
            with self._pooled_conn() as conn:
                with conn.cursor() as cur:
                    # BUGFIX: same INTERVAL-placeholder issue as clear_resolved.
                    cur.execute("""
                        SELECT COUNT(*) FROM genesis_bridge.dead_letter_queue
                        WHERE payload_hash = %s
                          AND created_at > NOW() - %s * INTERVAL '1 second'
                    """, (payload_hash, window_seconds))

                    count = cur.fetchone()[0]
                    return count > 0
        except Exception as e:
            logger.error(f"Failed to check duplicate: {e}")
            return False