#!/usr/bin/env python3
"""
AIVA Voice Command Bridge - Telnyx Webhook Handler
Handles call events and extracts directives from voice transcripts.
"""

import os
import json
import logging
import hashlib
import hmac
import re
from datetime import datetime
from typing import Optional, Dict, Any, List
from enum import Enum
from dataclasses import dataclass
from contextlib import contextmanager

from fastapi import FastAPI, Request, HTTPException, Header, Depends
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
import psycopg2
from psycopg2 import sql
from psycopg2.extras import RealDictCursor

from core.secrets_loader import get_postgres_config, get_telnyx_api_key

# RLM Gateway — shadow mode integration (import with fallback guard).
# `_rlm_gateway` stays None whenever the import or the gateway construction
# fails, so the rest of the module can test it to decide whether shadow
# processing is available.
# NOTE: this section runs before logging.basicConfig() is called further down,
# so these records are handled by whatever logging configuration exists at
# import time.
_rlm_gateway = None
try:
    import sys as _sys
    # Make the AIVA package importable from a hard-coded checkout location.
    # NOTE(review): this path is machine-specific — confirm it exists in the
    # deployment environment.
    _sys.path.append('/mnt/e/genesis-system')
    from AIVA.rlm_gateway import get_gateway as _get_rlm_gateway
    _rlm_gateway = _get_rlm_gateway()
    logger_rlm = logging.getLogger("RLM.webhook_bridge")
    logger_rlm.info("RLM Gateway loaded in shadow mode")
except Exception as _rlm_import_err:
    # Deliberately broad: any failure while loading the optional gateway only
    # disables shadow mode instead of preventing this module from importing.
    import logging as _log_fallback
    _log_fallback.getLogger("RLM.webhook_bridge").warning(
        f"RLM Gateway not available (shadow mode disabled): {_rlm_import_err}"
    )


# Configure logging: INFO level, timestamped records on the root logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Configuration from environment
# Loaded through the project's secrets loader at import time.
# NOTE(review): not referenced anywhere else in the visible part of this file.
TELNYX_API_KEY = get_telnyx_api_key()
# An empty secret (the default) means webhook signatures are never checked.
WEBHOOK_SECRET = os.getenv("TELNYX_WEBHOOK_SECRET", "")
# Shared secret expected in the X-API-Key header of every protected endpoint.
# NOTE(review): the fallback is a hard-coded, publicly visible key — confirm
# that API_KEY is always set explicitly in deployed environments.
API_KEY = os.getenv("API_KEY", "genesis-bridge-api-key")

# Directive keyword patterns.
# Each entry is a regular expression matched against transcript text; every
# pattern that matches counts as one independent signal that the caller issued
# a directive.
DIRECTIVE_PATTERNS = [
    # Addressing the assistant by name or role.
    # NOTE(review): "claude code" can never change the outcome, because the
    # shorter alternative "claude" already matches the same text.
    r'\b(claude|claude code|ai|assistant)\b',
    # Explicit directive vocabulary.
    r'\b(directive|command|task|instruction)\b',
    # Action verbs.
    r'\b(do something|perform|execute|run)\b',
    # Polite request phrasing.
    r'\b(please help|can you|could you)\b',
    # Urgency markers.
    r'\b(urgent|important|asap|immediately)\b',
]


class EventType(str, Enum):
    """Webhook event type names recognised by this service.

    The ``str`` mixin makes each member compare equal to its raw string value,
    so an ``event_type`` string taken from a payload can be compared directly
    against a member.
    """
    CALL_INITIATED = "call.initiated"
    CALL_ANSWERED = "call.answered"
    CALL_TRANSCRIPTION = "call.transcription"
    CALL_HANGUP = "call.hangup"
    CALL_MACHINE_DETECTION = "call.machine.detection"
    # Declared, but the webhook handler in this file has no branch for it.
    CALL_RECORDING = "call.recording.complete"


@dataclass
class CallTranscript:
    """In-memory form of one row of ``genesis_bridge.call_transcripts``."""
    # Database primary key; None until the row has been inserted.
    id: Optional[int] = None
    # Identifier of the call this transcript fragment belongs to.
    call_id: str = ""
    # Phone number of the caller as reported in the webhook payload.
    caller_number: str = ""
    # Transcribed speech.
    transcript_text: str = ""
    # True when directive detection flagged this transcript.
    is_directive: bool = False
    # Id of the command_queue row created from this transcript, if any.
    directive_id: Optional[int] = None
    # Confidence value reported with the transcription.
    confidence_score: float = 0.0
    # True for interim (non-final) transcription results.
    is_partial: bool = False
    # Creation time; when None, the insert code supplies the current time.
    created_at: Optional[datetime] = None


