import csv
import functools
import io
import json
import logging
import os
import secrets
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, Header, HTTPException, Query, Request, Response
from fastapi.responses import StreamingResponse
import psycopg2
from psycopg2 import sql
from psycopg2.extras import Json, RealDictCursor
from pydantic import BaseModel, Field

from core.secrets_loader import get_postgres_config

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("audit_bridge")

API_KEY = "aiva-voice-bridge-secret-key-2024"
SCHEMA = "genesis_bridge"

# Event types
class AuditEventType:
    """Namespace of string constants for the audit_log ``event_type`` column.

    Values are stored verbatim as VARCHAR(50) and compared as plain strings
    by the query helpers below, so they must never be renamed once recorded.
    """
    # Directive lifecycle
    DIRECTIVE_CREATED = "directive_created"
    DIRECTIVE_UPDATED = "directive_updated"
    DIRECTIVE_COMPLETED = "directive_completed"
    DIRECTIVE_FAILED = "directive_failed"
    DIRECTIVE_CANCELLED = "directive_cancelled"
    # Status queries and their responses
    STATUS_QUERIED = "status_queried"
    STATUS_RESPONSE = "status_response"
    # Authentication outcomes
    AUTH_SUCCESS = "auth_success"
    AUTH_FAILED = "auth_failed"
    # Service lifecycle
    SYSTEM_STARTUP = "system_startup"
    SYSTEM_SHUTDOWN = "system_shutdown"
    SYSTEM_RECONNECT = "system_reconnect"
    # Generic API traffic
    API_REQUEST = "api_request"
    API_RESPONSE = "api_response"

# Actors
class AuditActor:
    """Namespace of string constants for the audit_log ``actor`` column.

    Identifies who triggered an event; SYSTEM is used for internally
    generated events (see AuditLogger.log_system_event).
    """
    KINAN = "kinan"
    CLAUDE = "claude"
    AIVA = "aiva"
    SYSTEM = "system"

# Database connection
def get_db_connection():
    """Open and return a fresh psycopg2 connection from the shared config.

    Raises psycopg2.Error when PostgreSQL has not been configured.
    """
    cfg = get_postgres_config()
    if not cfg.is_configured:
        raise psycopg2.Error("PostgreSQL is not configured.")

    return psycopg2.connect(
        host=cfg.host,
        port=cfg.port,
        user=cfg.user,
        password=cfg.password,
        dbname=cfg.dbname,
        sslmode=cfg.sslmode,
    )

