#!/usr/bin/env python3
"""
bridge_poller.py - AIVA Command Bridge Poller for Claude Code
Polls PostgreSQL for incoming directives and processes them.
"""

import os
import sys
import time
import signal
import argparse
import logging
import datetime
from typing import Optional, List, Dict, Any
from contextlib import contextmanager
from dataclasses import dataclass

import psycopg2
from psycopg2 import sql
from psycopg2.extras import RealDictCursor, Json


# ============================================================================
# CONFIGURATION
# ============================================================================

from psycopg2.extras import RealDictCursor, Json

from core.secrets_loader import get_postgres_config


# ============================================================================
# CONFIGURATION
# ============================================================================

@dataclass(frozen=True)
class Config:
    """Frozen runtime settings for the bridge poller (12-factor style).

    ``poll_interval`` and ``log_file`` take their defaults from the
    ``POLL_INTERVAL`` and ``LOG_FILE`` environment variables. These defaults
    are evaluated once, when this class body executes (i.e. at module import),
    not on every ``Config()`` call.
    """

    # Seconds between polling cycles (env POLL_INTERVAL, default "30").
    poll_interval: int = int(os.environ.get("POLL_INTERVAL", "30"))
    # Schema that qualifies the command_queue table in every query.
    schema: str = "genesis_bridge"
    # Path used by the file log handler (env LOG_FILE).
    log_file: str = os.environ.get(
        "LOG_FILE",
        "/var/log/aiva_bridge/poller.log",
    )



# ============================================================================
# COLOR CODES
# ============================================================================

class Colors:
    """ANSI escape sequences used to color the poller's terminal output."""

    # Bright foreground colors (SGR codes 91-97).
    RED = '\033[91m'
    YELLOW = '\033[93m'
    GREEN = '\033[92m'
    BLUE = '\033[94m'
    MAGENTA = '\033[95m'
    CYAN = '\033[96m'
    WHITE = '\033[97m'

    # Text attributes.
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # Resets all attributes back to the terminal default.
    END = '\033[0m'

    @classmethod
    def colorize(cls, text: str, color: str) -> str:
        """Return *text* prefixed with *color* and followed by a reset code."""
        return "{0}{1}{2}".format(color, text, cls.END)


# ============================================================================
# DATABASE MANAGER
# ============================================================================

