
"""
Corrective RAG (CRAG) Implementation
Evaluates retrieval quality.
ADAPTED: Uses Anthropic.
"""

from typing import List, Dict
from config import config
from naive_rag import NaiveRAG

class CorrectiveRAG(NaiveRAG):
    """Corrective RAG pipeline.

    Extends the naive retriever by grading each retrieved document with an
    LLM and topping up with a web-search fallback when too few documents
    survive the relevance grading.
    """

    def evaluate_relevance(self, query: str, context: List[Dict]) -> List[str]:
        """Grade each retrieved document's relevance to *query* via the LLM.

        Args:
            query: The user query being answered.
            context: Retrieved documents; each dict must carry a 'content' key.

        Returns:
            One judgment string per document, in the same order as *context*.
            The model is instructed to answer 'yes'/'no', but the raw text is
            returned as-is after ``strip().lower()`` — callers should test for
            substring 'yes' rather than exact equality.
        """
        relevance_judgments = []
        for doc in context:
            # Prompt text kept verbatim — it is part of the grader's behavior.
            prompt = f"""
            You are a grader assessing relevance of a retrieved document to a user query.
            Query: {query}
            Document: {doc['content']}
            
            Is this document relevant? Output 'yes' or 'no' only.
            """
            # temperature=0 and a tiny max_tokens budget keep the grade
            # deterministic and cheap (we only need one word back).
            response = self.anthropic.messages.create(
                model=config.llm.model,
                max_tokens=10,
                messages=[{"role": "user", "content": prompt}],
                temperature=0
            )
            relevance_judgments.append(response.content[0].text.strip().lower())
        return relevance_judgments

    def web_search_fallback(self, query: str) -> List[Dict]:
        """Return a single simulated web-search result for *query*.

        Stands in for a real web-search integration; the result mimics the
        retriever's document shape (content/score/source keys).
        """
        return [{
            "content": f"[WEB SEARCH RESULT] External information about {query}. (Simulated fallback)",
            "score": 1.0,
            "source": "web"
        }]

    def retrieve(self, query: str, limit: int = 5) -> List[Dict]:
        """CRAG retrieval: base retrieve, grade, filter, and fall back.

        Args:
            query: The user query.
            limit: Number of documents requested from the base retriever.

        Returns:
            Documents the grader judged relevant, extended with the web
            fallback when grading left fewer than ``limit // 2`` survivors.
        """
        context = super().retrieve(query, limit)
        judgments = self.evaluate_relevance(query, context)

        # Keep only documents whose judgment contains 'yes'.
        # (The original tracked an `ambiguous` flag here that was never
        # read anywhere — removed as dead code.)
        refined_context = [
            doc for doc, judgment in zip(context, judgments)
            if "yes" in judgment
        ]

        # Fall back to web search when grading left nothing, or fewer than
        # half the requested number of documents.
        if not refined_context or len(refined_context) < limit // 2:
            refined_context.extend(self.web_search_fallback(query))

        return refined_context

    def query(self, query: str) -> Dict:
        """End-to-end CRAG pipeline: corrective retrieve, then generate."""
        context = self.retrieve(query)
        answer = self.generate(query, context)
        return {
            "query": query,
            "answer": answer,
            "context": context,
            "method": "Corrective"
        }

if __name__ == "__main__":
    rag = CorrectiveRAG()
