"""
Genesis Memory System - Hybrid Retriever
=========================================
Combines vector, keyword, temporal, and graph search for optimal recall.
"""

from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from datetime import datetime, timedelta
import sys

sys.path.insert(0, '/mnt/e/genesis-system/genesis-memory')
from storage.postgresql_store import get_postgresql_store, Episode
from storage.redis_store import get_redis_store, WorkingMemory
from storage.qdrant_store import get_qdrant_store, VectorSearchResult
from intelligence.embedding_generator import get_embedding_generator
from intelligence.voi_scoring import get_voi_scorer, VoIScore


@dataclass
class RetrievalResult:
    """
    A single memory hit in the common shape shared by every search path.

    Fields:
        id: Identifier of the memory (an episode id or a working-memory id).
        content: Text of the memory.
        score: Relevance score assigned by the search path that produced it.
        tier: Memory tier the hit belongs to: 'working', 'episodic',
            or 'semantic'.
        source: Label of the search path that produced the hit.
        strength: Memory strength reported for the hit.
        importance: Importance score reported for the hit.
        created_at: Creation timestamp when one is available, else None.
        metadata: Free-form, source-specific details.
    """
    id: str
    content: str
    score: float
    tier: str
    source: str
    strength: float
    importance: float
    created_at: Optional[str]
    metadata: Dict[str, Any]