class DatabaseManager:
    """
    Manages a single lazily-opened PostgreSQL connection and the
    command_queue operations used by the poller.

    There is no pooling: ``get_cursor`` (re)opens one connection whenever it
    is missing or closed, and every operation shares it.
    """

    def __init__(self, config: Config):
        self.config = config
        self._connection: Optional[psycopg2.extensions.connection] = None
        self._ensure_schema()

    def _ensure_schema(self) -> None:
        """Create the configured schema if it does not already exist."""
        with self.get_cursor(commit=True) as cursor:
            cursor.execute(
                sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(
                    sql.Identifier(self.config.schema)
                )
            )

    @contextmanager
    def get_cursor(self, commit: bool = False):
        """
        Yield a cursor on the shared connection, opening it first if needed.

        Args:
            commit: When True, commit after the ``with`` body finishes
                without raising.

        Raises:
            psycopg2.Error: If PostgreSQL is not configured. Any exception
                from the ``with`` body is re-raised after a rollback.
        """
        if self._connection is None or self._connection.closed:
            pg_config = get_postgres_config()
            if not pg_config.is_configured:
                raise psycopg2.Error("PostgreSQL is not configured.")
            self._connection = psycopg2.connect(
                pg_config.to_dsn(),
                cursor_factory=RealDictCursor
            )

        connection = self._connection
        cursor = connection.cursor()
        try:
            yield cursor
            if commit:
                connection.commit()
        except Exception:
            # A connection that has already dropped cannot be rolled back;
            # trying would raise InterfaceError and hide the real error.
            if not connection.closed:
                connection.rollback()
            raise
        finally:
            cursor.close()

    def close(self) -> None:
        """Close the database connection if it is open."""
        if self._connection and not self._connection.closed:
            self._connection.close()

    def fetch_pending_directives(self) -> List[Dict[str, Any]]:
        """
        Atomically claim pending inbound directives targeted at Claude.

        Matching rows are locked with ``FOR UPDATE SKIP LOCKED``, moved to
        status 'processing' and returned by one committed statement. A
        concurrent poller therefore cannot claim the same directive, and no
        transaction or row lock stays open between polls.

        Returns:
            The claimed rows (as dicts), ordered by priority DESC and
            created_at ASC. Their ``status`` is already 'processing'.
        """
        query = sql.SQL("""
            WITH claimed AS (
                UPDATE {schema}.command_queue
                SET status = 'processing', updated_at = NOW()
                WHERE id IN (
                    SELECT id FROM {schema}.command_queue
                    WHERE direction = 'inbound'
                      AND status = 'pending'
                      AND target = 'claude'
                    ORDER BY priority DESC, created_at ASC
                    FOR UPDATE SKIP LOCKED
                )
                RETURNING *
            )
            SELECT * FROM claimed
            ORDER BY priority DESC, created_at ASC
        """).format(schema=sql.Identifier(self.config.schema))

        # Commit right away so the claim is durable and the connection is not
        # left idle-in-transaction while directives are being processed.
        with self.get_cursor(commit=True) as cursor:
            cursor.execute(query)
            return cursor.fetchall()

    def update_directive_status(
        self, 
        directive_id: int, 
        status: str, 
        processing_time_ms: Optional[float] = None
    ) -> None:
        """
        Set a directive's status and bump ``updated_at``.

        Args:
            directive_id: Primary key of the command_queue row.
            status: New status value (e.g. 'processing', 'failed').
            processing_time_ms: If given, also stored in processing_time_ms.
        """
        fields = ["status = %s", "updated_at = NOW()"]
        values = [status]

        if processing_time_ms is not None:
            fields.append("processing_time_ms = %s")
            values.append(processing_time_ms)

        # Field fragments are fixed strings from this method, never user data.
        query = sql.SQL("""
            UPDATE {}.command_queue
            SET {}
            WHERE id = %s
        """).format(
            sql.Identifier(self.config.schema),
            sql.SQL(", ").join(map(sql.SQL, fields))
        )

        with self.get_cursor(commit=True) as cursor:
            cursor.execute(query, values + [directive_id])

    def insert_acknowledgment(self, original_id: int, ack_message: str) -> None:
        """
        Queue an outbound 'acknowledgment' row (priority 5) targeted at AIVA.

        Args:
            original_id: Id of the acknowledged directive, stored as source_id
                and inside the JSON payload.
            ack_message: Human-readable message placed in the payload.
        """
        query = sql.SQL("""
            INSERT INTO {}.command_queue 
            (direction, status, target, command_type, payload, source_id, priority, created_at)
            VALUES ('outbound', 'pending', 'aiva', 'acknowledgment', %s, %s, 5, NOW())
        """).format(sql.Identifier(self.config.schema))

        payload = {
            "acknowledged_directive_id": original_id,
            "message": ack_message,
            "timestamp": datetime.datetime.now().isoformat()
        }

        with self.get_cursor(commit=True) as cursor:
            cursor.execute(query, [Json(payload), original_id])


# ============================================================================
# DIRECTIVE PROCESSOR
# ============================================================================