def init_database():
    """Create the audit schema, tables, and indexes (idempotent).

    Every statement uses IF NOT EXISTS, so repeated calls are safe.
    Commits on success; rolls back and re-raises on failure; always
    closes the connection.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cur:
            # Create schema
            cur.execute(sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(
                sql.Identifier(SCHEMA)
            ))

            # Create audit_log table.
            # NOTE: braces in the JSONB default are doubled ('{{}}') because
            # sql.SQL.format() treats a bare {} as an auto-numbered
            # placeholder -- the previous single-brace form raised
            # IndexError before any SQL reached the server.
            cur.execute(sql.SQL("""
                CREATE TABLE IF NOT EXISTS {}.audit_log (
                    id BIGSERIAL PRIMARY KEY,
                    event_type VARCHAR(50) NOT NULL,
                    actor VARCHAR(50) NOT NULL,
                    directive_id BIGINT,
                    details JSONB DEFAULT '{{}}'::jsonb,
                    ip_address INET,
                    user_agent TEXT,
                    created_at TIMESTAMPTZ DEFAULT NOW()
                )
            """).format(sql.Identifier(SCHEMA)))

            # Create audit_archive table (rows are moved here by AuditCleanup).
            cur.execute(sql.SQL("""
                CREATE TABLE IF NOT EXISTS {}.audit_archive (
                    id BIGINT NOT NULL,
                    event_type VARCHAR(50) NOT NULL,
                    actor VARCHAR(50) NOT NULL,
                    directive_id BIGINT,
                    details JSONB,
                    ip_address INET,
                    user_agent TEXT,
                    created_at TIMESTAMPTZ,
                    archived_at TIMESTAMPTZ DEFAULT NOW(),
                    PRIMARY KEY (id, archived_at)
                )
            """).format(sql.Identifier(SCHEMA)))

            # Indexes on the columns the query helpers filter and sort by.
            for index_name, table_name, column_expr in (
                ("idx_audit_log_created_at", "audit_log", "created_at DESC"),
                ("idx_audit_log_event_type", "audit_log", "event_type"),
                ("idx_audit_log_actor", "audit_log", "actor"),
                ("idx_audit_log_directive_id", "audit_log", "directive_id"),
                ("idx_audit_archive_created_at", "audit_archive", "created_at DESC"),
            ):
                cur.execute(sql.SQL(
                    "CREATE INDEX IF NOT EXISTS {} ON {}.{} ({})"
                ).format(
                    sql.Identifier(index_name),
                    sql.Identifier(SCHEMA),
                    sql.Identifier(table_name),
                    sql.SQL(column_expr),  # trusted literal; may carry DESC
                ))

            conn.commit()
            logger.info("Database initialized successfully")
    except Exception as e:
        conn.rollback()
        logger.error(f"Database initialization failed: {e}")
        raise
    finally:
        conn.close()

# Models
class AuditLogCreate(BaseModel):
    """Request payload describing a single audit event to insert."""
    event_type: str
    actor: str
    directive_id: Optional[int] = None
    # Pydantic deep-copies field defaults per instance, so the mutable {} is
    # safe here (unlike a plain function default).
    details: Optional[Dict[str, Any]] = {}
    ip_address: Optional[str] = None
    user_agent: Optional[str] = None

class AuditLogResponse(BaseModel):
    """One genesis_bridge.audit_log row as exposed by the API."""
    id: int
    event_type: str
    actor: str
    directive_id: Optional[int]
    details: Dict[str, Any]
    ip_address: Optional[str]
    user_agent: Optional[str]
    created_at: datetime

class AuditQueryParams(BaseModel):
    """Filter and pagination parameters for audit log queries.

    NOTE(review): the endpoints below declare these parameters individually
    via Query(); this model appears unused within this module -- confirm
    external callers before removing.
    """
    event_type: Optional[str] = None
    actor: Optional[str] = None
    directive_id: Optional[int] = None
    start_date: Optional[datetime] = None
    end_date: Optional[datetime] = None
    page: int = Field(default=1, ge=1)
    page_size: int = Field(default=50, ge=1, le=1000)

class AuditSummaryResponse(BaseModel):
    """Aggregated per-day statistics (see AuditQueries.get_daily_summary)."""
    date: str  # "YYYY-MM-DD"
    total_events: int
    directives_created: int
    directives_completed: int
    directives_failed: int
    unique_directives: int
    # Mean seconds from directive_created to directive_completed;
    # None when nothing completed that day.
    avg_response_time_seconds: Optional[float]
    # Percentage: failed / created * 100 (0.0 when none created).
    error_rate: float

# Audit Logger Service
class AuditLogger:
    """Service for logging audit events to the database.

    Keeps one lazily-created, cached psycopg2 connection and reopens it
    whenever the cached one reports closed. Each log_event() commits its
    own insert.
    """

    def __init__(self):
        # Cached connection; created on first use (see _get_connection).
        self._conn = None

    def _get_connection(self):
        """Return the cached connection, opening a new one if absent or closed."""
        if self._conn is None or self._conn.closed:
            self._conn = get_db_connection()
        return self._conn

    def log_event(
        self,
        event_type: str,
        actor: str,
        directive_id: Optional[int] = None,
        details: Optional[Dict[str, Any]] = None,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None
    ) -> int:
        """Insert one audit event and return its generated id.

        Re-raises the underlying database error after attempting a
        rollback. Fix: the rollback is now guarded -- on a dead connection
        (e.g. server restart) rollback itself raises, which previously
        masked the original error and left the broken connection cached;
        now the cached connection is discarded so the next call reconnects.
        """
        conn = self._get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    sql.SQL("""
                        INSERT INTO {}.audit_log 
                        (event_type, actor, directive_id, details, ip_address, user_agent)
                        VALUES (%s, %s, %s, %s, %s, %s)
                        RETURNING id
                    """).format(sql.Identifier(SCHEMA)),
                    (
                        event_type,
                        actor,
                        directive_id,
                        Json(details) if details else Json({}),
                        ip_address,
                        user_agent
                    )
                )
                result = cur.fetchone()
                conn.commit()
                event_id = result[0] if result else None
                logger.debug(f"Audit event logged: {event_type} by {actor}, id={event_id}")
                return event_id
        except Exception as e:
            try:
                conn.rollback()
            except psycopg2.Error:
                # Connection is unusable; drop it to force a reconnect on
                # the next call instead of masking the original error.
                self._conn = None
            logger.error(f"Failed to log audit event: {e}")
            raise

    def log_directive_created(
        self,
        directive_id: int,
        actor: str,
        details: Dict[str, Any],
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None
    ) -> int:
        """Log directive creation; returns the audit event id."""
        return self.log_event(
            event_type=AuditEventType.DIRECTIVE_CREATED,
            actor=actor,
            directive_id=directive_id,
            details=details,
            ip_address=ip_address,
            user_agent=user_agent
        )

    def log_directive_completed(
        self,
        directive_id: int,
        actor: str,
        details: Optional[Dict[str, Any]] = None,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None
    ) -> int:
        """Log directive completion; returns the audit event id."""
        return self.log_event(
            event_type=AuditEventType.DIRECTIVE_COMPLETED,
            actor=actor,
            directive_id=directive_id,
            details=details or {},
            ip_address=ip_address,
            user_agent=user_agent
        )

    def log_directive_failed(
        self,
        directive_id: int,
        actor: str,
        error_message: str,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None
    ) -> int:
        """Log directive failure, recording the error message in details."""
        return self.log_event(
            event_type=AuditEventType.DIRECTIVE_FAILED,
            actor=actor,
            directive_id=directive_id,
            details={"error": error_message},
            ip_address=ip_address,
            user_agent=user_agent
        )

    def log_status_query(
        self,
        directive_id: Optional[int],
        actor: str,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None
    ) -> int:
        """Log a status query (directive_id may be None for broad queries)."""
        return self.log_event(
            event_type=AuditEventType.STATUS_QUERIED,
            actor=actor,
            directive_id=directive_id,
            details={},
            ip_address=ip_address,
            user_agent=user_agent
        )

    def log_auth_attempt(
        self,
        success: bool,
        actor: str,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
        reason: Optional[str] = None
    ) -> int:
        """Log an authentication attempt as AUTH_SUCCESS or AUTH_FAILED."""
        return self.log_event(
            event_type=AuditEventType.AUTH_SUCCESS if success else AuditEventType.AUTH_FAILED,
            actor=actor,
            details={"reason": reason} if reason else {},
            ip_address=ip_address,
            user_agent=user_agent
        )

    def log_system_event(
        self,
        event_type: str,
        details: Optional[Dict[str, Any]] = None
    ) -> int:
        """Log an internally generated event attributed to the SYSTEM actor."""
        return self.log_event(
            event_type=event_type,
            actor=AuditActor.SYSTEM,
            details=details or {}
        )

    def close(self):
        """Close the cached database connection, if open."""
        if self._conn and not self._conn.closed:
            self._conn.close()

# Module-level singleton used by the endpoints below; it reuses one cached
# database connection across calls (see AuditLogger._get_connection).
audit_logger = AuditLogger()

# Audit Queries Service
class AuditQueries:
    """Predefined queries for common audit reports.

    Every method opens a fresh connection per call and closes it on exit.
    """

    @staticmethod
    def get_paginated_audit_logs(
        event_type: Optional[str] = None,
        actor: Optional[str] = None,
        directive_id: Optional[int] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        page: int = 1,
        page_size: int = 50
    ) -> tuple[List[Dict[str, Any]], int]:
        """Get paginated audit logs with optional filters.

        Returns ``(rows, total_matching_count)`` with rows newest-first.
        Filter column names are fixed literals; only values travel as
        query parameters.

        Fix: the previous implementation emitted asyncpg-style $N
        placeholders and rendered LIMIT/OFFSET through sql.Identifier,
        neither of which psycopg2 can execute; it now uses %s placeholders
        throughout.
        """
        conn = get_db_connection()
        try:
            conditions = []
            params: List[Any] = []

            if event_type:
                conditions.append("event_type = %s")
                params.append(event_type)
            if actor:
                conditions.append("actor = %s")
                params.append(actor)
            if directive_id is not None:
                conditions.append("directive_id = %s")
                params.append(directive_id)
            if start_date:
                conditions.append("created_at >= %s")
                params.append(start_date)
            if end_date:
                conditions.append("created_at <= %s")
                params.append(end_date)

            where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""

            # Total count over the same filters.
            count_sql = sql.SQL("""
                SELECT COUNT(*) as total 
                FROM {}.audit_log
                {}
            """).format(sql.Identifier(SCHEMA), sql.SQL(where_clause))

            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute(count_sql, params)
                total = cur.fetchone()["total"]

                # One page of results, newest first.
                offset = (page - 1) * page_size
                query_sql = sql.SQL("""
                    SELECT id, event_type, actor, directive_id, details, 
                           ip_address, user_agent, created_at
                    FROM {}.audit_log
                    {}
                    ORDER BY created_at DESC
                    LIMIT %s OFFSET %s
                """).format(sql.Identifier(SCHEMA), sql.SQL(where_clause))

                cur.execute(query_sql, params + [page_size, offset])
                results = [dict(row) for row in cur.fetchall()]

                return results, total
        finally:
            conn.close()

    @staticmethod
    def get_daily_summary(date: Optional[datetime] = None) -> AuditSummaryResponse:
        """Get daily summary statistics for one day.

        ``date`` is truncated to midnight (UTC-naive, matching the
        utcnow-based writes); defaults to today.
        """
        if date is None:
            date = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
        else:
            date = date.replace(hour=0, minute=0, second=0, microsecond=0)

        next_date = date + timedelta(days=1)

        conn = get_db_connection()
        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                # Total events
                cur.execute(
                    sql.SQL("""
                        SELECT COUNT(*) as total
                        FROM {}.audit_log
                        WHERE created_at >= %s AND created_at < %s
                    """).format(sql.Identifier(SCHEMA)),
                    (date, next_date)
                )
                total_events = cur.fetchone()["total"]

                # Directives created
                cur.execute(
                    sql.SQL("""
                        SELECT COUNT(*) as count
                        FROM {}.audit_log
                        WHERE event_type = %s AND created_at >= %s AND created_at < %s
                    """).format(sql.Identifier(SCHEMA)),
                    (AuditEventType.DIRECTIVE_CREATED, date, next_date)
                )
                directives_created = cur.fetchone()["count"]

                # Directives completed
                cur.execute(
                    sql.SQL("""
                        SELECT COUNT(*) as count
                        FROM {}.audit_log
                        WHERE event_type = %s AND created_at >= %s AND created_at < %s
                    """).format(sql.Identifier(SCHEMA)),
                    (AuditEventType.DIRECTIVE_COMPLETED, date, next_date)
                )
                directives_completed = cur.fetchone()["count"]

                # Directives failed
                cur.execute(
                    sql.SQL("""
                        SELECT COUNT(*) as count
                        FROM {}.audit_log
                        WHERE event_type = %s AND created_at >= %s AND created_at < %s
                    """).format(sql.Identifier(SCHEMA)),
                    (AuditEventType.DIRECTIVE_FAILED, date, next_date)
                )
                directives_failed = cur.fetchone()["count"]

                # Unique directives touched that day
                cur.execute(
                    sql.SQL("""
                        SELECT COUNT(DISTINCT directive_id) as count
                        FROM {}.audit_log
                        WHERE directive_id IS NOT NULL 
                        AND created_at >= %s AND created_at < %s
                    """).format(sql.Identifier(SCHEMA)),
                    (date, next_date)
                )
                unique_directives = cur.fetchone()["count"]

                # Average response time (created -> completed).
                # NOTE(review): a directive with several completion events
                # contributes one created/completed pair per event to this
                # average.
                cur.execute(
                    sql.SQL("""
                        SELECT AVG(completed_at - created_at) as avg_time
                        FROM (
                            SELECT 
                                l1.created_at as created_at,
                                l2.created_at as completed_at
                            FROM {}.audit_log l1
                            JOIN {}.audit_log l2 ON l1.directive_id = l2.directive_id
                            WHERE l1.event_type = %s
                            AND l2.event_type = %s
                            AND l1.created_at >= %s AND l1.created_at < %s
                        ) sub
                    """).format(sql.Identifier(SCHEMA), sql.Identifier(SCHEMA)),
                    (AuditEventType.DIRECTIVE_CREATED, AuditEventType.DIRECTIVE_COMPLETED, date, next_date)
                )
                avg_time_result = cur.fetchone()["avg_time"]
                avg_response_time = None
                if avg_time_result:
                    # AVG of an interval column comes back as a timedelta.
                    avg_response_time = avg_time_result.total_seconds()

                # Error rate as a percentage of created directives.
                error_rate = 0.0
                if directives_created > 0:
                    error_rate = (directives_failed / directives_created) * 100

                return AuditSummaryResponse(
                    date=date.strftime("%Y-%m-%d"),
                    total_events=total_events,
                    directives_created=directives_created,
                    directives_completed=directives_completed,
                    directives_failed=directives_failed,
                    unique_directives=unique_directives,
                    avg_response_time_seconds=avg_response_time,
                    error_rate=round(error_rate, 2)
                )
        finally:
            conn.close()

    @staticmethod
    def get_range_summary(start_date: datetime, end_date: datetime) -> List[AuditSummaryResponse]:
        """Get one daily summary per day in [start_date, end_date] inclusive.

        Days whose summary query fails are logged and skipped.
        """
        summaries = []
        current_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
        end_date = end_date.replace(hour=0, minute=0, second=0, microsecond=0)

        while current_date <= end_date:
            try:
                summary = AuditQueries.get_daily_summary(current_date)
                summaries.append(summary)
            except Exception as e:
                logger.warning(f"Failed to get summary for {current_date}: {e}")
            current_date += timedelta(days=1)

        return summaries

    @staticmethod
    def get_directive_timeline(directive_id: int) -> List[Dict[str, Any]]:
        """Get the full event timeline for a directive, oldest first."""
        conn = get_db_connection()
        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute(
                    sql.SQL("""
                        SELECT id, event_type, actor, details, created_at
                        FROM {}.audit_log
                        WHERE directive_id = %s
                        ORDER BY created_at ASC
                    """).format(sql.Identifier(SCHEMA)),
                    (directive_id,)
                )
                return [dict(row) for row in cur.fetchall()]
        finally:
            conn.close()

    @staticmethod
    def export_to_csv(
        event_type: Optional[str] = None,
        actor: Optional[str] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None
    ) -> io.StringIO:
        """Export (up to 100k) filtered audit logs to an in-memory CSV.

        Timestamps are serialized as ISO-8601 and the details dict as JSON.
        The returned buffer is rewound to position 0.
        """
        logs, _ = AuditQueries.get_paginated_audit_logs(
            event_type=event_type,
            actor=actor,
            start_date=start_date,
            end_date=end_date,
            page=1,
            page_size=100000  # Large number to get all
        )

        output = io.StringIO()
        if logs:
            fieldnames = ["id", "event_type", "actor", "directive_id", "details", 
                         "ip_address", "user_agent", "created_at"]
            writer = csv.DictWriter(output, fieldnames=fieldnames, extrasaction='ignore')
            writer.writeheader()

            for log in logs:
                log["created_at"] = log["created_at"].isoformat()
                log["details"] = json.dumps(log.get("details", {}))
                writer.writerow(log)

        output.seek(0)
        return output


# Cleanup Job
class AuditCleanup:
    """Cleanup job that moves aged audit rows into the archive table."""

    # Rows older than this many days are eligible for archiving.
    ARCHIVE_DAYS = 90

    @classmethod
    def archive_old_logs(cls) -> int:
        """Copy rows older than ARCHIVE_DAYS into audit_archive, then delete them.

        Both statements run in one transaction; returns the number of rows
        copied. Rolls back and re-raises on failure.
        """
        cutoff_date = datetime.utcnow() - timedelta(days=cls.ARCHIVE_DAYS)
        archived_count = 0

        conn = get_db_connection()
        try:
            with conn.cursor() as cur:
                copy_stmt = sql.SQL("""
                        INSERT INTO {}.audit_archive 
                        (id, event_type, actor, directive_id, details, ip_address, user_agent, created_at)
                        SELECT id, event_type, actor, directive_id, details, ip_address, user_agent, created_at
                        FROM {}.audit_log
                        WHERE created_at < %s
                    """).format(sql.Identifier(SCHEMA), sql.Identifier(SCHEMA))
                cur.execute(copy_stmt, (cutoff_date,))
                archived_count = cur.rowcount

                delete_stmt = sql.SQL("""
                        DELETE FROM {}.audit_log
                        WHERE created_at < %s
                    """).format(sql.Identifier(SCHEMA))
                cur.execute(delete_stmt, (cutoff_date,))

                conn.commit()
                logger.info(f"Archived {archived_count} audit logs older than {cutoff_date}")
                return archived_count
        except Exception as e:
            conn.rollback()
            logger.error(f"Archive cleanup failed: {e}")
            raise
        finally:
            conn.close()

    @classmethod
    def get_log_counts(cls) -> Dict[str, int]:
        """Return row counts for the active and archive tables."""
        conn = get_db_connection()
        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                counts_stmt = sql.SQL("""
                        SELECT 
                            (SELECT COUNT(*) FROM {}.audit_log) as active_logs,
                            (SELECT COUNT(*) FROM {}.audit_archive) as archived_logs
                    """).format(sql.Identifier(SCHEMA), sql.Identifier(SCHEMA))
                cur.execute(counts_stmt)
                row = cur.fetchone()
                return {key: row[key] for key in ("active_logs", "archived_logs")}
        finally:
            conn.close()


# FastAPI Application
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler.

    Startup: initializes the database schema and records a SYSTEM_STARTUP
    audit event. Shutdown: records SYSTEM_SHUTDOWN and closes the shared
    audit logger connection. Errors in either phase are logged but never
    propagated, so the server can still come up / go down.
    """
    # Startup
    logger.info("Starting AIVA Voice Command Bridge - Audit System")
    try:
        init_database()
        audit_logger.log_system_event(AuditEventType.SYSTEM_STARTUP, {
            "version": "1.0.0",
            "service": "audit_bridge"
        })
    except Exception as e:
        logger.error(f"Startup error: {e}")
    
    yield
    
    # Shutdown
    logger.info("Shutting down AIVA Voice Command Bridge - Audit System")
    try:
        audit_logger.log_system_event(AuditEventType.SYSTEM_SHUTDOWN, {})
        audit_logger.close()
    except Exception as e:
        logger.error(f"Shutdown error: {e}")

