
"""
Fusion RAG Implementation
Uses Reciprocal Rank Fusion (RRF) to combine results from multiple query variations.
ADAPTED: Uses the Anthropic Messages API to generate the query variations.
"""

from typing import List, Dict
from collections import defaultdict
from config import config
from naive_rag import NaiveRAG

class FusionRAG(NaiveRAG):
    """RAG pipeline that fuses retrieval results from several query variations.

    The user query is expanded into additional search queries by an LLM, each
    query is retrieved with the parent class's ``retrieve``, and the ranked
    lists are merged with Reciprocal Rank Fusion (RRF).

    NOTE(review): relies on ``self.anthropic`` and ``self.generate`` being
    provided by ``NaiveRAG`` -- confirm against that class.
    """

    def generate_queries(self, query: str, num_queries: int = 4) -> List[str]:
        """Generate multiple search queries from a single user query.

        Args:
            query: The original user query.
            num_queries: Maximum number of variations to return.

        Returns:
            Up to ``num_queries`` non-empty, whitespace-stripped lines from the
            model's reply. Empty when ``num_queries`` is not positive or the
            reply carries no text.
        """
        if num_queries <= 0:
            # Nothing requested: skip the API round-trip entirely.
            return []

        prompt = f"""
        You are a helpful assistant that generates multiple search queries based on a single input query.
        Generate {num_queries} search queries related to: "{query}"
        OUTPUT ONLY THE QUERIES, one per line. Do not number them.
        """

        response = self.anthropic.messages.create(
            model=config.llm.model,
            max_tokens=200,
            messages=[
                {"role": "user", "content": prompt}
            ],
            system="You are a helpful assistant."
        )

        # Collect text defensively: the reply may be empty or contain blocks
        # without a ``text`` attribute, which ``content[0].text`` would crash on.
        raw_text = "\n".join(
            getattr(block, "text", None) or "" for block in (response.content or [])
        )
        candidates = [line.strip() for line in raw_text.split("\n") if line.strip()]

        # The model may ignore the requested count; never return more than asked
        # so callers do not fan out into unexpected extra retrievals.
        return candidates[:num_queries]

    def reciprocal_rank_fusion(self, results_dict: Dict[str, List[Dict]], k: int = 60) -> List[Dict]:
        """Fuse several ranked result lists using the RRF algorithm.

        Each document contributes ``1 / (rank + k)`` per list it appears in,
        where ``rank`` is its 0-based position in that list. Documents are
        identified by their ``'content'`` value.

        Args:
            results_dict: Mapping of query string to its ranked result list.
                Every result must be a dict with a ``'content'`` key.
            k: RRF smoothing constant; must be positive.

        Returns:
            New dicts (shallow copies of the inputs) sorted by descending fused
            score. ``'score'`` holds the RRF score and ``'original_score'`` the
            score the document carried before fusion (``None`` if it had none).
            When a document appears in several lists, the copy is taken from
            the last list in which it was seen.

        Raises:
            ValueError: If ``k`` is not positive.
            KeyError: If a result has no ``'content'`` key.
        """
        if k <= 0:
            # With 0-based ranks, k <= 0 leads to division by zero or
            # negative weights.
            raise ValueError("k must be positive")

        fused_scores = defaultdict(float)
        doc_map = {}

        for results in results_dict.values():
            for rank, doc in enumerate(results):
                content = doc['content']
                doc_map[content] = doc  # keep reference to full doc
                fused_scores[content] += 1 / (rank + k)

        # sorted() is stable, so ties keep first-seen order.
        ranked = sorted(fused_scores.items(), key=lambda item: item[1], reverse=True)

        final_results = []
        for content, score in ranked:
            # Copy so the caller's retrieval results are not mutated.
            fused_doc = dict(doc_map[content])
            # Capture the retriever's score BEFORE overwriting it; doing this
            # in the opposite order would store the RRF score in both fields.
            fused_doc['original_score'] = fused_doc.get('score')
            fused_doc['score'] = score
            final_results.append(fused_doc)

        return final_results

    def retrieve(self, query: str, limit: int = 5) -> List[Dict]:
        """Retrieve using the Fusion strategy.

        Args:
            query: The original user query.
            limit: Number of results requested per query variation and the
                maximum number of fused results returned.

        Returns:
            The top ``limit`` documents after Reciprocal Rank Fusion.
        """
        # The original query always takes part. De-duplicate while preserving
        # order so an echoed or repeated query is not retrieved twice (its
        # results would collapse onto the same dict key anyway).
        queries = list(dict.fromkeys([*self.generate_queries(query), query]))

        # Bind the parent method first: zero-argument super() is not usable
        # inside a comprehension on older Python versions.
        base_retrieve = super().retrieve
        all_results = {q: base_retrieve(q, limit=limit) for q in queries}

        fused_results = self.reciprocal_rank_fusion(all_results)
        return fused_results[:limit]

    def query(self, query: str) -> Dict:
        """Run the end-to-end Fusion pipeline: retrieve, then generate.

        Args:
            query: The user query.

        Returns:
            Dict with the ``"query"``, the generated ``"answer"``, the fused
            ``"context"`` documents and ``"method"`` set to ``"Fusion"``.
        """
        context = self.retrieve(query)
        answer = self.generate(query, context)
        return {
            "query": query,
            "answer": answer,
            "context": context,
            "method": "Fusion"
        }

if __name__ == "__main__":
    # Script entry point: construct a FusionRAG with default arguments.
    # The instance is not used further here -- no query is run.
    rag = FusionRAG()