class DirectiveProcessor:
    """
    Processes individual directives.

    Each directive is marked 'processing', printed to stdout, logged, and
    acknowledged with an outbound queue entry.
    """

    def __init__(self, db_manager: DatabaseManager, logger: logging.Logger):
        self.db = db_manager
        self.logger = logger

    def process(self, directive: Dict[str, Any]) -> None:
        """
        Run one directive through the initial processing lifecycle.

        Args:
            directive: A command_queue row as a dict; must contain 'id'.

        Raises:
            Exception: Any error raised while processing is logged. The
                directive is then marked 'failed' (best effort) and the
                original error is re-raised.
        """
        start_time = time.time()
        directive_id = directive['id']
        # A NULL priority column arrives as None; treat it as 0 so the
        # comparisons in _display_directive do not raise TypeError.
        priority = directive.get('priority') or 0

        try:
            # 1. Mark as processing
            self.db.update_directive_status(directive_id, 'processing')

            # 2. Format and display
            self._display_directive(directive, priority)

            # 3. Log to file
            self._log_directive(directive)

            # 4. Send acknowledgment
            self.db.insert_acknowledgment(
                directive_id, 
                f"Directive {directive_id} received and processing started"
            )

            # 5. Calculate processing time (initial phase)
            processing_time = (time.time() - start_time) * 1000

            # Note: Directive remains in 'processing' state until Claude Code 
            # explicitly marks it completed via bridge_respond.py

            self.logger.info(
                f"Directive {directive_id} initialized in {processing_time:.2f}ms"
            )

        except Exception as e:
            self.logger.error(f"Failed to process directive {directive_id}: {e}")
            try:
                self.db.update_directive_status(directive_id, 'failed')
            except Exception as status_error:
                # Keep the original error as the one that propagates.
                self.logger.error(
                    f"Could not mark directive {directive_id} as failed: {status_error}"
                )
            raise

    def _display_directive(self, directive: Dict[str, Any], priority: int) -> None:
        """
        Print a color-coded banner for the directive to stdout.

        A priority of 8 or more is shown in red with a terminal bell. A
        priority of 6 or 7 is shown in yellow, and anything lower in green.
        """
        if priority >= 8:
            color = Colors.RED
            bell = "\a"  # Terminal bell
            urgency = "🔴 URGENT"
        elif priority >= 6:
            color = Colors.YELLOW
            bell = ""
            urgency = "🟡 HIGH"
        else:
            color = Colors.GREEN
            bell = ""
            urgency = "🟢 NORMAL"

        header = f"{urgency} DIRECTIVE #{directive['id']} | Priority: {priority}"
        separator = "=" * 80

        # Row dicts contain every column, so a NULL value is present as None.
        # dict.get defaults apply only to missing keys, so fall back explicitly.
        source = directive.get('source') or 'Kinan/AIVA'
        content = directive.get('payload')
        if content is None:
            content = directive.get('command')
        if content is None:
            content = 'No content'

        output = f"""
{Colors.BOLD}{color}{separator}{Colors.END}
{Colors.BOLD}{color}{header}{Colors.END}
{Colors.BOLD}{color}{separator}{Colors.END}

{Colors.CYAN}From:{Colors.END} {source}
{Colors.CYAN}Received:{Colors.END} {directive.get('created_at')}
{Colors.CYAN}Type:{Colors.END} {directive.get('command_type', 'unknown')}

{Colors.BOLD}Command:{Colors.END}
{content}

{Colors.BOLD}{color}{separator}{Colors.END}{bell}
"""
        print(output)
        sys.stdout.flush()

    def _log_directive(self, directive: Dict[str, Any]) -> None:
        """Write a single DIRECTIVE_RECEIVED line to the logger at INFO level."""
        self.logger.info(
            f"DIRECTIVE_RECEIVED: id={directive['id']} "
            f"priority={directive.get('priority', 0)} "
            f"type={directive.get('command_type', 'unknown')} "
            f"source={directive.get('source', 'unknown')}"
        )


# ============================================================================
# POLLER
# ============================================================================

