from fastmcp import FastMCP
import datetime
import hashlib
import json
import re
import random
import time
from typing import List, Dict, Any, Optional

# Initialize the server
mcp = FastMCP("Genesis Patent OS")

# --- Patent 4: Immutable Audit Trail Logic ---
# Genesis value used when the audit file has no prior entries.
_GENESIS_HASH = "0" * 64


def _read_last_hash(path: str = "patent_audit_trail.jsonl") -> str:
    """Return the hash of the most recent audit entry, or the genesis hash.

    Best-effort: a missing, unreadable, or corrupt trail file yields the
    genesis hash rather than raising, matching the append path's tolerance.
    """
    try:
        with open(path, "rb") as f:
            lines = f.read().splitlines()
        for raw in reversed(lines):
            if raw.strip():
                return json.loads(raw)["hash"]
    except (OSError, ValueError, KeyError):
        pass
    return _GENESIS_HASH


def log_to_audit_trail(event_type: str, details: Any, user_id: str = "system") -> str:
    """
    Patent 4: Immutable Audit Trail.

    Appends a hash-chained entry to patent_audit_trail.jsonl and returns the
    entry's SHA-256 hash. Each entry embeds the previous entry's hash, so
    later tampering with any line breaks the chain.

    Args:
        event_type: Category label for the event (e.g. "RISK_ASSESSMENT").
        details: JSON-serializable payload describing the event.
        user_id: Actor attribution; defaults to "system".

    Returns:
        Hex SHA-256 digest of the canonicalized entry.
    """
    # Timezone-aware UTC timestamp; naive local time made entries ambiguous.
    timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat()
    # Chain to the previous entry (the original left a hard-coded placeholder
    # and noted the real hash should be fetched from the file).
    prev_hash = _read_last_hash()

    entry = {
        "timestamp": timestamp,
        "event_type": event_type,
        "details": details,
        "user_id": user_id,
        "prev_hash": prev_hash
    }

    # sort_keys gives a canonical serialization so the hash is reproducible.
    entry_str = json.dumps(entry, sort_keys=True)
    current_hash = hashlib.sha256(entry_str.encode()).hexdigest()

    try:
        with open("patent_audit_trail.jsonl", "a", encoding="utf-8") as f:
            f.write(json.dumps({**entry, "hash": current_hash}) + "\n")
    except OSError:
        # Best-effort logging: callers still get a hash even if the trail is
        # unwritable. Narrowed from a bare `except Exception` so programming
        # errors (e.g. non-serializable details) surface instead of vanishing.
        pass
    return current_hash

# --- Patent 7: Hallucination Detection ---
@mcp.tool()
def validate_hallucination(text: str, context: str = "") -> dict:
    """
    Patent 7: Real-Time Hallucination Detection.

    Scans *text* for hedging/uncertainty phrases and, when *context* is
    supplied, checks lexical overlap between the text and that context.

    Args:
        text: Model output to validate.
        context: Optional reference material the output should be grounded in.

    Returns:
        dict with an is_valid flag, a confidence_score clamped to [0.0, 1.0],
        the list of issues found, and the patent reference string.
    """
    hallucination_indicators = ["i think", "maybe", "possibly", "it is likely", "to the best of my knowledge"]
    score = 1.0
    issues = []

    # Lowercase once, not once per indicator. Substring matching is
    # intentionally simple and may match inside longer words.
    text_lower = text.lower()
    for indicator in hallucination_indicators:
        if indicator in text_lower:
            score -= 0.15
            issues.append(f"Uncertainty marker found: '{indicator}'")

    if context:
        # Simple Jaccard-like overlap check: |context ∩ text| / |text|.
        context_words = set(re.findall(r'\w+', context.lower()))
        text_words = set(re.findall(r'\w+', text_lower))
        if text_words:
            overlap = len(context_words.intersection(text_words)) / len(text_words)
            if overlap < 0.1:
                score -= 0.3
                issues.append(f"Low contextual relevance (Overlap: {overlap:.2f})")

    # Clamp: all five markers (-0.75) plus the relevance penalty (-0.3)
    # previously drove the score to -0.05; confidence must stay in [0, 1].
    score = max(0.0, score)

    result = {
        "is_valid": score > 0.6,
        "confidence_score": round(score, 2),
        "issues": issues,
        "patent_ref": "US Patent 7 - Real-Time Hallucination Detection"
    }

    log_to_audit_trail("HALLUCINATION_CHECK", result)
    return result

# --- Patent 3: Multi-Dimensional Risk Assessment ---
@mcp.tool()
def assess_risk(proposal: str, industry: str = "general") -> dict:
    """
    Patent 3: Multi-Dimensional Risk Assessment.

    Scores *proposal* across Financial, Legal, Operational, and Reputational
    dimensions via keyword hits, then combines them with fixed weights.

    Args:
        proposal: Free-text proposal to assess.
        industry: Reserved for industry-specific weighting; currently unused
            (kept for interface compatibility). TODO: wire into the weights.

    Returns:
        dict with the weighted overall risk score, per-dimension scores, a
        High/Medium/Low risk level, and the patent reference string.
    """
    dimensions = {
        "financial": ["cost", "investment", "price", "profit", "budget"],
        "legal": ["contract", "liability", "law", "regulation", "compliance"],
        "operational": ["process", "efficiency", "workflow", "implementation"],
        "reputational": ["ethics", "brand", "public", "trust", "scandal"]
    }

    # Lowercase once instead of once per keyword test.
    proposal_lower = proposal.lower()

    # Each keyword hit adds 0.2; a dimension saturates at 1.0.
    scores = {
        dim: min(1.0, sum(0.2 for kw in keywords if kw in proposal_lower))
        for dim, keywords in dimensions.items()
    }

    # Fixed dimension weights (sum to 1.0), so overall risk stays in [0, 1].
    weights = {"financial": 0.4, "legal": 0.3, "operational": 0.2, "reputational": 0.1}
    overall_risk = sum(scores[d] * weights[d] for d in scores)

    result = {
        "overall_risk_score": round(overall_risk, 2),
        "dimension_scores": scores,
        "risk_level": "High" if overall_risk > 0.6 else "Medium" if overall_risk > 0.3 else "Low",
        "patent_ref": "US Patent 3 - Multi-Dimensional Risk Assessment"
    }

    log_to_audit_trail("RISK_ASSESSMENT", result)
    return result