# Web application instance; the lifespan handler above records startup and
# shutdown audit events.
app = FastAPI(
    title="AIVA Voice Command Bridge - Audit System",
    version="1.0.0",
    lifespan=lifespan
)

def verify_api_key(x_api_key: str = Header(...)) -> str:
    """Validate the X-API-Key header; raise 401 on mismatch.

    Uses a constant-time comparison so the check does not leak key
    prefixes through response timing. NOTE(review): API_KEY is hard-coded
    at module level -- consider loading it from configuration/secrets.
    """
    if not secrets.compare_digest(x_api_key, API_KEY):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return x_api_key

# API Endpoints
@app.get("/bridge/audit", response_model=Dict[str, Any])
async def get_audit_logs(
    request: Request,
    x_api_key: str = Header(...),
    event_type: Optional[str] = Query(None),
    actor: Optional[str] = Query(None),
    directive_id: Optional[int] = Query(None),
    start_date: Optional[str] = Query(None),
    end_date: Optional[str] = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=1000)
):
    """Get paginated audit log with filters.

    Dates are ISO-8601 strings (a trailing 'Z' is accepted); a malformed
    date yields 400. Returns the page of rows plus pagination metadata and
    records the request itself as an API_REQUEST audit event.
    """
    verify_api_key(x_api_key)
    
    # Log this API request
    client_ip = request.client.host if request.client else None
    # Fix: HTTP header names use hyphens; "user_agent" never matched.
    user_agent = request.headers.get("user-agent")
    
    # Parse dates
    start_dt = None
    end_dt = None
    try:
        if start_date:
            start_dt = datetime.fromisoformat(start_date.replace('Z', '+00:00'))
        if end_date:
            end_dt = datetime.fromisoformat(end_date.replace('Z', '+00:00'))
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid date format: {e}")
    
    logs, total = AuditQueries.get_paginated_audit_logs(
        event_type=event_type,
        actor=actor,
        directive_id=directive_id,
        start_date=start_dt,
        end_date=end_dt,
        page=page,
        page_size=page_size
    )
    
    # Log the query
    audit_logger.log_event(
        event_type=AuditEventType.API_REQUEST,
        actor=AuditActor.CLAUDE,
        details={
            "endpoint": "/bridge/audit",
            "filters": {
                "event_type": event_type,
                "actor": actor,
                "directive_id": directive_id,
                "start_date": start_date,
                "end_date": end_date
            }
        },
        ip_address=client_ip,
        user_agent=user_agent
    )
    
    # Ceiling division for the page count.
    total_pages = (total + page_size - 1) // page_size
    
    return {
        "data": logs,
        "pagination": {
            "page": page,
            "page_size": page_size,
            "total": total,
            "total_pages": total_pages
        }
    }

