
"""
RAG Benchmark Script
Compares the performance of Naive, HyDE, Fusion, Agentic, and Corrective RAG.
Metrics: Latency, Retrieved-Context Count, Answer Length, Success Rate.
"""

import json
import time
from pathlib import Path
from typing import Dict, List

from agentic_rag import AgenticRAG
from corrective_rag import CorrectiveRAG
from fusion_rag import FusionRAG
from hyde_rag import HydeRAG
from naive_rag import NaiveRAG

# Fixed question set used for every RAG pattern so results are comparable.
QUESTIONS = [
    "What is the Genesis System?",
    "How do I add a new MCP server?",
    "Who is Kinan?",
    "Explain the memory tiers.",
    "What is the plan for GHL integration?"
]

def benchmark_pattern(rag_class, name: str) -> List[Dict]:
    """Run every benchmark question through one RAG pattern.

    Args:
        rag_class: Zero-argument constructor of a RAG implementation; the
            resulting object must expose ``query(question) -> dict`` with
            optional ``"context"`` (list) and ``"answer"`` (str) keys.
        name: Human-readable pattern name used in output and result rows.

    Returns:
        One result dict per question. Successful rows carry ``latency``,
        ``context_count`` and ``answer_length``; failed rows carry
        ``error``. Returns an empty list if construction itself fails.
    """
    print(f"\n--- Benchmarking {name} ---")

    try:
        rag = rag_class()
    except Exception as e:
        # Initialization failures (missing credentials, unreachable backing
        # services, ...) must not abort the whole benchmark run.
        print(f"Failed to initialize {name}: {e}")
        return []

    results: List[Dict] = []

    for q in QUESTIONS:
        # perf_counter is a monotonic high-resolution clock; time.time() can
        # jump backwards/forwards with system clock adjustments, corrupting
        # measured durations.
        start_time = time.perf_counter()
        try:
            res = rag.query(q)
            duration = time.perf_counter() - start_time

            # Simple metric extraction; missing keys count as empty.
            context_count = len(res.get("context", []))
            answer_len = len(res.get("answer", ""))

            print(f"  Q: {q[:30]}... | T: {duration:.2f}s | Ctx: {context_count}")

            results.append({
                "pattern": name,
                "question": q,
                "latency": duration,
                "context_count": context_count,
                "answer_length": answer_len,
                "status": "success"
            })

        except Exception as e:
            print(f"  Failed Q: {q[:30]}... | Error: {e}")
            results.append({
                "pattern": name,
                "question": q,
                # Record time-to-failure too; useful for spotting timeouts.
                "latency": time.perf_counter() - start_time,
                "status": "failed",
                "error": str(e)
            })

    return results

def run_benchmarks() -> None:
    """Benchmark every RAG pattern, print a summary table, and persist
    the per-question results as JSON.

    Side effects: prints to stdout and writes
    ``research/rag_patterns/benchmark_results.json`` (creating the
    directory if needed).
    """
    all_results: List[Dict] = []

    patterns = [
        (NaiveRAG, "Naive"),
        (HydeRAG, "HyDE"),
        (FusionRAG, "Fusion"),
        (AgenticRAG, "Agentic"),
        (CorrectiveRAG, "Corrective")
    ]

    for cls, name in patterns:
        all_results.extend(benchmark_pattern(cls, name))

    print("\n\n=== BENCHMARK SUMMARY ===")
    print(f"{'Pattern':<12} | {'Avg Latency':<12} | {'Success Rate':<12}")
    print("-" * 40)

    for _, name in patterns:
        pattern_res = [r for r in all_results if r['pattern'] == name]
        if not pattern_res:
            continue

        successes = [r for r in pattern_res if r['status'] == 'success']
        # BUGFIX: average over successful runs only. The previous code summed
        # successful latencies but divided by the total row count (failures
        # included), understating latency whenever any query failed.
        avg_lat = sum(r['latency'] for r in successes) / len(successes) if successes else 0.0

        print(f"{name:<12} | {avg_lat:.2f}s        | {len(successes)}/{len(QUESTIONS)}")

    # Save detailed results; create the output directory first so open()
    # cannot fail with FileNotFoundError on a fresh checkout.
    out_path = Path("research/rag_patterns/benchmark_results.json")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w") as f:
        json.dump(all_results, f, indent=2)

if __name__ == "__main__":
    run_benchmarks()