@contextmanager
def get_db_connection():
    """Open a PostgreSQL connection and close it when the block exits.

    Connection parameters come from ``get_postgres_config()``. No commit or
    rollback is performed here; transaction control is left to the caller.

    Only failures while *establishing* the connection are logged as connection
    errors. Exceptions raised inside the ``with`` body propagate unchanged, so
    query errors are no longer mislabelled (and logged twice) as connection
    problems.

    Yields:
        An open psycopg2 connection.

    Raises:
        psycopg2.Error: if PostgreSQL is not configured or the connection
            attempt fails.
    """
    try:
        pg_config = get_postgres_config()
        if not pg_config.is_configured:
            raise psycopg2.Error("PostgreSQL is not configured via environment variables.")

        conn = psycopg2.connect(
            host=pg_config.host,
            port=pg_config.port,
            user=pg_config.user,
            password=pg_config.password,
            dbname=pg_config.dbname,
            sslmode=pg_config.sslmode,
        )
    except psycopg2.Error as e:
        logger.error(f"Database connection error: {e}")
        raise

    try:
        yield conn
    finally:
        # Always release the connection, whether the body succeeded or raised.
        conn.close()


@contextmanager
def get_db_cursor(dict_cursor=True):
    """Yield a cursor on a fresh connection as a single transaction.

    The transaction is committed when the ``with`` body finishes normally and
    rolled back (then the exception re-raised) when it fails. The cursor is
    always closed.

    Args:
        dict_cursor: when true, rows are returned through ``RealDictCursor``;
            otherwise no custom cursor factory is requested.
    """
    factory = RealDictCursor if dict_cursor else None
    with get_db_connection() as connection:
        db_cursor = connection.cursor(cursor_factory=factory)
        try:
            yield db_cursor
            # Reached only when the caller's block completed without raising.
            connection.commit()
        except Exception as exc:
            connection.rollback()
            logger.error(f"Database operation error: {exc}")
            raise
        finally:
            db_cursor.close()