@app.get("/bridge/audit/summary", response_model=List[AuditSummaryResponse])
async def get_audit_summary(
    request: Request,
    x_api_key: str = Header(...),
    start_date: Optional[str] = Query(None),
    end_date: Optional[str] = Query(None)
):
    """Get daily summary statistics.

    Both dates default to today, so omitting them returns a single-day
    summary. Malformed dates yield 400. The query is recorded as an
    API_REQUEST audit event.
    """
    verify_api_key(x_api_key)
    
    client_ip = request.client.host if request.client else None
    # Fix: HTTP header names use hyphens; "user_agent" never matched.
    user_agent = request.headers.get("user-agent")
    
    # Parse dates
    start_dt = None
    end_dt = None
    try:
        if start_date:
            start_dt = datetime.fromisoformat(start_date.replace('Z', '+00:00'))
        else:
            start_dt = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
        
        if end_date:
            end_dt = datetime.fromisoformat(end_date.replace('Z', '+00:00'))
        else:
            end_dt = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid date format: {e}")
    
    # Log the query
    audit_logger.log_event(
        event_type=AuditEventType.API_REQUEST,
        actor=AuditActor.CLAUDE,
        details={
            "endpoint": "/bridge/audit/summary",
            "start_date": start_date,
            "end_date": end_date
        },
        ip_address=client_ip,
        user_agent=user_agent
    )
    
    summaries = AuditQueries.get_range_summary(start_dt, end_dt)
    return summaries

