from fastapi import FastAPI, HTTPException, Depends
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from sqlalchemy.orm import Session
from pydantic import BaseModel
import uvicorn
import sys
import os
import hashlib
import datetime
import json

# Add current directory to path to import server
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from server import validate_hallucination, assess_risk, audit_log, privacy_scan
from database import engine, get_db, Base
from models import AuditLog, RiskAssessment, HallucinationCheck

# Create Tables for every model registered on Base. This runs at import time.
# NOTE(review): create_all only creates missing tables; it does not alter
# existing ones, so model/schema changes need a separate migration.
Base.metadata.create_all(bind=engine)

app = FastAPI(title="Genesis Patent OS API", version="1.0.0")

# Serve static files (widget) under the /widget URL prefix.
# NOTE(review): "widget" is resolved against the current working directory,
# whereas sys.path is extended with this file's own directory. If the server is
# launched from a different directory, an empty widget/ is created there and the
# real assets are not served — confirm this is intended.
os.makedirs("widget", exist_ok=True)  # ensure the directory exists before StaticFiles is constructed
app.mount("/widget", StaticFiles(directory="widget"), name="widget")

# --- Schemas ---
class ValidationRequest(BaseModel):
    # Request body for POST /validate/hallucination.
    text: str  # text to validate; the endpoint persists only its first 100 characters
    context: str = ""  # optional extra context, forwarded to validate_hallucination

class RiskRequest(BaseModel):
    # Request body for POST /assess/risk.
    proposal: str  # text to assess; the RiskAssessment row stores only its SHA-256 hex digest
    industry: str = "general"  # forwarded to assess_risk and stored on the RiskAssessment row

class AuditRequest(BaseModel):
    # Request body for POST /audit/log; all fields are written into the hash-chained AuditLog.
    event_type: str  # event label stored on the audit entry
    details: dict  # arbitrary payload; must be JSON-serializable because _log_to_db hashes json.dumps(details)
    user_id: str = "system"  # NOTE(review): supplied by the caller and not authenticated anywhere here — any client can log as any user

class PrivacyRequest(BaseModel):
    # Request body for POST /privacy/scan.
    text: str  # text passed to privacy_scan; the endpoint never stores it, only the detected PII types

# --- Endpoints ---

@app.post("/validate/hallucination")
async def api_validate_hallucination(req: ValidationRequest, db: Session = Depends(get_db)):
    """Patent 7 Endpoint (with DB Persistence)

    Runs ``validate_hallucination`` on the submitted text and context, stores a
    ``HallucinationCheck`` row (keeping only the first 100 characters of the
    text), appends a ``HALLUCINATION_CHECK`` entry to the audit chain, and
    returns the validator's result.

    NOTE: the check row is committed before the audit entry is written, so a
    failure inside ``_log_to_db`` leaves the check persisted without an audit
    record.
    """
    verdict = validate_hallucination(req.text, req.context)

    # Extract the persisted fields in the same order the row is built.
    snippet = req.text[:100]  # never store the full submitted text
    score = verdict['confidence_score']
    valid_flag = int(bool(verdict['is_valid']))  # stored as integer 1/0
    issues = verdict['issues']

    check_row = HallucinationCheck(
        text_snippet=snippet,
        confidence_score=score,
        is_valid=valid_flag,
        issues_found=issues,
    )
    db.add(check_row)
    db.commit()

    # Record the full verdict in the tamper-evident audit log.
    _log_to_db(db, "HALLUCINATION_CHECK", verdict, "api_user")

    return verdict

@app.post("/assess/risk")
async def api_assess_risk(req: RiskRequest, db: Session = Depends(get_db)):
    """Patent 3 Endpoint (with DB Persistence)

    Scores the proposal with ``assess_risk``, stores a ``RiskAssessment`` row
    keyed by the SHA-256 hex digest of the proposal (not the raw text),
    appends a ``RISK_ASSESSMENT`` entry to the audit chain, and returns the
    assessment.
    """
    assessment = assess_risk(req.proposal, req.industry)

    # Fingerprint the proposal so the row never holds the original text.
    fingerprint = hashlib.sha256(req.proposal.encode()).hexdigest()
    overall = assessment['overall_risk_score']
    level = assessment['risk_level']
    per_dimension = assessment['dimension_scores']

    risk_row = RiskAssessment(
        proposal_hash=fingerprint,
        risk_score=overall,
        risk_level=level,
        dimension_scores=per_dimension,
        industry=req.industry,
    )
    db.add(risk_row)
    db.commit()

    _log_to_db(db, "RISK_ASSESSMENT", assessment, "api_user")
    return assessment

@app.post("/audit/log")
async def api_audit_log(req: AuditRequest, db: Session = Depends(get_db)):
    """Patent 4 Endpoint (Direct Log)

    Appends the caller-supplied event to the hash-chained audit log and
    returns ``{"status": "logged", "tx_hash": <hex digest>}`` from
    ``_log_to_db``.
    """
    # NOTE(review): user_id is taken verbatim from the request body; nothing
    # here authenticates it.
    event_type, details, user_id = req.event_type, req.details, req.user_id
    receipt = _log_to_db(db, event_type, details, user_id)
    return receipt