def init_database():
    """Create the ``genesis_bridge`` schema, tables and indexes if missing.

    All statements run in one transaction and are idempotent
    (``IF NOT EXISTS``), so the function is safe to call on every startup.

    ``command_queue`` is created before ``call_transcripts`` because the
    latter declares a foreign key to it; without that, initialization fails on
    a fresh database where the referenced table does not exist yet.

    Raises:
        Exception: any database error is propagated after the transaction has
            been rolled back by ``get_db_cursor``.
    """
    logger.info("Initializing database schema...")

    with get_db_cursor(dict_cursor=False) as cursor:
        # Create schema
        cursor.execute("CREATE SCHEMA IF NOT EXISTS genesis_bridge")

        # Create command_queue first: call_transcripts references it.
        # The definition matches the one used when directives are created.
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS genesis_bridge.command_queue (
                id SERIAL PRIMARY KEY,
                call_id VARCHAR(255),
                caller_number VARCHAR(50),
                command_text TEXT NOT NULL,
                command_type VARCHAR(50) DEFAULT 'voice_transcript',
                status VARCHAR(50) DEFAULT 'pending',
                priority INTEGER DEFAULT 5,
                metadata JSONB,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                processed_at TIMESTAMP
            )
        """)

        # Create call_transcripts table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS genesis_bridge.call_transcripts (
                id SERIAL PRIMARY KEY,
                call_id VARCHAR(255) NOT NULL,
                caller_number VARCHAR(50),
                transcript_text TEXT,
                is_directive BOOLEAN DEFAULT FALSE,
                directive_id INTEGER,
                confidence_score DECIMAL(5,4),
                is_partial BOOLEAN DEFAULT FALSE,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                CONSTRAINT fk_directive FOREIGN KEY (directive_id)
                    REFERENCES genesis_bridge.command_queue(id) ON DELETE SET NULL
            )
        """)

        # Indexes for per-call lookups and directive filtering
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_call_transcripts_call_id
            ON genesis_bridge.call_transcripts(call_id)
        """)

        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_call_transcripts_directive
            ON genesis_bridge.call_transcripts(is_directive, directive_id)
        """)

        # Create call_events table for tracking call lifecycle
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS genesis_bridge.call_events (
                id SERIAL PRIMARY KEY,
                call_id VARCHAR(255) NOT NULL,
                event_type VARCHAR(100) NOT NULL,
                caller_number VARCHAR(50),
                callee_number VARCHAR(50),
                payload JSONB,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_call_events_call_id
            ON genesis_bridge.call_events(call_id)
        """)

    logger.info("Database schema initialized successfully")


def verify_telnyx_signature(payload: bytes, signature: str, secret: str) -> bool:
    """Check a webhook signature against an HMAC-SHA256 of the raw body.

    Args:
        payload: raw request body bytes.
        signature: hex digest supplied by the sender.
        secret: shared secret; when empty, verification is skipped.

    Returns:
        True when the signature matches, or when no secret is configured.
        False on a mismatch or when the comparison itself fails.

    NOTE(review): this assumes the sender signs with a hex-encoded
    HMAC-SHA256 of the body — confirm against the provider's signing scheme.
    """
    if secret:
        try:
            digest = hmac.new(secret.encode('utf-8'), payload, hashlib.sha256).hexdigest()
            # Constant-time comparison to avoid leaking digest prefixes.
            return hmac.compare_digest(digest, signature)
        except Exception as exc:
            logger.error(f"Signature verification error: {exc}")
            return False

    logger.warning("No webhook secret configured, skipping signature verification")
    return True


def detect_directive(text: str) -> Dict[str, Any]:
    """Score *text* against the directive keyword patterns.

    Every pattern in ``DIRECTIVE_PATTERNS`` that matches counts as one hit.
    With at least one hit the confidence is ``0.5 + 0.3 * hits`` capped at
    1.0; with none it is 0.0.

    Returns:
        A dict with ``is_directive``, ``confidence``, ``matched_patterns``
        (the pattern strings that matched) and the original ``text``.
    """
    lowered = text.lower()
    hits = [
        pattern
        for pattern in DIRECTIVE_PATTERNS
        if re.search(pattern, lowered, re.IGNORECASE)
    ]

    if hits:
        found = True
        score = min(1.0, len(hits) * 0.3 + 0.5)
    else:
        found = False
        score = 0.0

    return {
        "is_directive": found,
        "confidence": score,
        "matched_patterns": hits,
        "text": text
    }


def analyze_transcript_context(call_id: str) -> Dict[str, Any]:
    """Look at the transcripts already stored for a call for directive signals.

    Args:
        call_id: identifier of the call whose stored transcripts are examined.

    Returns:
        One of three shapes:

        * no stored transcripts: ``has_context`` and ``directive_detected``
          are both False;
        * an earlier transcript is already flagged as a directive:
          ``directive_detected`` is True, plus ``previous_count`` and
          ``directive_count``;
        * otherwise: the last three transcripts are joined and run through
          ``detect_directive``; the result carries ``combined_text``
          (truncated to 500 characters), ``transcript_count`` and
          ``recent_count``.
    """
    with get_db_cursor() as cursor:
        cursor.execute("""
            SELECT transcript_text, is_directive, confidence_score, is_partial
            FROM genesis_bridge.call_transcripts
            WHERE call_id = %s
            ORDER BY created_at ASC
        """, (call_id,))

        transcripts = cursor.fetchall()

    if not transcripts:
        return {"has_context": False, "directive_detected": False}

    # Check if any previous transcript was marked as directive
    previous_directives = [t for t in transcripts if t.get('is_directive')]
    if previous_directives:
        return {
            "has_context": True,
            "directive_detected": True,
            "previous_count": len(transcripts),
            "directive_count": len(previous_directives)
        }

    # Check for multi-turn conversation patterns. Slicing already copes with
    # fewer than three rows.
    recent_transcripts = transcripts[-3:]
    # transcript_text is a nullable column: skip NULL/empty values instead of
    # letting str.join() fail on None.
    combined_text = " ".join(
        filter(None, (t.get('transcript_text') for t in recent_transcripts))
    )

    directive_result = detect_directive(combined_text)

    return {
        "has_context": True,
        "directive_detected": directive_result["is_directive"],
        "combined_text": combined_text[:500],
        "transcript_count": len(transcripts),
        "recent_count": len(recent_transcripts)
    }


def create_directive_from_transcript(
    call_id: str,
    transcript_text: str,
    caller_number: str,
    confidence: float
) -> Optional[int]:
    """Queue a voice transcript as a pending command.

    The ``command_queue`` table is created on demand when it does not exist.
    The new row gets priority 7 when ``confidence`` is above 0.7, otherwise 5,
    and its metadata records the source and the confidence.

    Returns:
        The id of the inserted ``command_queue`` row, or None when the insert
        returned no row.
    """
    with get_db_cursor() as cursor:
        # Look the table up first so the DDL only runs when it is missing.
        cursor.execute("""
            SELECT EXISTS (
                SELECT FROM information_schema.tables 
                WHERE table_schema = 'genesis_bridge' 
                AND table_name = 'command_queue'
            )
        """)
        table_exists = cursor.fetchone()['exists']

        if not table_exists:
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS genesis_bridge.command_queue (
                    id SERIAL PRIMARY KEY,
                    call_id VARCHAR(255),
                    caller_number VARCHAR(50),
                    command_text TEXT NOT NULL,
                    command_type VARCHAR(50) DEFAULT 'voice_transcript',
                    status VARCHAR(50) DEFAULT 'pending',
                    priority INTEGER DEFAULT 5,
                    metadata JSONB,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    processed_at TIMESTAMP
                )
            """)

        # High-confidence directives are queued with a higher priority value.
        priority = 7 if confidence > 0.7 else 5
        metadata_json = json.dumps({
            "source": "telnyx_transcription",
            "confidence": confidence,
            "auto_created": True
        })
        row_values = (
            call_id,
            caller_number,
            transcript_text,
            'voice_transcript',
            'pending',
            priority,
            metadata_json,
        )

        cursor.execute("""
            INSERT INTO genesis_bridge.command_queue 
            (call_id, caller_number, command_text, command_type, status, priority, metadata)
            VALUES (%s, %s, %s, %s, %s, %s, %s)
            RETURNING id
        """, row_values)

        inserted = cursor.fetchone()
        if inserted:
            directive_id = inserted['id']
        else:
            directive_id = None

        logger.info(f"Created directive {directive_id} from call {call_id}")
        return directive_id


