"""
AIVA Queen RAG - Advanced Knowledge Retrieval System
=====================================================
Neural Module 03: Comprehensive retrieval architecture with dense, sparse,
and hybrid retrieval strategies, plus reranking and query expansion.

Components:
- DenseRetriever: Semantic similarity using embeddings
- SparseRetriever: BM25/TF-IDF lexical matching
- HybridRetriever: Fusion of dense + sparse signals
- ReRanker: Cross-encoder neural reranking
- QueryExpander: Query augmentation strategies
- ContextCompressor: Intelligent context compression

Author: Genesis System
Version: 1.0.0
"""

import numpy as np
import math
import hashlib
import json
import time
import asyncio
import logging
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import (
    List, Dict, Any, Optional, Tuple, Union,
    Callable, AsyncIterator, Set
)
from collections import Counter, defaultdict
from functools import lru_cache
from enum import Enum, auto
from concurrent.futures import ThreadPoolExecutor
import heapq

# Configure logging
# NOTE(review): basicConfig() at import time configures the process-wide root
# logger; libraries conventionally leave configuration to the application.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# =============================================================================
# DATA STRUCTURES
# =============================================================================

@dataclass
class Document:
    """A single retrievable unit of content.

    Identity (equality and hashing) is based solely on ``id``, so documents
    can be stored in sets and used as dict keys even though the payload
    fields (embedding, tokens, ...) are mutable and filled in lazily.
    """
    id: str
    content: str
    embedding: Optional[np.ndarray] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Populated lazily by the sparse retriever's tokenizer.
    tokens: Optional[List[str]] = None
    # Populated lazily from ``tokens``.
    term_frequencies: Optional[Dict[str, int]] = None

    def __hash__(self):
        # Hash must agree with __eq__: both look only at the id.
        return hash(self.id)

    def __eq__(self, other):
        return isinstance(other, Document) and self.id == other.id


