# patent_api_server.py
import asyncio
import logging
import os
import secrets
import time
import uuid
from typing import List, Dict, Any

import prometheus_client
from fastapi import (
    BackgroundTasks,
    Depends,
    FastAPI,
    Header,
    HTTPException,
    WebSocket,
    WebSocketDisconnect,
    status,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.staticfiles import StaticFiles
from prometheus_client import CollectorRegistry, Gauge
from pydantic import BaseModel, Field
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.status import HTTP_401_UNAUTHORIZED

# --- Logging Configuration ---
# Root logging: INFO and above, timestamped single-line records.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module logger used by the handlers in this file.
logger = logging.getLogger(__name__)

# --- Environment Variables ---
# Signing key for access tokens. Without JWT_SECRET we fall back to a random
# per-process key instead of a publicly known constant: anyone who knows the
# key can forge tokens for any user. A random key means tokens do not survive
# a restart and are not shared between worker processes, so always set
# JWT_SECRET in a real deployment.
_env_jwt_secret = os.environ.get("JWT_SECRET")
if _env_jwt_secret:
    JWT_SECRET = _env_jwt_secret
else:
    JWT_SECRET = secrets.token_urlsafe(32)
    logging.getLogger(__name__).warning(
        "JWT_SECRET is not set; using a random per-process signing key. "
        "Issued tokens will not survive restarts or be shared across workers."
    )
ALGORITHM = "HS256"
# Lifetime of tokens issued by POST /token, in minutes (default 30).
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))

# --- Prometheus Metrics ---
# Dedicated registry: only the metrics registered here are exported by /metrics.
registry = CollectorRegistry()
# Monotonically increasing request count. A Counter (rather than a Gauge) is the
# Prometheus type for values that only go up, so rate()/increase() handle
# process restarts correctly. The exported sample keeps the
# `patent_api_requests_total` name.
REQUESTS_TOTAL = prometheus_client.Counter(
    'patent_api_requests_total', 'Total number of requests', registry=registry
)
# Duration of the most recent request that sets it (overwritten on each .set()).
REQUEST_LATENCY = Gauge(
    'patent_api_request_latency_seconds', 'Request latency in seconds', registry=registry
)

# --- FastAPI App Setup ---
# The OpenAPI schema and the Swagger/ReDoc UIs (not the API routes themselves)
# are served under this versioned prefix.
_API_PREFIX = "/api/v1"

app = FastAPI(
    title="Patent Validation System API",
    description="API for validating AI outputs and managing patent-specific validations.",
    version="0.1.0",
    openapi_url=f"{_API_PREFIX}/openapi.json",
    docs_url=f"{_API_PREFIX}/docs",
    redoc_url=f"{_API_PREFIX}/redoc",
)

# --- Rate Limiting ---
# Limits are keyed by get_remote_address (the client address). Routes decorated
# with @limiter.limit use their own limits; every other route falls back to
# default_limits.
limiter = Limiter(key_func=get_remote_address, default_limits=["200/minute"])
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# slowapi applies default_limits only through its middleware. Without it, the
# 200/minute default above is never enforced on undecorated routes.
app.add_middleware(SlowAPIMiddleware)

# --- CORS Configuration ---
# Allowed browser origins, overridable as a comma-separated list in
# CORS_ALLOW_ORIGINS. The default keeps the original development list,
# including the "*" wildcard; set an explicit list in production.
origins = [
    origin.strip()
    for origin in os.environ.get(
        "CORS_ALLOW_ORIGINS", "http://localhost,http://localhost:8080,*"
    ).split(",")
    if origin.strip()
]

# Credentialed CORS is enabled only for an explicit allow-list. The CORS spec
# forbids combining "*" with credentials, and Starlette's CORSMiddleware handles
# that combination by echoing the requesting Origin on credentialed requests,
# which effectively trusts every site. This API authenticates with bearer
# tokens, not cookies.
_allow_credentials = "*" not in origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=_allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Security ---
# Extracts the bearer token from the Authorization header. tokenUrl tells
# clients and the OpenAPI docs where to obtain one (the POST /token route).
_TOKEN_URL = "token"
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=_TOKEN_URL)