@app.post("/privacy/scan")
async def api_privacy_scan(req: PrivacyRequest, db: Session = Depends(get_db)):
    """Patent 8 Endpoint

    Scans the text with ``privacy_scan`` and returns the scan result. When PII
    is detected, a ``PRIVACY_ALERT`` audit entry is appended that records only
    the detected types, never the scanned text.
    """
    findings = privacy_scan(req.text)

    # Nothing sensitive found: no audit entry is needed.
    if not findings['pii_detected']:
        return findings

    alert_details = {"types": findings['detected_types']}
    _log_to_db(db, "PRIVACY_ALERT", alert_details, "api_user")
    return findings

# --- Internal Helper ---
def _log_to_db(db: Session, event_type: str, details: dict, user_id: str):
    """Append one event to the hash-chained audit log and commit it.

    The new entry's ``prev_hash`` is the ``current_hash`` of the newest existing
    entry (64 ``"0"`` characters when the log is empty). Its ``current_hash`` is
    the SHA-256 hex digest of
    ``f"{timestamp}{event_type}{json.dumps(details, sort_keys=True)}{user_id}{prev_hash}"``
    (Patent 4 requirement).

    Args:
        db: SQLAlchemy session to write through; it is committed here.
        event_type: Event label stored on the entry.
        details: Event payload stored on the entry. Must be JSON-serializable,
            because it is also passed through ``json.dumps`` for hashing.
        user_id: Actor identifier stored on the entry.

    Returns:
        ``{"status": "logged", "tx_hash": <the new entry's current_hash>}``.

    Raises:
        TypeError: (from ``json.dumps``) if ``details`` is not
            JSON-serializable. It is raised before the entry is added to the
            session.
    """
    from sqlalchemy import inspect as sa_inspect

    # Fetch the chain tip. Ordering by timestamp alone is ambiguous when two
    # entries share a timestamp (coarse clocks, or DB columns that truncate
    # sub-second precision). In that case .first() may return either row, and
    # the new entry would link to the wrong predecessor, forking the chain.
    # Break ties on the primary key so the choice is deterministic.
    # NOTE(review): assumes an autoincrementing primary key, so the highest PK
    # is the latest insert — confirm against models.AuditLog.
    pk_desc = [column.desc() for column in sa_inspect(AuditLog).primary_key]
    last_entry = (
        db.query(AuditLog)
        .order_by(AuditLog.timestamp.desc(), *pk_desc)
        .first()
    )
    prev_hash = last_entry.current_hash if last_entry is not None else "0" * 64
    # NOTE(review): reading the tip and inserting are not atomic. Concurrent
    # writers (e.g. several worker processes) could both link to the same tip.

    # Create new entry
    new_entry = AuditLog(
        event_type=event_type,
        details=details,
        user_id=user_id,
        prev_hash=prev_hash,
        timestamp=datetime.datetime.utcnow(),
    )

    # Calculate hash (Patent 4 requirement). The payload format must stay
    # byte-for-byte identical: existing entries were hashed with it, so any
    # change breaks verification of the chain.
    data_str = f"{new_entry.timestamp}{new_entry.event_type}{json.dumps(details, sort_keys=True)}{new_entry.user_id}{new_entry.prev_hash}"
    new_entry.current_hash = hashlib.sha256(data_str.encode()).hexdigest()

    db.add(new_entry)
    db.commit()
    db.refresh(new_entry)
    return {"status": "logged", "tx_hash": new_entry.current_hash}

# --- Widget Endpoint ---
@app.get("/badge/status")
async def get_badge_status(db: Session = Depends(get_db)):
    """Returns the live status for the 'Powered by Sunaiva' badge.

    Looks back one hour from ``datetime.utcnow()``. Any ``RiskAssessment`` in
    that window with ``risk_score > 0.7`` switches the badge to
    WARNING/orange; otherwise it is SECURE/green. The response also reports
    how many ``HallucinationCheck`` rows fall in the same window.
    """
    # NOTE(review): assumes the models' timestamp columns hold naive UTC
    # datetimes, matching utcnow() here — confirm the model defaults.
    window_start = datetime.datetime.utcnow() - datetime.timedelta(hours=1)

    # NOTE(review): the 0.7 threshold assumes risk_score is on a 0-1 scale —
    # confirm against assess_risk.
    risk_query = db.query(RiskAssessment).filter(
        RiskAssessment.timestamp > window_start,
        RiskAssessment.risk_score > 0.7,
    )
    high_risk_recent = risk_query.count()

    check_query = db.query(HallucinationCheck).filter(
        HallucinationCheck.timestamp > window_start
    )
    checks_in_window = check_query.count()

    if high_risk_recent > 0:
        status, color = "WARNING", "orange"
    else:
        status, color = "SECURE", "green"

    return {
        "status": status,
        "color": color,
        "checks_last_hour": checks_in_window,
        "message": "9-Layer Shield Active",
        "powered_by": "Sunaiva Digital",
    }

if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on port 8000.
    # NOTE(review): host "0.0.0.0" listens on all network interfaces, not only localhost.
    uvicorn.run(app, port=8000, host="0.0.0.0")