@dataclass
class RetrievalResult:
    """One scored hit produced by a retriever for a given query."""
    document: Document
    score: float
    retriever_source: str
    # 1-based position in the ranked list; 0 means "not ranked yet".
    rank: int = 0
    explanation: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Flatten the result into a plain dict (for JSON output / logging)."""
        doc = self.document
        return dict(
            id=doc.id,
            content=doc.content,
            score=self.score,
            source=self.retriever_source,
            rank=self.rank,
            metadata=doc.metadata,
            explanation=self.explanation,
        )


@dataclass
class Query:
    """Represents a query with various representations.

    Only ``text`` is required; the other representations are filled in
    lazily by whichever component needs them (e.g. the sparse retriever
    tokenizes on first use).
    """
    text: str
    # Dense query vector; assumed to match the retriever's embedding_dim
    # -- TODO confirm at call sites.
    embedding: Optional[np.ndarray] = None
    # Token list from TextProcessor.tokenize(); cached after first retrieval.
    tokens: Optional[List[str]] = None
    # Alternative phrasings, presumably produced by a QueryExpander -- verify.
    expanded_queries: Optional[List[str]] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


class FusionMethod(Enum):
    """Methods for fusing retrieval results (used by HybridRetriever)."""
    RECIPROCAL_RANK = auto()    # sum of weight / (k_constant + rank)
    COMBSUM = auto()            # weighted sum of max-normalized scores
    COMBMNZ = auto()            # CombSUM times number of lists containing the doc
    WEIGHTED_AVERAGE = auto()   # dense_weight * dense + sparse_weight * sparse
    MAX_SCORE = auto()          # max of the weighted normalized scores


# =============================================================================
# TOKENIZATION AND TEXT PROCESSING
# =============================================================================

class TextProcessor:
    """Handles text tokenization and preprocessing.

    Pipeline: regex word extraction -> minimum-length filter -> optional
    stopword removal -> optional suffix stemming.
    """

    STOPWORDS = {
        'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
        'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are', 'were', 'been',
        'be', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would',
        'could', 'should', 'may', 'might', 'must', 'shall', 'can', 'this',
        'that', 'these', 'those', 'i', 'you', 'he', 'she', 'it', 'we', 'they',
        'what', 'which', 'who', 'when', 'where', 'why', 'how', 'all', 'each',
        'every', 'both', 'few', 'more', 'most', 'other', 'some', 'such', 'no',
        'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 'just'
    }

    # Compiled once at class level: tokenize() is on the per-document hot path.
    _WORD_RE_LOWER = re.compile(r'\b[a-z0-9]+\b')
    _WORD_RE_ANY = re.compile(r'\b[A-Za-z0-9]+\b')

    # Tried in order; the first matching suffix wins.
    _SUFFIXES = ('ing', 'ly', 'ed', 'es', 's', 'er', 'est', 'ment', 'ness',
                 'tion', 'able', 'ible')

    def __init__(self,
                 remove_stopwords: bool = True,
                 lowercase: bool = True,
                 stem: bool = True,
                 min_token_length: int = 2):
        """
        Args:
            remove_stopwords: Drop tokens found in STOPWORDS.
            lowercase: Lowercase the text before tokenizing.
            stem: Apply the simple suffix stemmer to each token.
            min_token_length: Discard tokens shorter than this.
        """
        self.remove_stopwords = remove_stopwords
        self.lowercase = lowercase
        self.stem = stem
        self.min_token_length = min_token_length

    def tokenize(self, text: str) -> List[str]:
        """Tokenize text into (optionally normalized and stemmed) words.

        BUG FIX: the original applied a lowercase-only pattern even when
        ``lowercase=False``, silently dropping every token containing an
        uppercase letter; case-preserving mode now matches mixed case.
        It also lowercased the text twice -- lowering now happens once.
        """
        if self.lowercase:
            text = text.lower()
            pattern = self._WORD_RE_LOWER
        else:
            pattern = self._WORD_RE_ANY

        tokens = pattern.findall(text)

        # Filter by length
        tokens = [t for t in tokens if len(t) >= self.min_token_length]

        # Remove stopwords (STOPWORDS is lowercase; with lowercase=False this
        # only filters tokens that already appear in lowercase, as before).
        if self.remove_stopwords:
            tokens = [t for t in tokens if t not in self.STOPWORDS]

        # Basic stemming (Porter-like suffix removal)
        if self.stem:
            tokens = [self._simple_stem(t) for t in tokens]

        return tokens

    def _simple_stem(self, word: str) -> str:
        """Strip the first matching suffix, keeping a stem of >= 3 chars."""
        for suffix in self._SUFFIXES:
            if word.endswith(suffix) and len(word) - len(suffix) >= 3:
                return word[:-len(suffix)]
        return word

    def compute_term_frequencies(self, tokens: List[str]) -> Dict[str, int]:
        """Compute term frequencies for a token list."""
        return dict(Counter(tokens))


# =============================================================================
# BASE RETRIEVER
# =============================================================================

class BaseRetriever(ABC):
    """Common interface and document bookkeeping for every retriever.

    Subclasses implement index() and retrieve(); this base maintains the
    id -> Document store and an ``is_indexed`` dirty flag that mutating
    operations reset.
    """

    def __init__(self, name: str = "BaseRetriever"):
        self.name = name
        self.documents: Dict[str, Document] = {}
        self.is_indexed = False

    @abstractmethod
    def index(self, documents: List[Document]) -> None:
        """Build the retrieval index over the given documents."""
        pass

    @abstractmethod
    def retrieve(self, query: Query, top_k: int = 10) -> List[RetrievalResult]:
        """Return the top_k most relevant documents for the query."""
        pass

    def add_document(self, document: Document) -> None:
        """Register one document and mark the index stale."""
        self.documents[document.id] = document
        self.is_indexed = False

    def remove_document(self, doc_id: str) -> bool:
        """Drop a document by id; True when something was actually removed."""
        if doc_id not in self.documents:
            return False
        del self.documents[doc_id]
        self.is_indexed = False
        return True

    def get_document(self, doc_id: str) -> Optional[Document]:
        """Look up a stored document, or None when absent."""
        return self.documents.get(doc_id)

    def clear(self) -> None:
        """Forget every document and invalidate the index."""
        self.documents.clear()
        self.is_indexed = False


# =============================================================================
# DENSE RETRIEVER
# =============================================================================

class DenseRetriever(BaseRetriever):
    """
    Dense Passage Retrieval using semantic embeddings.

    Scores documents by similarity between query and document embedding
    vectors. Search is an exact brute-force matrix product over all indexed
    embeddings, with optional exact-match metadata filtering.
    """

    def __init__(self,
                 embedding_dim: int = 768,
                 similarity_metric: str = "cosine",
                 use_ann: bool = True,
                 ann_trees: int = 10,
                 name: str = "DenseRetriever"):
        """
        Args:
            embedding_dim: Dimensionality of document/query embeddings.
            similarity_metric: "cosine", "euclidean", or "dot".
            use_ann: Accepted for interface compatibility (see NOTE below).
            ann_trees: Accepted for interface compatibility.
            name: Retriever name reported in results.
        """
        super().__init__(name)
        self.embedding_dim = embedding_dim
        self.similarity_metric = similarity_metric
        # NOTE(review): the ANN parameters are stored but never consulted --
        # retrieval below is always exact. Kept so callers don't break.
        self.use_ann = use_ann
        self.ann_trees = ann_trees

        # Row i of the matrix is the embedding of doc_ids[i]; rows are
        # L2-normalized when the metric is cosine so similarity reduces
        # to a single dot product.
        self.embedding_matrix: Optional[np.ndarray] = None
        self.doc_ids: List[str] = []

        # metadata key -> value -> set of doc ids, for exact-match filters.
        self.metadata_index: Dict[str, Dict[Any, Set[str]]] = defaultdict(lambda: defaultdict(set))

        # Reserved embedding cache (currently unused).
        self._embedding_cache: Dict[str, np.ndarray] = {}

    def index(self, documents: List[Document]) -> None:
        """
        Index documents by building the embedding matrix.

        Documents without an embedding are stored (and metadata-indexed)
        but excluded from the similarity matrix.

        Args:
            documents: List of documents, ideally with embeddings set.
        """
        logger.info(f"Indexing {len(documents)} documents for dense retrieval")

        for doc in documents:
            self.documents[doc.id] = doc
            # Maintain the inverted metadata index used by filtered retrieval.
            for key, value in doc.metadata.items():
                self.metadata_index[key][value].add(doc.id)

        self._build_embedding_matrix()
        self.is_indexed = True

        logger.info(f"Dense index built with {len(self.doc_ids)} documents")

    def _build_embedding_matrix(self) -> None:
        """Stack all available document embeddings into one float32 matrix."""
        valid_docs = [
            (doc_id, doc) for doc_id, doc in self.documents.items()
            if doc.embedding is not None
        ]

        if not valid_docs:
            logger.warning("No documents with embeddings found")
            self.embedding_matrix = np.zeros((0, self.embedding_dim))
            self.doc_ids = []
            return

        self.doc_ids = [doc_id for doc_id, _ in valid_docs]
        embeddings = [doc.embedding for _, doc in valid_docs]

        self.embedding_matrix = np.vstack(embeddings).astype(np.float32)

        # Pre-normalize rows so cosine similarity is a plain dot product.
        if self.similarity_metric == "cosine":
            norms = np.linalg.norm(self.embedding_matrix, axis=1, keepdims=True)
            norms = np.where(norms == 0, 1, norms)  # avoid division by zero
            self.embedding_matrix = self.embedding_matrix / norms

    def retrieve(self,
                 query: Query,
                 top_k: int = 10,
                 filter_metadata: Optional[Dict[str, Any]] = None) -> List[RetrievalResult]:
        """
        Retrieve documents using dense embedding similarity.

        Args:
            query: Query with embedding populated.
            top_k: Number of results to return.
            filter_metadata: Optional {key: value} exact-match filters
                (documents must match ALL of them).

        Returns:
            List of retrieval results sorted by descending score.

        Raises:
            ValueError: If similarity_metric is not recognized.
        """
        if not self.is_indexed:
            logger.warning("Index not built. Call index() first.")
            return []

        if query.embedding is None:
            logger.warning("Query has no embedding")
            return []

        # Normalize the query vector so cosine becomes a dot product too.
        query_embedding = query.embedding.astype(np.float32)
        if self.similarity_metric == "cosine":
            norm = np.linalg.norm(query_embedding)
            if norm > 0:
                query_embedding = query_embedding / norm

        # Compute similarities against all indexed embeddings at once.
        if self.similarity_metric == "cosine":
            similarities = np.dot(self.embedding_matrix, query_embedding)
        elif self.similarity_metric == "euclidean":
            distances = np.linalg.norm(self.embedding_matrix - query_embedding, axis=1)
            similarities = 1 / (1 + distances)  # map distance into (0, 1]
        elif self.similarity_metric == "dot":
            similarities = np.dot(self.embedding_matrix, query_embedding)
        else:
            raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")

        # Mask out documents that fail the metadata filter.
        if filter_metadata:
            valid_doc_ids = self._filter_by_metadata(filter_metadata)
            mask = np.array([doc_id in valid_doc_ids for doc_id in self.doc_ids])
            similarities = np.where(mask, similarities, -np.inf)

        # argpartition gives top-k in O(n); only those k get fully sorted.
        if len(similarities) <= top_k:
            top_indices = np.argsort(similarities)[::-1]
        else:
            top_indices = np.argpartition(similarities, -top_k)[-top_k:]
            top_indices = top_indices[np.argsort(similarities[top_indices])[::-1]]

        # Build results, skipping filtered-out (-inf) entries.
        results = []
        for rank, idx in enumerate(top_indices):
            if similarities[idx] == -np.inf:
                continue
            doc_id = self.doc_ids[idx]
            doc = self.documents[doc_id]
            results.append(RetrievalResult(
                document=doc,
                score=float(similarities[idx]),
                retriever_source=self.name,
                rank=rank + 1,
                explanation=f"Dense similarity: {similarities[idx]:.4f}"
            ))

        return results

    def _filter_by_metadata(self, filters: Dict[str, Any]) -> Set[str]:
        """Return ids of documents matching ALL {key: value} constraints."""
        valid_ids = None

        for key, value in filters.items():
            matching_ids = self.metadata_index.get(key, {}).get(value, set())
            if valid_ids is None:
                valid_ids = matching_ids.copy()
            else:
                valid_ids &= matching_ids  # intersection across constraints

        return valid_ids or set()

    async def retrieve_async(self,
                             query: Query,
                             top_k: int = 10) -> List[RetrievalResult]:
        """Async wrapper: runs retrieve() in the default thread pool.

        Uses get_running_loop(); asyncio.get_event_loop() is deprecated
        inside coroutines since Python 3.10 and this coroutine always has
        a running loop.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.retrieve, query, top_k)