@app.get("/bridge/audit/export")
async def export_audit_logs(
    request: Request,
    x_api_key: str = Header(...),
    event_type: Optional[str] = Query(None),
    actor: Optional[str] = Query(None),
    start_date: Optional[str] = Query(None),
    end_date: Optional[str] = Query(None)
):
    """Export audit logs as a CSV attachment.

    Malformed dates yield 400. The export request is recorded as an
    API_REQUEST audit event.
    """
    verify_api_key(x_api_key)
    
    client_ip = request.client.host if request.client else None
    # Fix: HTTP header names use hyphens; "user_agent" never matched.
    user_agent = request.headers.get("user-agent")
    
    # Parse dates
    start_dt = None
    end_dt = None
    try:
        if start_date:
            start_dt = datetime.fromisoformat(start_date.replace('Z', '+00:00'))
        if end_date:
            end_dt = datetime.fromisoformat(end_date.replace('Z', '+00:00'))
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid date format: {e}")
    
    # Log the export request
    audit_logger.log_event(
        event_type=AuditEventType.API_REQUEST,
        actor=AuditActor.CLAUDE,
        details={
            "endpoint": "/bridge/audit/export",
            "filters": {
                "event_type": event_type,
                "actor": actor,
                "start_date": start_date,
                "end_date": end_date
            }
        },
        ip_address=client_ip,
        user_agent=user_agent
    )
    
    csv_output = AuditQueries.export_to_csv(
        event_type=event_type,
        actor=actor,
        start_date=start_dt,
        end_date=end_dt
    )
    
    filename = f"audit_export_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.csv"
    
    return StreamingResponse(
        iter([csv_output.getvalue()]),
        media_type="text/csv",
        # Fix: previously emitted the literal "(unknown)" instead of the
        # generated filename; quoted per RFC 6266.
        headers={"Content-Disposition": f'attachment; filename="{filename}"'}
    )