def store_transcript(transcript: CallTranscript) -> int:
    """Insert one transcript row and return its id.

    Args:
        transcript: the record to persist. When ``created_at`` is None the
            current UTC time is used.

    Returns:
        The id of the inserted row, or -1 when the insert returned no row.
    """
    # Local import: the file-level imports only bring in ``datetime``.
    from datetime import timezone

    created_at = transcript.created_at
    if created_at is None:
        # ``datetime.utcnow()`` is deprecated. Build the same naive-UTC value
        # from an aware timestamp so the TIMESTAMP column stores what it
        # stored before.
        created_at = datetime.now(timezone.utc).replace(tzinfo=None)

    with get_db_cursor() as cursor:
        cursor.execute("""
            INSERT INTO genesis_bridge.call_transcripts
            (call_id, caller_number, transcript_text, is_directive, directive_id,
             confidence_score, is_partial, created_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
            RETURNING id
        """, (
            transcript.call_id,
            transcript.caller_number,
            transcript.transcript_text,
            transcript.is_directive,
            transcript.directive_id,
            transcript.confidence_score,
            transcript.is_partial,
            created_at,
        ))

        result = cursor.fetchone()
        return result['id'] if result else -1


def store_call_event(call_id: str, event_type: str, caller_number: str, 
                     callee_number: str, payload: Dict[str, Any]) -> None:
    """Append one webhook event to ``genesis_bridge.call_events``.

    The payload is serialised to JSON text before being handed to the
    database.
    """
    insert_sql = """
            INSERT INTO genesis_bridge.call_events 
            (call_id, event_type, caller_number, callee_number, payload)
            VALUES (%s, %s, %s, %s, %s)
        """
    with get_db_cursor() as cursor:
        event_row = (
            call_id,
            event_type,
            caller_number,
            callee_number,
            json.dumps(payload),
        )
        cursor.execute(insert_sql, event_row)