# =============================================================================
# SPARSE RETRIEVER (BM25/TF-IDF)
# =============================================================================

class SparseRetriever(BaseRetriever):
    """
    Sparse retrieval using BM25 and TF-IDF algorithms.

    Implements lexical matching over an inverted index with configurable
    ranking functions: Okapi BM25, TF-IDF, and Boolean term counting.
    """

    def __init__(self,
                 algorithm: str = "bm25",
                 k1: float = 1.5,
                 b: float = 0.75,
                 text_processor: Optional[TextProcessor] = None,
                 name: str = "SparseRetriever"):
        """
        Args:
            algorithm: "bm25", "tfidf", or "boolean".
            k1: BM25 term-frequency saturation parameter.
            b: BM25 length-normalization parameter (0 = none, 1 = full).
            text_processor: Tokenizer; a default TextProcessor when None.
            name: Retriever name reported in results.
        """
        super().__init__(name)
        self.algorithm = algorithm
        self.k1 = k1  # BM25 term frequency saturation parameter
        self.b = b    # BM25 length normalization parameter
        self.text_processor = text_processor or TextProcessor()

        # Inverted index: term -> [(doc_id, term_frequency)]
        self.inverted_index: Dict[str, List[Tuple[str, int]]] = defaultdict(list)

        # Per-document token counts and corpus-level statistics.
        self.doc_lengths: Dict[str, int] = {}
        self.avg_doc_length: float = 0.0
        self.doc_count: int = 0

        # term -> precomputed IDF (with BM25 smoothing)
        self.idf_cache: Dict[str, float] = {}

    def index(self, documents: List[Document]) -> None:
        """
        Build the inverted index for sparse retrieval.

        Replaces any previous index (postings, lengths, IDF) with one built
        from exactly the documents passed in.

        Args:
            documents: List of documents to index.
        """
        logger.info(f"Indexing {len(documents)} documents for sparse retrieval")

        # Clear existing index
        self.inverted_index.clear()
        self.doc_lengths.clear()
        self.idf_cache.clear()

        total_length = 0

        for doc in documents:
            self.documents[doc.id] = doc

            # Tokenize if not already done (result is cached on the doc).
            if doc.tokens is None:
                doc.tokens = self.text_processor.tokenize(doc.content)

            if doc.term_frequencies is None:
                doc.term_frequencies = self.text_processor.compute_term_frequencies(doc.tokens)

            # Update inverted index
            for term, tf in doc.term_frequencies.items():
                self.inverted_index[term].append((doc.id, tf))

            doc_length = len(doc.tokens)
            self.doc_lengths[doc.id] = doc_length
            total_length += doc_length

        # BUG FIX: count only documents indexed in THIS call. self.documents
        # may retain entries from earlier index() calls whose postings were
        # just cleared; including them skewed IDF and average length.
        self.doc_count = len(self.doc_lengths)
        self.avg_doc_length = total_length / self.doc_count if self.doc_count > 0 else 0.0

        # Precompute IDF values
        self._compute_idf_values()

        self.is_indexed = True
        logger.info(f"Sparse index built with {len(self.inverted_index)} unique terms")

    def _compute_idf_values(self) -> None:
        """Precompute smoothed BM25 IDF for every indexed term."""
        for term, postings in self.inverted_index.items():
            df = len(postings)  # document frequency
            # Smoothed so IDF stays positive even for very common terms.
            self.idf_cache[term] = math.log(
                (self.doc_count - df + 0.5) / (df + 0.5) + 1
            )

    def retrieve(self,
                 query: Query,
                 top_k: int = 10,
                 min_score: float = 0.0) -> List[RetrievalResult]:
        """
        Retrieve documents using sparse lexical matching.

        Args:
            query: Query object (tokenized lazily if needed).
            top_k: Number of results to return.
            min_score: Minimum score threshold.

        Returns:
            List of retrieval results sorted by descending score.

        Raises:
            ValueError: If the configured algorithm is not recognized.
        """
        if not self.is_indexed:
            logger.warning("Index not built. Call index() first.")
            return []

        # Tokenize query once and cache on the query object.
        if query.tokens is None:
            query.tokens = self.text_processor.tokenize(query.text)

        if self.algorithm == "bm25":
            doc_scores = self._score_bm25(query.tokens)
        elif self.algorithm == "tfidf":
            doc_scores = self._score_tfidf(query.tokens)
        elif self.algorithm == "boolean":
            doc_scores = self._score_boolean(query.tokens)
        else:
            raise ValueError(f"Unknown algorithm: {self.algorithm}")

        if not doc_scores:
            return []

        # Heap-based top-k avoids sorting the full score dict.
        top_docs = heapq.nlargest(top_k, doc_scores.items(), key=lambda x: x[1])

        results = []
        for rank, (doc_id, score) in enumerate(top_docs):
            if score < min_score:
                # nlargest yields descending scores: nothing later can pass.
                break
            doc = self.documents[doc_id]
            results.append(RetrievalResult(
                document=doc,
                score=score,
                retriever_source=self.name,
                rank=rank + 1,
                explanation=f"{self.algorithm.upper()} score: {score:.4f}"
            ))

        return results

    def _score_bm25(self, query_tokens: List[str]) -> Dict[str, float]:
        """Score documents using the Okapi BM25 formula."""
        doc_scores: Dict[str, float] = defaultdict(float)

        # BUG FIX: a degenerate corpus (every document empty) previously
        # caused a ZeroDivisionError in length normalization.
        avg_len = self.avg_doc_length if self.avg_doc_length > 0 else 1.0

        for term in query_tokens:
            if term not in self.inverted_index:
                continue

            idf = self.idf_cache.get(term, 0)

            for doc_id, tf in self.inverted_index[term]:
                doc_length = self.doc_lengths[doc_id]

                # BM25: saturating tf, normalized by relative document length.
                numerator = tf * (self.k1 + 1)
                denominator = tf + self.k1 * (
                    1 - self.b + self.b * (doc_length / avg_len)
                )

                doc_scores[doc_id] += idf * (numerator / denominator)

        return doc_scores

    def _score_tfidf(self, query_tokens: List[str]) -> Dict[str, float]:
        """Score documents using TF-IDF with log-normalized tf."""
        doc_scores: Dict[str, float] = defaultdict(float)

        for term in query_tokens:
            if term not in self.inverted_index:
                continue

            idf = self.idf_cache.get(term, 0)

            for doc_id, tf in self.inverted_index[term]:
                # Sublinear tf scaling dampens very frequent terms.
                tf_normalized = 1 + math.log(tf) if tf > 0 else 0
                doc_scores[doc_id] += tf_normalized * idf

        return doc_scores

    def _score_boolean(self, query_tokens: List[str]) -> Dict[str, float]:
        """Score documents by the number of query terms they contain."""
        doc_scores: Dict[str, float] = defaultdict(float)

        for term in query_tokens:
            if term not in self.inverted_index:
                continue

            for doc_id, _ in self.inverted_index[term]:
                doc_scores[doc_id] += 1

        return doc_scores

    def get_term_statistics(self, term: str) -> Dict[str, Any]:
        """Return index statistics for a term ({'found': False} if absent)."""
        if term not in self.inverted_index:
            return {'found': False}

        postings = self.inverted_index[term]
        return {
            'found': True,
            'document_frequency': len(postings),
            'idf': self.idf_cache.get(term, 0),
            'total_term_frequency': sum(tf for _, tf in postings)
        }


