# memory_recall_engine.py
"""
High-Performance Memory Recall Engine
======================================

This module provides a memory recall system with semantic similarity search,
fast retrieval with caching, memory tier promotion, comprehensive error
handling, logging, and metrics.

Features:
    - Semantic similarity search using embeddings via VectorManager
    - Fast retrieval using WorkingMemoryCache (Redis)
    - Memory tier promotion logic
    - Error handling with circuit breaker
    - Logging and metrics integration

Usage:
    from memory_recall_engine import MemoryRecallEngine

    engine = MemoryRecallEngine(cortex)
    results = engine.recall("search query", top_k=5)
"""

import time
from typing import List, Optional, Tuple
from genesis_memory_cortex import MemoryCortex, Memory, MemoryTier, WorkingMemoryCache, VectorManager, LOGGING_AVAILABLE, METRICS_AVAILABLE, GenesisMetrics, logger
from memory_schemas import MemoryOutput
from redis import Redis

class MemoryRecallEngine:
    """
    Memory recall engine orchestrating similarity search, caching, and tier promotion.

    Pipeline per ``recall()`` call:
        1. Semantic similarity search via VectorManager.
        2. Fast retrieval of hits from the WorkingMemoryCache (Redis).
        3. Fallback retrieval of cache misses from episodic/semantic storage.
        4. Optional domain filtering and score-based ranking.
        5. Tier promotion for high-surprise memories.
    """

    def __init__(self, cortex: MemoryCortex):
        """
        Initializes the MemoryRecallEngine.

        Args:
            cortex: An instance of the MemoryCortex class supplying the
                working-memory cache, vector manager, and surprise memory.
        """
        self.cortex = cortex
        self.working_memory_cache: WorkingMemoryCache = cortex.working_memory_cache
        self.vector_manager: VectorManager = cortex.vector_manager
        self.surprise_memory = cortex.surprise_memory
        # Normalize to a real bool; the original stored the GenesisMetrics
        # object itself when metrics were available.
        self.metrics_enabled: bool = bool(METRICS_AVAILABLE and GenesisMetrics)

        if LOGGING_AVAILABLE and logger:
            logger.info("MemoryRecallEngine initialized.")

    def recall(self, query: str, top_k: int = 5, domain_filter: Optional[str] = None) -> List[MemoryOutput]:
        """
        Recalls relevant memories based on semantic similarity search and caching.

        Args:
            query: The search query.
            top_k: The number of results to return.
            domain_filter: Optional domain to filter memories by.

        Returns:
            A list of MemoryOutput objects representing the recalled memories,
            sorted by similarity score (highest first). Empty on any failure —
            recall is best-effort and never raises to callers.
        """
        start_time = time.time()
        try:
            # 1. Semantic similarity search via VectorManager.
            if not self.vector_manager or not self.vector_manager.available:
                if LOGGING_AVAILABLE and logger:
                    logger.warning("VectorManager unavailable, skipping semantic search.")
                return [] # Or return from other memory tiers if available

            vector_results: List[Tuple[str, float]] = self.vector_manager.search(query, top_k=top_k)

            if not vector_results:
                if LOGGING_AVAILABLE and logger:
                    logger.info("No vector search results found.")
                return []

            memory_ids = [memory_id for memory_id, _score in vector_results]

            # 2. Fast retrieval with caching (WorkingMemoryCache).
            cached_memories: List[Memory] = []
            uncached_ids: List[str] = []

            if self.working_memory_cache and self.working_memory_cache.available:
                for memory_id in memory_ids:
                    memory: Optional[Memory] = self.working_memory_cache.get(memory_id)
                    if memory:
                        cached_memories.append(memory)
                    else:
                        uncached_ids.append(memory_id)
            else:
                uncached_ids = memory_ids

            # 3. Retrieve cache misses from episodic/semantic storage.
            retrieved_memories: List[Memory] = []
            if uncached_ids:
                retrieved_memories = self._retrieve_from_episodic_semantic(uncached_ids)

            # 4. Combine and apply the optional domain filter.
            memories = cached_memories + retrieved_memories
            if domain_filter:
                memories = [m for m in memories if m.domain == domain_filter]

            # 5. Rank by the similarity score from the vector search.
            memory_id_to_score = {memory_id: score for memory_id, score in vector_results}
            memories.sort(key=lambda m: memory_id_to_score.get(m.id, 0.0), reverse=True)

            # 6. Memory tier promotion (e.g. high-surprise episodic -> semantic).
            self._promote_memory_tiers(memories)

            # 7. Convert to the MemoryOutput schema. Membership by id in a set
            # is O(1); the original's `m in cached_memories` was an O(n) list
            # scan per result and relied on Memory object equality.
            cached_ids = {m.id for m in cached_memories}
            output_memories: List[MemoryOutput] = [
                MemoryOutput(
                    id=m.id,
                    content=m.content,
                    tier=m.tier.value,
                    score=memory_id_to_score.get(m.id, 0.0),
                    timestamp=m.timestamp,
                    stored_in=["Redis" if m.id in cached_ids else "Episodic/Semantic"],
                    metadata=m.metadata if m.metadata else {}
                )
                for m in memories[:top_k]
            ]

            # Metrics and logging.
            if self.metrics_enabled:
                GenesisMetrics.recall_latency.observe(time.time() - start_time)
                GenesisMetrics.memory_recalls.inc()

            if LOGGING_AVAILABLE and logger:
                logger.info(f"Successfully recalled {len(output_memories)} memories.")

            return output_memories

        except Exception:
            # Best-effort contract: log the full traceback, count the error,
            # and degrade to an empty result rather than propagating.
            if LOGGING_AVAILABLE and logger:
                logger.exception("Error during memory recall.")
            if self.metrics_enabled:
                GenesisMetrics.recall_errors.inc()
            return []

    def _retrieve_from_episodic_semantic(self, memory_ids: List[str]) -> List[Memory]:
        """Retrieves memories from episodic/semantic storage, warming the cache.

        Args:
            memory_ids: Ids that missed the working-memory cache.

        Returns:
            The memories that could be resolved; unresolvable ids are dropped.
        """
        memories: List[Memory] = []
        for memory_id in memory_ids:
            memory = self.cortex.retrieve_memory(memory_id) # Assuming this method exists in MemoryCortex
            if memory:
                memories.append(memory)
                # Write-through: cache in working memory so the next recall hits.
                if self.working_memory_cache and self.working_memory_cache.available:
                    self.working_memory_cache.set(memory)
        return memories

    def _promote_memory_tiers(self, memories: List[Memory]):
        """Promotes memory tiers based on surprise scoring.

        Episodic memories whose surprise score exceeds 0.7 are promoted to the
        semantic tier and persisted via the cortex.
        """
        if not self.surprise_memory:
            # Check availability once; the original emitted this warning once
            # per memory inside the loop.
            if LOGGING_AVAILABLE and logger:
                logger.warning("Surprise Memory unavailable, skipping tier promotion.")
            return

        for memory in memories:
            surprise_score = self.surprise_memory.calculate_surprise(memory.content)
            if surprise_score.surprise_score > 0.7 and memory.tier == MemoryTier.EPISODIC:
                memory.tier = MemoryTier.SEMANTIC # Example promotion
                self.cortex.update_memory(memory) # Assuming this method exists
                if LOGGING_AVAILABLE and logger:
                    logger.info(f"Promoted memory {memory.id} to Semantic tier.")

    def clear_cache(self):
        """Clears the working memory cache (Redis).

        Failures are logged and swallowed — clearing the cache is never fatal.
        """
        if self.working_memory_cache and self.working_memory_cache.available:
            try:
                self.working_memory_cache.clear_cache()
                if LOGGING_AVAILABLE and logger:
                    logger.info("Working memory cache cleared.")
            except Exception as e:
                if LOGGING_AVAILABLE and logger:
                    logger.error(f"Error clearing working memory cache: {e}")
        else:
            if LOGGING_AVAILABLE and logger:
                logger.warning("Working memory cache unavailable, cannot clear.")

