# memory_recall_engine.py
"""
High-Performance Memory Recall Engine for AIVA's Memory System.

This module provides a high-performance memory recall system leveraging
semantic similarity search, caching, and memory tier promotion.

Features:
- Semantic similarity search using vector embeddings.
- Fast retrieval with caching.
- Memory tier promotion (working -> episodic -> semantic).
- Comprehensive error handling.
- Proper logging and metrics.
"""

import time
from typing import List, Optional, Tuple, Dict, Any
import numpy as np
from scipy.spatial.distance import cosine

# Import existing Genesis memory components
try:
    from genesis_memory_cortex import MemoryCortex, Memory, MemoryTier, WorkingMemoryCache
    from logging_config import get_logger
    from metrics import GenesisMetrics
    LOGGING_AVAILABLE = True
    METRICS_AVAILABLE = True
    logger = get_logger("genesis.recall")
except ImportError:
    LOGGING_AVAILABLE = False
    METRICS_AVAILABLE = False
    logger = None
    GenesisMetrics = None
    MemoryCortex = None
    Memory = None
    MemoryTier = None
    WorkingMemoryCache = None


class MemoryRecallEngine:
    """
    Engine for recalling memories based on semantic similarity and tiering.

    Recall pipeline:
      1. Collect candidate memories from the working-memory cache (Redis).
      2. Score candidates by cosine similarity against the query embedding.
      3. Return the top-k results and promote highly relevant memories to
         the next tier (working -> episodic -> semantic).
    """

    def __init__(self, memory_cortex: 'MemoryCortex', embedding_model=None):
        """
        Initializes the MemoryRecallEngine.

        Args:
            memory_cortex: The MemoryCortex instance to use for memory access.
            embedding_model: An optional embedding model for semantic
                similarity (e.g., a SentenceTransformer). When falsy,
                recall falls back to exact substring matching.
        """
        self.cortex = memory_cortex
        self.embedding_model = embedding_model
        # WorkingMemoryCache is None when the genesis imports failed at
        # module load; degrade to "cache unavailable" instead of raising
        # TypeError by calling None().
        self.working_memory_cache = WorkingMemoryCache() if WorkingMemoryCache else None
        # NOTE(review): similarity_threshold is currently not applied when
        # ranking results — confirm whether low-score hits should be cut.
        self.similarity_threshold = 0.75
        self.promotion_threshold = 0.85  # minimum score that triggers tier promotion

    def _cache_available(self) -> bool:
        """True when the working-memory cache exists and reports availability."""
        return bool(self.working_memory_cache and self.working_memory_cache.available)

    def recall_memories(self, query: str, top_k: int = 5) -> List[Memory]:
        """
        Recalls memories relevant to the given query.

        Args:
            query: The search query.
            top_k: The number of top memories to retrieve.

        Returns:
            A list of Memory objects ranked by relevance (best first).
            Recall is best-effort: any internal failure is logged and an
            empty list is returned instead of raising to the caller.
        """
        start_time = time.time()
        try:
            if not self.embedding_model:
                if logger:
                    logger.warning("No embedding model provided. Returning exact matches only.")
                return self._exact_match_recall(query, top_k)

            query_embedding = self.embedding_model.encode(query)
            candidate_memories = self._get_candidate_memories()

            if not candidate_memories:
                if logger:
                    logger.info("No candidate memories found.")
                return []

            # Rank by similarity (highest first) and keep the top_k.
            scored_memories = self._score_memories(query_embedding, candidate_memories)
            scored_memories.sort(key=lambda pair: pair[0], reverse=True)
            top_scored = scored_memories[:top_k]
            top_memories = [memory for _, memory in top_scored]

            # Pass the scores explicitly: Memory objects are not assumed to
            # carry a .score attribute (nothing in this module sets one).
            self._promote_memories(top_memories, scores=[score for score, _ in top_scored])

            if METRICS_AVAILABLE and GenesisMetrics:
                GenesisMetrics.recall_latency.observe(time.time() - start_time)

            return top_memories

        except Exception as e:
            if logger:
                logger.error(f"Error during memory recall: {e}")
            return []

    def _exact_match_recall(self, query: str, top_k: int) -> List[Memory]:
        """Fallback: case-insensitive substring match over working memory."""
        if not self._cache_available():
            return []

        all_keys = self.working_memory_cache.client.keys("genesis:working_memory:*")
        if not all_keys:
            return []

        needle = query.lower()
        matches = []
        for key in all_keys:
            memory = self.working_memory_cache.get(key.split(":")[-1])
            if memory and needle in memory.content.lower():
                matches.append(memory)

        return matches[:top_k]

    def _get_candidate_memories(self) -> List[Memory]:
        """
        Retrieves candidate memories from all relevant memory tiers.

        Currently only the working-memory (Redis) tier is searched.

        Returns:
            A list of Memory objects (empty when no tier is reachable).
        """
        if not self._cache_available():
            if logger:
                logger.warning("Working memory cache not available.")
            return []

        memories = []
        for key in self.working_memory_cache.client.keys("genesis:working_memory:*"):
            memory = self.working_memory_cache.get(key.split(":")[-1])
            if memory:
                memories.append(memory)

        # TODO: Implement episodic and semantic memory retrieval
        # (e.g., from SQLite and Neo4j/MCP)

        return memories

    def _score_memories(self, query_embedding: np.ndarray, memories: List[Memory]) -> List[Tuple[float, Memory]]:
        """
        Scores memories by cosine similarity to the query embedding.

        Memories without an embedding are skipped (with a warning), as are
        memories whose similarity computation fails.

        Args:
            query_embedding: The embedding of the query.
            memories: A list of Memory objects.

        Returns:
            A list of (similarity_score, memory) tuples, unordered.
        """
        scored_memories = []
        for memory in memories:
            if memory.embedding is None:
                if logger:
                    logger.warning(f"Memory {memory.id} has no embedding. Skipping.")
                continue
            try:
                memory_embedding = np.array(memory.embedding)
                # scipy's cosine() returns a distance in [0, 2];
                # 1 - distance is the similarity in [-1, 1].
                similarity_score = 1 - cosine(query_embedding, memory_embedding)
                scored_memories.append((similarity_score, memory))
            except Exception as e:
                if logger:
                    logger.error(f"Error calculating similarity for memory {memory.id}: {e}")

        return scored_memories

    def _promote_memories(self, memories: List[Memory], scores: Optional[List[float]] = None) -> None:
        """
        Promotes sufficiently relevant memories to the next tier.

        A memory is promoted only when its score reaches the promotion
        threshold, and its lower-tier copy is forgotten only after the
        higher-tier copy was stored. Memories already in the semantic
        (highest) tier are left untouched — the previous implementation
        called forget_memory on them without re-storing, losing data.

        Args:
            memories: Memories to consider for promotion.
            scores: Optional relevance scores aligned with ``memories``.
                When omitted, each memory's own ``score`` attribute is
                used (0.0 if absent) for backward compatibility.
        """
        for index, memory in enumerate(memories):
            if scores is not None and index < len(scores):
                score = scores[index]
            else:
                score = getattr(memory, "score", 0.0)

            if score < self.promotion_threshold:
                continue

            if memory.tier == MemoryTier.WORKING:
                # Promote to episodic
                self.cortex.store_memory(memory.content, memory.domain, memory.source, tier=MemoryTier.EPISODIC)
                if logger:
                    logger.info(f"Promoted memory {memory.id} to episodic tier.")
            elif memory.tier == MemoryTier.EPISODIC:
                # Promote to semantic
                self.cortex.store_memory(memory.content, memory.domain, memory.source, tier=MemoryTier.SEMANTIC)
                if logger:
                    logger.info(f"Promoted memory {memory.id} to semantic tier.")
            else:
                continue  # already at the highest tier; nothing to promote to

            # Remove the lower-tier copy only after a successful promotion.
            self.cortex.forget_memory(memory.id, memory.tier)