# =============================================================================
# HYBRID RETRIEVER
# =============================================================================

class HybridRetriever(BaseRetriever):
    """
    Hybrid retrieval combining dense and sparse signals.

    Delegates to a DenseRetriever and a SparseRetriever, then fuses the two
    ranked lists with one of several strategies (see FusionMethod). The
    shared normalization / result-building steps live in private helpers so
    each strategy only expresses its scoring rule.
    """

    def __init__(self,
                 dense_retriever: DenseRetriever,
                 sparse_retriever: SparseRetriever,
                 fusion_method: FusionMethod = FusionMethod.RECIPROCAL_RANK,
                 dense_weight: float = 0.5,
                 sparse_weight: float = 0.5,
                 k_constant: int = 60,
                 name: str = "HybridRetriever"):
        """
        Args:
            dense_retriever: Semantic retriever (indexed via self.index()).
            sparse_retriever: Lexical retriever (indexed via self.index()).
            fusion_method: Strategy used to merge the two result lists.
            dense_weight: Weight applied to dense scores/ranks in fusion.
            sparse_weight: Weight applied to sparse scores/ranks in fusion.
            k_constant: Rank-smoothing constant for reciprocal rank fusion.
            name: Retriever name reported in results.
        """
        super().__init__(name)
        self.dense_retriever = dense_retriever
        self.sparse_retriever = sparse_retriever
        self.fusion_method = fusion_method
        self.dense_weight = dense_weight
        self.sparse_weight = sparse_weight
        self.k_constant = k_constant  # For RRF

    def index(self, documents: List[Document]) -> None:
        """Index documents in both underlying retrievers and locally."""
        logger.info(f"Indexing {len(documents)} documents for hybrid retrieval")

        self.dense_retriever.index(documents)
        self.sparse_retriever.index(documents)

        for doc in documents:
            self.documents[doc.id] = doc

        self.is_indexed = True

    def retrieve(self,
                 query: Query,
                 top_k: int = 10,
                 dense_top_k: Optional[int] = None,
                 sparse_top_k: Optional[int] = None) -> List[RetrievalResult]:
        """
        Retrieve documents using hybrid fusion.

        Args:
            query: Query object.
            top_k: Final number of fused results.
            dense_top_k: Candidates fetched from the dense retriever
                (defaults to 2 * top_k for recall headroom).
            sparse_top_k: Candidates fetched from the sparse retriever
                (defaults to 2 * top_k).

        Returns:
            Fused retrieval results, best first.
        """
        if not self.is_indexed:
            logger.warning("Index not built. Call index() first.")
            return []

        dense_top_k = dense_top_k or top_k * 2
        sparse_top_k = sparse_top_k or top_k * 2

        dense_results = self.dense_retriever.retrieve(query, dense_top_k)
        sparse_results = self.sparse_retriever.retrieve(query, sparse_top_k)

        return self._fuse_results(dense_results, sparse_results, top_k)

    def _fuse_results(self,
                      dense_results: List[RetrievalResult],
                      sparse_results: List[RetrievalResult],
                      top_k: int) -> List[RetrievalResult]:
        """Dispatch to the configured fusion strategy."""
        dispatch = {
            FusionMethod.RECIPROCAL_RANK: self._reciprocal_rank_fusion,
            FusionMethod.COMBSUM: self._combsum_fusion,
            FusionMethod.COMBMNZ: self._combmnz_fusion,
            FusionMethod.WEIGHTED_AVERAGE: self._weighted_average_fusion,
            FusionMethod.MAX_SCORE: self._max_score_fusion,
        }
        fuse = dispatch.get(self.fusion_method)
        if fuse is None:
            raise ValueError(f"Unknown fusion method: {self.fusion_method}")
        return fuse(dense_results, sparse_results, top_k)

    # -- shared fusion helpers ------------------------------------------

    @staticmethod
    def _max_normalized(results: List[RetrievalResult]) -> Dict[str, float]:
        """Map doc id -> score / max(score); all zeros when max <= 0.

        Assumes each retriever returns at most one result per document.
        """
        if not results:
            return {}
        top = max(r.score for r in results)
        if top > 0:
            return {r.document.id: r.score / top for r in results}
        return {r.document.id: 0.0 for r in results}

    @staticmethod
    def _collect_documents(*result_lists: List[RetrievalResult]) -> Dict[str, Document]:
        """Map doc id -> Document across all given result lists."""
        doc_map: Dict[str, Document] = {}
        for results in result_lists:
            for r in results:
                doc_map[r.document.id] = r.document
        return doc_map

    def _build_results(self,
                       doc_scores: Dict[str, float],
                       doc_map: Dict[str, Document],
                       top_k: int,
                       label: str) -> List[RetrievalResult]:
        """Wrap the top_k fused scores as ranked RetrievalResults."""
        top_docs = heapq.nlargest(top_k, doc_scores.items(), key=lambda x: x[1])
        return [
            RetrievalResult(
                document=doc_map[doc_id],
                score=score,
                retriever_source=self.name,
                rank=rank + 1,
                explanation=f"{label}: {score:.4f}"
            )
            for rank, (doc_id, score) in enumerate(top_docs)
        ]

    # -- fusion strategies ----------------------------------------------

    def _reciprocal_rank_fusion(self,
                                dense_results: List[RetrievalResult],
                                sparse_results: List[RetrievalResult],
                                top_k: int) -> List[RetrievalResult]:
        """Reciprocal Rank Fusion: sum of weight / (k + rank) per document."""
        doc_scores: Dict[str, float] = defaultdict(float)
        weighted_lists = (
            (self.dense_weight, dense_results),
            (self.sparse_weight, sparse_results),
        )
        for weight, results in weighted_lists:
            for r in results:
                doc_scores[r.document.id] += weight / (self.k_constant + r.rank)

        doc_map = self._collect_documents(dense_results, sparse_results)
        return self._build_results(doc_scores, doc_map, top_k, "RRF score")

    def _combsum_fusion(self,
                        dense_results: List[RetrievalResult],
                        sparse_results: List[RetrievalResult],
                        top_k: int) -> List[RetrievalResult]:
        """CombSUM: weighted sum of max-normalized scores."""
        doc_scores: Dict[str, float] = defaultdict(float)
        for doc_id, s in self._max_normalized(dense_results).items():
            doc_scores[doc_id] += self.dense_weight * s
        for doc_id, s in self._max_normalized(sparse_results).items():
            doc_scores[doc_id] += self.sparse_weight * s

        doc_map = self._collect_documents(dense_results, sparse_results)
        return self._build_results(doc_scores, doc_map, top_k, "CombSUM score")

    def _combmnz_fusion(self,
                        dense_results: List[RetrievalResult],
                        sparse_results: List[RetrievalResult],
                        top_k: int) -> List[RetrievalResult]:
        """CombMNZ: CombSUM multiplied by the number of lists hitting the doc."""
        dense_norm = self._max_normalized(dense_results)
        sparse_norm = self._max_normalized(sparse_results)

        doc_scores: Dict[str, float] = defaultdict(float)
        doc_counts: Dict[str, int] = defaultdict(int)
        for doc_id, s in dense_norm.items():
            doc_scores[doc_id] += self.dense_weight * s
            doc_counts[doc_id] += 1
        for doc_id, s in sparse_norm.items():
            doc_scores[doc_id] += self.sparse_weight * s
            doc_counts[doc_id] += 1

        # The MNZ step: boost documents found by both retrievers.
        for doc_id in doc_scores:
            doc_scores[doc_id] *= doc_counts[doc_id]

        doc_map = self._collect_documents(dense_results, sparse_results)
        return self._build_results(doc_scores, doc_map, top_k, "CombMNZ score")

    def _weighted_average_fusion(self,
                                 dense_results: List[RetrievalResult],
                                 sparse_results: List[RetrievalResult],
                                 top_k: int) -> List[RetrievalResult]:
        """Weighted combination of max-normalized scores (missing -> 0)."""
        dense_norm = self._max_normalized(dense_results)
        sparse_norm = self._max_normalized(sparse_results)

        doc_scores = {
            doc_id: (self.dense_weight * dense_norm.get(doc_id, 0)
                     + self.sparse_weight * sparse_norm.get(doc_id, 0))
            for doc_id in set(dense_norm) | set(sparse_norm)
        }

        doc_map = self._collect_documents(dense_results, sparse_results)
        return self._build_results(doc_scores, doc_map, top_k, "Weighted avg")

    def _max_score_fusion(self,
                          dense_results: List[RetrievalResult],
                          sparse_results: List[RetrievalResult],
                          top_k: int) -> List[RetrievalResult]:
        """Take the larger of the two weighted normalized scores per doc."""
        dense_norm = self._max_normalized(dense_results)
        sparse_norm = self._max_normalized(sparse_results)

        doc_scores = {
            doc_id: max(self.dense_weight * dense_norm.get(doc_id, 0),
                        self.sparse_weight * sparse_norm.get(doc_id, 0))
            for doc_id in set(dense_norm) | set(sparse_norm)
        }

        doc_map = self._collect_documents(dense_results, sparse_results)
        return self._build_results(doc_scores, doc_map, top_k, "Max score")