class Poller:
    """
    Main polling loop with graceful shutdown support.

    Owns the DatabaseManager, the logger and the DirectiveProcessor. It
    installs SIGTERM/SIGINT handlers that stop ``run_daemon``.
    """

    def __init__(self, config: Config):
        self.config = config
        self.running = False
        # Recorded by the signal handler and reported after the loop exits.
        self._shutdown_signal: Optional[int] = None
        self.db = DatabaseManager(config)
        self.logger = self._setup_logging()
        self.processor = DirectiveProcessor(self.db, self.logger)
        self._setup_signal_handlers()

    def _setup_logging(self) -> logging.Logger:
        """
        Configure the 'bridge_poller' logger.

        INFO and above go to ``config.log_file``; WARNING and above also go
        to stderr.
        """
        logger = logging.getLogger("bridge_poller")
        logger.setLevel(logging.INFO)
        if logger.handlers:
            # Already configured (e.g. a second Poller in this process);
            # adding handlers again would duplicate every log line.
            return logger

        # Ensure log directory exists
        log_dir = os.path.dirname(self.config.log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        # Plain append-mode file handler; it does not rotate the file itself.
        file_handler = logging.FileHandler(self.config.log_file)
        file_handler.setLevel(logging.INFO)
        file_format = logging.Formatter(
            '%(asctime)s | %(levelname)-8s | %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        file_handler.setFormatter(file_format)

        # Console handler for errors only (directives printed separately)
        console_handler = logging.StreamHandler(sys.stderr)
        console_handler.setLevel(logging.WARNING)
        console_format = logging.Formatter('%(levelname)s: %(message)s')
        console_handler.setFormatter(console_format)

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)

        return logger

    def _setup_signal_handlers(self) -> None:
        """Route SIGTERM and SIGINT to the graceful-shutdown handler."""
        signal.signal(signal.SIGTERM, self._handle_shutdown)
        signal.signal(signal.SIGINT, self._handle_shutdown)

    def _handle_shutdown(self, signum, frame) -> None:
        """Request a graceful stop; ``run_daemon`` notices within ~0.1s."""
        # Only record the request here. The logging docs warn that logging
        # may not be usable from inside signal handlers.
        self._shutdown_signal = signum
        self.running = False

    def poll_once(self) -> int:
        """
        Execute a single polling cycle.

        Returns:
            The number of directives processed successfully. Returns -1 if
            the directives could not be fetched. Failures of individual
            directives are logged and simply not counted.
        """
        try:
            directives = self.db.fetch_pending_directives()
        except psycopg2.Error as e:
            self.logger.error(f"Database error during poll: {e}")
            return -1
        except Exception as e:
            self.logger.exception(f"Unexpected error during poll: {e}")
            return -1

        count = 0
        for directive in directives:
            try:
                self.processor.process(directive)
                count += 1
            except Exception as e:
                # Continue processing other directives
                self.logger.error(f"Error processing directive {directive.get('id')}: {e}")
        return count

    def run_daemon(self) -> None:
        """Poll every ``poll_interval`` seconds until a shutdown signal arrives."""
        self.running = True
        self.logger.info(f"Starting AIVA Bridge Poller (interval: {self.config.poll_interval}s)")

        try:
            while self.running:
                # Monotonic clock: immune to wall-clock adjustments.
                cycle_start = time.monotonic()

                processed = self.poll_once()
                if processed > 0:
                    self.logger.info(f"Processed {processed} directives in this cycle")

                # Sleep in short slices until the next cycle is due, so a
                # shutdown request is honoured promptly.
                deadline = cycle_start + self.config.poll_interval
                while self.running:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        break
                    time.sleep(min(0.1, remaining))

            if self._shutdown_signal is not None:
                self.logger.info(f"Received signal {self._shutdown_signal}, shutting down")
            self.logger.info("Poller stopped gracefully")
        finally:
            self.db.close()


# ============================================================================
# MAIN ENTRY POINT
# ============================================================================

def main():
    """
    Entry point: parse CLI flags, then poll once (--once) or run as a daemon.

    In --once mode the process exits 0 on success and 1 when the poll
    reports failure (a negative count), so cron can detect problems.
    """
    parser = argparse.ArgumentParser(
        description="AIVA Bridge Poller - Watches for incoming directives"
    )
    parser.add_argument(
        "--once", 
        action="store_true", 
        help="Run single poll and exit (useful for cron)"
    )
    parser.add_argument(
        "--daemon", 
        action="store_true", 
        help="Run continuous polling loop (default)"
    )

    args = parser.parse_args()

    config = Config()
    poller = Poller(config)

    if args.once:
        try:
            count = poller.poll_once()
        finally:
            # run_daemon closes the connection itself; --once must do it here.
            poller.db.close()
        if count < 0:
            # Treat a negative count as a failed poll so cron sees exit 1.
            print("Poll failed; see log for details", file=sys.stderr)
            sys.exit(1)
        print(f"Processed {count} directives")
        sys.exit(0)
    else:
        # Default to daemon mode
        poller.run_daemon()


if __name__ == "__main__":
    main()