# --- Data Models ---
class Token(BaseModel):
    """Response body of ``POST /token``."""

    access_token: str  # encoded JWT from create_access_token
    token_type: str  # "bearer" as issued by the login endpoint

class User(BaseModel):
    """Authenticated principal resolved from a bearer token."""

    # Defaults to the single hard-coded account this demo server knows about.
    username: str = "admin"

class ValidationRequest(BaseModel):
    """Request body carrying one AI-generated text to validate."""

    input_text: str = Field(
        default=...,  # required
        description="The AI generated text to validate.",
    )

class BatchValidationRequest(BaseModel):
    """Request body carrying several AI-generated texts to validate at once."""

    inputs: list[str] = Field(
        default=...,  # required
        description="List of AI generated texts to validate.",
    )

class ValidationResult(BaseModel):
    """Outcome of one validation, as stored in ``validation_results``."""

    validation_id: str  # opaque ID handed back to the client
    status: str  # "pending" or "completed" within this module
    # pydantic copies mutable defaults per instance, so `{}` is not shared.
    result: dict[str, Any] = {}

class PatentStatus(BaseModel):
    """Status entry for one patent subsystem (served by GET /patents/status)."""

    patent_id: str
    status: str
    details: dict[str, Any] = {}  # free-form extra information

class AuditEvent(BaseModel):
    """One entry of the audit trail served by the /audit endpoints."""

    event_id: str
    timestamp: float  # presumably seconds since the epoch (time.time()) -- confirm with producer
    event_type: str
    user: str  # username of the actor
    details: dict[str, Any] = {}  # free-form event payload

class ThresholdUpdateRequest(BaseModel):
    """Request body of POST /admin/thresholds."""

    threshold_name: str  # must name an existing entry in `thresholds`
    new_value: float

class ConfigResponse(BaseModel):
    """Response body of GET /admin/config: the current threshold values."""

    thresholds: dict[str, float]

# --- In-memory stores (placeholders; replace with a database) ---
# Module-level state: it lives only as long as this process and is not shared
# between worker processes.
validation_results: dict[str, ValidationResult] = {}  # keyed by validation_id
patent_statuses: list[PatentStatus] = []
audit_trail: list[AuditEvent] = []
# Tunables read by the /patents endpoints and editable via POST /admin/thresholds.
thresholds: dict[str, float] = {
    "hallucination_threshold": 0.8,
    "consensus_agreement": 0.7,
}

# --- Security Functions ---
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from datetime import datetime, timedelta

# Password hashing context; bcrypt is the only configured scheme.
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)

def verify_password(plain_password, hashed_password):
    """Check a plaintext password against a stored hash.

    Returns:
        True when *plain_password* matches *hashed_password* under ``pwd_context``.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match

def get_password_hash(password):
    """Return the ``pwd_context`` (bcrypt) hash of *password*."""
    hashed = pwd_context.hash(password)
    return hashed

def create_access_token(data: dict, expires_delta: timedelta | None = None):
    """Return a signed JWT carrying *data* plus an ``exp`` expiry claim.

    Args:
        data: Claims to embed (e.g. ``{"sub": username}``). It is copied, not mutated.
        expires_delta: Token lifetime. ``None`` selects the 15-minute default.
            Any explicit value, including ``timedelta(0)``, is honoured.

    Returns:
        The token encoded with ``JWT_SECRET`` using ``ALGORITHM``.
    """
    # Local import: only the `datetime` class and `timedelta` are imported at module level.
    from datetime import timezone

    to_encode = data.copy()
    # `is None` rather than truthiness: timedelta(0) is falsy but is an explicit choice.
    lifetime = timedelta(minutes=15) if expires_delta is None else expires_delta
    # Timezone-aware "now"; datetime.utcnow() is deprecated (3.12+) and returns a naive value.
    to_encode.update({"exp": datetime.now(timezone.utc) + lifetime})
    return jwt.encode(to_encode, JWT_SECRET, algorithm=ALGORITHM)

async def get_current_user(token: str = Depends(oauth2_scheme)):
    """FastAPI dependency: decode a bearer JWT and return the User it names.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            ``jwt.decode`` raises ``JWTError`` (bad signature, malformed or
            expired token) or when the token carries no ``sub`` claim.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        claims = jwt.decode(token, JWT_SECRET, algorithms=[ALGORITHM])
    except JWTError:
        raise unauthorized
    subject: str = claims.get("sub")
    if subject is None:
        raise unauthorized
    return User(username=subject)