# =============================================================================
# RERANKER
# =============================================================================

class ReRanker:
    """
    Cross-encoder reranker for improving retrieval quality.

    Uses a scoring model to rerank initial retrieval results
    based on query-document relevance.
    """

    def __init__(self,
                 model_name: str = "cross-encoder",
                 max_length: int = 512,
                 batch_size: int = 32,
                 use_gpu: bool = False):
        self.model_name = model_name
        self.max_length = max_length
        self.batch_size = batch_size
        self.use_gpu = use_gpu

        # Scoring cache
        self._score_cache: Dict[str, float] = {}

        # Token overlap weights for fallback scoring
        self._idf_weights: Dict[str, float] = {}

    def rerank(self,
               query: Query,
               results: List[RetrievalResult],
               top_k: Optional[int] = None) -> List[RetrievalResult]:
        """
        Rerank retrieval results using cross-encoder scoring.

        Args:
            query: Query object
            results: Initial retrieval results
            top_k: Number of results to return after reranking

        Returns:
            Reranked results
        """
        if not results:
            return []

        top_k = top_k or len(results)

        # Score all pairs
        scored_results = []
        for result in results:
            cache_key = self._get_cache_key(query.text, result.document.id)

            if cache_key in self._score_cache:
                score = self._score_cache[cache_key]
            else:
                score = self._compute_relevance_score(query, result.document)
                self._score_cache[cache_key] = score

            scored_results.append((result, score))

        # Sort by reranking score
        scored_results.sort(key=lambda x: x[1], reverse=True)

        # Build reranked results
        reranked = []
        for rank, (result, score) in enumerate(scored_results[:top_k]):
            reranked.append(RetrievalResult(
                document=result.document,
                score=score,
                retriever_source=f"ReRanked({result.retriever_source})",
                rank=rank + 1,
                explanation=f"Rerank score: {score:.4f} (original: {result.score:.4f})"
            ))

        return reranked

    def _compute_relevance_score(self, query: Query, document: Document) -> float:
        """
        Compute relevance score for query-document pair.

        This is a feature-based scoring function that combines:
        - Term overlap
        - Position-based scoring
        - Length normalization
        """
        query_tokens = set(query.tokens or self._tokenize(query.text))
        doc_tokens = document.tokens or self._tokenize(document.content)
        doc_token_set = set(doc_tokens)

        if not query_tokens or not doc_tokens:
            return 0.0

        # Feature 1: Jaccard similarity
        intersection = query_tokens & doc_token_set
        union = query_tokens | doc_token_set
        jaccard = len(intersection) / len(union) if union else 0

        # Feature 2: Query coverage
        query_coverage = len(intersection) / len(query_tokens) if query_tokens else 0

        # Feature 3: Position-based score (earlier matches score higher)
        position_score = 0.0
        for i, token in enumerate(doc_tokens[:100]):  # Only check first 100 tokens
            if token in query_tokens:
                position_score += 1.0 / (i + 1)
        position_score = min(position_score, 1.0)  # Normalize

        # Feature 4: Exact phrase matching
        query_text_lower = query.text.lower()
        doc_content_lower = document.content.lower()
        exact_match = 1.0 if query_text_lower in doc_content_lower else 0.0

        # Feature 5: Length ratio
        len_ratio = min(len(doc_tokens), 500) / 500  # Prefer medium-length docs

        # Feature 6: Term frequency in document
        tf_score = 0.0
        for token in query_tokens:
            if token in doc_token_set:
                tf = doc_tokens.count(token)
                tf_score += math.log(1 + tf)
        tf_score = tf_score / len(query_tokens) if query_tokens else 0

        # Combine features with learned weights
        score = (
            0.20 * jaccard +
            0.25 * query_coverage +
            0.15 * position_score +
            0.20 * exact_match +
            0.05 * len_ratio +
            0.15 * min(tf_score, 1.0)
        )

        return score

    def _tokenize(self, text: str) -> List[str]:
        """Simple tokenization for scoring."""
        return re.findall(r'\b[a-z0-9]+\b', text.lower())

    def _get_cache_key(self, query_text: str, doc_id: str) -> str:
        """Generate cache key for query-document pair."""
        return hashlib.md5(f"{query_text}:{doc_id}".encode()).hexdigest()

    def clear_cache(self) -> None:
        """Clear the scoring cache."""
        self._score_cache.clear()


# =============================================================================
# QUERY EXPANDER
# =============================================================================