@app.get("/bridge/audit/stats")
async def get_audit_stats(
    request: Request,
    x_api_key: str = Header(...)
):
    """Get audit system statistics."""
    verify_api_key(x_api_key)
    
    log_counts = AuditCleanup.get_log_counts()
    
    db = get_db_connection()
    try:
        with db.cursor(cursor_factory=RealDictCursor) as cursor:
            # Event-type breakdown over the last week.
            cursor.execute(
                sql.SQL("""
                    SELECT event_type, COUNT(*) as count
                    FROM {}.audit_log
                    WHERE created_at >= NOW() - INTERVAL '7 days'
                    GROUP BY event_type
                    ORDER BY count DESC
                """).format(sql.Identifier(SCHEMA))
            )
            by_event_type = [dict(r) for r in cursor.fetchall()]
            
            # Actor breakdown over the last week.
            cursor.execute(
                sql.SQL("""
                    SELECT actor, COUNT(*) as count
                    FROM {}.audit_log
                    WHERE created_at >= NOW() - INTERVAL '7 days'
                    GROUP BY actor
                    ORDER BY count DESC
                """).format(sql.Identifier(SCHEMA))
            )
            by_actor = [dict(r) for r in cursor.fetchall()]
            
            # Event volume over the last day.
            cursor.execute(
                sql.SQL("""
                    SELECT COUNT(*) as count
                    FROM {}.audit_log
                    WHERE created_at >= NOW() - INTERVAL '24 hours'
                """).format(sql.Identifier(SCHEMA))
            )
            recent_count = cursor.fetchone()["count"]
    finally:
        db.close()
    
    return {
        "total_logs": log_counts["active_logs"],
        "archived_logs": log_counts["archived_logs"],
        "last_24h_events": recent_count,
        "event_type_distribution": by_event_type,
        "actor_distribution": by_actor
    }

