# rag_engine_advanced.py
import logging
import os
import re
import time
from typing import List, Dict, Tuple, Callable, Optional, Union

import nltk
import numpy as np
import torch
from nltk.corpus import wordnet
from nltk.tokenize import sent_tokenize, word_tokenize
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity

# Ensure NLTK resources are available.
# word_tokenize/sent_tokenize load "punkt_tab" on recent NLTK releases and "punkt" on
# older ones, so both are ensured here; WordNet is needed for query expansion.
for _resource_path, _package_name in (
        ('corpora/wordnet', 'wordnet'),
        ('tokenizers/punkt', 'punkt'),
        ('tokenizers/punkt_tab', 'punkt_tab'),
):
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        nltk.download(_package_name)

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class RAGQueryEngine:
    """
    A Retrieval-Augmented Generation (RAG) query engine that combines semantic search,
    BM25 retrieval, query expansion, reranking, MMR, and context packing to provide
    context-grounded and citation-supported answers.
    """

    def __init__(
            self,
            documents: List[str],
            document_ids: List[str],
            embedding_model_name: str = "all-mpnet-base-v2",
            cross_encoder_model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
            bm25_k1: float = 1.2,
            bm25_b: float = 0.75,
            query_expansion_synonyms: int = 3,
            mmr_diversity: float = 0.5,
            max_context_length: int = 3000,
            device: Optional[str] = None,
            hyde_model_name: str = "text-generation-gpt2-medium",
            cache_dir: str = "rag_cache",
            llm_callable: Callable[[str], str] = lambda x: f"Dummy LLM: {x}" # Replace with actual LLM call
    ):
        """
        Initializes the RAGQueryEngine.

        Args:
            documents (List[str]): A list of documents to index.
            document_ids (List[str]): A list of document IDs corresponding to the documents.
            embedding_model_name (str): The name of the SentenceTransformer model to use for embeddings.
            cross_encoder_model_name (str): The name of the CrossEncoder model to use for reranking.
            bm25_k1 (float): The k1 parameter for BM25.
            bm25_b (float): The b parameter for BM25.
            query_expansion_synonyms (int): The number of synonyms to use for query expansion.
            mmr_diversity (float): The diversity parameter for MMR.
            max_context_length (int): The maximum length of the context to provide to the language model.
            device (Optional[str]): The device to use for the models ("cpu" or "cuda"). If None, it will automatically detect GPU availability.
            hyde_model_name (str): Name of the model to use for HyDE.
            cache_dir (str): Directory to store cached embeddings and other data.
            llm_callable (Callable[[str], str]): A callable that takes a prompt and returns a string (the LLM's response).  Defaults to a dummy LLM.
        """

        self.documents = documents
        self.document_ids = document_ids
        self.embedding_model_name = embedding_model_name
        self.cross_encoder_model_name = cross_encoder_model_name
        self.bm25_k1 = bm25_k1
        self.bm25_b = bm25_b
        self.query_expansion_synonyms = query_expansion_synonyms
        self.mmr_diversity = mmr_diversity
        self.max_context_length = max_context_length
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.hyde_model_name = hyde_model_name
        self.cache_dir = cache_dir
        self.llm_callable = llm_callable

        if not os.path.exists(self.cache_dir):
            os.makedirs(self.cache_dir)

        self.embedding_model = self._load_embedding_model()
        self.cross_encoder = self._load_cross_encoder()
        self.bm25 = self._build_bm25_index()
        self.document_embeddings = self._compute_document_embeddings()
        self.hyde_model = self._load_hyde_model()

        self.chunk_cache: Dict[str, List[str]] = {}  # Cache for adaptive chunking results

        logging.info(f"RAGQueryEngine initialized with device: {self.device}")

    def _load_embedding_model(self) -> SentenceTransformer:
        """Loads the SentenceTransformer model."""
        model = SentenceTransformer(self.embedding_model_name, device=self.device)
        logging.info(f"Loaded embedding model: {self.embedding_model_name}")
        return model

    def _load_cross_encoder(self) -> CrossEncoder:
        """Loads the CrossEncoder model."""
        model = CrossEncoder(self.cross_encoder_model_name, device=self.device)
        logging.info(f"Loaded cross-encoder model: {self.cross_encoder_model_name}")
        return model

    def _load_hyde_model(self):
        """Loads the HyDE model (GPT-2 or similar).  This is a placeholder."""
        # In a real implementation, you would load a language model here.
        # For demonstration purposes, we return None and use a dummy implementation.
        logging.info(f"HyDE model loading skipped (placeholder).")
        return None

    def _build_bm25_index(self) -> BM25Okapi:
        """Builds the BM25 index."""
        tokenized_corpus = [word_tokenize(doc.lower()) for doc in self.documents]
        bm25 = BM25Okapi(tokenized_corpus, k1=self.bm25_k1, b=self.bm25_b)
        logging.info("Built BM25 index.")
        return bm25

    def _compute_document_embeddings(self) -> np.ndarray:
        """Computes and caches document embeddings."""
        cache_file = os.path.join(self.cache_dir, "document_embeddings.npy")
        if os.path.exists(cache_file):
            logging.info("Loading document embeddings from cache.")
            embeddings = np.load(cache_file)
        else:
            logging.info("Computing document embeddings...")
            embeddings = self.embedding_model.encode(self.documents, convert_to_numpy=True, show_progress_bar=True)
            np.save(cache_file, embeddings)
            logging.info("Document embeddings computed and cached.")
        return embeddings

    def _adaptive_chunking(self, document: str, query: str) -> List[str]:
        """
        Dynamically chunks a document based on query complexity.

        Args:
            document: The document to chunk.
            query: The query being asked.

        Returns:
            A list of text chunks.
        """
        cache_key = f"{hash(document)}_{hash(query)}"
        if cache_key in self.chunk_cache:
            return self.chunk_cache[cache_key]

        sentences = sent_tokenize(document)
        num_sentences = len(sentences)

        # Heuristic: More complex queries require smaller chunks
        query_complexity = len(word_tokenize(query))  # Simple proxy for complexity

        if query_complexity > 15:
            # Complex query: smaller chunks (e.g., single sentences)
            chunks = sentences
        elif query_complexity > 8:
            # Medium query: combine 2-3 sentences
            chunks = [" ".join(sentences[i:i + 2]) for i in range(0, num_sentences, 2)]
        else:
            # Simple query: larger chunks (e.g., paragraphs or multiple sentences)
            chunks = [" ".join(sentences[i:i + 4]) for i in range(0, num_sentences, 4)]

        # Ensure chunks are not empty
        chunks = [chunk for chunk in chunks if chunk.strip()]

        self.chunk_cache[cache_key] = chunks
        return chunks

    def _query_expansion(self, query: str, num_synonyms: int = 3) -> str:
        """
        Expands the query using WordNet synonyms.

        Args:
            query (str): The original query.
            num_synonyms (int): The number of synonyms to add for each word.

        Returns:
            str: The expanded query.
        """
        expanded_query = query
        words = word_tokenize(query)
        for word in words:
            synonyms = []
            for syn in wordnet.synsets(word):
                for lemma in syn.lemmas():
                    synonyms.append(lemma.name())
            synonyms = list(set(synonyms))[:num_synonyms]  # Get unique synonyms and limit the number
            if synonyms:
                expanded_query += " OR " + " OR ".join(synonyms)
        logging.info(f"Expanded query: {expanded_query}")
        return expanded_query

    def _hyde(self, query: str) -> str:
        """
        Generates a hypothetical document embedding for the query using HyDE.

        Args:
            query (str): The query.

        Returns:
            str: The generated hypothetical document.
        """

        if not self.hyde_model:
            logging.warning("HyDE model not loaded. Using dummy HyDE implementation.")
            # Dummy HyDE implementation
            return f"A hypothetical document about: {query}"

        # In a real implementation, you would use the HyDE model to generate a document.
        # This is a placeholder.
        # The LLM should be prompted to create a hypothetical document that answers the query.
        hyde_prompt = f"Write a detailed and comprehensive document that answers the following question: {query}"
        hypothetical_document = self.llm_callable(hyde_prompt)  # Use LLM to generate the hypo doc

        logging.info(f"Generated hypothetical document using HyDE.")
        return hypothetical_document

    def _semantic_search(self, query: str, top_k: int = 10) -> List[Tuple[int, float]]:
        """
        Performs semantic search using SentenceTransformers.

        Args:
            query (str): The query.
            top_k (int): The number of documents to retrieve.

        Returns:
            List[Tuple[int, float]]: A list of tuples containing the document index and similarity score.
        """
        query_embedding = self.embedding_model.encode(query, convert_to_numpy=True)
        similarities = cosine_similarity([query_embedding], self.document_embeddings)[0]
        indices = np.argsort(similarities)[::-1][:top_k]
        return [(int(i), float(similarities[i])) for i in indices]

    def _bm25_retrieval(self, query: str, top_k: int = 10) -> List[Tuple[int, float]]:
        """
        Performs BM25 retrieval.

        Args:
            query (str): The query.
            top_k (int): The number of documents to retrieve.

        Returns:
            List[Tuple[int, float]]: A list of tuples containing the document index and BM25 score.
        """
        tokenized_query = word_tokenize(query.lower())
        scores = self.bm25.get_scores(tokenized_query)
        indices = np.argsort(scores)[::-1][:top_k]
        return [(int(i), float(scores[i])) for i in indices]

    def _reciprocal_rank_fusion(self, results: List[List[Tuple[int, float]]], k: int = 60) -> List[int]:
        """
        Performs Reciprocal Rank Fusion (RRF) to combine results from multiple retrieval methods.

        Args:
            results (List[List[Tuple[int, float]]]): A list of lists, where each inner list contains tuples of (document index, score).
            k (int): The RRF parameter.

        Returns:
            List[int]: A list of document indices, ranked by RRF score.
        """
        document_scores: Dict[int, float] = {}
        for result_list in results:
            for rank, (doc_id, score) in enumerate(result_list):
                if doc_id not in document_scores:
                    document_scores[doc_id] = 0.0
                document_scores[doc_id] += 1 / (rank + k)

        ranked_documents = sorted(document_scores.items(), key=lambda item: item[1], reverse=True)
        return [doc_id for doc_id, score in ranked_documents]

    def _rerank(self, query: str, retrieved_documents: List[str]) -> List[int]:
        """
        Reranks the retrieved documents using a cross-encoder.

        Args:
            query (str): The query.
            retrieved_documents (List[str]): A list of retrieved documents.

        Returns:
            List[int]: A list of document indices, reranked by the cross-encoder.
        """
        features = [(query, doc) for doc in retrieved_documents]
        scores = self.cross_encoder.predict(features, show_progress_bar=False)
        indices = np.argsort(scores)[::-1]
        return indices

    def _mmr(self, query_embedding: np.ndarray, document_embeddings: np.ndarray, document_indices: List[int], diversity: float = 0.5, top_n: int = 10) -> List[int]:
        """
        Performs Maximum Marginal Relevance (MMR) to select a diverse set of documents.

        Args:
            query_embedding (np.ndarray): The embedding of the query.
            document_embeddings (np.ndarray): The embeddings of the retrieved documents.
            document_indices (List[int]): The indices of the retrieved documents.
            diversity (float): The diversity parameter (0.0 for maximal relevance, 1.0 for maximal diversity).
            top_n (int): Number of documents to select.

        Returns:
            List[int]: A list of document indices selected by MMR.
        """
        # Convert to numpy arrays if necessary
        if not isinstance(query_embedding, np.ndarray):
            query_embedding = np.array(query_embedding)
        if not isinstance(document_embeddings, np.ndarray):
            document_embeddings = np.array(document_embeddings)

        selected_indices = []
        remaining_indices = document_indices[:]  # Create a copy to avoid modifying the original list

        if not remaining_indices:
            return selected_indices  # Return empty list if no documents are provided

        # Calculate initial similarities to the query
        similarities = cosine_similarity([query_embedding], document_embeddings[remaining_indices])[0]

        for _ in range(min(top_n, len(document_indices))):
            # Calculate MMR scores for each document
            mmr_scores = []
            for i in range(len(remaining_indices)):
                index = remaining_indices[i]
                if selected_indices:
                    similarity_to_selected = np.max(cosine_similarity(document_embeddings[[index]], document_embeddings[selected_indices]))
                else:
                    similarity_to_selected = 0.0  # No selected documents yet

                mmr_score = diversity * similarities[i] - (1 - diversity) * similarity_to_selected
                mmr_scores.append(mmr_score)

            # Select the document with the highest MMR score
            best_index_within_remaining = np.argmax(mmr_scores)
            best_index = remaining_indices[best_index_within_remaining]
            selected_indices.append(best_index)

            # Remove the selected document from the remaining pool
            del remaining_indices[best_index_within_remaining]
            similarities = np.delete(similarities, best_index_within_remaining) # Update similarities

        return selected_indices

    def _assemble_context(self, indices: List[int]) -> Tuple[str, List[str]]:
        """
        Assembles the context from the retrieved documents, packing as much information as possible
        while staying within the maximum context length.

        Args:
            indices (List[int]): A list of document indices to include in the context.

        Returns:
            Tuple[str, List[str]]: A tuple containing the assembled context and a list of source document IDs.
        """
        context = ""
        source_ids = []
        added_documents = set()  # To deduplicate documents

        for index in indices:
            if index < 0 or index >= len(self.documents):
                logging.warning(f"Invalid document index: {index}")
                continue

            document = self.documents[index]
            doc_id = self.document_ids[index]

            if doc_id in added_documents:
                continue  # Skip duplicate documents

            if len(context) + len(document) + 100 > self.max_context_length: # Add some buffer
                logging.info("Maximum context length reached. Truncating context.")
                break

            context += document + "\n\n"
            source_ids.append(doc_id)
            added_documents.add(doc_id)

        return context.strip(), source_ids

    def _inject_citations(self, response: str, source_ids: List[str]) -> str:
        """
        Injects citations into the generated response.

        Args:
            response (str): The generated response.
            source_ids (List[str]): A list of source document IDs.

        Returns:
            str: The response with citations injected.
        """
        if not source_ids:
            return response

        citations = "[" + ", ".join(source_ids) + "]"
        return response + " (Citations: " + citations + ")"

    def _retrieve_necessity_prediction(self, query: str, context: str) -> bool:
        """
        Predicts whether retrieval is necessary for answering the query.
        This is a simplified placeholder. A real implementation would use a model
        trained for this task.

        Args:
            query (str): The query.
            context (str): The context.

        Returns:
            bool: True if retrieval is necessary, False otherwise.
        """
        # Simple heuristic: if the query contains specific keywords, retrieval is likely necessary
        keywords = ["what", "where", "when", "who", "how", "explain", "describe"]
        if any(keyword in query.lower() for keyword in keywords):
            return True
        else:
            # Otherwise, assume the LLM can answer from its own knowledge
            return False

    def _query_decomposition(self, query: str) -> List[str]:
        """
        Decomposes a complex query into simpler sub-queries.
        This is a simplified placeholder. A real implementation would use a model
        trained for query decomposition.

        Args:
            query (str): The complex query.

        Returns:
            List[str]: A list of sub-queries.
        """
        # Simple heuristic: Split the query if it contains "and" or "or"
        if " and " in query.lower():
            return query.split(" and ")
        elif " or " in query.lower():
            return query.split(" or ")
        else:
            return [query]  # No decomposition needed

    def aiva_query(self, query: str) -> str:
        """
        Answers a query using the RAG pipeline.

        Args:
            query (str): The query to answer.

        Returns:
            str: The answer to the query, with citations.
        """
        start_time = time.time()
        try:
            # 1. Query Decomposition
            sub_queries = self._query_decomposition(query)
            if len(sub_queries) > 1:
                logging.info(f"Decomposed query into sub-queries: {sub_queries}")

            # 2. Self-RAG: Retrieval Necessity Prediction (simplified)
            if self._retrieve_necessity_prediction(query, ""):  # No context yet for the first check
                logging.info("Retrieval is deemed necessary.")

                # 3. Query Expansion
                expanded_query = self._query_expansion(query, num_synonyms=self.query_expansion_synonyms)

                # 4. HyDE (Hypothetical Document Embedding)
                hypothetical_document = self._hyde(expanded_query)

                # 5. Retrieval (Hybrid: Semantic + BM25)
                semantic_results = self._semantic_search(hypothetical_document, top_k=20)  # Use HyDE query
                bm25_results = self._bm25_retrieval(expanded_query, top_k=20)  # Use expanded query

                # 6. Reciprocal Rank Fusion
                fused_results = self._reciprocal_rank_fusion([semantic_results, bm25_results])

                # 7. Reranking
                retrieved_documents = [self.documents[i] for i in fused_results[:20]]  # Limit for reranking
                reranked_indices = self._rerank(expanded_query, retrieved_documents)
                reranked_ids = [fused_results[i] for i in reranked_indices] # Get original doc IDs

                # 8. MMR (Maximum Marginal Relevance)
                query_embedding = self.embedding_model.encode(expanded_query, convert_to_numpy=True)
                mmr_indices = self._mmr(query_embedding, self.document_embeddings, reranked_ids, diversity=self.mmr_diversity, top_n=10)

                # 9. Assemble Context
                context, source_ids = self._assemble_context(mmr_indices)

                logging.info(f"Retrieved document IDs: {source_ids}")
                logging.info(f"Assembled context: {context[:200]}...")

                # 10. Self-RAG: Re-check Retrieval Necessity with Context
                if not self._retrieve_necessity_prediction(query, context):
                    logging.info("Retrieval is now deemed unnecessary based on context.")
                    final_response = self.llm_callable(query) # Answer directly, no context
                    source_ids = []  # No sources needed
                else:
                    # 11. Generate Answer with Context
                    prompt = f"Answer the following question based on the context provided:\n\nContext:\n{context}\n\nQuestion: {query}"
                    final_response = self.llm_callable(prompt)

                # 12. Citation Injection
                final_response = self._inject_citations(final_response, source_ids)

            else:  # Retrieval is not necessary
                logging.info("Retrieval is deemed unnecessary.")
                final_response = self.llm_callable(query) # Answer directly, no context
                source_ids = [] # No sources needed
                final_response = self._inject_citations(final_response, source_ids) # Inject empty citations

            # 13. Confidence Scoring (Placeholder)
            confidence_score = 0.95  # Replace with a real confidence scoring mechanism
            logging.info(f"Confidence score: {confidence_score}")

            end_time = time.time()
            logging.info(f"Query processed in {end_time - start_time:.2f} seconds.")
            return final_response

        except Exception as e:
            logging.exception(f"An error occurred during query processing: {e}")
            return f"An error occurred: {e}"


# Example Usage (Replace with your actual data and LLM)
if __name__ == '__main__':
    # Small demo corpus; IDs are "doc0", "doc1", ... in corpus order.
    documents = [
        "The capital of France is Paris.",
        "Paris is a beautiful city and the most populous city of France.",
        "France is located in Western Europe.",
        "The Eiffel Tower is a famous landmark in Paris.",
        "London is the capital of England.",
        "England is part of the United Kingdom."
    ]
    document_ids = [f"doc{position}" for position, _ in enumerate(documents)]

    def dummy_llm(prompt: str) -> str:
        """Stand-in LLM that echoes the prompt it receives."""
        return f"Dummy LLM says: {prompt}"

    # Build the engine over the demo corpus with a small context budget.
    rag_engine = RAGQueryEngine(
        documents=documents,
        document_ids=document_ids,
        llm_callable=dummy_llm,
        max_context_length=500
    )

    # Run each example query and print the question with its answer.
    for query in (
        "What is the capital of France and where is it located?",
        "What is the capital of England?",
        "Tell me about the Eiffel Tower",
    ):
        answer = rag_engine.aiva_query(query)
        print(f"Query: {query}")
        print(f"Answer: {answer}")