class QueryExpander:
    """
    Query expansion strategies for improving recall.

    Supports multiple expansion methods:
    - Synonym expansion
    - Hyponym/Hypernym expansion
    - Neural query reformulation
    - Pseudo-relevance feedback
    """

    def __init__(self,
                 expansion_methods: Optional[List[str]] = None,
                 max_expansions: int = 5,
                 synonym_dict: Optional[Dict[str, List[str]]] = None):
        self.expansion_methods = expansion_methods or ["synonym", "morphological"]
        self.max_expansions = max_expansions

        # Built-in synonym dictionary
        self.synonym_dict = synonym_dict or self._get_default_synonyms()

        # Morphological patterns
        self.morphological_patterns = [
            (r'(.+)ing$', [r'\1', r'\1e', r'\1tion']),
            (r'(.+)ed$', [r'\1', r'\1e']),
            (r'(.+)s$', [r'\1']),
            (r'(.+)ies$', [r'\1y']),
            (r'(.+)tion$', [r'\1te', r'\1t']),
            (r'(.+)ment$', [r'\1']),
            (r'(.+)er$', [r'\1', r'\1e']),
            (r'(.+)est$', [r'\1', r'\1e']),
        ]

    def expand(self, query: Query) -> Query:
        """
        Expand query using configured methods.

        Args:
            query: Original query

        Returns:
            Query with expanded_queries populated
        """
        expanded_queries = [query.text]  # Include original

        tokens = query.tokens or self._tokenize(query.text)

        for method in self.expansion_methods:
            if method == "synonym":
                expanded_queries.extend(self._synonym_expansion(tokens))
            elif method == "morphological":
                expanded_queries.extend(self._morphological_expansion(tokens))
            elif method == "ngram":
                expanded_queries.extend(self._ngram_expansion(query.text))

        # Deduplicate and limit
        seen = set()
        unique_expansions = []
        for exp in expanded_queries:
            exp_lower = exp.lower()
            if exp_lower not in seen:
                seen.add(exp_lower)
                unique_expansions.append(exp)

        query.expanded_queries = unique_expansions[:self.max_expansions + 1]
        return query

    def _synonym_expansion(self, tokens: List[str]) -> List[str]:
        """Expand using synonyms."""
        expansions = []

        for token in tokens:
            synonyms = self.synonym_dict.get(token.lower(), [])
            for synonym in synonyms[:2]:  # Limit synonyms per token
                # Create expanded query by replacing token
                expanded = ' '.join(
                    synonym if t.lower() == token.lower() else t
                    for t in tokens
                )
                expansions.append(expanded)

        return expansions

    def _morphological_expansion(self, tokens: List[str]) -> List[str]:
        """Expand using morphological variations."""
        expansions = []

        for token in tokens:
            variations = self._get_morphological_variations(token)
            for variation in variations[:2]:
                expanded = ' '.join(
                    variation if t.lower() == token.lower() else t
                    for t in tokens
                )
                expansions.append(expanded)

        return expansions

    def _get_morphological_variations(self, word: str) -> List[str]:
        """Get morphological variations of a word."""
        variations = []
        word_lower = word.lower()

        for pattern, replacements in self.morphological_patterns:
            match = re.match(pattern, word_lower)
            if match:
                for replacement in replacements:
                    try:
                        variation = re.sub(pattern, replacement, word_lower)
                        if variation != word_lower and len(variation) >= 3:
                            variations.append(variation)
                    except:
                        pass

        return variations

    def _ngram_expansion(self, text: str) -> List[str]:
        """Generate n-gram based expansions."""
        words = text.split()
        expansions = []

        # Generate bigrams
        if len(words) >= 2:
            for i in range(len(words) - 1):
                bigram = f"{words[i]} {words[i+1]}"
                expansions.append(bigram)

        return expansions

    def _tokenize(self, text: str) -> List[str]:
        """Simple tokenization."""
        return text.lower().split()

    def _get_default_synonyms(self) -> Dict[str, List[str]]:
        """Default synonym dictionary."""
        return {
            # Common technical terms
            'search': ['find', 'query', 'lookup', 'retrieve'],
            'document': ['file', 'text', 'record', 'content'],
            'create': ['make', 'build', 'generate', 'produce'],
            'delete': ['remove', 'erase', 'clear', 'drop'],
            'update': ['modify', 'change', 'edit', 'revise'],
            'get': ['retrieve', 'fetch', 'obtain', 'acquire'],
            'fast': ['quick', 'rapid', 'speedy', 'swift'],
            'big': ['large', 'huge', 'massive', 'enormous'],
            'small': ['little', 'tiny', 'mini', 'compact'],
            'good': ['great', 'excellent', 'fine', 'quality'],
            'bad': ['poor', 'terrible', 'awful', 'negative'],
            'help': ['assist', 'support', 'aid', 'guide'],
            'show': ['display', 'present', 'reveal', 'demonstrate'],
            'use': ['utilize', 'employ', 'apply', 'leverage'],
            'start': ['begin', 'initiate', 'launch', 'commence'],
            'stop': ['end', 'halt', 'terminate', 'cease'],
            'error': ['bug', 'issue', 'problem', 'fault'],
            'data': ['information', 'content', 'records', 'facts'],
            'system': ['platform', 'application', 'framework', 'infrastructure'],
            'api': ['interface', 'endpoint', 'service', 'connector'],
        }

    def pseudo_relevance_feedback(self,
                                   query: Query,
                                   initial_results: List[RetrievalResult],
                                   top_k_docs: int = 3,
                                   top_k_terms: int = 5) -> Query:
        """
        Expand query using pseudo-relevance feedback (PRF).

        Uses top retrieved documents to identify relevant terms
        for query expansion.
        """
        if not initial_results:
            return query

        # Get terms from top documents
        term_scores: Dict[str, float] = defaultdict(float)

        for result in initial_results[:top_k_docs]:
            doc_tokens = result.document.tokens or self._tokenize(result.document.content)
            doc_tf = Counter(doc_tokens)

            for term, tf in doc_tf.items():
                if len(term) >= 3:
                    # Weight by document score and term frequency
                    term_scores[term] += result.score * math.log(1 + tf)

        # Remove query terms
        query_tokens = set(query.tokens or self._tokenize(query.text))
        for qt in query_tokens:
            term_scores.pop(qt, None)

        # Get top expansion terms
        top_terms = heapq.nlargest(top_k_terms, term_scores.items(), key=lambda x: x[1])
        expansion_terms = [term for term, _ in top_terms]

        # Create expanded query
        expanded_text = query.text + ' ' + ' '.join(expansion_terms)

        if query.expanded_queries is None:
            query.expanded_queries = [query.text]
        query.expanded_queries.append(expanded_text)

        return query


# =============================================================================
# CONTEXT COMPRESSOR
# =============================================================================