class HybridRetriever:
    """
    Orchestrates multi-source memory retrieval.

    Strategies:
    - vector: Pure semantic similarity via Qdrant
    - keyword: Text matching in PostgreSQL
    - temporal: Time-based retrieval
    - graph: Traverse memory connections
    - hybrid: Combined scoring from multiple sources

    Result Fusion:
    - Reciprocal Rank Fusion (RRF) for combining ranked lists
    - Configurable weights per source
    """

    # Default weights for hybrid scoring
    DEFAULT_WEIGHTS = {
        "vector": 0.4,
        "keyword": 0.2,
        "temporal": 0.2,
        "working": 0.2
    }

    def __init__(self) -> None:
        """Acquire the storage backends and helpers used by the search methods."""
        # PostgreSQL store: keyword, temporal, graph and entity-episode lookups.
        self.pg_store = get_postgresql_store()
        # Redis store: per-agent working memory.
        self.redis_store = get_redis_store()
        # Qdrant store: vector search over episodes and entities.
        self.qdrant_store = get_qdrant_store()
        # Turns query text into the embedding handed to Qdrant.
        self.embedding_gen = get_embedding_generator()
        # Value-of-Information scorer used by the VoI-enhanced retrieval methods.
        self.voi_scorer = get_voi_scorer()

    def retrieve(
        self,
        query: str,
        strategy: str = "hybrid",
        limit: int = 10,
        agent_id: Optional[str] = None,
        include_working: bool = True,
        include_episodic: bool = True,
        include_semantic: bool = True,
        min_importance: float = 0.0,
        min_strength: float = 0.0,
        time_range: Optional[tuple] = None,
        weights: Optional[Dict[str, float]] = None
    ) -> List[RetrievalResult]:
        """
        Route a query to the search routine selected by ``strategy``.

        Args:
            query: Search query text.
            strategy: 'vector', 'keyword', 'temporal' or 'graph'; any other
                value (including the default 'hybrid') runs the hybrid search.
            limit: Maximum number of results to return.
            agent_id: Restrict results to one agent (not used by 'graph').
            include_working: Hybrid only - include working memory.
            include_episodic: Hybrid only - include episodic memory.
            include_semantic: Hybrid only - forwarded to the hybrid search.
            min_importance: Importance floor (vector, keyword and hybrid).
            min_strength: Strength floor (vector, keyword and hybrid).
            time_range: (start, end) window; only the temporal strategy uses it.
            weights: Hybrid only - per-source weights; DEFAULT_WEIGHTS is used
                when this is None or empty.

        Returns:
            List of RetrievalResult as produced by the selected routine.
        """
        if strategy == "vector":
            return self._vector_search(
                query, limit, agent_id, min_importance, min_strength
            )

        if strategy == "keyword":
            return self._keyword_search(
                query, limit, agent_id, min_importance, min_strength
            )

        if strategy == "temporal":
            return self._temporal_search(query, limit, agent_id, time_range)

        if strategy == "graph":
            return self._graph_search(query, limit)

        # Everything else is treated as a hybrid request.
        active_weights = weights if weights else self.DEFAULT_WEIGHTS
        return self._hybrid_search(
            query, limit, agent_id,
            include_working, include_episodic, include_semantic,
            min_importance, min_strength,
            active_weights
        )

    def _vector_search(
        self,
        query: str,
        limit: int,
        agent_id: Optional[str],
        min_importance: float,
        min_strength: float
    ) -> List[RetrievalResult]:
        """
        Embed the query and return the episodes Qdrant reports for it.

        The filters (agent, importance floor, strength floor) are passed to
        the Qdrant store; each hit's payload becomes the result metadata and
        supplies strength / importance / created_at, with defaults of
        1.0 / 0.5 / None when a key is absent.
        """
        query_vector = self.embedding_gen.generate(query).embedding

        hits = self.qdrant_store.search_episodes(
            query_embedding=query_vector,
            limit=limit,
            agent_id=agent_id,
            min_importance=min_importance,
            min_strength=min_strength
        )

        found = []
        for hit in hits:
            found.append(RetrievalResult(
                id=hit.id,
                content=hit.content,
                score=hit.score,
                tier="episodic",
                source="vector_search",
                strength=hit.payload.get("strength", 1.0),
                importance=hit.payload.get("importance", 0.5),
                created_at=hit.payload.get("created_at"),
                metadata=hit.payload
            ))
        return found

    def _keyword_search(
        self,
        query: str,
        limit: int,
        agent_id: Optional[str],
        min_importance: float,
        min_strength: float
    ) -> List[RetrievalResult]:
        """
        Text search over episodes in the PostgreSQL store.

        The store call returns no relevance value that this method uses, so
        each result's score is set to the episode's importance score.
        """
        matches = self.pg_store.search_episodes(
            query=query,
            agent_id=agent_id,
            min_importance=min_importance,
            min_strength=min_strength,
            limit=limit
        )

        found = []
        for episode in matches:
            details = {
                "source_type": episode.source_type,
                "agent_id": episode.agent_id,
                "consolidation_count": episode.consolidation_count
            }
            found.append(RetrievalResult(
                id=episode.episode_id,
                content=episode.content,
                score=episode.importance_score,  # importance stands in for relevance
                tier="episodic",
                source="keyword_search",
                strength=episode.strength,
                importance=episode.importance_score,
                created_at=episode.created_at,
                metadata=details
            ))
        return found

    def _temporal_search(
        self,
        query: str,
        limit: int,
        agent_id: Optional[str],
        time_range: Optional[tuple]
    ) -> List[RetrievalResult]:
        """
        Time-based search: keyword-match episodes, then rank newest first.

        Timestamps are compared as timezone-aware datetimes, so ``created_at``
        values and ``time_range`` bounds may each be either datetime objects
        or ISO-8601 strings. Naive values are treated as UTC.

        Args:
            query: Search text; an empty value is sent to the store as None.
            limit: Maximum number of results returned.
            agent_id: Restrict the lookup to one agent (None for all).
            time_range: Optional (start, end) pair, inclusive on both ends.
                A bound of None leaves that side open.

        Returns:
            Up to ``limit`` results ordered by creation time, newest first.
            Episodes without a usable timestamp sort last and are dropped
            when a time_range is given. The score falls by 0.1 per rank
            position and is floored at 0.0.

        Raises:
            ValueError: If a time_range bound is neither None, a datetime,
                nor a parseable ISO-8601 string.
        """
        from datetime import timezone

        def to_aware(value):
            """Return value as an aware datetime, or None if it cannot be one."""
            if isinstance(value, str):
                # fromisoformat() only accepts a trailing "Z" from Python 3.11.
                text = value[:-1] + "+00:00" if value.endswith("Z") else value
                try:
                    value = datetime.fromisoformat(text)
                except ValueError:
                    return None
            if not isinstance(value, datetime):
                return None
            if value.tzinfo is None:
                return value.replace(tzinfo=timezone.utc)
            return value

        # Over-fetch so there is something left to rank after filtering. Only
        # these candidates are considered, so a narrow time_range can return
        # fewer than `limit` hits.
        episodes = self.pg_store.search_episodes(
            query=query if query else None,
            agent_id=agent_id,
            limit=limit * 2
        )

        # Normalise each timestamp once; reused for filtering and sorting.
        stamped = [(to_aware(ep.created_at), ep) for ep in episodes]

        if time_range:
            bounds = []
            for bound in time_range:
                moment = to_aware(bound)
                if bound is not None and moment is None:
                    raise ValueError(
                        f"time_range bound is not a datetime or ISO-8601 string: {bound!r}"
                    )
                bounds.append(moment)
            start, end = bounds
            stamped = [
                (moment, ep) for moment, ep in stamped
                if moment is not None
                and (start is None or start <= moment)
                and (end is None or moment <= end)
            ]

        # Newest first; entries without a timestamp fall to the end.
        oldest = datetime.min.replace(tzinfo=timezone.utc)
        stamped.sort(
            key=lambda pair: pair[0] if pair[0] is not None else oldest,
            reverse=True
        )

        return [
            RetrievalResult(
                id=ep.episode_id,
                content=ep.content,
                # Linear decay by position, floored so ranks past 10 are not negative.
                score=max(0.0, 1.0 - (i * 0.1)),
                tier="episodic",
                source="temporal_search",
                strength=ep.strength,
                importance=ep.importance_score,
                created_at=ep.created_at,
                metadata={"source_type": ep.source_type}
            )
            for i, (_, ep) in enumerate(stamped[:limit])
        ]

    def _graph_search(
        self,
        query: str,
        limit: int
    ) -> List[RetrievalResult]:
        """
        Graph traversal search via memory connections.

        Up to three seed episodes are found by keyword, then the episodes
        connected to each seed are collected. An episode reachable through
        several connections is kept once, scored by its strongest connection
        (strength * confidence), so the ranking does not depend on the order
        in which seeds are visited.

        Args:
            query: Text used to find the seed episodes.
            limit: Maximum connections fetched per seed and maximum results.

        Returns:
            Connected episodes ordered by connection score, highest first;
            an empty list when no seed matches.
        """
        seeds = self.pg_store.search_episodes(query=query, limit=3)

        if not seeds:
            return []

        best = {}  # episode_id -> strongest RetrievalResult seen so far

        for seed in seeds:
            connected = self.pg_store.get_connected_episodes(
                seed.episode_id,
                limit=limit
            )

            for ep, conn in connected:
                link_score = conn.strength * conn.confidence
                known = best.get(ep.episode_id)
                if known is not None and known.score >= link_score:
                    continue

                best[ep.episode_id] = RetrievalResult(
                    id=ep.episode_id,
                    content=ep.content,
                    score=link_score,
                    tier="episodic",
                    source="graph_search",
                    strength=ep.strength,
                    importance=ep.importance_score,
                    # Carry the timestamp through when the episode has one, so
                    # recency-aware callers do not see the memory as brand new.
                    created_at=getattr(ep, "created_at", None),
                    metadata={
                        "relation_type": conn.relation_type,
                        "connection_strength": conn.strength
                    }
                )

        # Strongest connections first.
        ranked = sorted(best.values(), key=lambda x: x.score, reverse=True)
        return ranked[:limit]

    def _hybrid_search(
        self,
        query: str,
        limit: int,
        agent_id: Optional[str],
        include_working: bool,
        include_episodic: bool,
        include_semantic: bool,
        min_importance: float,
        min_strength: float,
        weights: Dict[str, float]
    ) -> List[RetrievalResult]:
        """
        Combined search that fuses several ranked lists into one ranking.

        Each enabled source yields a ranked list. A hit at rank r (1-based)
        earns 1 / r from that source, and its final score is the sum of
        weight * (1 / r) over every source that returned it. When several
        sources return the same id, the result object from the first source
        is kept.

        Args:
            query: Search query text.
            limit: Maximum results returned. The vector, keyword and temporal
                sources are each asked for ``limit * 2`` candidates.
            agent_id: Agent filter. Working memory is only searched when this
                is set.
            include_working: Enable the working-memory source.
            include_episodic: Enable the vector, keyword and temporal sources.
            include_semantic: Accepted for interface compatibility; no source
                in this method reads it.
            min_importance: Importance floor for vector and keyword search.
            min_strength: Strength floor for vector and keyword search.
            weights: Per-source weights keyed by 'vector', 'keyword',
                'temporal' and 'working'. A source whose weight is missing
                or not positive is skipped.

        Returns:
            Up to ``limit`` results by fused score, highest first. Each
            result's metadata carries a 'score_breakdown' entry with the
            per-source contributions before weighting.
        """
        fused = {}  # id -> (result, {source: 1 / rank})

        def absorb(source, ranked):
            """Record the reciprocal rank of every hit in one source's list."""
            for rank, item in enumerate(ranked, start=1):
                if item.id not in fused:
                    fused[item.id] = (item, {})
                fused[item.id][1][source] = 1 / rank

        if include_episodic:
            # 1. Vector search
            if weights.get("vector", 0) > 0:
                absorb("vector", self._vector_search(
                    query, limit * 2, agent_id, min_importance, min_strength
                ))

            # 2. Keyword search
            if weights.get("keyword", 0) > 0:
                absorb("keyword", self._keyword_search(
                    query, limit * 2, agent_id, min_importance, min_strength
                ))

            # 3. Temporal search (no time window in hybrid mode)
            if weights.get("temporal", 0) > 0:
                absorb("temporal", self._temporal_search(
                    query, limit * 2, agent_id, None
                ))

        # 4. Working memory search
        if include_working and agent_id and weights.get("working", 0) > 0:
            working = self.redis_store.search_working_memory(
                agent_id, query, limit
            )
            absorb("working", [
                RetrievalResult(
                    id=wm.memory_id,
                    content=wm.content,
                    score=wm.importance,
                    tier="working",
                    source="working_memory",
                    strength=1.0,
                    importance=wm.importance,
                    created_at=None,
                    metadata={"agent_id": wm.agent_id}
                )
                for wm in working
            ])

        # Weighted sum of the per-source reciprocal ranks.
        combined = []
        for result, scores in fused.values():
            result.score = sum(
                scores.get(source, 0) * weight
                for source, weight in weights.items()
            )
            # Build a new dict instead of writing into the existing one: for
            # vector hits the metadata is the Qdrant payload object itself.
            result.metadata = {
                **(result.metadata or {}),
                "score_breakdown": scores
            }
            combined.append(result)

        combined.sort(key=lambda x: x.score, reverse=True)

        return combined[:limit]

    def retrieve_by_entity(
        self,
        entity_query: str,
        limit: int = 10
    ) -> List[RetrievalResult]:
        """
        Retrieve memories related to a semantic entity.

        The query is embedded and matched against entity embeddings (top 5).
        For each matching entity, the episodes linked to it are fetched and
        scored with that entity's similarity.

        Args:
            entity_query: Text describing the entity to look for.
            limit: Maximum episodes fetched per entity and maximum results.

        Returns:
            Up to ``limit`` unique episodes, highest entity similarity first.
            An episode linked to several matched entities is reported once,
            under its best-scoring entity.
        """
        emb_result = self.embedding_gen.generate(entity_query)

        entities = self.qdrant_store.search_entities(
            query_embedding=emb_result.embedding,
            limit=5
        )

        results = []
        for entity in entities:
            entity_id = entity.id

            # Episodes mentioning this entity
            episodes = self.pg_store.get_entity_episodes(entity_id, limit=limit)

            for ep in episodes:
                results.append(RetrievalResult(
                    id=ep.episode_id,
                    content=ep.content,
                    score=entity.score,  # Entity similarity as score
                    tier="semantic",
                    source="entity_search",
                    strength=ep.strength,
                    importance=ep.importance_score,
                    # Carry the timestamp through when the episode has one,
                    # as the other episode-based searches do.
                    created_at=getattr(ep, "created_at", None),
                    metadata={
                        "entity_id": entity_id,
                        "entity_name": entity.payload.get("name"),
                        "entity_type": entity.payload.get("entity_type")
                    }
                ))

        # Highest score first; the stable sort means the first occurrence of
        # an id is also its best-scoring one.
        seen = set()
        unique = []
        for r in sorted(results, key=lambda x: x.score, reverse=True):
            if r.id not in seen:
                seen.add(r.id)
                unique.append(r)

        return unique[:limit]

    def get_context_for_agent(
        self,
        agent_id: str,
        current_task: Optional[str] = None,
        limit: int = 20
    ) -> List[RetrievalResult]:
        """
        Assemble the context an agent should see for its current task.

        Three pools are merged, each capped at 10 entries:
        - the agent's working memory, scored importance * 1.5
        - the agent's episodes from the store, scored by importance
        - hybrid search hits for ``current_task`` (only when a task is
          given), with their score multiplied by 1.2

        The merged pool is ordered by score, duplicate ids are collapsed to
        their highest-scoring entry, and the top ``limit`` are returned.
        """
        pool = []

        # Pool 1: working memory, boosted over everything else.
        pool.extend(
            RetrievalResult(
                id=item.memory_id,
                content=item.content,
                score=item.importance * 1.5,
                tier="working",
                source="working_memory",
                strength=1.0,
                importance=item.importance,
                created_at=None,
                metadata={"agent_id": item.agent_id}
            )
            for item in self.redis_store.get_working_memory(agent_id, limit=10)
        )

        # Pool 2: this agent's episodic memories.
        pool.extend(
            RetrievalResult(
                id=episode.episode_id,
                content=episode.content,
                score=episode.importance_score,
                tier="episodic",
                source="recent_episodic",
                strength=episode.strength,
                importance=episode.importance_score,
                created_at=episode.created_at,
                metadata={"source_type": episode.source_type}
            )
            for episode in self.pg_store.search_episodes(
                agent_id=agent_id,
                limit=10
            )
        )

        # Pool 3: memories relevant to the task, if one was supplied.
        if current_task:
            task_hits = self.retrieve(
                query=current_task,
                strategy="hybrid",
                limit=10,
                agent_id=agent_id
            )
            for hit in task_hits:
                hit.score *= 1.2
                pool.append(hit)

        # Best score first, then keep only the first occurrence of each id.
        pool.sort(key=lambda entry: entry.score, reverse=True)
        taken_ids = set()
        context = []
        for entry in pool:
            if entry.id in taken_ids:
                continue
            taken_ids.add(entry.id)
            context.append(entry)

        return context[:limit]

    # =========================================================================
    # VoI-ENHANCED RETRIEVAL
    # =========================================================================

    def retrieve_with_voi(
        self,
        query: str,
        limit: int = 10,
        strategy: str = "hybrid",
        agent_id: Optional[str] = None,
        min_voi: float = 0.0,
        voi_weight: float = 0.3
    ) -> List[RetrievalResult]:
        """
        Retrieve memories with VoI-enhanced ranking.

        Fetches ``limit * 2`` base results, scores each with the VoI scorer,
        drops those below ``min_voi`` and re-ranks the rest by
        ``(1 - voi_weight) * base_score + voi_weight * voi_score``.

        A retrieval is recorded with the VoI scorer only for the results
        actually returned, so candidates trimmed by ``limit`` do not have
        their usage counted.

        Args:
            query: Search query
            limit: Maximum results
            strategy: Base retrieval strategy
            agent_id: Filter by agent
            min_voi: Minimum VoI threshold
            voi_weight: Weight given to VoI in final score (0.0-1.0);
                the value is used as given and is not range-checked.

        Returns:
            List of RetrievalResult with VoI-enhanced scores, highest first.
            Each result's source gains the suffix "_voi_enhanced" and its
            metadata gains the VoI score, its breakdown, the base retrieval
            score and the usage count.
        """
        from datetime import timezone

        # Over-fetch so filtering by min_voi still leaves enough candidates.
        base_results = self.retrieve(
            query=query,
            strategy=strategy,
            limit=limit * 2,
            agent_id=agent_id
        )

        enhanced_results = []
        for r in base_results:
            voi = self.voi_scorer.calculate_voi(
                memory_id=r.id,
                content=r.content,
                # Results without a timestamp are scored as if created now.
                created_at=r.created_at or datetime.now(timezone.utc).isoformat(),
                importance=r.importance,
                relevance_score=r.score
            )

            if voi.total_score < min_voi:
                continue

            combined_score = (1 - voi_weight) * r.score + voi_weight * voi.total_score

            enhanced_results.append(RetrievalResult(
                id=r.id,
                content=r.content,
                score=combined_score,
                tier=r.tier,
                source=r.source + "_voi_enhanced",
                strength=r.strength,
                importance=r.importance,
                created_at=r.created_at,
                metadata={
                    **r.metadata,
                    "voi_score": voi.total_score,
                    "voi_breakdown": {
                        "recency": voi.recency_score,
                        "relevance": voi.relevance_score,
                        "importance": voi.importance_score,
                        "outcome": voi.outcome_score
                    },
                    "base_retrieval_score": r.score,
                    "usage_count": voi.usage_count
                }
            ))

        enhanced_results.sort(key=lambda x: x.score, reverse=True)
        returned = enhanced_results[:limit]

        # Track usage for future VoI calculations - returned memories only.
        for item in returned:
            self.voi_scorer.record_retrieval(item.id)

        return returned

    def get_high_value_memories(
        self,
        limit: int = 20,
        min_voi: float = 0.5
    ) -> List[RetrievalResult]:
        """
        Return the episodes with the highest VoI score.

        ``limit * 3`` episodes are pulled from the store and scored. Those
        scoring at least ``min_voi`` are kept, ordered by VoI score, and the
        top ``limit`` are returned. The VoI total becomes the result score
        and its components are stored under metadata['voi_breakdown'].
        """
        from datetime import datetime, timezone

        candidates = self.pg_store.search_episodes(limit=limit * 3)

        valuable = []
        for episode in candidates:
            voi = self.voi_scorer.calculate_voi(
                memory_id=episode.episode_id,
                content=episode.content,
                # Episodes without a timestamp are scored as if created now.
                created_at=episode.created_at or datetime.now(timezone.utc).isoformat(),
                importance=episode.importance_score
            )

            if not (voi.total_score >= min_voi):
                continue

            breakdown = {
                "recency": voi.recency_score,
                "relevance": voi.relevance_score,
                "importance": voi.importance_score,
                "outcome": voi.outcome_score
            }
            valuable.append(RetrievalResult(
                id=episode.episode_id,
                content=episode.content,
                score=voi.total_score,
                tier="episodic",
                source="high_value_scan",
                strength=episode.strength,
                importance=episode.importance_score,
                created_at=episode.created_at,
                metadata={
                    "voi_breakdown": breakdown,
                    "source_type": episode.source_type
                }
            ))

        # Highest VoI first.
        ranked = sorted(valuable, key=lambda entry: entry.score, reverse=True)
        return ranked[:limit]


# Module-wide retriever, built on first request.
_retriever = None


def get_hybrid_retriever() -> HybridRetriever:
    """Return the shared HybridRetriever, constructing it on the first call."""
    global _retriever
    if _retriever is not None:
        return _retriever
    _retriever = HybridRetriever()
    return _retriever


if __name__ == "__main__":
    def _smoke_test() -> None:
        """Run one hybrid query and one entity query, printing what comes back."""
        retriever = get_hybrid_retriever()

        print("Hybrid Retriever Test")
        print("=" * 50)

        # Hybrid search
        results = retriever.retrieve(
            query="memory system architecture",
            strategy="hybrid",
            limit=5
        )
        print(f"Hybrid search results: {len(results)}")
        for r in results:
            print(f"  [{r.tier}] {r.id}: score={r.score:.3f}")
            print(f"    Content: {r.content[:60]}...")

        # Entity search
        entity_results = retriever.retrieve_by_entity(
            entity_query="Genesis",
            limit=5
        )
        print(f"\nEntity search results: {len(entity_results)}")

    _smoke_test()