def process_transcription_event(event_data: Dict[str, Any]) -> Dict[str, Any]:
    """Store a transcription event and, when warranted, queue a directive.

    The utterance is stored verbatim in ``call_transcripts``. A directive is
    queued only for final transcriptions with confidence of at least 0.6 that
    are flagged either by ``detect_directive`` on the utterance itself or by
    ``analyze_transcript_context`` on the earlier turns of the call. When the
    earlier turns supplied combined text, the directive text is that context
    followed by the current utterance.

    Previously the combined text of *earlier* turns replaced the current
    utterance in both the stored transcript and the directive, which lost the
    utterance and duplicated older text.

    Args:
        event_data: decoded webhook payload. Missing or null ``transcription``
            / ``call`` sections and a null ``confidence`` are tolerated.

    Returns:
        A dict with ``transcript_stored``, ``directive_created``,
        ``directive_id``, ``is_partial`` and ``confidence``; plus
        ``transcript_id`` once stored and ``flagged_for_review`` for
        low-confidence transcriptions.
    """
    transcription = event_data.get("transcription") or {}
    call_info = event_data.get("call") or {}

    transcript_text = transcription.get("text") or ""
    confidence = transcription.get("confidence") or 0.0
    # A single source of truth for finality, used for both the result and the
    # stored row.
    is_final = bool(transcription.get("is_final", True))

    result = {
        "transcript_stored": False,
        "directive_created": False,
        "directive_id": None,
        "is_partial": not is_final,
        "confidence": confidence
    }

    call_id = event_data.get("call_session_id") or call_info.get("id", "")
    caller_number = (call_info.get("from") or {}).get("phone_number", "")

    if not transcript_text:
        logger.warning(f"No transcript text in event for call {call_id}")
        return result

    # Check the current utterance, then the turns already stored for the call.
    directive_result = detect_directive(transcript_text)
    context_analysis = analyze_transcript_context(call_id)
    context_directive = bool(context_analysis.get("directive_detected"))

    is_directive = directive_result["is_directive"] or context_directive

    # The stored context never contains the current utterance (it is not in
    # the database yet), so append it rather than replace it.
    directive_text = transcript_text
    earlier_text = context_analysis.get("combined_text")
    if context_directive and earlier_text:
        directive_text = f"{earlier_text} {transcript_text}"

    # Handle low confidence speech
    low_confidence = confidence < 0.6

    transcript = CallTranscript(
        call_id=call_id,
        caller_number=caller_number,
        transcript_text=transcript_text,
        is_directive=is_directive,
        confidence_score=confidence,
        is_partial=not is_final
    )

    transcript_id = store_transcript(transcript)
    result["transcript_stored"] = True
    result["transcript_id"] = transcript_id

    # If directive detected and this is a final transcription, create command
    if is_directive and is_final and not low_confidence:
        directive_id = create_directive_from_transcript(
            call_id=call_id,
            transcript_text=directive_text,
            caller_number=caller_number,
            confidence=max(directive_result["confidence"], confidence)
        )

        if directive_id:
            # Link the stored transcript to the directive it produced.
            with get_db_cursor() as cursor:
                cursor.execute("""
                    UPDATE genesis_bridge.call_transcripts
                    SET directive_id = %s, is_directive = TRUE
                    WHERE id = %s
                """, (directive_id, transcript_id))

            result["directive_created"] = True
            result["directive_id"] = directive_id
            logger.info(f"Directive {directive_id} created from transcription in call {call_id}")
    elif low_confidence:
        logger.info(f"Low confidence transcription ({confidence}) flagged for review in call {call_id}")
        result["flagged_for_review"] = True

    return result


# FastAPI application.
# The title, description and version below are the metadata passed to FastAPI
# for this service; all routes in this file are registered on this instance.
app = FastAPI(
    title="AIVA Voice Command Bridge - Telnyx Webhook",
    description="Webhook handler for Telnyx call events and transcription",
    version="1.0.0"
)


class TelnyxWebhookEvent(BaseModel):
    """Model for Telnyx webhook events.

    NOTE(review): no code in the visible part of this file uses this model;
    the webhook handler parses the raw request body itself.
    """
    # Event body as an arbitrary JSON object.
    data: Dict[str, Any]
    # Event type name. NOTE(review): the alias is identical to the field name.
    event_type: str = Field(alias="event_type")
    
    class Config:
        # Presumably allows populating fields by name as well as by alias —
        # confirm against the installed pydantic version.
        populate_by_name = True


