"""
Genesis Talking Website Widget - FastAPI Server
Production-ready API for voice-enabled website widgets
"""

import json
import os
import secrets
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import Optional, List

import redis
from fastapi import FastAPI, HTTPException, Depends, Header, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

# Import local modules
from models import (
    ConversationRequest, ConversationResponse,
    TextConversationRequest, TextConversationResponse,
    WidgetConfig, LeadCaptureRequest, Lead, LeadListResponse,
    AnalyticsResponse, BusinessCreateRequest, Business,
    HealthResponse, ErrorResponse, LeadStatus
)
from database import db
from memory_handler import memory
from voice_handler import voice_handler
from tenant_manager import tenant_manager

# Import Elestio Redis config
import sys
sys.path.append('/mnt/e/genesis-system/data/genesis-memory')
from elestio_config import RedisConfig


# === STARTUP/SHUTDOWN ===

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup checks before serving requests and log on shutdown.

    Startup order:
      1. Verify PostgreSQL is reachable (critical).
      2. Initialize the database schema.
      3. Probe Qdrant and Telnyx (non-critical; failures are only logged).

    Args:
        app: The FastAPI application (required by the lifespan protocol; unused).

    Raises:
        RuntimeError: If PostgreSQL is unreachable, so the server refuses to start.
    """
    # Startup
    print("🚀 Genesis Widget API starting...")

    # Check the critical dependency before touching the schema, so an
    # unreachable database is reported by this explicit check rather than by
    # whatever db.init_schema() raises against a dead connection.
    print("🔌 Testing service connections...")
    if not db.test_connection():
        print("❌ PostgreSQL connection failed")
        raise RuntimeError("Critical services unavailable")
    print("✅ PostgreSQL connected")

    # Initialize database schema
    print("📊 Initializing database schema...")
    db.init_schema()

    # Non-critical services: report the outcome but keep starting up.
    if not memory.test_connection():
        print("⚠️  Qdrant connection failed (non-critical)")
    else:
        print("✅ Qdrant connected")

    if not voice_handler.test_connection():
        print("⚠️  Telnyx connection failed (non-critical)")
    else:
        print("✅ Telnyx connected")

    print("✅ Genesis Widget API ready")

    yield

    # Shutdown
    print("🛑 Genesis Widget API shutting down...")


# === APP INITIALIZATION ===

app = FastAPI(
    title="Genesis Talking Website Widget API",
    description="Voice-enabled AI chat widget for business websites",
    version="1.0.0",
    lifespan=lifespan
)

# Rate limiting: requests are keyed by the client's remote address.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# CORS: ALLOWED_ORIGINS is a comma-separated list. Blank entries (e.g. from a
# trailing comma) are dropped instead of being registered as an empty origin.
allowed_origins_env = os.getenv("ALLOWED_ORIGINS", "http://localhost:8888,http://127.0.0.1:8888")
allowed_origins = [origin.strip() for origin in allowed_origins_env.split(",") if origin.strip()]

# For local development, allow all origins if "*" is in the list
if "*" in allowed_origins:
    allowed_origins = ["*"]

# Browsers reject credentialed responses with a wildcard origin, so
# credentials are only enabled when an explicit allow-list is configured.
allow_credentials = allowed_origins != ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Redis for session caching
redis_config = RedisConfig.get_connection_params()
redis_client = redis.Redis(**redis_config)


# === DEPENDENCIES ===

async def verify_api_key(x_api_key: str = Header(...)) -> dict:
    """Resolve the business that owns the ``X-API-Key`` header.

    Args:
        x_api_key: Widget API key from the required ``X-API-Key`` header.

    Returns:
        The business record ``tenant_manager`` associates with the key.

    Raises:
        HTTPException: 401 when no business matches the key.
    """
    matched_business = tenant_manager.get_business_by_api_key(x_api_key)
    if matched_business:
        return matched_business
    raise HTTPException(status_code=401, detail="Invalid API key")


def get_business_from_id_or_key(
    business_id: str,
    x_api_key: Optional[str] = Header(None)
) -> dict:
    """
    Resolve the business context for a request.

    An ``X-API-Key`` header always takes precedence. If a key is supplied it
    must be valid: an unknown key is rejected outright instead of silently
    falling back to the unauthenticated ``business_id`` lookup.

    When no API key is sent, the business may be identified by ``business_id``
    alone. That path performs no authentication (anyone who knows a business
    ID could act as that tenant), so it is only honoured when the
    ``ALLOW_BUSINESS_ID_AUTH`` env var is "true". When that var is unset it
    follows ``DEBUG``, i.e. local testing.

    Args:
        business_id: Business ID supplied in the request body.
        x_api_key: Optional widget API key from the ``X-API-Key`` header.

    Returns:
        The matching business record.

    Raises:
        HTTPException: 401 if no credential resolves to a business.
    """
    # Try API key first (production)
    if x_api_key:
        business = tenant_manager.get_business_by_api_key(x_api_key)
        if business:
            return business
        # A presented-but-wrong key is a failed authentication, not a cue to
        # retry with an unauthenticated lookup.
        raise HTTPException(status_code=401, detail="Invalid API key or business ID")

    # Fallback: Use business_id (local testing only, opt-in via env)
    allow_id_fallback = os.getenv(
        "ALLOW_BUSINESS_ID_AUTH", os.getenv("DEBUG", "false")
    ).strip().lower() == "true"

    if allow_id_fallback and business_id:
        business = tenant_manager.get_business_by_id(business_id)
        if business:
            return business

    raise HTTPException(status_code=401, detail="Invalid API key or business ID")


async def verify_admin_token(authorization: str = Header(...)) -> bool:
    """Verify the admin bearer token in the ``Authorization`` header.

    MVP scheme: the header must equal ``"Bearer <ADMIN_TOKEN>"`` exactly.

    Args:
        authorization: Raw value of the required ``Authorization`` header.

    Returns:
        True when the token matches.

    Raises:
        HTTPException: 403 when the header does not match.
    """
    # TODO: Implement proper JWT verification
    # NOTE(review): when ADMIN_TOKEN is unset this falls back to a well-known
    # default secret; every deployment must set ADMIN_TOKEN.
    expected_token = os.getenv("ADMIN_TOKEN", "genesis_admin_secret_change_me")
    expected_header = f"Bearer {expected_token}"
    # Constant-time comparison so response timing does not leak how much of
    # the token matched. Compare bytes: compare_digest raises TypeError for
    # str arguments containing non-ASCII characters.
    if not secrets.compare_digest(
        authorization.encode("utf-8"), expected_header.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Unauthorized")
    return True


# === ENDPOINTS ===

@app.get("/", response_model=dict)
async def root():
    """Describe the service: its name, version, status and docs location."""
    service_info = dict(
        service="Genesis Talking Website Widget API",
        version="1.0.0",
        status="operational",
        docs="/docs",
    )
    return service_info


@app.get("/v1/health", response_model=HealthResponse)
async def health_check():
    """Report the health of each backing service.

    PostgreSQL and Redis are reported as "healthy"/"unhealthy"; either being
    unhealthy makes the overall status "unhealthy". Qdrant and Telnyx are
    reported as "healthy"/"degraded" and never make the overall status
    unhealthy.
    """
    services = {}

    # Test PostgreSQL
    services["postgresql"] = "healthy" if db.test_connection() else "unhealthy"

    # Test Redis. Catch Exception rather than using a bare ``except`` so that
    # KeyboardInterrupt / SystemExit still propagate.
    try:
        redis_client.ping()
        services["redis"] = "healthy"
    except Exception:
        services["redis"] = "unhealthy"

    # Test Qdrant
    services["qdrant"] = "healthy" if memory.test_connection() else "degraded"

    # Test Telnyx
    services["telnyx"] = "healthy" if voice_handler.test_connection() else "degraded"

    overall_status = "healthy" if all(
        s in ["healthy", "degraded"] for s in services.values()
    ) else "unhealthy"

    return HealthResponse(
        status=overall_status,
        version="1.0.0",
        services=services,
        timestamp=datetime.now()
    )


@app.post("/v1/conversation", response_model=ConversationResponse)
@limiter.limit("60/minute")
async def process_voice_conversation(
    request: Request,
    conversation_req: ConversationRequest,
    business: dict = Depends(verify_api_key)
):
    """
    Process voice conversation from widget.

    Flow:
    1. Receive audio from visitor
    2. Transcribe to text (STT)
    3. Generate AI response based on business context
    4. Convert response to speech (TTS)
    5. Persist the exchange and any detected lead
    6. Cache the updated session in Redis (30 min TTL)
    7. Return audio + transcript

    ``request`` is not used in the body but must stay: slowapi's
    ``limiter.limit`` decorator needs a ``Request`` parameter on the endpoint.

    Raises:
        HTTPException: re-raised unchanged if raised inside the flow; any other
            failure becomes a 500.
    """
    try:
        # Session context: a cached Redis copy takes precedence over the one
        # sent by the client.
        # NOTE(review): the key is per visitor only, not per business; confirm
        # visitor_ids cannot collide across tenants.
        session_key = f"session:{conversation_req.visitor_id}"
        session_data = redis_client.get(session_key)

        if session_data:
            session_context = json.loads(session_data)
        else:
            session_context = conversation_req.session_context or {}

        # Transcribe audio
        visitor_message = await voice_handler.process_voice_input(
            conversation_req.audio_data
        )

        # Get business context from memory
        business_context = {
            "name": business["name"],
            "agent_name": business.get("agent_name", "Sarah"),
            "website_url": business.get("website_url"),
            "knowledge_base": business.get("config", {}).get("knowledge_base", {})
        }

        # Generate AI response
        ai_response_text, updated_session = await voice_handler.generate_ai_response(
            visitor_message,
            business_context,
            session_context
        )

        # Convert to speech
        audio_response = await voice_handler.generate_voice_response(
            ai_response_text,
            business["id"]
        )

        # Generate conversation ID
        conversation_id = f"conv_{secrets.token_hex(8)}"

        # Store conversation in database. The session dict is serialized to
        # JSON exactly as the text endpoint does; a raw dict is not a bindable
        # query parameter for drivers like psycopg2 without a registered adapter.
        with db.get_cursor() as cursor:
            cursor.execute("""
                INSERT INTO widget_conversations
                (id, business_id, visitor_id, mode, transcript, session_context)
                VALUES (%s, %s, %s, 'voice', %s, %s)
            """, (
                conversation_id,
                business["id"],
                conversation_req.visitor_id,
                f"Visitor: {visitor_message}\nAI: {ai_response_text}",
                json.dumps(updated_session)
            ))

        # Check for lead capture
        lead_info = await voice_handler.detect_lead_info(visitor_message, updated_session)
        lead_captured = False

        if lead_info:
            lead_id = f"lead_{secrets.token_hex(8)}"
            with db.get_cursor() as cursor:
                cursor.execute("""
                    INSERT INTO widget_leads
                    (id, business_id, visitor_id, conversation_id, name, phone, email, status)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, 'new')
                """, (
                    lead_id,
                    business["id"],
                    conversation_req.visitor_id,
                    conversation_id,
                    lead_info.get("name"),
                    lead_info.get("phone"),
                    lead_info.get("email")
                ))
            lead_captured = True
            updated_session["lead_captured"] = True

        # Save updated session to Redis (30 min TTL)
        redis_client.setex(session_key, 1800, json.dumps(updated_session))

        return ConversationResponse(
            conversation_id=conversation_id,
            audio_response=audio_response,
            transcript=ai_response_text,
            session_context=updated_session,
            lead_captured=lead_captured
        )

    except HTTPException:
        # Keep intentional HTTP errors (status + detail) intact instead of
        # rewrapping them as a generic 500.
        raise
    except Exception as e:
        print(f"Conversation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/v1/conversation/text", response_model=TextConversationResponse)
@limiter.limit("60/minute")
async def process_text_conversation(
    text_req: TextConversationRequest,
    request: Request,
    x_api_key: Optional[str] = Header(None)
):
    """
    Process text-only conversation (fallback for no-mic scenarios).

    Mirrors the voice endpoint without STT/TTS:
    1. Resolve the business from the API key (or business_id fallback)
    2. Load the visitor's session (Redis copy wins over the client's)
    3. Generate the AI reply
    4. Persist the exchange and any detected lead
    5. Cache the updated session in Redis (30 min TTL)

    ``request`` is not used in the body but is required by slowapi's
    ``limiter.limit`` decorator.

    Raises:
        HTTPException: 401 when the business cannot be resolved (propagated
            unchanged); 500 for any other failure.
    """
    try:
        # Get business context
        business = get_business_from_id_or_key(text_req.business_id, x_api_key)

        # Get session context
        session_key = f"session:{text_req.visitor_id}"
        session_data = redis_client.get(session_key)

        if session_data:
            session_context = json.loads(session_data)
        else:
            session_context = text_req.session_context or {}

        # Business context (same shape as the voice endpoint)
        business_context = {
            "name": business["name"],
            "agent_name": business.get("agent_name", "Sarah"),
            "website_url": business.get("website_url"),
            "knowledge_base": business.get("config", {}).get("knowledge_base", {})
        }

        # Generate AI response
        ai_response, updated_session = await voice_handler.generate_ai_response(
            text_req.message,
            business_context,
            session_context
        )

        # Generate conversation ID
        conversation_id = f"conv_{secrets.token_hex(8)}"

        # Store conversation
        with db.get_cursor() as cursor:
            cursor.execute("""
                INSERT INTO widget_conversations
                (id, business_id, visitor_id, mode, transcript, session_context)
                VALUES (%s, %s, %s, 'text', %s, %s)
            """, (
                conversation_id,
                business["id"],
                text_req.visitor_id,
                f"Visitor: {text_req.message}\nAI: {ai_response}",
                json.dumps(updated_session)
            ))

        # Check for lead capture
        lead_info = await voice_handler.detect_lead_info(text_req.message, updated_session)
        lead_captured = False

        if lead_info:
            lead_id = f"lead_{secrets.token_hex(8)}"
            with db.get_cursor() as cursor:
                cursor.execute("""
                    INSERT INTO widget_leads
                    (id, business_id, visitor_id, conversation_id, name, phone, email, status)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, 'new')
                """, (
                    lead_id,
                    business["id"],
                    text_req.visitor_id,
                    conversation_id,
                    lead_info.get("name"),
                    lead_info.get("phone"),
                    lead_info.get("email")
                ))
            lead_captured = True
            # Record the capture in the session, as the voice endpoint does.
            updated_session["lead_captured"] = True

        # Save session (30 min TTL)
        redis_client.setex(session_key, 1800, json.dumps(updated_session))

        return TextConversationResponse(
            conversation_id=conversation_id,
            message=ai_response,
            session_context=updated_session,
            lead_captured=lead_captured
        )

    except HTTPException:
        # Auth failures from get_business_from_id_or_key (401) must reach the
        # client with their real status, not be rewrapped as a 500.
        raise
    except Exception as e:
        print(f"Text conversation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/v1/widget/config/{business_id}", response_model=WidgetConfig)
async def get_widget_config(
    business_id: str,
    business: dict = Depends(verify_api_key)
):
    """Return the embed configuration for the business owning the API key.

    The path ``business_id`` must match the authenticated business, otherwise
    403 is raised. Optional display settings fall back to built-in defaults.
    """
    if business["id"] != business_id:
        raise HTTPException(status_code=403, detail="Business ID mismatch")

    # Display settings are optional on the business record; fill in defaults.
    appearance = {
        "agent_name": business.get("agent_name", "Sarah"),
        "greeting_message": business.get("greeting_message", "Hi! How can I help?"),
        "primary_color": business.get("primary_color", "#6366F1"),
        "position": business.get("position", "bottom-right"),
        "avatar_url": business.get("avatar_url"),
    }

    return WidgetConfig(
        business_id=business["id"],
        business_name=business["name"],
        **appearance,
    )


@app.post("/v1/leads")
@limiter.limit("30/minute")
async def capture_lead(
    request: Request,
    lead_req: LeadCaptureRequest,
    business: dict = Depends(verify_api_key)
):
    """Store a lead submitted explicitly by the widget.

    The lead is inserted with status ``'new'`` for the business owning the API
    key. ``request`` is required by slowapi's rate-limit decorator.

    Returns:
        ``{"lead_id": <id>, "status": "captured"}``.

    Raises:
        HTTPException: 500 if anything in the insert fails.
    """
    try:
        insert_params = (
            f"lead_{secrets.token_hex(8)}",
            business["id"],
            lead_req.visitor_id,
            lead_req.conversation_id,
            lead_req.name,
            lead_req.phone,
            lead_req.email,
            lead_req.notes,
        )

        with db.get_cursor() as cursor:
            cursor.execute("""
                INSERT INTO widget_leads
                (id, business_id, visitor_id, conversation_id, name, phone, email, notes, status)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, 'new')
                RETURNING *
            """, insert_params)

            inserted_row = dict(cursor.fetchone())

        return {"lead_id": inserted_row["id"], "status": "captured"}

    except Exception as exc:
        print(f"Lead capture error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))


@app.get("/v1/leads/{business_id}", response_model=LeadListResponse)
async def list_leads(
    business_id: str,
    page: int = 1,
    page_size: int = 50,
    status: Optional[str] = None,
    business: dict = Depends(verify_api_key),
    _admin: bool = Depends(verify_admin_token)
):
    """List a business's leads, newest first, one page at a time (admin only).

    Requires both a widget API key belonging to ``business_id`` and the admin
    bearer token.

    Args:
        business_id: Business whose leads are listed; must match the API key.
        page: 1-based page number.
        page_size: Number of leads per page; must be at least 1.
        status: Optional exact-match filter on the lead status.

    Returns:
        LeadListResponse with the page of leads and the total matching count.

    Raises:
        HTTPException: 403 on business mismatch; 400 when ``page`` or
            ``page_size`` is below 1. Without this check they would produce a
            negative OFFSET/LIMIT, which the database rejects.
    """
    if business["id"] != business_id:
        raise HTTPException(status_code=403, detail="Business ID mismatch")

    if page < 1 or page_size < 1:
        raise HTTPException(status_code=400, detail="page and page_size must be >= 1")

    offset = (page - 1) * page_size

    # Build the filter once so the count query and the page query always agree.
    # Only fixed SQL fragments are interpolated; all values stay parameterized.
    where_sql = "business_id = %s"
    where_params = [business_id]
    if status:
        where_sql += " AND status = %s"
        where_params.append(status)

    with db.get_cursor() as cursor:
        # Count total
        cursor.execute(
            f"SELECT COUNT(*) AS count FROM widget_leads WHERE {where_sql}",
            tuple(where_params),
        )
        total = cursor.fetchone()["count"]

        # Fetch the requested page
        cursor.execute(
            f"SELECT * FROM widget_leads WHERE {where_sql} "
            "ORDER BY created_at DESC LIMIT %s OFFSET %s",
            (*where_params, page_size, offset),
        )
        leads = [Lead(**dict(row)) for row in cursor.fetchall()]

    return LeadListResponse(
        leads=leads,
        total=total,
        page=page,
        page_size=page_size
    )


@app.get("/v1/analytics/{business_id}", response_model=AnalyticsResponse)
async def get_analytics(
    business_id: str,
    period: str = "last_30_days",
    business: dict = Depends(verify_api_key),
    _admin: bool = Depends(verify_admin_token)
):
    """Summarise conversation and lead activity for one business (admin only).

    ``period`` selects the look-back window: "last_7_days" or "last_30_days".
    Any other value uses a 30-day window but is echoed back unchanged in the
    response. Requires an API key for ``business_id`` plus the admin token.
    """
    if business["id"] != business_id:
        raise HTTPException(status_code=403, detail="Business ID mismatch")

    # Map the requested period to a window length; unknown periods use 30 days.
    window_days = {"last_7_days": 7, "last_30_days": 30}.get(period, 30)
    start_date = datetime.now() - timedelta(days=window_days)

    with db.get_cursor() as cursor:
        # Conversation counts split by mode, plus average duration
        cursor.execute("""
            SELECT
                COUNT(*) as total,
                COUNT(CASE WHEN mode = 'voice' THEN 1 END) as voice,
                COUNT(CASE WHEN mode = 'text' THEN 1 END) as text,
                AVG(duration_seconds) as avg_duration
            FROM widget_conversations
            WHERE business_id = %s AND created_at >= %s
        """, (business_id, start_date))
        stats = dict(cursor.fetchone())

        # Leads created within the same window
        cursor.execute("""
            SELECT COUNT(*) as count FROM widget_leads
            WHERE business_id = %s AND created_at >= %s
        """, (business_id, start_date))
        leads_count = cursor.fetchone()["count"]

    # Top pages would require parsing the session_context JSONB; not implemented.
    top_pages = []

    total_conversations = stats["total"]
    conversion_rate = 0
    if total_conversations > 0:
        conversion_rate = leads_count / total_conversations * 100

    return AnalyticsResponse(
        business_id=business_id,
        period=period,
        total_conversations=total_conversations or 0,
        total_leads=leads_count,
        conversion_rate=round(conversion_rate, 2),
        avg_conversation_duration=stats["avg_duration"] or 0,
        voice_conversations=stats["voice"] or 0,
        text_conversations=stats["text"] or 0,
        top_pages=top_pages
    )


@app.post("/v1/businesses", response_model=dict)
async def create_business(
    business_req: BusinessCreateRequest,
    _admin: bool = Depends(verify_admin_token)
):
    """Provision a new business tenant (admin only).

    Returns:
        The new ``business_id``, its widget ``api_key`` and a confirmation
        message.

    Raises:
        HTTPException: 500 if creating the tenant fails.
    """
    try:
        created = tenant_manager.create_business(
            name=business_req.name,
            website_url=business_req.website_url,
            greeting_message=business_req.greeting_message,
            agent_name=business_req.agent_name,
            primary_color=business_req.primary_color,
            knowledge_base=business_req.knowledge_base
        )
        result = {
            "business_id": created["id"],
            "api_key": created["api_key"],
            "message": "Business created successfully",
        }
    except Exception as exc:
        print(f"Business creation error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))

    return result


# === ERROR HANDLERS ===

@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an HTTPException as an ErrorResponse JSON body.

    Both the exception's status code and any headers it carries are preserved.
    Examples are ``WWW-Authenticate`` or ``Retry-After``, set via
    ``HTTPException(headers=...)``; replacing FastAPI's default handler must
    not silently drop them. The caller-supplied ``X-Request-ID`` header is
    echoed back in the body.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            error=exc.detail,
            request_id=request.headers.get("X-Request-ID")
        ).dict(),
        headers=getattr(exc, "headers", None),
    )


@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: return a generic 500 ErrorResponse.

    The exception text is exposed in ``detail`` only when the DEBUG env var is
    "true" (case-insensitive). In every other case it is omitted so internal
    error messages are not leaked to clients.
    """
    # Parse DEBUG as an explicit flag (as the __main__ block does for reload):
    # a plain truthiness test would treat DEBUG=false as enabled.
    debug_enabled = os.getenv("DEBUG", "false").lower() == "true"
    return JSONResponse(
        status_code=500,
        content=ErrorResponse(
            error="Internal server error",
            detail=str(exc) if debug_enabled else None,
            request_id=request.headers.get("X-Request-ID")
        ).dict()
    )


# === STARTUP ===

if __name__ == "__main__":
    import uvicorn

    # Listen port comes from PORT; hot reload is on only when DEBUG is "true".
    listen_port = int(os.getenv("PORT", 8000))
    reload_enabled = os.getenv("DEBUG", "false").lower() == "true"

    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=listen_port,
        reload=reload_enabled,
        log_level="info",
    )