# Example Test Functions (replace with proper unit tests)
if __name__ == "__main__":
    # Mock MemoryCortex and other dependencies for testing
    class MockRedis:
        """Minimal in-memory stand-in for a Redis client (get/set/delete only)."""
        def __init__(self):
            self.data = {}

        def get(self, key):
            return self.data.get(key)

        def set(self, key, value):
            self.data[key] = value

        def delete(self, key):
            if key in self.data:
                del self.data[key]

    class MockVectorManager:
        """Simulates vector search with a fixed result set for one known query."""
        def __init__(self):
            self.available = True

        def search(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
            # Simulate vector search results
            if query == "performance optimization":
                return [("memory_1", 0.9), ("memory_2", 0.8), ("memory_3", 0.7)]
            return []

    class MockSurpriseMemory:
        """Always reports a high surprise score to exercise tier promotion."""
        def calculate_surprise(self, content):
            class MockScore:
                def __init__(self, surprise_score):
                    self.surprise_score = surprise_score
            return MockScore(0.8)

    class MockMemoryCortex:
        def __init__(self):
            # NOTE(review): feeds a MockRedis into the real WorkingMemoryCache;
            # assumes the cache only calls get/set/delete on its client — confirm.
            self.working_memory_cache = WorkingMemoryCache(redis_client=MockRedis())
            self.vector_manager = MockVectorManager()
            self.surprise_memory = MockSurpriseMemory()
            # Backing store covering every id the mock vector index can return.
            # The original mock resolved only "memory_2", so memory_1/memory_3
            # were never recallable and Tests 1 and 2 failed their length and
            # ordering assertions.
            self._store = {
                "memory_1": Memory(id="memory_1", content="Working memory content", tier=MemoryTier.EPISODIC, score=0.9, domain="test", source="test", timestamp="2024-01-01"),
                "memory_2": Memory(id="memory_2", content="Episodic memory content", tier=MemoryTier.EPISODIC, score=0.6, domain="test", source="test", timestamp="2024-01-01"),
                "memory_3": Memory(id="memory_3", content="Semantic memory content", tier=MemoryTier.EPISODIC, score=0.5, domain="test", source="test", timestamp="2024-01-01"),
            }

        def retrieve_memory(self, memory_id: str) -> Optional[Memory]:
            return self._store.get(memory_id)

        def update_memory(self, memory: Memory):
            print(f"Memory {memory.id} updated.")

    # Test 1: Basic recall with mock data — top result must be the highest-scored id.
    def test_basic_recall():
        cortex = MockMemoryCortex()
        engine = MemoryRecallEngine(cortex)
        results = engine.recall("performance optimization", top_k=2)
        assert len(results) == 2
        assert results[0].id == "memory_1"
        print("Test 1 passed.")

    # Test 2: Recall with episodic memory retrieval — memory_2 ranks second by score.
    def test_episodic_recall():
        cortex = MockMemoryCortex()
        engine = MemoryRecallEngine(cortex)
        results = engine.recall("performance optimization", top_k=3)
        assert len(results) == 3
        assert results[1].id == "memory_2"
        assert results[1].content == "Episodic memory content"
        print("Test 2 passed.")

    # Test 3: Recall with no results for an unknown query.
    def test_no_results():
        cortex = MockMemoryCortex()
        engine = MemoryRecallEngine(cortex)
        results = engine.recall("unrelated query", top_k=5)
        assert len(results) == 0
        print("Test 3 passed.")

    test_basic_recall()
    test_episodic_recall()
    test_no_results()