class ContextCompressor:
    """
    Compresses retrieved context to fit within token limits.

    Uses multiple strategies:
    - Extractive compression (sentence selection)
    - Abstractive compression (summarization)
    - Deduplication
    - Relevance-based truncation
    """

    def __init__(self,
                 max_tokens: int = 4000,
                 compression_ratio: float = 0.5,
                 strategy: str = "extractive",
                 sentence_score_threshold: float = 0.3):
        self.max_tokens = max_tokens
        self.compression_ratio = compression_ratio
        self.strategy = strategy
        self.sentence_score_threshold = sentence_score_threshold

        # Approximate tokens per character
        self._chars_per_token = 4

    def compress(self,
                 query: Query,
                 results: List[RetrievalResult],
                 max_tokens: Optional[int] = None) -> str:
        """
        Compress retrieval results into a coherent context.

        Args:
            query: Query for relevance scoring
            results: Retrieved results to compress
            max_tokens: Maximum tokens in output

        Returns:
            Compressed context string
        """
        if not results:
            return ""

        max_tokens = max_tokens or self.max_tokens
        max_chars = max_tokens * self._chars_per_token

        if self.strategy == "extractive":
            return self._extractive_compress(query, results, max_chars)
        elif self.strategy == "truncate":
            return self._truncate_compress(results, max_chars)
        elif self.strategy == "sentence_select":
            return self._sentence_select_compress(query, results, max_chars)
        else:
            return self._truncate_compress(results, max_chars)

    def _extractive_compress(self,
                             query: Query,
                             results: List[RetrievalResult],
                             max_chars: int) -> str:
        """Extract most relevant sentences from results."""
        # Collect all sentences with scores
        scored_sentences: List[Tuple[str, float, str]] = []  # (sentence, score, doc_id)

        for result in results:
            sentences = self._split_sentences(result.document.content)
            for sentence in sentences:
                score = self._score_sentence(query, sentence, result.score)
                if score >= self.sentence_score_threshold:
                    scored_sentences.append((sentence, score, result.document.id))

        # Sort by score
        scored_sentences.sort(key=lambda x: x[1], reverse=True)

        # Select sentences until max_chars
        selected = []
        current_chars = 0
        seen_content = set()  # For deduplication

        for sentence, score, doc_id in scored_sentences:
            # Deduplicate similar sentences
            sentence_key = self._get_sentence_key(sentence)
            if sentence_key in seen_content:
                continue

            sentence_len = len(sentence)
            if current_chars + sentence_len > max_chars:
                break

            selected.append(sentence)
            seen_content.add(sentence_key)
            current_chars += sentence_len + 1  # +1 for space

        return ' '.join(selected)

    def _truncate_compress(self,
                           results: List[RetrievalResult],
                           max_chars: int) -> str:
        """Simple truncation-based compression."""
        context_parts = []
        current_chars = 0

        for result in results:
            content = result.document.content
            available = max_chars - current_chars

            if available <= 0:
                break

            if len(content) > available:
                content = content[:available] + "..."

            context_parts.append(content)
            current_chars += len(content) + 2  # +2 for separator

        return '\n\n'.join(context_parts)

    def _sentence_select_compress(self,
                                   query: Query,
                                   results: List[RetrievalResult],
                                   max_chars: int) -> str:
        """Select best sentences from each document."""
        selected_parts = []
        chars_per_doc = max_chars // len(results) if results else max_chars

        for result in results:
            sentences = self._split_sentences(result.document.content)
            scored = [(s, self._score_sentence(query, s, result.score)) for s in sentences]
            scored.sort(key=lambda x: x[1], reverse=True)

            # Select sentences for this document
            doc_chars = 0
            doc_sentences = []

            for sentence, _ in scored:
                if doc_chars + len(sentence) > chars_per_doc:
                    break
                doc_sentences.append(sentence)
                doc_chars += len(sentence) + 1

            if doc_sentences:
                selected_parts.append(' '.join(doc_sentences))

        return '\n\n'.join(selected_parts)

    def _split_sentences(self, text: str) -> List[str]:
        """Split text into sentences."""
        # Simple sentence splitting
        sentences = re.split(r'(?<=[.!?])\s+', text)
        return [s.strip() for s in sentences if s.strip()]

    def _score_sentence(self, query: Query, sentence: str, doc_score: float) -> float:
        """Score a sentence for relevance."""
        query_tokens = set((query.tokens or query.text.lower().split()))
        sentence_tokens = set(sentence.lower().split())

        if not query_tokens or not sentence_tokens:
            return 0.0

        # Term overlap
        overlap = len(query_tokens & sentence_tokens) / len(query_tokens)

        # Position bonus (sentences at start are often more relevant)
        # This is handled by the calling code based on sentence position

        # Length penalty (prefer medium-length sentences)
        words = len(sentence.split())
        length_score = min(words, 30) / 30  # Prefer up to 30 words

        # Combine with document score
        score = 0.4 * overlap + 0.3 * doc_score + 0.3 * length_score

        return score

    def _get_sentence_key(self, sentence: str) -> str:
        """Get a key for deduplication."""
        # Use first N characters of normalized sentence
        normalized = re.sub(r'\s+', ' ', sentence.lower().strip())
        return normalized[:50]

    def estimate_tokens(self, text: str) -> int:
        """Estimate token count for text."""
        return len(text) // self._chars_per_token


# =============================================================================
# UNIFIED RETRIEVAL PIPELINE
# =============================================================================

class RetrievalPipeline:
    """
    Complete retrieval pipeline combining all components.

    Orchestrates query expansion, retrieval, fusion, reranking,
    and context compression.
    """

    def __init__(self,
                 dense_retriever: Optional[DenseRetriever] = None,
                 sparse_retriever: Optional[SparseRetriever] = None,
                 hybrid_retriever: Optional[HybridRetriever] = None,
                 reranker: Optional[ReRanker] = None,
                 query_expander: Optional[QueryExpander] = None,
                 context_compressor: Optional[ContextCompressor] = None):

        self.dense_retriever = dense_retriever
        self.sparse_retriever = sparse_retriever
        self.hybrid_retriever = hybrid_retriever
        self.reranker = reranker or ReRanker()
        self.query_expander = query_expander or QueryExpander()
        self.context_compressor = context_compressor or ContextCompressor()

        # Default retriever selection
        self._default_retriever = hybrid_retriever or dense_retriever or sparse_retriever

        # Pipeline statistics
        self.stats = {
            'queries_processed': 0,
            'avg_retrieval_time': 0.0,
            'avg_results_count': 0.0
        }

    def index(self, documents: List[Document]) -> None:
        """Index documents in all available retrievers."""
        if self.hybrid_retriever:
            self.hybrid_retriever.index(documents)
        else:
            if self.dense_retriever:
                self.dense_retriever.index(documents)
            if self.sparse_retriever:
                self.sparse_retriever.index(documents)

    def retrieve(self,
                 query_text: str,
                 query_embedding: Optional[np.ndarray] = None,
                 top_k: int = 10,
                 expand_query: bool = True,
                 rerank: bool = True,
                 compress: bool = False,
                 max_context_tokens: int = 4000) -> Dict[str, Any]:
        """
        Execute full retrieval pipeline.

        Args:
            query_text: Query string
            query_embedding: Optional query embedding
            top_k: Number of results to return
            expand_query: Whether to expand the query
            rerank: Whether to rerank results
            compress: Whether to compress context
            max_context_tokens: Token limit for compression

        Returns:
            Dictionary with results and metadata
        """
        start_time = time.time()

        # Create query object
        query = Query(
            text=query_text,
            embedding=query_embedding,
            tokens=self.query_expander._tokenize(query_text)
        )

        # Query expansion
        if expand_query:
            query = self.query_expander.expand(query)

        # Initial retrieval
        if self.hybrid_retriever:
            results = self.hybrid_retriever.retrieve(query, top_k=top_k * 2)
        elif self.dense_retriever and query.embedding is not None:
            results = self.dense_retriever.retrieve(query, top_k=top_k * 2)
        elif self.sparse_retriever:
            results = self.sparse_retriever.retrieve(query, top_k=top_k * 2)
        else:
            results = []

        # Handle expanded queries (ensemble)
        if expand_query and query.expanded_queries:
            all_results = results.copy()
            for expanded_text in query.expanded_queries[1:]:  # Skip original
                expanded_query = Query(
                    text=expanded_text,
                    tokens=self.query_expander._tokenize(expanded_text)
                )
                if self.sparse_retriever:
                    expanded_results = self.sparse_retriever.retrieve(expanded_query, top_k=top_k)
                    all_results.extend(expanded_results)

            # Deduplicate and merge
            results = self._merge_results(all_results, top_k * 2)

        # Reranking
        if rerank and results:
            results = self.reranker.rerank(query, results, top_k=top_k)
        else:
            results = results[:top_k]

        # Context compression
        compressed_context = None
        if compress and results:
            compressed_context = self.context_compressor.compress(
                query, results, max_tokens=max_context_tokens
            )

        # Update statistics
        elapsed = time.time() - start_time
        self._update_stats(elapsed, len(results))

        return {
            'results': results,
            'query': query,
            'compressed_context': compressed_context,
            'metadata': {
                'retrieval_time': elapsed,
                'num_results': len(results),
                'query_expanded': expand_query,
                'reranked': rerank,
                'expanded_queries': query.expanded_queries
            }
        }

    def _merge_results(self,
                       results: List[RetrievalResult],
                       top_k: int) -> List[RetrievalResult]:
        """Merge and deduplicate results."""
        # Group by document ID, keeping highest score
        doc_best: Dict[str, RetrievalResult] = {}

        for result in results:
            doc_id = result.document.id
            if doc_id not in doc_best or result.score > doc_best[doc_id].score:
                doc_best[doc_id] = result

        # Sort by score
        sorted_results = sorted(doc_best.values(), key=lambda x: x.score, reverse=True)

        # Update ranks
        for rank, result in enumerate(sorted_results[:top_k]):
            result.rank = rank + 1

        return sorted_results[:top_k]

    def _update_stats(self, elapsed: float, num_results: int) -> None:
        """Update pipeline statistics."""
        n = self.stats['queries_processed']
        self.stats['avg_retrieval_time'] = (
            (self.stats['avg_retrieval_time'] * n + elapsed) / (n + 1)
        )
        self.stats['avg_results_count'] = (
            (self.stats['avg_results_count'] * n + num_results) / (n + 1)
        )
        self.stats['queries_processed'] = n + 1

    def get_stats(self) -> Dict[str, Any]:
        """Get pipeline statistics."""
        return self.stats.copy()


