# memory_recall_engine.py
"""
High-Performance Memory Recall Engine
=======================================

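This module provides a memory recall system built on semantic similarity
search over a MemoryCortex, with per-query result caching and memory tier
promotion. Recall scores every stored memory against the query embedding, so
its cost grows linearly with the number of memories in the cortex.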

Key Features:
- Semantic Similarity Search: Ranks memories by embedding similarity to a query and keeps those at or above a configurable threshold.
- Result Caching: Caches ranked recall results per query string in an in-process dict. Entries are not expired automatically; call clear_cache() when the underlying memories change.
- Memory Tier Promotion: Moves a memory to a requested tier via promote_memory_tier(). Promotion is an explicit call, not triggered automatically.
- Error Handling: Recall and promotion catch and log exceptions instead of raising them to the caller.
- Logging and Metrics: Logs cache hits/misses and the wall-clock duration of each recall.

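Dependencies:
- genesis_memory_cortex (MemoryCortex, Memory, MemoryTier, WorkingMemoryCache, VectorManager)
- memory_schemas (MemoryItemInput, MemoryOutput)
- Standard library: logging, time, typing
- Note: redis is not imported by this module; the cache is a plain in-process dict.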
"""

import time
from typing import List, Optional, Tuple
import logging
import json

# Import existing Genesis memory components (adjust path if needed)
try:
    from genesis_memory_cortex import MemoryCortex, Memory, MemoryTier, WorkingMemoryCache
    from genesis_memory_cortex import VectorManager
    from memory_schemas import MemoryItemInput, MemoryOutput

    MODULES_AVAILABLE = True

except ImportError as e:
    print(f"Error importing modules: {e}")
    MODULES_AVAILABLE = False

# Module logging setup.
# NOTE(review): basicConfig runs at import time and configures the root logger
# for any application importing this module; consider moving it under the
# `if __name__ == '__main__':` section.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

class MemoryRecallEngine:
    """
    Recall memories from a MemoryCortex by embedding similarity.

    Recall is a linear scan: every memory returned by
    ``cortex.get_all_memories()`` that has an embedding is scored against the
    query embedding with ``VectorManager.calculate_similarity``. The full
    ranked result for each query string is cached so that any ``top_k`` can be
    served from the cache. The class also offers a helper for moving a memory
    up a tier.
    """

    # Tier values in ascending order, used by promote_memory_tier to reject
    # demotions. NOTE(review): the order is inferred from this class treating
    # "semantic" as the highest tier; confirm against MemoryTier.
    _TIER_ORDER = ("working", "episodic", "semantic")

    def __init__(self, memory_cortex: 'MemoryCortex', similarity_threshold: float = 0.7):
        """
        Initializes the MemoryRecallEngine.

        Args:
            memory_cortex: Object providing get_all_memories(), get_memory_by_id()
                and update_memory_tier(); normally a MemoryCortex.
            similarity_threshold: Minimum similarity score (inclusive) for a
                memory to be returned by recall().

        Raises:
            ImportError: If the genesis_memory_cortex / memory_schemas imports
                failed when this module was loaded.
        """
        if not MODULES_AVAILABLE:
            raise ImportError("Required modules not available. Check installation.")

        self.cortex = memory_cortex
        self.similarity_threshold = similarity_threshold
        # query string -> full ranked list of matching memories. Unbounded and
        # never invalidated automatically; call clear_cache() after the cortex
        # changes (or a cached answer will be stale).
        self.cache = {}
        # NOTE(review): assumes VectorManager() needs no constructor arguments.
        self.vector_manager = VectorManager()

    def recall(self, query: str, top_k: int = 5) -> List['Memory']:
        """
        Recalls memories relevant to the given query using semantic similarity search and caching.

        Args:
            query: The query string (also used as the cache key).
            top_k: The maximum number of memories to return. Values <= 0
                return an empty list.

        Returns:
            Up to ``top_k`` Memory objects whose similarity to the query is at
            least ``similarity_threshold``, most similar first. The list is a
            fresh copy, so mutating it does not affect the cache. An empty list
            is returned if the embedding cannot be generated, the cortex has no
            memories, or any exception occurs (the exception is logged).
        """
        start_time = time.perf_counter()

        try:
            # A non-positive top_k would otherwise slice from the end of the list.
            if top_k <= 0:
                return []

            # 1. Check cache. Entries hold the *full* ranked list, so a later
            # call with a larger top_k is not limited by an earlier smaller one.
            cached_results = self.cache.get(query)
            if cached_results is not None:
                logger.info("Cache hit for query: %s", query)
                # NOTE: no TTL/validity check; entries may be stale if the
                # cortex changed since they were stored.
                return cached_results[:top_k]

            # 2. Semantic similarity search
            query_embedding = self.vector_manager.get_embedding(query)
            if query_embedding is None:
                logger.error("Failed to generate embedding for query: %s", query)
                return []

            all_memories = self.cortex.get_all_memories()
            if not all_memories:
                # Return without caching: an empty answer cached now would stay
                # wrong after memories are added.
                logger.warning("No memories found in the cortex.")
                return []

            ranked = self._rank_memories(query_embedding, all_memories)

            # 3. Cache the full ranked list; callers get a slice (a new list).
            self.cache[query] = ranked
            logger.info("Cache miss for query: %s, storing results.", query)

            return ranked[:top_k]

        except Exception as e:
            logger.exception("Error during recall: %s", e)
            return []
        finally:
            duration = time.perf_counter() - start_time
            logger.info("Recall completed in %.4f seconds for query: %s", duration, query)

    def _rank_memories(self, query_embedding, memories) -> List['Memory']:
        """Return memories scoring >= similarity_threshold, most similar first."""
        scored = []
        for memory in memories:
            # Memories without an embedding cannot be compared; skip them.
            if memory.embedding is None:
                continue
            score = self.vector_manager.calculate_similarity(query_embedding, memory.embedding)
            if score >= self.similarity_threshold:
                scored.append((memory, score))

        # list.sort is stable: equal scores keep the cortex's order.
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return [memory for memory, _ in scored]

    def promote_memory_tier(self, memory_id: str, target_tier: 'MemoryTier'):
        """
        Moves a memory to ``target_tier`` when that is a promotion.

        The decision uses only the memory's current tier and the requested
        tier. Requests are ignored (with a log message) when the memory does not
        exist, is already at the target tier, is already "semantic", when the
        target is not one of "working"/"episodic"/"semantic", or when the target
        ranks below the current tier. Exceptions are logged, not raised.

        Args:
            memory_id: The ID of the memory to promote.
            target_tier: The target memory tier.
        """
        try:
            memory = self.cortex.get_memory_by_id(memory_id)
            if memory is None:
                logger.warning("Memory with ID %s not found.", memory_id)
                return

            if memory.tier == target_tier:
                logger.info("Memory %s is already at tier %s.", memory_id, target_tier)
                return

            current_value = memory.tier.value
            if current_value == "semantic":
                logger.info("Memory %s is already at the highest tier (semantic).", memory_id)
                return

            target_value = target_tier.value
            if target_value not in self._TIER_ORDER:
                logger.error("Invalid memory tier: %s", target_tier)
                return

            # This method only promotes: refuse to move down the tier order.
            if (current_value in self._TIER_ORDER
                    and self._TIER_ORDER.index(target_value) < self._TIER_ORDER.index(current_value)):
                logger.warning(
                    "Refusing to demote memory %s from tier %s to %s.",
                    memory_id, memory.tier, target_tier,
                )
                return

            self.cortex.update_memory_tier(memory_id, target_tier)
            logger.info("Memory %s promoted to tier %s.", memory_id, target_tier)

        except Exception as e:
            logger.exception("Error promoting memory %s: %s", memory_id, e)

    def clear_cache(self):
        """Discard all cached recall results."""
        self.cache = {}
        logger.info("Cache cleared.")

# --- Test Functions ---
if __name__ == '__main__':
    if not MODULES_AVAILABLE:
        print("Skipping tests due to missing modules.")
    else:
        # Minimal stand-in for MemoryCortex implementing only the methods that
        # MemoryRecallEngine calls (replace with an actual instance in production).
        class MockMemoryCortex:
            def __init__(self):
                self.memories = {}  # memory_id -> Memory, in insertion order

            def remember(self, content: str, domain: str = "test", source: str = "test",
                         relations: Optional[List[str]] = None, metadata: Optional[dict] = None):
                """Store a new WORKING-tier memory and return its generated ID."""
                # Build fresh containers per call; a literal []/{} default would be
                # shared by every call that omits the argument.
                if relations is None:
                    relations = []
                if metadata is None:
                    metadata = {}
                memory_input = MemoryItemInput(content=content, domain=domain, source=source, relations=relations, metadata=metadata)
                memory_id = str(len(self.memories) + 1)
                # NOTE: every mock memory gets the same fixed embedding, so the
                # similarity score cannot tell them apart; among equal scores the
                # stable sort in recall() keeps insertion order.
                memory = Memory(id=memory_id, content=memory_input.content, tier=MemoryTier.WORKING, score=0.6, domain=memory_input.domain, source=memory_input.source, timestamp="2024-01-01", embedding=[0.1, 0.2, 0.3])
                self.memories[memory_id] = memory
                return memory_id

            def get_all_memories(self):
                return list(self.memories.values())

            def get_memory_by_id(self, memory_id: str):
                return self.memories.get(memory_id)

            def update_memory_tier(self, memory_id: str, tier: 'MemoryTier'):
                memory = self.memories.get(memory_id)
                if memory is not None:
                    # The dict holds the object itself, so mutating it is enough.
                    memory.tier = tier

        # Create a MockMemoryCortex instance
        mock_cortex = MockMemoryCortex()

        # Create a MemoryRecallEngine instance
        recall_engine = MemoryRecallEngine(mock_cortex)

        # Test 1: Recall with no memories
        results = recall_engine.recall("test query")
        print(f"Test 1: Recall with no memories - {len(results) == 0}")

        # Test 2: Add memories and recall
        mock_cortex.remember("This is a test memory about performance.", domain="performance")
        mock_cortex.remember("This is another memory about something else.", domain="other")
        results = recall_engine.recall("performance", top_k=1)
        print(f"Test 2: Recall with memories - {len(results) == 1 and 'performance' in results[0].domain}")

        # Test 3: Promote memory tier
        if not results:
            # Test 2 can come back empty (e.g. the fixed mock embedding scores
            # below the threshold); don't crash with IndexError in that case.
            print("Test 3: Promote memory tier - skipped (no memory recalled in Test 2)")
        else:
            memory_id = results[0].id
            recall_engine.promote_memory_tier(memory_id, MemoryTier.SEMANTIC)
            memory = mock_cortex.get_memory_by_id(memory_id)
            print(f"Test 3: Promote memory tier - {memory.tier == MemoryTier.SEMANTIC}")