"""
AIVA Voice Command Bridge - Priority Routing System
Production-ready implementation for Genesis
"""

import asyncio
import datetime
import json
import logging
import os
import re
import secrets
import subprocess
import sys
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import IntEnum
from queue import PriorityQueue, Queue
from typing import Optional, List, Dict, Any, Callable

import psycopg2
import requests
from psycopg2 import pool
from psycopg2.extras import RealDictCursor

# Configure logging.
# The log file defaults to /var/log/aiva/priority_router.log and can be
# redirected with AIVA_LOG_FILE. If the file cannot be opened (missing
# directory, no permission) the module still imports and logs to stdout only,
# rather than raising at import time.
_LOG_FILE = os.getenv('AIVA_LOG_FILE', '/var/log/aiva/priority_router.log')
_log_handlers = []
_log_file_error = None
try:
    _log_handlers.append(logging.FileHandler(_LOG_FILE))
except OSError as exc:
    _log_file_error = exc
_log_handlers.append(logging.StreamHandler(sys.stdout))

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=_log_handlers
)
logger = logging.getLogger(__name__)

if _log_file_error is not None:
    logger.warning(
        f"Could not open log file {_LOG_FILE} ({_log_file_error}); logging to stdout only"
    )


# =============================================================================
# Database Connection Manager
# =============================================================================

class DatabaseConnectionPool:
    """Process-wide singleton around a psycopg2 ``ThreadedConnectionPool``.

    Every ``DatabaseConnectionPool()`` call returns the same instance. The
    pool is opened by the first successful initialization, using connection
    settings read from the POSTGRES_* environment variables.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        # Double-checked locking: the lock is only taken while no instance exists.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Python calls __init__ on every instantiation, including when __new__
        # returned the existing singleton, so configure only once. The check is
        # repeated under the lock so two concurrent first callers cannot both
        # open a pool.
        if getattr(self, 'initialized', False):
            return
        with self._lock:
            if getattr(self, 'initialized', False):
                return
            self.host = os.getenv('POSTGRES_HOST', 'postgresql-genesis-u50607.vm.elestio.app')
            self.port = os.getenv('POSTGRES_PORT', '25432')
            self.user = os.getenv('POSTGRES_USER', 'postgres')
            # SECURITY: never embed a credential in source. When unset, None is
            # passed; psycopg2 omits None parameters from the DSN, so libpq can
            # fall back to PGPASSWORD / ~/.pgpass.
            self.password = os.getenv('POSTGRES_PASSWORD') or None
            if self.password is None:
                logger.warning("POSTGRES_PASSWORD is not set; relying on libpq defaults (PGPASSWORD/.pgpass)")
            self.database = os.getenv('POSTGRES_DB', 'postgres')
            self.pool = None
            self._init_pool()
            # Set only after the pool exists. If _init_pool raised, a later
            # instantiation retries, instead of being stuck with pool=None.
            self.initialized = True

    def _init_pool(self):
        """Open the threaded connection pool (2..20 connections).

        Raises:
            Exception: whatever psycopg2 raises when the database is
                unreachable or the credentials are rejected (logged first).
        """
        try:
            self.pool = psycopg2.pool.ThreadedConnectionPool(
                minconn=2,
                maxconn=20,
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                database=self.database
            )
            logger.info(f"Database connection pool initialized: {self.host}:{self.port}")
        except Exception as e:
            logger.error(f"Failed to initialize database pool: {e}")
            raise

    def get_connection(self):
        """Borrow a connection from the pool, in explicit-transaction mode.

        Callers must hand it back with ``return_connection``.
        """
        try:
            if self.pool is None:
                raise RuntimeError("Database connection pool is not initialized")
            conn = self.pool.getconn()
            conn.autocommit = False
            return conn
        except Exception as e:
            logger.error(f"Failed to get connection from pool: {e}")
            raise

    def return_connection(self, conn):
        """Return a borrowed connection to the pool; errors are logged, not raised."""
        try:
            self.pool.putconn(conn)
        except Exception as e:
            logger.error(f"Failed to return connection to pool: {e}")

    def execute_query(self, query: str, params: Optional[tuple] = None, fetch: bool = True):
        """Run one statement in its own transaction.

        Args:
            query: SQL text with ``%s`` placeholders.
            params: Values bound to the placeholders.
            fetch: When True, return the result rows as ``RealDictRow`` dicts.
                A statement that produces no result set (plain INSERT/UPDATE
                or DDL) yields ``[]`` in this mode. When False, return the
                affected row count.

        Raises:
            Exception: any driver error; the transaction is rolled back first.
        """
        conn = None
        cursor = None
        try:
            conn = self.get_connection()
            cursor = conn.cursor(cursor_factory=RealDictCursor)
            cursor.execute(query, params)

            if not fetch:
                result = cursor.rowcount
            elif cursor.description is None:
                # No result set: fetchall() would raise "no results to fetch"
                # and roll back a write that actually succeeded.
                result = []
            else:
                result = cursor.fetchall()
            conn.commit()
            return result
        except Exception as e:
            if conn is not None:
                try:
                    conn.rollback()
                except Exception as rollback_error:
                    # Do not let a failed rollback (e.g. dead connection) mask
                    # the original error re-raised below.
                    logger.error(f"Rollback failed: {rollback_error}")
            logger.error(f"Query execution failed: {e}")
            raise
        finally:
            if cursor is not None:
                cursor.close()
            if conn is not None:
                self.return_connection(conn)

    def close_all(self):
        """Close every connection held by the pool."""
        if self.pool:
            self.pool.closeall()
            logger.info("Database connection pool closed")


# =============================================================================
# Database Schema Setup
# =============================================================================

class DatabaseSchema:
    """Idempotent DDL for the ``genesis_bridge`` priority-routing schema."""

    @staticmethod
    def setup_schema():
        """Create the schema, its three tables and their indexes.

        Every statement uses ``IF NOT EXISTS``, so repeated calls are harmless.
        Statements run one at a time in dependency order: ``directives`` is
        created before the two tables whose foreign keys reference it.
        """
        connection_pool = DatabaseConnectionPool()

        def run_ddl(statement):
            # DDL yields no rows, so always take execute_query's rowcount path.
            connection_pool.execute_query(statement, fetch=False)

        run_ddl("CREATE SCHEMA IF NOT EXISTS genesis_bridge")

        # One row per voice directive, carrying its 1-10 priority and lifecycle status.
        run_ddl("""
            CREATE TABLE IF NOT EXISTS genesis_bridge.directives (
                id SERIAL PRIMARY KEY,
                directive_id VARCHAR(255) UNIQUE NOT NULL,
                voice_command TEXT NOT NULL,
                raw_transcript TEXT,
                priority INTEGER DEFAULT 5 CHECK (priority >= 1 AND priority <= 10),
                urgency_keyword BOOLEAN DEFAULT FALSE,
                status VARCHAR(50) DEFAULT 'pending' CHECK (status IN ('pending', 'processing', 'completed', 'failed', 'escalated')),
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                processed_at TIMESTAMP,
                escalated_at TIMESTAMP,
                previous_priority INTEGER,
                notification_sent BOOLEAN DEFAULT FALSE,
                retry_count INTEGER DEFAULT 0,
                metadata JSONB DEFAULT '{}'::jsonb
            )
        """)

        # Queue entries; each references a directive by its external id.
        run_ddl("""
            CREATE TABLE IF NOT EXISTS genesis_bridge.priority_queue (
                id SERIAL PRIMARY KEY,
                directive_id VARCHAR(255) REFERENCES genesis_bridge.directives(directive_id),
                priority INTEGER NOT NULL CHECK (priority >= 1 AND priority <= 10),
                enqueued_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                scheduled_processing TIMESTAMP,
                processing_deadline TIMESTAMP,
                status VARCHAR(50) DEFAULT 'queued' CHECK (status IN ('queued', 'processing', 'completed', 'failed', 'bypassed'))
            )
        """)

        # Audit trail of every priority change.
        run_ddl("""
            CREATE TABLE IF NOT EXISTS genesis_bridge.priority_escalation_log (
                id SERIAL PRIMARY KEY,
                directive_id VARCHAR(255) REFERENCES genesis_bridge.directives(directive_id),
                old_priority INTEGER NOT NULL,
                new_priority INTEGER NOT NULL,
                escalation_reason VARCHAR(255) NOT NULL,
                escalated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                escalation_method VARCHAR(50) DEFAULT 'automatic'
            )
        """)

        # Indexes supporting priority ordering and status scans.
        run_ddl("""
            CREATE INDEX IF NOT EXISTS idx_directives_priority 
            ON genesis_bridge.directives(priority DESC, created_at ASC)
        """)

        run_ddl("""
            CREATE INDEX IF NOT EXISTS idx_directives_status 
            ON genesis_bridge.directives(status, created_at)
        """)

        run_ddl("""
            CREATE INDEX IF NOT EXISTS idx_priority_queue_priority 
            ON genesis_bridge.priority_queue(priority, enqueued_at)
        """)

        logger.info("Database schema setup completed")


# =============================================================================
# Priority Level Definitions
# =============================================================================

class Priority(IntEnum):
    """Priority levels for directive processing (1 = lowest, 10 = highest).

    The integer values fall within the ``priority >= 1 AND priority <= 10``
    CHECK constraint used by the genesis_bridge tables. Per-level handling
    (queue bypass, wait limits, notification channels) is defined in
    PriorityBehavior.BEHAVIOR_MAP.
    """
    CRITICAL = 10      # highest level
    URGENT_HIGH = 9
    URGENT = 8
    HIGH = 7
    HIGH_MEDIUM = 6
    NORMAL = 5         # default level
    NORMAL_LOW = 4
    LOW = 3
    LOW_MEDIUM = 2
    LOW_MIN = 1        # lowest level


class PriorityBehavior:
    """Static table describing how each Priority level is handled.

    Every policy is a dict with, in this order, the keys ``execution``,
    ``notification`` (channel names), ``bypass_queue``, ``max_wait_time``
    (seconds, or ``None`` for no limit) and ``description``.
    """

    def _policy(execution, notification, bypass_queue, max_wait_time, description):
        # Class-body-only factory (deleted below). It returns a fresh dict per
        # level and fixes the key order seen by anything that serializes it.
        return {
            "execution": execution,
            "notification": notification,
            "bypass_queue": bypass_queue,
            "max_wait_time": max_wait_time,
            "description": description,
        }

    BEHAVIOR_MAP = {
        Priority.CRITICAL: _policy(
            execution="immediate",
            notification=["websocket", "terminal_bell", "log", "desktop_notification"],
            bypass_queue=True,
            max_wait_time=0,
            description="Immediate execution, terminal bell + desktop notification, bypass queue",
        ),
        Priority.URGENT_HIGH: _policy(
            execution="immediate",
            notification=["websocket", "terminal_bell", "log"],
            bypass_queue=False,
            max_wait_time=60,
            description="Move to front of queue, process within 60 seconds",
        ),
        Priority.URGENT: _policy(
            execution="fast_track",
            notification=["websocket", "terminal_bell", "log"],
            bypass_queue=False,
            max_wait_time=60,
            description="Move to front of queue, process within 60 seconds",
        ),
        Priority.HIGH: _policy(
            execution="standard",
            notification=["websocket", "log"],
            bypass_queue=False,
            max_wait_time=300,
            description="Standard priority, process within 5 minutes",
        ),
        Priority.HIGH_MEDIUM: _policy(
            execution="standard",
            notification=["websocket", "log"],
            bypass_queue=False,
            max_wait_time=300,
            description="Standard priority, process within 5 minutes",
        ),
        Priority.NORMAL: _policy(
            execution="standard",
            notification=["websocket", "log"],
            bypass_queue=False,
            max_wait_time=None,
            description="Default, process in order",
        ),
        Priority.NORMAL_LOW: _policy(
            execution="standard",
            notification=["websocket", "log"],
            bypass_queue=False,
            max_wait_time=None,
            description="Default, process in order",
        ),
        Priority.LOW: _policy(
            execution="background",
            notification=["log"],
            bypass_queue=False,
            max_wait_time=None,
            description="Background processing, batch if possible",
        ),
        Priority.LOW_MEDIUM: _policy(
            execution="background",
            notification=["log"],
            bypass_queue=False,
            max_wait_time=None,
            description="Background processing, batch if possible",
        ),
        Priority.LOW_MIN: _policy(
            execution="background",
            notification=["log"],
            bypass_queue=False,
            max_wait_time=None,
            description="Background processing, batch if possible",
        ),
    }

    del _policy

    @classmethod
    def get_behavior(cls, priority: int) -> Dict[str, Any]:
        """Return the policy for *priority*, clamped into the 1..10 range.

        A clamped value that is still not a map key (e.g. a non-integral
        number) falls back to the NORMAL policy.
        """
        table = cls.BEHAVIOR_MAP
        default_policy = table[Priority.NORMAL]
        return table.get(max(1, min(10, priority)), default_policy)

    @classmethod
    def get_notification_channels(cls, priority: int) -> List[str]:
        """Return the notification channel names configured for *priority*."""
        policy = cls.get_behavior(priority)
        return policy.get("notification", ["log"])

    @classmethod
    def should_bypass_queue(cls, priority: int) -> bool:
        """Return True when directives at *priority* skip the normal queue."""
        policy = cls.get_behavior(priority)
        return policy.get("bypass_queue", False)


# =============================================================================
# Notification Dispatcher
# =============================================================================

class NotificationDispatcher:
    """Send directive notifications over the channels configured per priority.

    The channel list comes from ``PriorityBehavior.get_notification_channels``.
    Recognised channels are ``log``, ``terminal_bell``,
    ``desktop_notification`` and ``websocket``. When no channel raises, the
    directive row is flagged ``notification_sent = TRUE``.
    """

    # Upper bound (seconds) on external notifier processes, so a hung
    # notification daemon cannot stall directive routing.
    _SUBPROCESS_TIMEOUT = 5

    def __init__(self, websocket_handler: Optional[Callable] = None):
        """
        Args:
            websocket_handler: Optional callable invoked with a JSON string for
                each websocket notification. When ``None``, the websocket
                channel is skipped (logged at debug level).
        """
        self.websocket_handler = websocket_handler
        self.db = DatabaseConnectionPool()
        self.logger = logging.getLogger(f"{__name__}.NotificationDispatcher")

    def dispatch(self, directive_id: str, priority: int, message: str,
                 command_text: str = "") -> bool:
        """Deliver *message* on every channel configured for *priority*.

        Each channel is attempted independently. A failure in one is logged
        and does not prevent the others.

        Returns:
            True if no channel raised (the directive is then marked as
            notified in the database); False otherwise.
        """
        channels = PriorityBehavior.get_notification_channels(priority)

        self.logger.info(f"Dispatching notifications for directive {directive_id} "
                         f"with priority {priority} via channels: {channels}")

        success = True

        for channel in channels:
            try:
                if channel == "log":
                    self._send_log_notification(directive_id, priority, message)
                elif channel == "terminal_bell":
                    self._send_terminal_bell(directive_id, priority, message)
                elif channel == "desktop_notification":
                    self._send_desktop_notification(directive_id, priority, message, command_text)
                elif channel == "websocket":
                    self._send_websocket_notification(directive_id, priority, message)
            except Exception as e:
                self.logger.error(f"Failed to send {channel} notification: {e}")
                success = False

        if success:
            self._mark_notification_sent(directive_id)

        return success

    def _send_log_notification(self, directive_id: str, priority: int, message: str):
        """Log the notification: WARNING for priority >= 8, otherwise INFO."""
        priority_label = self._get_priority_label(priority)
        log_message = f"[{priority_label}] Directive {directive_id}: {message}"

        if priority >= 8:
            self.logger.warning(log_message)
        else:
            self.logger.info(log_message)

    def _send_terminal_bell(self, directive_id: str, priority: int, message: str):
        """Ring the terminal bell and echo the notification to stdout."""
        priority_label = self._get_priority_label(priority)
        # BEL characters make an attached terminal beep, if it is set up to.
        print(f"\a\a\a[{priority_label}] {directive_id}: {message}", flush=True)

        try:
            if sys.platform == "win32":
                import winsound
                winsound.Beep(1000, 500)
            else:
                # NOTE(review): on the Linux virtual console, ESC[10;n] sets the
                # bell frequency; most other terminals ignore the sequence.
                print('\033[10;100]', flush=True)
        except Exception as e:
            self.logger.debug(f"Could not send system beep: {e}")

    def _send_desktop_notification(self, directive_id: str, priority: int,
                                   message: str, command_text: str = ""):
        """Best-effort desktop notification.

        Tries, in order: ``notify-send``, the optional ``plyer`` package, then
        the webhook named by ``DESKTOP_NOTIFICATION_WEBHOOK`` (if set). Each
        later mechanism is tried only if the earlier ones did not deliver.
        """
        priority_label = self._get_priority_label(priority)
        title = f"AIVA: {priority_label} Directive"
        body = f"{message}\n\nCommand: {command_text}" if command_text else message

        notification_sent = False

        # notify-send (libnotify, Linux). It is often installed on hosts with
        # no desktop session, where it exits non-zero, so only a zero exit
        # status counts as delivered.
        try:
            result = subprocess.run(
                ['notify-send', '-u', 'critical' if priority == 10 else 'normal',
                 '-t', '10000', title, body],
                check=False,
                capture_output=True,
                timeout=self._SUBPROCESS_TIMEOUT
            )
            notification_sent = result.returncode == 0
        except (OSError, ValueError, subprocess.SubprocessError) as e:
            self.logger.debug(f"notify-send unavailable: {e}")

        # plyer is optional; when installed it may still fail at runtime
        # (e.g. no notification backend), which must not skip the webhook.
        if not notification_sent:
            try:
                from plyer import notification
            except ImportError:
                notification = None
            if notification is not None:
                try:
                    notification.notify(
                        title=title,
                        message=body,
                        app_name='AIVA Voice Command Bridge',
                        timeout=10
                    )
                    notification_sent = True
                except Exception as e:
                    self.logger.debug(f"plyer notification failed: {e}")

        # Fallback: webhook, if configured.
        if not notification_sent:
            webhook_url = os.getenv('DESKTOP_NOTIFICATION_WEBHOOK')
            if webhook_url:
                try:
                    response = requests.post(webhook_url, json={
                        "title": title,
                        "body": body,
                        "directive_id": directive_id,
                        "priority": priority
                    }, timeout=5)
                    response.raise_for_status()
                    notification_sent = True
                except Exception as e:
                    self.logger.warning(f"Failed to send webhook notification: {e}")

        if not notification_sent:
            self.logger.debug(f"No desktop notification mechanism delivered for {directive_id}")

    def _send_websocket_notification(self, directive_id: str, priority: int, message: str):
        """Serialize the notification as JSON and hand it to the websocket handler.

        Handler errors are logged here rather than propagated.
        """
        if self.websocket_handler:
            try:
                payload = {
                    "type": "directive_notification",
                    "directive_id": directive_id,
                    "priority": priority,
                    "priority_label": self._get_priority_label(priority),
                    "message": message,
                    "timestamp": datetime.datetime.utcnow().isoformat()
                }
                self.websocket_handler(json.dumps(payload))
            except Exception as e:
                self.logger.error(f"Failed to send WebSocket notification: {e}")
        else:
            self.logger.debug(f"WebSocket handler not configured, skipping notification for {directive_id}")

    def _mark_notification_sent(self, directive_id: str):
        """Set ``notification_sent = TRUE`` on the directive (errors are logged)."""
        try:
            # fetch=False: an UPDATE has no result set, and fetching would
            # raise and roll the update back.
            self.db.execute_query(
                """UPDATE genesis_bridge.directives 
                   SET notification_sent = TRUE, updated_at = CURRENT_TIMESTAMP 
                   WHERE directive_id = %s""",
                (directive_id,),
                fetch=False
            )
        except Exception as e:
            self.logger.error(f"Failed to mark notification sent: {e}")

    def _get_priority_label(self, priority: int) -> str:
        """Map a 1-10 priority to CRITICAL / URGENT / HIGH / NORMAL / LOW."""
        if priority == 10:
            return "CRITICAL"
        elif priority >= 8:
            return "URGENT"
        elif priority >= 6:
            return "HIGH"
        elif priority >= 4:
            return "NORMAL"
        else:
            return "LOW"


# =============================================================================
# Priority Router
# =============================================================================

class PriorityRouter:
    """Main priority routing engine for directives.

    For each incoming directive it derives a 1-10 priority from the command
    text, persists the directive, enqueues it according to PriorityBehavior,
    and dispatches notifications when the priority is 8 or higher.
    """

    URGENT_KEYWORDS = ['urgent', 'now', 'immediately', 'asap', 'emergency', 'critical', 'right now']

    def __init__(self):
        self.db = DatabaseConnectionPool()
        self.notification_dispatcher = NotificationDispatcher()
        self.logger = logging.getLogger(f"{__name__}.PriorityRouter")

    @staticmethod
    def _contains_phrase(text: str, phrase: str) -> bool:
        """Return True if *phrase* occurs in *text* as whole word(s), ignoring case.

        Plain substring tests misfire on this vocabulary: 'now' is inside
        'know' and 'snow', 'lock' inside 'clock', 'priority 1' inside
        'priority 10'.
        """
        pattern = r'(?<!\w)' + re.escape(phrase) + r'(?!\w)'
        return re.search(pattern, text, re.IGNORECASE) is not None

    @classmethod
    def _contains_any(cls, text: str, phrases) -> bool:
        """Return True if any of *phrases* occurs in *text* as whole word(s)."""
        return any(cls._contains_phrase(text, phrase) for phrase in phrases)

    def route_directive(self, directive_id: str, voice_command: str,
                        raw_transcript: str = "") -> Dict[str, Any]:
        """
        Route a directive based on priority determination.

        Stores the directive, enqueues it and, for priority >= 8, dispatches
        notifications.

        Returns:
            Dict with directive_id, priority, priority_label, behavior,
            queue_status and routed_at (naive-UTC ISO timestamp).
        """
        self.logger.info(f"Routing directive: {directive_id}")

        priority = self._determine_priority(voice_command, raw_transcript)
        behavior = PriorityBehavior.get_behavior(priority)

        self._store_directive(directive_id, voice_command, raw_transcript, priority)
        queue_result = self._enqueue_directive(directive_id, priority, behavior)

        if priority >= 8:
            self.notification_dispatcher.dispatch(
                directive_id=directive_id,
                priority=priority,
                message=f"New {self._get_priority_label(priority)} directive received",
                command_text=voice_command
            )

        return {
            "directive_id": directive_id,
            "priority": priority,
            "priority_label": self._get_priority_label(priority),
            "behavior": behavior,
            "queue_status": queue_result,
            "routed_at": datetime.datetime.utcnow().isoformat()
        }

    def _determine_priority(self, voice_command: str, raw_transcript: str = "") -> int:
        """
        Determine priority from the voice command and transcript text.

        All matching is whole-word and case-insensitive. First match wins:
        - 10: any URGENT_KEYWORDS entry, or 'priority 10' / 'priority ten'
        - 8:  'priority 9' / 'priority 8' (digits or words)
        - 7:  'priority 7' / 'priority 6'
        - 5:  'priority 5' / 'priority 4'
        - 3:  'priority 1'..'priority 3', 'low priority', 'when possible'
        - 7:  an action verb from high_urgency_commands ('lock', 'start', ...)
        - 4:  a deferral word from low_urgency_commands ('schedule', 'later', ...)
        - 5:  default
        """
        # raw_transcript may be None (e.g. an Optional request field), so
        # coerce both parts before combining them.
        command_text = f"{voice_command or ''} {raw_transcript or ''}"

        for keyword in self.URGENT_KEYWORDS:
            if self._contains_phrase(command_text, keyword):
                self.logger.info(f"Detected urgent keyword: {keyword}")
                return Priority.CRITICAL

        # Explicit priority indicators
        if self._contains_any(command_text, ['priority 10', 'priority ten', 'critical']):
            return Priority.CRITICAL
        if self._contains_any(command_text, ['priority 9', 'priority nine', 'priority 8', 'priority eight']):
            return Priority.URGENT
        if self._contains_any(command_text, ['priority 7', 'priority seven', 'priority 6', 'priority six']):
            return Priority.HIGH
        if self._contains_any(command_text, ['priority 5', 'priority five', 'priority 4', 'priority four']):
            return Priority.NORMAL
        if self._contains_any(command_text, ['priority 3', 'priority three', 'priority 2', 'priority two',
                                             'priority 1', 'priority one', 'low priority', 'when possible']):
            return Priority.LOW

        # Fall back to command-type analysis
        high_urgency_commands = ['lock', 'unlock', 'start', 'stop', 'activate', 'deactivate',
                                 'open', 'close', 'enable', 'disable']
        low_urgency_commands = ['schedule', 'remind', 'maybe', 'sometime', 'later', 'when you can']

        if self._contains_any(command_text, high_urgency_commands):
            return Priority.HIGH
        if self._contains_any(command_text, low_urgency_commands):
            return Priority.NORMAL_LOW

        return Priority.NORMAL

    def _store_directive(self, directive_id: str, voice_command: str,
                         raw_transcript: str, priority: int):
        """Insert the directive as 'pending', or update it if the id already exists.

        Raises:
            Exception: any database error (logged, then re-raised).
        """
        # Use the same text and matching as _determine_priority, so the flag
        # agrees with the priority that was assigned.
        urgency_keyword = self._contains_any(
            f"{voice_command or ''} {raw_transcript or ''}", self.URGENT_KEYWORDS
        )

        try:
            # fetch=False: an INSERT without RETURNING has no result set.
            self.db.execute_query(
                """INSERT INTO genesis_bridge.directives 
                   (directive_id, voice_command, raw_transcript, priority, urgency_keyword, status)
                   VALUES (%s, %s, %s, %s, %s, 'pending')
                   ON CONFLICT (directive_id) DO UPDATE SET
                   voice_command = EXCLUDED.voice_command,
                   raw_transcript = EXCLUDED.raw_transcript,
                   priority = EXCLUDED.priority,
                   urgency_keyword = EXCLUDED.urgency_keyword,
                   updated_at = CURRENT_TIMESTAMP""",
                # int(): hand the driver a plain int rather than the IntEnum member.
                (directive_id, voice_command, raw_transcript, int(priority), urgency_keyword),
                fetch=False
            )
            self.logger.debug(f"Stored directive {directive_id} with priority {priority}")
        except Exception as e:
            self.logger.error(f"Failed to store directive: {e}")
            raise

    def _enqueue_directive(self, directive_id: str, priority: int,
                           behavior: Dict[str, Any]) -> Dict[str, Any]:
        """Insert a priority_queue row for the directive.

        Bypassing priorities get status 'bypassed' with deadline = now. Others
        get 'queued' with deadline = now + max_wait_time, or no deadline when
        max_wait_time is falsy.
        """
        # One timestamp, so scheduled and deadline are computed from the same instant.
        # NOTE(review): these are naive UTC values, while enqueued_at uses the
        # server's CURRENT_TIMESTAMP; they only agree if the DB session time
        # zone is UTC.
        now = datetime.datetime.utcnow()
        scheduled = now

        if behavior.get("bypass_queue", False):
            status = "bypassed"
            deadline = now
        else:
            status = "queued"
            max_wait = behavior.get("max_wait_time")
            deadline = now + datetime.timedelta(seconds=max_wait) if max_wait else None

        try:
            self.db.execute_query(
                """INSERT INTO genesis_bridge.priority_queue 
                   (directive_id, priority, enqueued_at, scheduled_processing, processing_deadline, status)
                   VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s)""",
                (directive_id, int(priority), scheduled, deadline, status),
                fetch=False
            )

            return {
                "status": status,
                "scheduled_processing": scheduled.isoformat(),
                "deadline": deadline.isoformat() if deadline else None
            }
        except Exception as e:
            self.logger.error(f"Failed to enqueue directive: {e}")
            raise

    def escalate_priority(self, directive_id: str, new_priority: int,
                          reason: str, method: str = "automatic") -> Dict[str, Any]:
        """
        Set a directive's priority to *new_priority*, log the change, and notify.

        The directive update, the escalation-log insert and the queue update
        run in a single transaction, with the directive row locked. This keeps
        concurrent escalations from interleaving, and avoids leaving the
        tables half-updated.

        Returns:
            Dict with directive_id, old_priority, new_priority, reason, method
            and escalated_at.

        Raises:
            ValueError: if the directive does not exist.
        """
        new_priority_value = int(new_priority)
        conn = self.db.get_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute(
                    """SELECT priority FROM genesis_bridge.directives
                       WHERE directive_id = %s
                       FOR UPDATE""",
                    (directive_id,)
                )
                row = cursor.fetchone()
                if row is None:
                    raise ValueError(f"Directive {directive_id} not found")
                old_priority = row[0]

                cursor.execute(
                    """UPDATE genesis_bridge.directives 
                       SET priority = %s, 
                           previous_priority = %s,
                           escalated_at = CURRENT_TIMESTAMP,
                           status = 'escalated',
                           updated_at = CURRENT_TIMESTAMP
                       WHERE directive_id = %s""",
                    (new_priority_value, old_priority, directive_id)
                )

                cursor.execute(
                    """INSERT INTO genesis_bridge.priority_escalation_log 
                       (directive_id, old_priority, new_priority, escalation_reason, escalation_method)
                       VALUES (%s, %s, %s, %s, %s)""",
                    (directive_id, old_priority, new_priority_value, reason, method)
                )

                cursor.execute(
                    """UPDATE genesis_bridge.priority_queue 
                       SET priority = %s, status = 'queued'
                       WHERE directive_id = %s""",
                    (new_priority_value, directive_id)
                )
            conn.commit()
        except Exception:
            try:
                conn.rollback()
            except Exception as rollback_error:
                self.logger.error(f"Rollback failed: {rollback_error}")
            raise
        finally:
            self.db.return_connection(conn)

        # Notify only after the change is committed.
        self.notification_dispatcher.dispatch(
            directive_id=directive_id,
            priority=new_priority,
            message=f"Priority escalated from {old_priority} to {new_priority}: {reason}",
            command_text=""
        )

        self.logger.warning(f"Directive {directive_id} escalated from {old_priority} to {new_priority}: {reason}")

        return {
            "directive_id": directive_id,
            "old_priority": old_priority,
            "new_priority": new_priority,
            "reason": reason,
            "method": method,
            "escalated_at": datetime.datetime.utcnow().isoformat()
        }

    def _get_priority_label(self, priority: int) -> str:
        """Map a 1-10 priority to CRITICAL / URGENT / HIGH / NORMAL / LOW."""
        if priority == 10:
            return "CRITICAL"
        elif priority >= 8:
            return "URGENT"
        elif priority >= 6:
            return "HIGH"
        elif priority >= 4:
            return "NORMAL"
        else:
            return "LOW"


# =============================================================================
# Priority Escalation Worker
# =============================================================================

class PriorityEscalationWorker:
    """Background thread that periodically raises the priority of stale directives.

    Every ``check_interval`` seconds it selects directives whose status is
    'pending' or 'escalated' and whose ``notification_sent`` is FALSE, and
    escalates them based on how long ago they were created.
    """

    # Same vocabulary as PriorityRouter.URGENT_KEYWORDS; keep the two aligned.
    URGENT_KEYWORDS = ('urgent', 'now', 'immediately', 'asap', 'emergency', 'critical', 'right now')

    def __init__(self, check_interval: int = 60):
        self.db = DatabaseConnectionPool()
        self.router = PriorityRouter()
        self.check_interval = check_interval
        self.running = False
        self.thread = None
        # Set by stop() to wake the worker from its inter-check wait.
        self._stop_event = threading.Event()
        self.logger = logging.getLogger(f"{__name__}.PriorityEscalationWorker")

    def start(self):
        """Start the escalation loop in a daemon thread (no-op if already running)."""
        if self.running:
            self.logger.warning("Escalation worker already running")
            return

        self.running = True
        # Fresh event per run: a previous thread that outlived stop()'s join
        # timeout keeps its own already-set event, and still exits.
        self._stop_event = threading.Event()
        self.thread = threading.Thread(target=self._run, args=(self._stop_event,), daemon=True)
        self.thread.start()
        self.logger.info("Priority escalation worker started")

    def stop(self):
        """Signal the loop to exit and wait up to 5 seconds for the thread."""
        self.running = False
        self._stop_event.set()
        if self.thread:
            self.thread.join(timeout=5)
        self.logger.info("Priority escalation worker stopped")

    def _run(self, stop_event: Optional[threading.Event] = None):
        """Worker loop: check, then wait check_interval seconds or until stopped."""
        if stop_event is None:
            stop_event = self._stop_event
        while self.running and not stop_event.is_set():
            try:
                self._check_and_escalate()
            except Exception as e:
                self.logger.error(f"Error in escalation check: {e}")

            # Interruptible sleep: returns as soon as stop() sets the event,
            # so stop() does not have to wait out a full check_interval.
            stop_event.wait(self.check_interval)

    def _check_and_escalate(self):
        """Escalate pending directives by age.

        - pending > 15 min and priority < 8: set to URGENT (8)
        - pending > 5 min and priority < 8: raise by 1, capped at 7

        A failure on one directive is logged and does not stop the sweep.
        """
        pending = self.db.execute_query(
            """SELECT d.directive_id, d.priority, d.created_at, d.status,
                      EXTRACT(EPOCH FROM (NOW() - d.created_at)) as pending_seconds
               FROM genesis_bridge.directives d
               WHERE d.status IN ('pending', 'escalated')
               AND d.notification_sent = FALSE"""
        )

        for directive in pending:
            directive_id = directive['directive_id']
            current_priority = directive['priority']
            pending_seconds = directive['pending_seconds']

            try:
                if pending_seconds > 900 and current_priority < 8:
                    self.logger.info(f"Escalating {directive_id} to URGENT (pending > 15 min)")
                    self.router.escalate_priority(
                        directive_id=directive_id,
                        new_priority=Priority.URGENT,
                        reason="Pending for more than 15 minutes",
                        method="automatic_time_based"
                    )
                elif pending_seconds > 300 and current_priority < 8:
                    new_priority = min(current_priority + 1, 7)
                    if new_priority > current_priority:
                        self.logger.info(f"Escalating {directive_id} from {current_priority} to {new_priority}")
                        self.router.escalate_priority(
                            directive_id=directive_id,
                            new_priority=new_priority,
                            reason="Pending for more than 5 minutes",
                            method="automatic_time_based"
                        )
            except Exception as e:
                self.logger.error(f"Failed to escalate directive {directive_id}: {e}")

    def handle_urgency_keyword(self, directive_id: str, command_text: str) -> Optional[Dict[str, Any]]:
        """
        Escalate a directive to CRITICAL if a follow-up contains an urgency keyword.

        Called when Kinan says 'urgent' or 'now' in a follow-up. Keywords are
        matched as whole words, case-insensitively, so 'know' or 'snow' do not
        count as 'now'.

        Returns:
            The escalation result dict, or None if no keyword was found.
        """
        text = command_text or ""
        matched = any(
            re.search(r'(?<!\w)' + re.escape(keyword) + r'(?!\w)', text, re.IGNORECASE)
            for keyword in self.URGENT_KEYWORDS
        )

        if matched:
            self.logger.info(f"Urgency keyword detected for {directive_id}, escalating to CRITICAL")
            return self.router.escalate_priority(
                directive_id=directive_id,
                new_priority=Priority.CRITICAL,
                reason=f"Urgency keyword detected: {command_text}",
                method="keyword_detection"
            )

        return None


# =============================================================================
# FastAPI Application
# =============================================================================

from typing import List, Optional

import uvicorn
from fastapi import (
    Depends,
    FastAPI,
    Header,
    HTTPException,
    Query,
    Request,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

app = FastAPI(title="AIVA Priority Routing API", version="1.0.0")

# API Key authentication
# Shared secret compared by verify_api_key on every protected route. The
# fallback literal is published in this source file, so anyone who can read
# it can authenticate whenever the API_KEY environment variable is unset.
# Warn loudly so that a misconfigured deployment gets noticed.
_DEFAULT_API_KEY = 'aiva-priority-router-secret-key'
API_KEY = os.getenv('API_KEY', _DEFAULT_API_KEY)
if API_KEY == _DEFAULT_API_KEY:
    logger.warning(
        "API_KEY is not set; falling back to the built-in default key. "
        "Set API_KEY in the environment before exposing this service."
    )

def verify_api_key(x_api_key: str = Header(...)):
    """Verify API key for authentication.

    FastAPI reads ``x_api_key`` from the ``x-api-key`` request header.

    Raises:
        HTTPException: 401 when the supplied key does not match API_KEY.
    """
    import secrets  # local import keeps this block self-contained

    # Compare in constant time so response timing does not leak how many
    # leading characters of the key were correct. Encode first, because
    # compare_digest only accepts ASCII-only str but any bytes.
    if not secrets.compare_digest(x_api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid API key")


# Pydantic models
class DirectiveRequest(BaseModel):
    # Request body for POST /route; its fields are passed straight through to
    # router.route_directive(...).
    directive_id: str = Field(..., description="Unique directive identifier")
    voice_command: str = Field(..., description="The voice command text")
    # Optional: defaults to "" when omitted, and the Optional annotation also
    # admits an explicit null.
    raw_transcript: Optional[str] = Field("", description="Raw transcript from voice recognition")


class EscalationRequest(BaseModel):
    # Request body for POST /escalate (manual escalation, method="manual").
    directive_id: str = Field(..., description="Directive to escalate")
    # Validated to the inclusive range 1..10 before the handler runs.
    # NOTE(review): nothing here checks that new_priority exceeds the current
    # priority; whether escalate_priority enforces that is not visible here.
    new_priority: int = Field(..., ge=1, le=10, description="New priority level (1-10)")
    reason: str = Field(..., description="Reason for escalation")


class UrgencyKeywordRequest(BaseModel):
    # Request body for POST /urgency-keyword. command_text is scanned by
    # PriorityEscalationWorker.handle_urgency_keyword for urgency keywords.
    directive_id: str = Field(..., description="Directive identifier")
    command_text: str = Field(..., description="Follow-up command text containing urgency keywords")


class DirectiveResponse(BaseModel):
    # Response model for POST /route. The handler returns the result of
    # router.route_directive(...), which FastAPI validates against these fields.
    directive_id: str
    priority: int
    # Presumably the label from PriorityRouter._get_priority_label
    # (e.g. "CRITICAL", "URGENT"); confirm against route_directive.
    priority_label: str
    behavior: dict
    queue_status: dict
    # NOTE(review): presumably an ISO-8601 timestamp string; confirm.
    routed_at: str


class EscalationResponse(BaseModel):
    # Response model for POST /escalate and POST /urgency-keyword. When no
    # urgency keyword is detected, /urgency-keyword fills it with placeholder
    # values: old_priority=0, new_priority=0, method="none".
    directive_id: str
    old_priority: int
    new_priority: int
    reason: str
    # Escalation source, e.g. "manual" or "keyword_detection".
    method: str
    escalated_at: str


# WebSocket connection manager
class ConnectionManager:
    """Track accepted WebSocket clients and fan out JSON messages to them."""

    def __init__(self):
        # Currently tracked sockets, in the order they connected.
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and start tracking the socket."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking ``websocket``; a no-op if it is not tracked.

        Idempotent because a socket can be dropped twice: once by
        ``broadcast`` after a failed send, and again by the endpoint's
        disconnect handling. ``list.remove`` would raise ValueError the
        second time.
        """
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            pass  # already removed; nothing to do

    async def broadcast(self, message: dict):
        """Send ``message`` as JSON to every client, dropping any whose send fails."""
        # Iterate over a snapshot. Each ``await`` yields to the event loop,
        # and other coroutines may connect or disconnect clients meanwhile,
        # which would otherwise mutate the list mid-iteration.
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception:
                self.disconnect(connection)


# Shared WebSocket client registry; the /ws endpoint below registers and
# unregisters sockets here. No caller of manager.broadcast is visible in this
# section.
manager = ConnectionManager()

# Global instances
# Module-level singletons shared by all request handlers. Both are built at
# import time, so any side effects of their constructors run on import.
# NOTE(review): whether escalation_worker.router is this same PriorityRouter
# instance or a separate one cannot be told from here; confirm.
router = PriorityRouter()
escalation_worker = PriorityEscalationWorker()


@app.on_event("startup")
async def startup_event():
    """Prepare the database schema, then launch the background escalation worker.

    Steps run in order, each preceded by a log line. Any exception propagates
    and aborts startup.
    """
    startup_steps = (
        ("Setting up database schema...", DatabaseSchema.setup_schema),
        ("Starting escalation worker...", escalation_worker.start),
    )
    for announcement, step in startup_steps:
        logger.info(announcement)
        step()


@app.on_event("shutdown")
async def shutdown_event():
    """Stop the escalation worker, then release pooled database connections.

    The pool is closed in ``finally`` so that an exception while stopping the
    worker no longer leaks the database connections. The exception itself
    still propagates.
    """
    try:
        logger.info("Stopping escalation worker...")
        escalation_worker.stop()
    finally:
        logger.info("Closing database connections...")
        DatabaseConnectionPool().close_all()


# Routes
@app.post("/route", response_model=DirectiveResponse, dependencies=[Depends(verify_api_key)])
async def route_directive(request: DirectiveRequest, x_api_key: str = Header(...)):
    """
    Route a directive and determine its priority.

    Determines priority based on voice command analysis and stores
    the directive in the priority queue.
    """
    # Auth is enforced by the verify_api_key dependency. x_api_key is unused
    # here but kept so the signature and OpenAPI schema stay unchanged.
    try:
        return router.route_directive(
            directive_id=request.directive_id,
            voice_command=request.voice_command,
            raw_transcript=request.raw_transcript
        )
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped;
        # chaining keeps the original cause attached to the HTTPException.
        logger.exception("Error routing directive: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/escalate", response_model=EscalationResponse, dependencies=[Depends(verify_api_key)])
async def escalate_directive(request: EscalationRequest, x_api_key: str = Header(...)):
    """
    Manually escalate a directive's priority.

    Allows manual priority escalation with a reason.
    """
    # Auth is enforced by the verify_api_key dependency. x_api_key is unused
    # here but kept so the signature and OpenAPI schema stay unchanged.
    try:
        return router.escalate_priority(
            directive_id=request.directive_id,
            new_priority=request.new_priority,
            reason=request.reason,
            method="manual"
        )
    except ValueError as e:
        # ValueError from escalate_priority is reported to clients as 404.
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception("Error escalating directive: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/urgency-keyword", response_model=EscalationResponse, dependencies=[Depends(verify_api_key)])
async def handle_urgency_keyword(request: UrgencyKeywordRequest, x_api_key: str = Header(...)):
    """
    Handle urgency keyword detection in follow-up commands.

    Called when Kinan says 'urgent' or 'now' in follow-up,
    escalates the directive to CRITICAL (priority 10).
    """
    # Auth is enforced by the verify_api_key dependency; x_api_key is unused.
    try:
        result = escalation_worker.handle_urgency_keyword(
            directive_id=request.directive_id,
            command_text=request.command_text
        )
    except ValueError as e:
        # Same contract as /escalate: a ValueError from escalate_priority
        # means the directive could not be escalated (e.g. unknown id) -> 404,
        # not a generic 500.
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error handling urgency keyword: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    if result is None:
        # No keyword found: respond with placeholder values so the payload
        # still satisfies EscalationResponse.
        return {
            "directive_id": request.directive_id,
            "old_priority": 0,
            "new_priority": 0,
            "reason": "No urgency keywords detected",
            "method": "none",
            "escalated_at": datetime.datetime.utcnow().isoformat()
        }

    return result


@app.get("/directive/{directive_id}", dependencies=[Depends(verify_api_key)])
async def get_directive(directive_id: str, x_api_key: str = Header(...)):
    """Get directive details including current priority."""
    db = DatabaseConnectionPool()
    # Only the database call is guarded, so the 404 below needs no
    # HTTPException passthrough clause.
    try:
        result = db.execute_query(
            """SELECT d.*, pq.status as queue_status, pq.enqueued_at, pq.processing_deadline
               FROM genesis_bridge.directives d
               LEFT JOIN genesis_bridge.priority_queue pq ON d.directive_id = pq.directive_id
               WHERE d.directive_id = %s""",
            (directive_id,)
        )
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception("Error getting directive: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    if not result:
        raise HTTPException(status_code=404, detail="Directive not found")

    return result[0]


@app.get("/queue", dependencies=[Depends(verify_api_key)])
async def get_queue_status(
    min_priority: int = Query(1, ge=1, le=10),
    x_api_key: str = Header(...)
):
    """Get current priority queue status."""
    db = DatabaseConnectionPool()
    try:
        result = db.execute_query(
            """SELECT d.directive_id, d.priority, d.voice_command, d.status,
                      pq.enqueued_at, pq.processing_deadline, pq.status as queue_status,
                      EXTRACT(EPOCH FROM (NOW() - pq.enqueued_at)) as wait_seconds
               FROM genesis_bridge.directives d
               JOIN genesis_bridge.priority_queue pq ON d.directive_id = pq.directive_id
               WHERE d.priority >= %s AND pq.status IN ('queued', 'processing')
               ORDER BY d.priority DESC, pq.enqueued_at ASC""",
            (min_priority,)
        )
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception("Error getting queue status: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # execute_query's contract isn't visible here. If it ever yields None for
    # "no rows", report an empty queue instead of failing on len().
    rows = result or []
    return {
        "queue": rows,
        "total": len(rows),
        "timestamp": datetime.datetime.utcnow().isoformat()
    }


@app.get("/escalation-log", dependencies=[Depends(verify_api_key)])
async def get_escalation_log(
    directive_id: Optional[str] = None,
    limit: int = Query(50, ge=1, le=500),
    x_api_key: str = Header(...)
):
    """Get priority escalation history."""
    db = DatabaseConnectionPool()

    # Pick the query up front; an empty directive_id means "all directives".
    if directive_id:
        query = """SELECT * FROM genesis_bridge.priority_escalation_log
                   WHERE directive_id = %s
                   ORDER BY escalated_at DESC
                   LIMIT %s"""
        params = (directive_id, limit)
    else:
        query = """SELECT * FROM genesis_bridge.priority_escalation_log
                   ORDER BY escalated_at DESC
                   LIMIT %s"""
        params = (limit,)

    try:
        result = db.execute_query(query, params)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception("Error getting escalation log: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Guard against a None result (execute_query's contract isn't visible here).
    rows = result or []
    return {
        "escalations": rows,
        "total": len(rows)
    }


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time priority notifications.

    Registers the client with ``manager`` and echoes every text frame back
    (ping/pong). The socket is unregistered however the loop ends.

    NOTE(review): unlike the HTTP routes, this endpoint performs no API-key
    check.
    """
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            # Echo back for ping/pong
            await websocket.send_text(data)
    except WebSocketDisconnect:
        pass  # normal client-initiated close; cleanup happens in finally
    finally:
        # Always untrack the socket, also when an unexpected error propagates.
        # Tolerate it already being gone (e.g. dropped by a failed broadcast),
        # where a bare list.remove would raise ValueError.
        if websocket in manager.active_connections:
            manager.disconnect(websocket)


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    # Liveness only: no database or worker checks are performed here.
    checked_at = datetime.datetime.utcnow().isoformat()
    return dict(status="healthy", timestamp=checked_at)


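# Redundant re-import: Depends and Query must already be in scope when the
# route decorators above are evaluated, so they are imported with the other
# FastAPI names at the top of the FastAPI section. Importing them only here
# would be too late.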
from fastapi import Depends, Query


# =============================================================================
# Test Suite
# =============================================================================

import unittest
from unittest.mock import Mock, patch, MagicMock
import tempfile
import os


class TestPriorityBehavior(unittest.TestCase):
    """Unit tests for the PriorityBehavior lookup helpers."""

    def test_critical_priority_behavior(self):
        """CRITICAL (10) runs immediately, bypasses the queue and alerts loudly."""
        critical = PriorityBehavior.get_behavior(10)
        self.assertEqual(critical["execution"], "immediate")
        self.assertTrue(critical["bypass_queue"])
        self.assertEqual(critical["max_wait_time"], 0)
        for channel in ("terminal_bell", "desktop_notification"):
            self.assertIn(channel, critical["notification"])

    def test_urgent_priority_behavior(self):
        """URGENT (8-9) has max_wait_time 60 and bell + websocket alerts."""
        for level in (8, 9):
            urgent = PriorityBehavior.get_behavior(level)
            self.assertEqual(urgent["max_wait_time"], 60)
            for channel in ("terminal_bell", "websocket"):
                self.assertIn(channel, urgent["notification"])

    def test_high_priority_behavior(self):
        """HIGH (6-7) has max_wait_time 300 and websocket + log alerts."""
        for level in (6, 7):
            high = PriorityBehavior.get_behavior(level)
            self.assertEqual(high["max_wait_time"], 300)
            for channel in ("websocket", "log"):
                self.assertIn(channel, high["notification"])

    def test_normal_priority_behavior(self):
        """NORMAL (4-5) notifies via websocket and log."""
        for level in (4, 5):
            normal = PriorityBehavior.get_behavior(level)
            for channel in ("websocket", "log"):
                self.assertIn(channel, normal["notification"])

    def test_low_priority_behavior(self):
        """LOW (1-3) runs in the background and only logs."""
        for level in (1, 2, 3):
            low = PriorityBehavior.get_behavior(level)
            self.assertEqual(low["execution"], "background")
            self.assertEqual(low["notification"], ["log"])

    def test_notification_channels(self):
        """Spot-check get_notification_channels across the priority bands."""
        channels_for = PriorityBehavior.get_notification_channels
        self.assertIn("desktop_notification", channels_for(10))
        self.assertIn("terminal_bell", channels_for(9))
        self.assertIn("websocket", channels_for(7))
        self.assertNotIn("desktop_notification", channels_for(8))
        self.assertEqual(channels_for(1), ["log"])

    def test_bypass_queue(self):
        """Only CRITICAL (10) may skip the queue."""
        self.assertTrue(PriorityBehavior.should_bypass_queue(10))
        for level in (9, 5, 1):
            self.assertFalse(PriorityBehavior.should_bypass_queue(level))


class TestPriorityRouter(unittest.TestCase):
    """Checks PriorityRouter's command analysis and label helpers."""

    def setUp(self):
        """Give every test a fresh router."""
        self.router = PriorityRouter()

    def _check_exact(self, cases):
        """Assert _determine_priority(command) == expected for each (command, expected) pair."""
        for command, expected in cases:
            actual = self.router._determine_priority(command)
            self.assertEqual(actual, expected, f"Failed for command: {command}")

    def _check_within(self, commands, allowed):
        """Assert _determine_priority(cmd) lands in ``allowed`` for each command."""
        for cmd in commands:
            actual = self.router._determine_priority(cmd)
            self.assertIn(actual, allowed, f"Failed for: {cmd}")

    def test_determine_priority_urgent_keywords(self):
        """Urgency words force CRITICAL."""
        self._check_exact([
            ("This is urgent", Priority.CRITICAL),
            ("Do it now", Priority.CRITICAL),
            ("I need this immediately", Priority.CRITICAL),
            ("Emergency shutdown", Priority.CRITICAL),
            ("ASAP completion", Priority.CRITICAL),
            ("Critical system check", Priority.CRITICAL),
            ("Right now open the door", Priority.CRITICAL),
        ])

    def test_determine_priority_explicit_priority(self):
        """Spoken or numeric 'priority N' phrases map to the matching level."""
        self._check_exact([
            ("Set priority 10", Priority.CRITICAL),
            ("Priority nine task", Priority.URGENT),
            ("Priority 8 request", Priority.URGENT),
            ("Priority seven action", Priority.HIGH),
            ("Task priority six", Priority.HIGH_MEDIUM),
            ("Priority five item", Priority.NORMAL),
            ("Set priority 4", Priority.NORMAL_LOW),
            ("Priority three", Priority.LOW),
            ("Priority two", Priority.LOW_MEDIUM),
            ("Priority one task", Priority.LOW_MIN),
            ("Low priority item", Priority.LOW),
        ])

    def test_determine_priority_command_type(self):
        """Command verbs push priority up (security/stop) or down (scheduling)."""
        self._check_within(
            ["Lock the door", "Stop the system", "Activate security"],
            [Priority.HIGH, Priority.CRITICAL],
        )
        self._check_within(
            ["Schedule a meeting", "Remind me later", "Do this sometime"],
            [Priority.NORMAL, Priority.NORMAL_LOW, Priority.LOW],
        )

    def test_determine_priority_default(self):
        """Plain commands fall into the NORMAL band."""
        self._check_within(
            ["Run a report", "Check status", "Show data"],
            [Priority.NORMAL, Priority.NORMAL_LOW],
        )

    def test_get_priority_label(self):
        """Every numeric level maps to its band label."""
        expected_labels = {
            10: "CRITICAL",
            9: "URGENT", 8: "URGENT",
            7: "HIGH", 6: "HIGH",
            5: "NORMAL", 4: "NORMAL",
            3: "LOW", 2: "LOW", 1: "LOW",
        }
        for level, label in expected_labels.items():
            self.assertEqual(self.router._get_priority_label(level), label)


class TestPriorityEscalation(unittest.TestCase):
    """Tests for priority escalation scenarios."""

    def setUp(self):
        """Set up test fixtures."""
        self.worker = PriorityEscalationWorker(check_interval=1)

    def test_urgency_keyword_detection(self):
        """Follow-ups escalate to CRITICAL only when they contain urgency keywords.

        Drives PriorityEscalationWorker.handle_urgency_keyword itself, with a
        mocked router, instead of re-implementing the keyword check inline.
        Changes to the real matcher are therefore actually covered.
        """
        test_cases = [
            ("This is urgent now", True),
            ("Make it urgent", True),
            ("Do this immediately", True),
            ("Emergency situation", True),
            ("ASAP please", True),
            ("Just do it normally", False),
            ("No rush on this", False),
        ]

        for command, should_escalate in test_cases:
            with self.subTest(command=command):
                mock_router = Mock()
                mock_router.escalate_priority = Mock(return_value={
                    "directive_id": "test-123",
                    "new_priority": Priority.CRITICAL
                })
                self.worker.router = mock_router

                result = self.worker.handle_urgency_keyword("test-123", command)

                if should_escalate:
                    mock_router.escalate_priority.assert_called_once()
                    _, kwargs = mock_router.escalate_priority.call_args
                    self.assertEqual(kwargs["directive_id"], "test-123")
                    self.assertEqual(kwargs["new_priority"], Priority.CRITICAL)
                    self.assertEqual(result, mock_router.escalate_priority.return_value)
                else:
                    mock_router.escalate_priority.assert_not_called()
                    self.assertIsNone(result)

    def test_escalation_worker_initialization(self):
        """Test escalation worker initialization."""
        self.assertEqual(self.worker.check_interval, 1)
        self.assertFalse(self.worker.running)
        self.assertIsNone(self.worker.thread)


class TestNotificationDispatcher(unittest.TestCase):
    """Notification-channel expectations per priority band, plus dispatcher wiring."""

    def setUp(self):
        """Build a default dispatcher for each test."""
        self.dispatcher = NotificationDispatcher()

    def _assert_has(self, channels, wanted):
        """Assert every name in ``wanted`` appears in ``channels``."""
        for name in wanted:
            self.assertIn(name, channels)

    def test_get_notification_channels_critical(self):
        """CRITICAL uses every channel, including desktop notifications."""
        channels = PriorityBehavior.get_notification_channels(10)
        self._assert_has(channels, ("websocket", "terminal_bell", "log", "desktop_notification"))

    def test_get_notification_channels_urgent(self):
        """URGENT rings the bell but stops short of desktop notifications."""
        for level in (8, 9):
            channels = PriorityBehavior.get_notification_channels(level)
            self._assert_has(channels, ("websocket", "terminal_bell", "log"))
            self.assertNotIn("desktop_notification", channels)

    def test_get_notification_channels_high(self):
        """HIGH notifies via websocket and log, without the bell."""
        for level in (6, 7):
            channels = PriorityBehavior.get_notification_channels(level)
            self._assert_has(channels, ("websocket", "log"))
            self.assertNotIn("terminal_bell", channels)

    def test_get_notification_channels_normal(self):
        """NORMAL notifies via websocket and log, without the bell."""
        for level in (4, 5):
            channels = PriorityBehavior.get_notification_channels(level)
            self._assert_has(channels, ("websocket", "log"))
            self.assertNotIn("terminal_bell", channels)

    def test_get_notification_channels_low(self):
        """LOW only writes to the log."""
        for level in (1, 2, 3):
            self.assertEqual(PriorityBehavior.get_notification_channels(level), ["log"])

    def test_dispatcher_with_websocket(self):
        """A supplied websocket handler is kept on the dispatcher."""
        ws_handler = Mock()
        wired = NotificationDispatcher(websocket_handler=ws_handler)
        self.assertIsNotNone(wired.websocket_handler)


class TestDatabaseSchema(unittest.TestCase):
    """Sanity checks on the static priority configuration."""

    def test_priority_behavior_map_completeness(self):
        """Every level from 1 to 10 defines execution, notification and bypass_queue."""
        required_keys = ("execution", "notification", "bypass_queue")
        for level in range(1, 11):
            behavior = PriorityBehavior.get_behavior(level)
            self.assertIsNotNone(behavior)
            for key in required_keys:
                self.assertIn(key, behavior)


class TestIntegration(unittest.TestCase):
    """Integration tests for the priority routing system."""

    def test_priority_escalation_rules(self):
        """Test priority escalation rules logic.

        NOTE: this validates the tables below against a local copy of the
        worker's rules. It does not call PriorityEscalationWorker, so it will
        not notice if the worker's thresholds drift.
        """
        # Rule 1: pending > 15 min (strictly more than 900 s) and priority < 8
        # -> escalate to 8. Exactly 900 s is not "more than" 15 minutes, so it
        # must NOT trigger. The old table claimed it did, so this test could
        # never pass.
        test_cases_rule1 = [
            {"pending_seconds": 900, "current_priority": 7, "should_escalate": False, "expected_priority": None},
            {"pending_seconds": 901, "current_priority": 7, "should_escalate": True, "expected_priority": Priority.URGENT},
            {"pending_seconds": 901, "current_priority": 5, "should_escalate": True, "expected_priority": Priority.URGENT},
            {"pending_seconds": 899, "current_priority": 7, "should_escalate": False, "expected_priority": None},
            {"pending_seconds": 900, "current_priority": 8, "should_escalate": False, "expected_priority": None},
            {"pending_seconds": 900, "current_priority": 9, "should_escalate": False, "expected_priority": None},
        ]

        for case in test_cases_rule1:
            with self.subTest(rule=1, **case):
                escalates = case["pending_seconds"] > 900 and case["current_priority"] < 8
                self.assertEqual(escalates, case["should_escalate"])
                if escalates:
                    self.assertEqual(case["expected_priority"], Priority.URGENT)

        # Rule 2: pending > 5 min and priority < 8 -> escalate by 1, capped at 7.
        # Like the worker, no escalation is issued when the cap leaves the
        # priority unchanged (already at 7).
        test_cases_rule2 = [
            {"pending_seconds": 300, "current_priority": 7, "should_escalate": False, "reason": "already at threshold"},
            {"pending_seconds": 301, "current_priority": 7, "should_escalate": False, "reason": "capped at 7"},
            {"pending_seconds": 301, "current_priority": 6, "should_escalate": True, "expected_new": 7},
            {"pending_seconds": 301, "current_priority": 5, "should_escalate": True, "expected_new": 6},
            {"pending_seconds": 301, "current_priority": 4, "should_escalate": True, "expected_new": 5},
            {"pending_seconds": 301, "current_priority": 3, "should_escalate": True, "expected_new": 4},
            {"pending_seconds": 301, "current_priority": 2, "should_escalate": True, "expected_new": 3},
            {"pending_seconds": 301, "current_priority": 1, "should_escalate": True, "expected_new": 2},
            {"pending_seconds": 299, "current_priority": 5, "should_escalate": False, "reason": "under threshold"},
            {"pending_seconds": 301, "current_priority": 8, "should_escalate": False, "reason": "already high"},
        ]

        for case in test_cases_rule2:
            with self.subTest(rule=2, **case):
                pending_seconds = case["pending_seconds"]
                current_priority = case["current_priority"]
                new_priority = min(current_priority + 1, 7)
                escalates = (
                    pending_seconds > 300
                    and current_priority < 8
                    and new_priority > current_priority
                )
                self.assertEqual(escalates, case["should_escalate"])
                if escalates and "expected_new" in case:
                    self.assertEqual(new_priority, case["expected_new"])


def run_tests():
    """Collect every test case class in this module, run them verbosely, and return the TestResult."""
    test_classes = (
        TestPriorityBehavior,
        TestPriorityRouter,
        TestPriorityEscalation,
        TestNotificationDispatcher,
        TestDatabaseSchema,
        TestIntegration,
    )

    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for test_class in test_classes:
        suite.addTests(loader.loadTestsFromTestCase(test_class))

    return unittest.TextTestRunner(verbosity=2).run(suite)


# =============================================================================
# Main Entry Point
# =============================================================================

if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="AIVA Priority Routing System")
    # Boolean mode switches, registered in the same order as before so the
    # --help output is unchanged.
    for flag, help_text in (
        ("--run-server", "Run the FastAPI server"),
        ("--run-tests", "Run the test suite"),
        ("--setup-schema", "Setup database schema"),
    ):
        cli.add_argument(flag, action="store_true", help=help_text)
    cli.add_argument("--host", default="0.0.0.0", help="Server host")
    cli.add_argument("--port", type=int, default=8000, help="Server port")

    opts = cli.parse_args()

    # Modes are checked in a fixed order and only the first one requested runs:
    # tests, then schema setup, then the server; otherwise print usage.
    if opts.run_tests:
        print("Running test suite...")
        outcome = run_tests()
        sys.exit(0 if outcome.wasSuccessful() else 1)
    elif opts.setup_schema:
        print("Setting up database schema...")
        DatabaseSchema.setup_schema()
        print("Schema setup complete.")
    elif opts.run_server:
        print(f"Starting server on {opts.host}:{opts.port}...")
        uvicorn.run(app, host=opts.host, port=opts.port)
    else:
        cli.print_help()