@app.get("/bridge/audit/directive/{directive_id}")
async def get_directive_timeline(
    directive_id: int,
    request: Request,
    x_api_key: str = Header(...)
):
    """Get full timeline for a specific directive.

    Returns every audit event recorded for the directive plus a total
    count, and records the lookup itself as a STATUS_QUERIED event.
    """
    verify_api_key(x_api_key)

    client_ip = request.client.host if request.client else None
    # HTTP header names are hyphenated: "user_agent" never matches the real
    # User-Agent header, so the old lookup always yielded None.
    user_agent = request.headers.get("user-agent")

    timeline = AuditQueries.get_directive_timeline(directive_id)

    # Log the query itself so reads are auditable too.
    audit_logger.log_event(
        event_type=AuditEventType.STATUS_QUERIED,
        actor=AuditActor.CLAUDE,
        directive_id=directive_id,
        details={"endpoint": f"/bridge/audit/directive/{directive_id}"},
        ip_address=client_ip,
        user_agent=user_agent
    )

    return {
        "directive_id": directive_id,
        "timeline": timeline,
        "total_events": len(timeline)
    }

@app.post("/bridge/audit/cleanup")
async def trigger_cleanup(
    request: Request,
    x_api_key: str = Header(...)
):
    """Manually trigger archive cleanup.

    Archives old audit logs and records the cleanup as a system event.

    Raises:
        HTTPException: 500 if archiving (or logging the event) fails.
    """
    verify_api_key(x_api_key)

    # NOTE: client IP / User-Agent were previously extracted here but never
    # used; dropped as dead code. `request` stays in the signature so the
    # route's interface (and any dependency wiring) is unchanged.
    try:
        archived_count = AuditCleanup.archive_old_logs()

        audit_logger.log_system_event(
            "manual_cleanup",
            {
                "archived_count": archived_count,
                "triggered_by": "manual_api_call"
            }
        )

        return {
            "success": True,
            "archived_count": archived_count,
            "message": f"Archived {archived_count} old audit logs"
        }
    except Exception as e:
        # Record the failure server-side and chain the cause so the
        # original traceback survives into logs/debuggers.
        logger.exception("Manual audit cleanup failed")
        raise HTTPException(status_code=500, detail=f"Cleanup failed: {str(e)}") from e