# --- Patent 5: Multi-Model Consensus ---
@mcp.tool()
def model_consensus(predictions: List[Dict[str, Any]]) -> dict:
    """
    Patent 5: Multi-Model Consensus Validation.

    Performs confidence-weighted voting over the outputs of multiple models
    and reports the winning output plus an agreement ratio.

    Example input:
        [{"model": "gemini", "output": "yes", "conf": 0.9},
         {"model": "claude", "output": "yes", "conf": 0.8}]

    Args:
        predictions: One dict per model; "output" is the vote and "conf" its
            weight (defaults to 0.5 when missing).

    Returns:
        dict with the consensus output, agreement ratio, a reliability flag,
        and the patent reference — or an "error" dict when no usable votes
        exist.
    """
    if not predictions:
        return {"error": "No predictions provided"}

    # Weighted tally: output value -> accumulated confidence.
    votes = {}
    for p in predictions:
        val = p.get("output")
        conf = p.get("conf", 0.5)
        votes[val] = votes.get(val, 0) + conf

    total_weight = sum(votes.values())
    # Guard: all-zero confidences previously raised ZeroDivisionError;
    # report it as an error result, consistent with the empty-input path.
    if total_weight <= 0:
        return {"error": "Total confidence is zero; consensus is undefined"}

    winner = max(votes, key=votes.get)
    agreement_ratio = votes[winner] / total_weight

    result = {
        "consensus_output": winner,
        "agreement_ratio": round(agreement_ratio, 2),
        "is_reliable": agreement_ratio > 0.6,
        "patent_ref": "US Patent 5 - Multi-Model Consensus Validation"
    }

    log_to_audit_trail("CONSENSUS_VALIDATION", result)
    return result

# --- Patent 8: Privacy-Preserving Validation ---
@mcp.tool()
def privacy_scrub(text: str) -> dict:
    """
    Patent 8: Privacy-Preserving AI Validation Protocol.

    Detects and redacts PII (email, AU phone, credit-card-like digit runs)
    via regex substitution.

    Args:
        text: Raw text that may contain PII.

    Returns:
        dict with a detection flag, the list of PII types actually redacted,
        the scrubbed text, and the patent reference string.
    """
    patterns = {
        "EMAIL": r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}',
        "PHONE": r'\b(?:\+?61|0)[2-478](?:[ -]?[0-9]){8}\b', # AU Phone
        "CREDIT_CARD": r'\b(?:\d[ -]*?){13,16}\b'
    }

    redacted_text = text
    found_types = []

    for label, pattern in patterns.items():
        # subn substitutes and counts in one pass over the *current* text.
        # The original searched the raw input but substituted the running
        # text, so a span already consumed by an earlier pattern could be
        # reported as detected without actually being redacted.
        redacted_text, hits = re.subn(pattern, f"[{label}_REDACTED]", redacted_text)
        if hits:
            found_types.append(label)

    result = {
        "pii_detected": len(found_types) > 0,
        "detected_types": found_types,
        "scrubbed_text": redacted_text,
        "patent_ref": "US Patent 8 - Privacy-Preserving AI Validation"
    }

    # Log only the PII categories, never the raw text, so the audit trail
    # itself stays free of PII.
    log_to_audit_trail("PRIVACY_SCRUB", {"pii_types": found_types})
    return result

# --- Patent 6: Dynamic Confidence Scoring ---
@mcp.tool()
def calculate_confidence(base_score: float, time_since_last_check: float = 0) -> dict:
    """
    Patent 6: Dynamic Confidence Scoring.

    Applies linear confidence decay of 1% per minute of elapsed time.

    Args:
        base_score: Confidence at the last validation (expected in [0, 1]).
        time_since_last_check: Elapsed time — assumed to be seconds, since it
            is divided by 60 to get minutes; TODO confirm unit with callers.

    Returns:
        dict with the original score, the decayed score (floored at 0.0),
        the decay actually applied, and the patent reference string.
    """
    decay_rate = 0.01  # fraction of confidence lost per minute
    raw = base_score * (1 - (decay_rate * (time_since_last_check / 60)))

    # Floor BEFORE computing decay_applied. The original floored only the
    # reported score, so once the raw value went negative the reported
    # decay exceeded base_score (e.g. base 0.8, ~200 min -> decay 1.6).
    decayed_score = max(0.0, raw)

    result = {
        "original_score": base_score,
        "decayed_score": round(decayed_score, 2),
        "decay_applied": round(base_score - decayed_score, 2),
        "patent_ref": "US Patent 6 - Dynamic Confidence Scoring"
    }

    return result

# Run the MCP server only when executed as a script, not on import.
if __name__ == "__main__":
    mcp.run()