# --- Authentication Endpoint ---
# Credentials of the single built-in account. Environment variables can
# override them; the defaults are the original hard-coded admin/password pair.
_ADMIN_USERNAME = os.environ.get("ADMIN_USERNAME", "admin")
# Hashed once at import: bcrypt is deliberately slow, so hashing the constant
# password on every login request wasted CPU on the event loop.
_ADMIN_PASSWORD_HASH = get_password_hash(os.environ.get("ADMIN_PASSWORD", "password"))


@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange username/password form credentials for a bearer access token.

    Raises:
        HTTPException: 401 when the username or the password is wrong.
    """
    # Constant-time comparison, so response timing does not leak the username.
    username_ok = secrets.compare_digest(
        form_data.username.encode("utf-8"), _ADMIN_USERNAME.encode("utf-8")
    )
    # bcrypt verification is CPU-bound; run it in a worker thread so it does
    # not stall other requests on the event loop.
    password_ok = await asyncio.to_thread(
        verify_password, form_data.password, _ADMIN_PASSWORD_HASH
    )
    if not (username_ok and password_ok):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": _ADMIN_USERNAME}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}


# --- Background Task ---
async def validate_ai_output_background(input_text: str, validation_id: str | None = None):
    """Simulate validating *input_text* and store a completed result.

    Args:
        input_text: The AI-generated text. The simulated check does not inspect it yet.
        validation_id: ID of the pending entry created by ``POST /validate``.
            That entry is replaced by the completed result. When omitted (the
            original single-argument call form), a fresh ID is generated.
    """
    await asyncio.sleep(2)  # Simulate validation processing
    if validation_id is None:
        validation_id = str(uuid.uuid4())
    validation_results[validation_id] = ValidationResult(
        validation_id=validation_id,
        status="completed",
        result={"score": 0.95, "details": "Passed all validation checks."},
    )
    logger.info(f"Validation completed for ID: {validation_id}")


# --- API Endpoints ---
@app.get("/health")
@limiter.limit("5/second")
async def health_check(request: Request):
    """Health check endpoint."""
    # `request` is unused but required by slowapi's limit decorator.
    return {"status": "ok"}


@app.get("/metrics")
async def metrics():
    """Prometheus metrics endpoint."""
    # generate_latest() returns the whole exposition as a single bytes object.
    # StreamingResponse would iterate those bytes as ints and fail to encode
    # them, so send a plain Response with Prometheus' own content type.
    return Response(
        content=prometheus_client.generate_latest(registry),
        media_type=prometheus_client.CONTENT_TYPE_LATEST,
    )


@app.post("/validate", response_model=ValidationResult)
@limiter.limit("10/minute")
async def validate_ai(request: Request, validation_request: ValidationRequest, background_tasks: BackgroundTasks, current_user: User = Depends(get_current_user)):
    """Validate a single AI output.

    Stores a "pending" result and schedules the validation in the background.
    Poll ``GET /validate/{validation_id}`` for the completed result.
    """
    REQUESTS_TOTAL.inc()
    start_time = time.perf_counter()  # monotonic clock for durations
    # uuid4 rather than str(time.time()): timestamps can repeat across
    # concurrent requests, and a repeated ID would overwrite another result.
    validation_id = str(uuid.uuid4())
    pending = ValidationResult(validation_id=validation_id, status="pending")
    validation_results[validation_id] = pending
    # Pass the same ID to the worker so it completes *this* entry instead of
    # recording its result under an unrelated new ID.
    background_tasks.add_task(
        validate_ai_output_background, validation_request.input_text, validation_id
    )
    REQUEST_LATENCY.set(time.perf_counter() - start_time)
    return pending

@app.post("/validate/batch", response_model=List[ValidationResult])
@limiter.limit("5/minute")
async def validate_batch(request: Request, batch_validation_request: BatchValidationRequest, current_user: User = Depends(get_current_user)):
    """Validate a batch of AI outputs.

    Each input gets its own completed result (simulated score) and unique ID.
    """
    # `request` is unused but required by slowapi's limit decorator.
    results = []
    for _input_text in batch_validation_request.inputs:  # texts not inspected yet (simulated)
        # uuid4 rather than str(time.time()): consecutive iterations can read the
        # same timestamp (coarse clocks, e.g. on Windows), and a repeated ID
        # would overwrite an earlier result in the store.
        validation_id = str(uuid.uuid4())
        entry = ValidationResult(
            validation_id=validation_id, status="completed", result={"score": 0.8}
        )
        validation_results[validation_id] = entry
        results.append(entry)
    return results

@app.get("/validate/{validation_id}", response_model=ValidationResult)
async def get_validation_result(validation_id: str, current_user: User = Depends(get_current_user)):
    """Get the result of a specific validation.

    Responds 404 when the ID is unknown.
    """
    stored = validation_results.get(validation_id)
    if stored is None:
        raise HTTPException(status_code=404, detail="Validation not found")
    return stored

def _extract_websocket_token(websocket: WebSocket) -> str | None:
    """Return the bearer token sent with a WebSocket handshake, or None.

    Browsers cannot attach custom headers to a WebSocket handshake, so the
    ``token`` query parameter is checked first. Non-browser clients may send
    ``Authorization: Bearer <token>`` instead.
    """
    token = websocket.query_params.get("token")
    if token:
        return token
    scheme, _, credentials = websocket.headers.get("authorization", "").partition(" ")
    credentials = credentials.strip()
    if scheme.lower() == "bearer" and credentials:
        return credentials
    return None


@app.websocket("/validate/stream")
async def websocket_endpoint(websocket: WebSocket):
    """Real-time validation via WebSocket.

    The client must present a token from ``POST /token`` as ``?token=...`` or
    as an ``Authorization: Bearer`` header. Otherwise the handshake is closed
    with code 1008 (policy violation).
    """
    # OAuth2PasswordBearer needs an HTTP Request, which FastAPI does not supply
    # on WebSocket routes, so Depends(get_current_user) failed on every
    # connection. Authenticate explicitly before accepting instead.
    token = _extract_websocket_token(websocket)
    current_user = None
    if token is not None:
        try:
            current_user = await get_current_user(token)
        except HTTPException:
            current_user = None
    if current_user is None:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    await websocket.accept()
    logger.info("WebSocket client connected as %s", current_user.username)
    try:
        while True:
            data = await websocket.receive_text()
            # Simulate real-time validation
            await websocket.send_text(f"Validating: {data}")
            await asyncio.sleep(1)
            await websocket.send_text(f"Validation completed for: {data}")
    except WebSocketDisconnect:
        logger.info("Client disconnected from WebSocket")

@app.post("/patents/p1/crypto")
async def patent_p1_crypto(validation_request: ValidationRequest, current_user: User = Depends(get_current_user)):
    """Cryptographic validation for patent P1."""
    # Placeholder: the request body is accepted but not inspected yet.
    outcome = dict(status="success", result="Cryptographic validation passed.")
    return outcome

@app.post("/patents/p5/consensus")
async def patent_p5_consensus(validation_request: ValidationRequest, current_user: User = Depends(get_current_user)):
    """Multi-model consensus validation for patent P5.

    Passes when the configured ``consensus_agreement`` is at least 0.7.
    Otherwise responds 500.
    """
    agreement = float(thresholds["consensus_agreement"])
    # Inclusive bound: a strict `>` rejects exactly 0.7, which is the shipped
    # default, so the endpoint failed under the default configuration.
    if agreement >= 0.7:
        return {"status": "success", "result": "Multi-model consensus validation passed."}
    raise HTTPException(status_code=500, detail=f"Threshold agreement is below {thresholds['consensus_agreement']}")


@app.post("/patents/p7/hallucination")
async def patent_p7_hallucination(validation_request: ValidationRequest, current_user: User = Depends(get_current_user)):
    """Hallucination detection for patent P7.

    Passes when the configured ``hallucination_threshold`` is at most 0.8.
    Otherwise responds 500.
    """
    limit = float(thresholds["hallucination_threshold"])
    # Inclusive bound: a strict `<` rejects exactly 0.8, which is the shipped
    # default, so the endpoint failed under the default configuration.
    if limit <= 0.8:
        return {"status": "success", "result": "Hallucination detection passed."}
    raise HTTPException(status_code=500, detail=f"Threshold hallucination is above {thresholds['hallucination_threshold']}")

@app.get("/patents/status", response_model=List[PatentStatus])
async def get_patent_status(current_user: User = Depends(get_current_user)):
    """Get the status of all patent systems."""
    # Serves the in-memory list as-is; FastAPI validates it against response_model.
    statuses = patent_statuses
    return statuses

@app.get("/audit/trail", response_model=List[AuditEvent])
async def get_audit_trail(current_user: User = Depends(get_current_user)):
    """Get the audit trail."""
    # The whole in-memory trail, in insertion order.
    events = audit_trail
    return events

@app.get("/audit/{event_id}", response_model=AuditEvent)
async def get_audit_event(event_id: str, current_user: User = Depends(get_current_user)):
    """Get a specific audit event.

    Returns the first event with a matching ``event_id``; responds 404 otherwise.
    """
    found = next((entry for entry in audit_trail if entry.event_id == event_id), None)
    if found is None:
        raise HTTPException(status_code=404, detail="Event not found")
    return found

@app.post("/audit/export")
async def export_audit_data(current_user: User = Depends(get_current_user)):
    """Export audit data."""
    # Placeholder: no export is performed yet; only an acknowledgement is returned.
    return dict(status="Export initiated.")

@app.post("/admin/thresholds")
async def update_thresholds(threshold_update: ThresholdUpdateRequest, current_user: User = Depends(get_current_user)):
    """Update validation thresholds.

    Only existing threshold names may be changed; unknown names get a 400.
    """
    name = threshold_update.threshold_name
    if name not in thresholds:
        raise HTTPException(status_code=400, detail="Invalid threshold name")
    thresholds[name] = threshold_update.new_value
    return {"message": f"Threshold {name} updated to {threshold_update.new_value}"}

@app.get("/admin/config", response_model=ConfigResponse)
async def get_config(current_user: User = Depends(get_current_user)):
    """Get the current configuration."""
    snapshot = ConfigResponse(thresholds=thresholds)
    return snapshot

from fastapi import BackgroundTasks

# --- Example usage of BackgroundTasks ---
@app.post("/long_task")
async def long_task(background_tasks: BackgroundTasks):
    """Schedule ``some_long_task`` to run after the response is sent."""
    # `some_long_task` is looked up when this handler runs, so defining it
    # further down the module is fine.
    background_tasks.add_task(some_long_task, "some_data")
    reply = {"message": "Task started in the background"}
    return reply

async def some_long_task(data: str):
    """Sleep for five seconds (stand-in for real work), then log *data*."""
    delay_seconds = 5
    await asyncio.sleep(delay_seconds)
    logger.info(f"Long task completed with data: {data}")