# Decorator for auto-logging
def audit_log(event_type: str, actor: str, include_request: bool = True):
    """Decorator to auto-log function calls as audit events.

    Wraps both sync and async callables. On success it logs `event_type`;
    on exception it logs a failure event (the ``*_failed`` variant of
    `event_type` when the name contains ``_success``, otherwise
    DIRECTIVE_FAILED) and re-raises the original exception.

    Args:
        event_type: Audit event type to log on success.
        actor: Audit actor recorded on every event.
        include_request: When True, capture client IP and User-Agent from a
            ``request`` keyword argument if one is passed to the wrapped call.
    """
    import asyncio  # local import: asyncio is not imported at module level

    def decorator(func):
        def _request_meta(kwargs):
            """Return (ip_address, user_agent) from kwargs['request'], if any."""
            request = kwargs.get('request')
            if not (include_request and request):
                return None, None
            ip_address = request.client.host if request.client else None
            # HTTP header names are hyphenated; the previous "user_agent"
            # lookup could never match a real User-Agent header.
            return ip_address, request.headers.get("user-agent")

        def _log_outcome(directive_id, ip_address, user_agent, error=None):
            """Emit one success or failure audit event for this call."""
            if error is None:
                audit_logger.log_event(
                    event_type=event_type,
                    actor=actor,
                    directive_id=directive_id,
                    details={
                        "function": func.__name__,
                        "status": "success"
                    },
                    ip_address=ip_address,
                    user_agent=user_agent
                )
            else:
                failed_type = (
                    event_type.replace('_success', '_failed')
                    if '_success' in event_type
                    else AuditEventType.DIRECTIVE_FAILED
                )
                audit_logger.log_event(
                    event_type=failed_type,
                    actor=actor,
                    directive_id=directive_id,
                    details={
                        "function": func.__name__,
                        "error": str(error)
                    },
                    ip_address=ip_address,
                    user_agent=user_agent
                )

        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            directive_id = kwargs.get('directive_id')
            ip_address, user_agent = _request_meta(kwargs)
            try:
                result = await func(*args, **kwargs)
            except Exception as e:
                _log_outcome(directive_id, ip_address, user_agent, error=e)
                raise
            _log_outcome(directive_id, ip_address, user_agent)
            return result

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            directive_id = kwargs.get('directive_id')
            ip_address, user_agent = _request_meta(kwargs)
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                _log_outcome(directive_id, ip_address, user_agent, error=e)
                raise
            _log_outcome(directive_id, ip_address, user_agent)
            return result

        # Pick the wrapper matching the wrapped callable's flavor.
        return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper

    return decorator


if __name__ == "__main__":
    # Allow running the bridge directly (development entry point).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8080)