@app.on_event("startup")
async def startup_event():
    """Prepare the database schema before the service starts handling requests.

    Any initialization failure is logged and re-raised.
    """
    logger.info("Starting AIVA Voice Command Bridge - Telnyx Webhook Handler")
    try:
        init_database()
    except Exception as exc:
        logger.error(f"Failed to initialize database: {exc}")
        raise


def verify_api_key(x_api_key: str = Header(...)) -> bool:
    """Validate the ``X-API-Key`` header against the configured API key.

    The comparison is constant-time (``hmac.compare_digest``), so response
    timing does not reveal how much of a guessed key is correct. Both values
    are encoded to bytes first because ``compare_digest`` rejects non-ASCII
    ``str`` arguments.

    Args:
        x_api_key: the key supplied by the client.

    Returns:
        True when the key matches.

    Raises:
        HTTPException: 401 when the key is missing, not a string, or wrong.
    """
    key_matches = isinstance(x_api_key, str) and hmac.compare_digest(
        x_api_key.encode("utf-8"),
        API_KEY.encode("utf-8"),
    )
    if not key_matches:
        raise HTTPException(status_code=401, detail="Invalid API key")
    return True


# Strong references to in-flight RLM shadow tasks. The event loop keeps only
# weak references to tasks, so a task nobody else references may be garbage
# collected before it finishes.
_rlm_shadow_tasks: set = set()