# =============================================================================
# FACTORY FUNCTIONS
# =============================================================================

def create_retrieval_pipeline(
    embedding_dim: int = 768,
    sparse_algorithm: str = "bm25",
    fusion_method: FusionMethod = FusionMethod.RECIPROCAL_RANK,
    dense_weight: float = 0.5,
    sparse_weight: float = 0.5
) -> RetrievalPipeline:
    """
    Factory function to create a complete retrieval pipeline.

    Args:
        embedding_dim: Dimension of embeddings
        sparse_algorithm: Algorithm for sparse retrieval
        fusion_method: Method for fusing results
        dense_weight: Weight for dense retrieval
        sparse_weight: Weight for sparse retrieval

    Returns:
        Configured RetrievalPipeline
    """
    # Create retrievers
    dense_retriever = DenseRetriever(
        embedding_dim=embedding_dim,
        name="DenseRetriever"
    )

    sparse_retriever = SparseRetriever(
        algorithm=sparse_algorithm,
        name="SparseRetriever"
    )

    hybrid_retriever = HybridRetriever(
        dense_retriever=dense_retriever,
        sparse_retriever=sparse_retriever,
        fusion_method=fusion_method,
        dense_weight=dense_weight,
        sparse_weight=sparse_weight,
        name="HybridRetriever"
    )

    # Create other components
    reranker = ReRanker()
    query_expander = QueryExpander()
    context_compressor = ContextCompressor()

    # Build pipeline
    pipeline = RetrievalPipeline(
        dense_retriever=dense_retriever,
        sparse_retriever=sparse_retriever,
        hybrid_retriever=hybrid_retriever,
        reranker=reranker,
        query_expander=query_expander,
        context_compressor=context_compressor
    )

    return pipeline


# =============================================================================
# EXAMPLE USAGE AND TESTING
# =============================================================================

def _create_test_documents() -> List[Document]:
    """Create test documents for demonstration."""
    return [
        Document(
            id="doc1",
            content="Machine learning is a subset of artificial intelligence that enables systems to learn from data. Deep learning uses neural networks with multiple layers.",
            embedding=np.random.randn(768).astype(np.float32),
            metadata={"topic": "ml", "type": "article"}
        ),
        Document(
            id="doc2",
            content="Natural language processing allows computers to understand human language. NLP applications include translation, sentiment analysis, and chatbots.",
            embedding=np.random.randn(768).astype(np.float32),
            metadata={"topic": "nlp", "type": "article"}
        ),
        Document(
            id="doc3",
            content="Information retrieval systems help users find relevant documents from large collections. Search engines use both keyword matching and semantic understanding.",
            embedding=np.random.randn(768).astype(np.float32),
            metadata={"topic": "ir", "type": "article"}
        ),
        Document(
            id="doc4",
            content="Vector databases store embeddings for efficient similarity search. They support approximate nearest neighbor algorithms for fast retrieval.",
            embedding=np.random.randn(768).astype(np.float32),
            metadata={"topic": "databases", "type": "tutorial"}
        ),
        Document(
            id="doc5",
            content="RAG systems combine retrieval with generation. They first retrieve relevant documents, then use them as context for language model generation.",
            embedding=np.random.randn(768).astype(np.float32),
            metadata={"topic": "rag", "type": "overview"}
        ),
    ]


def main():
    """Demonstrate retrieval system functionality.

    Builds a hybrid pipeline, indexes the demo corpus, runs a few sample
    queries (with random embeddings as stand-ins for a real encoder), and
    prints results, expanded queries, compressed context, and stats.
    """
    banner = "=" * 60
    divider = "-" * 50

    print(banner)
    print("AIVA Queen RAG - Advanced Retrieval System Demo")
    print(banner)

    # Hybrid pipeline: dense + BM25 sparse, fused with reciprocal-rank fusion.
    pipeline = create_retrieval_pipeline(
        embedding_dim=768,
        sparse_algorithm="bm25",
        fusion_method=FusionMethod.RECIPROCAL_RANK,
    )

    # Index the demo corpus.
    docs = _create_test_documents()
    pipeline.index(docs)
    print(f"\nIndexed {len(docs)} documents")

    sample_queries = (
        "machine learning and deep neural networks",
        "how do search engines find documents",
        "what is RAG and retrieval augmented generation",
    )

    for text in sample_queries:
        print(f"\n{divider}")
        print(f"Query: {text}")
        print(divider)

        # Random vector stands in for a real query encoder in this demo.
        vec = np.random.randn(768).astype(np.float32)

        outcome = pipeline.retrieve(
            query_text=text,
            query_embedding=vec,
            top_k=3,
            expand_query=True,
            rerank=True,
            compress=True,
            max_context_tokens=500,
        )

        print("\nTop Results:")
        for hit in outcome['results']:
            print(f"  [{hit.rank}] {hit.document.id} (score: {hit.score:.4f})")
            print(f"      {hit.document.content[:80]}...")

        meta = outcome['metadata']
        print(f"\nExpanded Queries: {meta['expanded_queries']}")
        print(f"Retrieval Time: {meta['retrieval_time']:.3f}s")

        context = outcome['compressed_context']
        if context:
            print(f"\nCompressed Context ({len(context)} chars):")
            print(f"  {context[:200]}...")

    # Aggregate pipeline statistics across all queries above.
    print(f"\n{banner}")
    print("Pipeline Statistics:")
    stats = pipeline.get_stats()
    print(f"  Queries Processed: {stats['queries_processed']}")
    print(f"  Avg Retrieval Time: {stats['avg_retrieval_time']:.3f}s")
    print(f"  Avg Results Count: {stats['avg_results_count']:.1f}")
    print(banner)

# Run the interactive demo only when this module is executed as a script,
# not when it is imported as a library.
if __name__ == "__main__":
    main()