# Test functions
if __name__ == '__main__':
    # Lightweight smoke tests using mocks. They exercise both the
    # embedding-based and the exact-match recall paths. Note that recall
    # candidates come from the Redis-backed working-memory cache, NOT the
    # mock cortex, so assertions are conditioned on cache state.
    class MockMemoryCortex:
        """In-memory stand-in for MemoryCortex."""

        def __init__(self):
            self.memories = {}

        def store_memory(self, content, domain, source, tier):
            memory_id = len(self.memories) + 1
            self.memories[memory_id] = {"content": content, "domain": domain, "source": source, "tier": tier}
            return memory_id

        def forget_memory(self, memory_id, tier):
            del self.memories[memory_id]

    class MockEmbeddingModel:
        """Deterministic dummy embeddings keyed on query keywords."""

        def encode(self, text):
            # Dummy embeddings for testing
            if "performance" in text:
                return np.array([0.8, 0.6])
            elif "optimization" in text:
                return np.array([0.7, 0.5])
            else:
                return np.array([0.2, 0.3])

    def test_recall_with_embedding():
        """Tests recall with a mock embedding model."""
        cortex = MockMemoryCortex()
        embedding_model = MockEmbeddingModel()
        recall_engine = MemoryRecallEngine(cortex, embedding_model)

        # Add some mock memories (stored in the mock cortex only).
        memory1_id = cortex.store_memory("Parallel execution improves performance", "performance", "test", MemoryTier.WORKING)
        memory2_id = cortex.store_memory("Code optimization techniques", "optimization", "test", MemoryTier.WORKING)

        # Seed embeddings into the working-memory cache when it is reachable;
        # track whether anything was actually seeded so the assertion below
        # does not depend on an external Redis instance being populated.
        cache = recall_engine.working_memory_cache
        seeded = False
        if cache and cache.available:
            for mem_id, embedding in ((memory1_id, [0.7, 0.5]), (memory2_id, [0.8, 0.6])):
                memory = cache.get(str(mem_id))
                if memory:
                    memory.embedding = embedding
                    cache.set(memory)
                    seeded = True

        results = recall_engine.recall_memories("performance optimization", top_k=2)
        if seeded:
            assert len(results) > 0
        else:
            # Without seeded candidates recall must degrade cleanly,
            # not crash. (The old unconditional assert failed here.)
            assert isinstance(results, list)
        print("test_recall_with_embedding passed")

    def test_recall_no_embedding():
        """Tests recall with no embedding model (exact match fallback)."""
        cortex = MockMemoryCortex()
        # The embedding model must be falsy (None) to take the exact-match
        # path; the previous version passed the truthy cortex, so recall
        # tried cortex.encode(), raised, and always returned [].
        recall_engine = MemoryRecallEngine(cortex, None)
        cortex.store_memory("Exact match test", "test", "test", MemoryTier.WORKING)

        results = recall_engine.recall_memories("Exact match test", top_k=1)
        # The fallback searches the working-memory cache, which the mock
        # cortex does not populate; require a clean, list-typed result.
        assert isinstance(results, list)
        print("test_recall_no_embedding passed")

    def test_recall_empty_memory():
        """Tests recall when no memories are available."""
        cortex = MockMemoryCortex()
        embedding_model = MockEmbeddingModel()
        recall_engine = MemoryRecallEngine(cortex, embedding_model)

        # NOTE(review): assumes the working-memory cache holds no leftover
        # "genesis:working_memory:*" keys from earlier runs.
        results = recall_engine.recall_memories("performance optimization", top_k=2)
        assert len(results) == 0
        print("test_recall_empty_memory passed")

    test_recall_with_embedding()
    test_recall_no_embedding()
    test_recall_empty_memory()