def _dict_or_empty(value: Any) -> Dict[str, Any]:
    """Return *value* when it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _call_duration_seconds(call_info: Dict[str, Any]) -> int:
    """Return the call duration in whole seconds, or 0 when missing or malformed."""
    try:
        return int(float(call_info.get("duration_ms") or 0) / 1000)
    except (TypeError, ValueError):
        return 0


def _release_rlm_shadow_task(task) -> None:
    """Drop a finished RLM shadow task and log its failure, if any."""
    _rlm_shadow_tasks.discard(task)
    if task.cancelled():
        return
    error = task.exception()
    if error is not None:
        logger.warning(f"RLM shadow task failed (non-fatal): {error}")


def _fetch_final_transcript(call_id: str) -> str:
    """Join the final (non-partial) transcripts of a call; "" when unavailable."""
    try:
        with get_db_cursor() as cursor:
            cursor.execute("""
                SELECT transcript_text
                FROM genesis_bridge.call_transcripts
                WHERE call_id = %s AND is_partial = FALSE
                ORDER BY created_at ASC
            """, (call_id,))
            rows = cursor.fetchall()
    except Exception as e:
        logger.warning(f"Could not fetch transcripts for RLM processing on call {call_id}: {e}")
        return ""
    return " ".join(r["transcript_text"] for r in rows if r.get("transcript_text"))


def _process_hangup(call_id: str, caller_number: str, call_info: Dict[str, Any]) -> Dict[str, Any]:
    """Run post-call processing and return the extra response fields.

    Queues RLM shadow processing (log-only, never alters call behaviour) and
    re-analyses the stored transcripts for multi-turn directives. Every step
    is best-effort: failures are logged and do not abort the others.
    """
    import asyncio

    extra: Dict[str, Any] = {"message": "Call hangup processed"}
    call_duration_seconds = _call_duration_seconds(call_info)
    full_transcript = _fetch_final_transcript(call_id)

    # === RLM SHADOW MODE INTEGRATION ===
    if _rlm_gateway is not None and full_transcript:
        try:
            task = asyncio.create_task(
                _rlm_gateway.process_interaction(
                    call_id=call_id,
                    transcript=full_transcript,
                    caller_number=caller_number,
                    call_duration_seconds=call_duration_seconds,
                    outcome="completed",
                )
            )
            # Keep the task alive until it completes and surface its errors.
            _rlm_shadow_tasks.add(task)
            task.add_done_callback(_release_rlm_shadow_task)
            extra["rlm_shadow"] = "queued"
            logger.info(f"RLM shadow processing queued for call {call_id}")
        except Exception as e:
            logger.warning(f"RLM shadow mode error (non-fatal) for call {call_id}: {e}")
            extra["rlm_shadow"] = "error"
    # === END RLM SHADOW MODE ===

    # Analyze complete call for any missed directives
    try:
        context = analyze_transcript_context(call_id)
        if context.get("directive_detected") and context.get("has_context"):
            logger.info(f"Multi-turn directive detected in completed call {call_id}")
            extra["post_call_analysis"] = context
    except Exception as e:
        logger.error(f"Failed to analyze call context: {e}")

    return extra


@app.post("/bridge/webhook/telnyx", status_code=200)
async def handle_telnyx_webhook(
    request: Request,
    x_telnyx_signature: Optional[str] = Header(None, alias="X-Telnyx-Signature"),
    x_api_key: str = Header(...)
) -> JSONResponse:
    """
    Handle incoming Telnyx webhook events.

    Processes:
    - call.initiated: Call started
    - call.answered: Call was answered
    - call.transcription: Voice transcription available
    - call.hangup: Call ended
    - call.machine.detection: Voicemail detection result

    Authentication: the API key is always required. When a webhook secret is
    configured, a valid signature header is required as well; a request
    without the header is rejected instead of skipping verification.

    Returns:
        A 200 JSON response describing what was done. Processing errors are
        reported in the body (``status: "error"``) rather than through the
        HTTP status.

    Raises:
        HTTPException: 401 for a bad API key or signature, 400 when the body
            is not a JSON object.

    NOTE(review): the database helpers called from here are synchronous and
    run on the event loop thread — consider moving them to a thread pool.
    """
    # Verify API key through the shared helper.
    try:
        verify_api_key(x_api_key)
    except HTTPException:
        logger.warning("Unauthorized webhook request received")
        raise

    # Get raw body for signature verification
    raw_body = await request.body()

    # With a secret configured the signature is mandatory: omitting the
    # header must not bypass verification.
    if WEBHOOK_SECRET:
        if not x_telnyx_signature:
            logger.warning("Missing Telnyx webhook signature")
            raise HTTPException(status_code=401, detail="Invalid signature")
        if not verify_telnyx_signature(raw_body, x_telnyx_signature, WEBHOOK_SECRET):
            logger.warning("Invalid Telnyx webhook signature")
            raise HTTPException(status_code=401, detail="Invalid signature")

    try:
        event_data = json.loads(raw_body)
    except ValueError as e:
        # Covers JSONDecodeError and undecodable bytes (both are ValueError).
        logger.error(f"Failed to parse webhook payload: {e}")
        raise HTTPException(status_code=400, detail="Invalid JSON payload")

    if not isinstance(event_data, dict):
        logger.error("Webhook payload is not a JSON object")
        raise HTTPException(status_code=400, detail="Invalid JSON payload")

    call_info = _dict_or_empty(event_data.get("call"))
    event_type = event_data.get("event_type", "")
    call_id = event_data.get("call_session_id") or call_info.get("id", "unknown")

    logger.info(f"Received webhook event: {event_type} for call {call_id}")

    # Extract caller and callee info
    caller_number = _dict_or_empty(call_info.get("from")).get("phone_number", "")
    callee_number = _dict_or_empty(call_info.get("to")).get("phone_number", "")

    # Store the event; a storage failure must not prevent processing.
    try:
        store_call_event(call_id, event_type, caller_number, callee_number, event_data)
    except Exception as e:
        logger.error(f"Failed to store call event: {e}")

    # Process based on event type
    response_data = {
        "event_type": event_type,
        "call_id": call_id,
        "status": "processed"
    }

    try:
        if event_type == EventType.CALL_TRANSCRIPTION:
            # Process transcription - extract directives
            transcription_result = process_transcription_event(event_data)
            response_data.update(transcription_result)
            logger.info(f"Transcription processed for call {call_id}: {transcription_result}")

        elif event_type == EventType.CALL_INITIATED:
            logger.info(f"Call initiated: {caller_number} -> {callee_number}")
            response_data["message"] = "Call initiated"

        elif event_type == EventType.CALL_ANSWERED:
            logger.info(f"Call answered: {call_id}")
            response_data["message"] = "Call answered"

        elif event_type == EventType.CALL_HANGUP:
            logger.info(f"Call ended: {call_id}")
            response_data.update(_process_hangup(call_id, caller_number, call_info))

        elif event_type == EventType.CALL_MACHINE_DETECTION:
            detection_result = _dict_or_empty(event_data.get("machine_detection"))
            is_machine = detection_result.get("result") == "machine"
            logger.info(f"Machine detection result for {call_id}: {'machine' if is_machine else 'human'}")
            response_data["detection_result"] = detection_result.get("result")
            response_data["is_voicemail"] = is_machine

        else:
            logger.info(f"Unhandled event type: {event_type}")
            response_data["status"] = "unhandled_event_type"

    except Exception as e:
        logger.error(f"Error processing webhook event: {e}", exc_info=True)
        response_data["error"] = str(e)
        response_data["status"] = "error"

    return JSONResponse(content=response_data, status_code=200)


@app.get("/bridge/webhook/telnyx/health")
async def health_check() -> JSONResponse:
    """Report service health by running a trivial query against the database.

    Returns a "healthy" JSON body when the query succeeds and responds with
    HTTP 503 when it does not.
    """
    try:
        with get_db_cursor() as probe:
            probe.execute("SELECT 1")
        healthy_body = {"status": "healthy", "service": "telnyx-webhook"}
        return JSONResponse(content=healthy_body)
    except Exception as exc:
        logger.error(f"Health check failed: {exc}")
        raise HTTPException(status_code=503, detail="Service unhealthy")


@app.get("/bridge/webhook/telnyx/calls/{call_id}/transcripts")
async def get_call_transcripts(
    call_id: str,
    x_api_key: str = Header(...)
) -> JSONResponse:
    """Return every stored transcript of a call, oldest first.

    Rows contain ``datetime`` (``created_at``) and ``Decimal``
    (``confidence_score``) values that the plain JSON encoder behind
    ``JSONResponse`` rejects, so they are converted with FastAPI's
    ``jsonable_encoder`` first.

    Raises:
        HTTPException: 401 when the API key is invalid.
    """
    # Local import: the file-level imports do not include the encoder.
    from fastapi.encoders import jsonable_encoder

    verify_api_key(x_api_key)

    with get_db_cursor() as cursor:
        cursor.execute("""
            SELECT id, call_id, caller_number, transcript_text, is_directive,
                   directive_id, confidence_score, is_partial, created_at
            FROM genesis_bridge.call_transcripts
            WHERE call_id = %s
            ORDER BY created_at ASC
        """, (call_id,))

        transcripts = cursor.fetchall()

    return JSONResponse(
        content=jsonable_encoder({"call_id": call_id, "transcripts": transcripts})
    )


@app.get("/bridge/webhook/telnyx/directives")
async def list_directives(
    status: Optional[str] = None,
    limit: int = 50,
    x_api_key: str = Header(...)
) -> JSONResponse:
    """List directives created from voice transcriptions, newest first.

    Each directive is joined to the transcript that produced it through
    ``call_transcripts.directive_id``. Joining on ``call_id`` alone repeated
    a directive once per flagged transcript of the same call and paired it
    with unrelated confidence scores.

    Args:
        status: optional ``command_queue.status`` filter.
        limit: maximum number of rows; must not be negative.

    Raises:
        HTTPException: 401 for an invalid API key, 400 for a negative limit.
    """
    # Local import: the file-level imports do not include the encoder.
    from fastapi.encoders import jsonable_encoder

    verify_api_key(x_api_key)

    if limit < 0:
        # PostgreSQL rejects a negative LIMIT; fail with a client error instead.
        raise HTTPException(status_code=400, detail="limit must not be negative")

    with get_db_cursor() as cursor:
        query = """
            SELECT cq.id, cq.call_id, cq.caller_number, cq.command_text,
                   cq.status, cq.priority, cq.created_at, ct.confidence_score
            FROM genesis_bridge.command_queue cq
            LEFT JOIN genesis_bridge.call_transcripts ct
                ON ct.directive_id = cq.id
            WHERE cq.command_type = 'voice_transcript'
        """
        params = []

        if status:
            query += " AND cq.status = %s"
            params.append(status)

        query += " ORDER BY cq.created_at DESC LIMIT %s"
        params.append(limit)

        cursor.execute(query, params)
        directives = cursor.fetchall()

    # datetime / Decimal columns need converting before JSON serialization.
    return JSONResponse(
        content=jsonable_encoder({"directives": directives, "count": len(directives)})
    )


# Script entry point: serve the app with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        # Import string "module:attribute".
        # NOTE(review): presumably this file is telnyx_webhook.py — confirm,
        # otherwise uvicorn cannot import the application.
        "telnyx_webhook:app",
        host="0.0.0.0",
        port=8000,
        reload=False,
        log_level="info"
    )