# memory_recall_engine.py
"""
High-Performance Memory Recall Engine
======================================

This module implements a sophisticated memory recall system leveraging semantic similarity search,
caching, and memory tier promotion for optimal performance and accuracy.

Features:
    - Semantic similarity search using embeddings
    - Fast retrieval with caching
    - Memory tier promotion (working -> episodic -> semantic)
    - Comprehensive error handling
    - Proper logging and metrics
"""

import json
import time
import logging
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass
from enum import Enum
import hashlib

# Import existing Genesis memory components.
# Every optional dependency below is imported defensively: on ImportError the
# names are bound to None (and an *_AVAILABLE flag is cleared) so the module
# still imports and degrades gracefully instead of crashing at import time.
try:
    from surprise_memory import MemorySystem, MemoryItem, SurpriseScore
except ImportError:
    print("Error: surprise_memory module not found.  Please ensure it is installed.")
    MemorySystem = None
    MemoryItem = None
    SurpriseScore = None

try:
    from genesis_memory_cortex import Memory, MemoryTier, WorkingMemoryCache
except ImportError:
    print("Error: genesis_memory_cortex module not found.  Please ensure it is installed.")
    Memory = None
    MemoryTier = None
    WorkingMemoryCache = None

# Import observability modules
try:
    from logging_config import get_logger
    LOGGING_AVAILABLE = True
    logger = get_logger("memory_recall")
except ImportError:
    # Fall back to the stdlib logger so `logger` is always a usable object.
    LOGGING_AVAILABLE = False
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    print("Logging config not found, using basic logger.")

try:
    from metrics import GenesisMetrics
    METRICS_AVAILABLE = True
except ImportError:
    METRICS_AVAILABLE = False
    GenesisMetrics = None
    print("Metrics module not found.")

# Import vector backends for semantic similarity search
try:
    from vector_backends import VectorManager, VectorDocument
    VECTOR_AVAILABLE = True
except ImportError:
    # NOTE(review): VectorDocument is NOT bound to None on failure, unlike the
    # other fallbacks; all current uses are guarded by VECTOR_AVAILABLE, but
    # confirm no unguarded reference is added later.
    VECTOR_AVAILABLE = False
    VectorManager = None
    print("Vector backends module not found.")

class RecallError(Exception):
    """Base exception type for all failures raised by the recall engine."""


class NoRelevantMemoriesFound(RecallError):
    """Signals that no stored memory satisfied the recall criteria."""


class RecallEngine:
    """
    Core engine for recalling memories based on semantic similarity, caching,
    and tier promotion.

    Recall pipeline:
        1. Working-memory cache lookup (exact query-hash hit).
        2. Semantic similarity search through the vector backends.
        3. Ranking and filtering against ``recall_threshold``.
        4. Simulated tier promotion and re-caching of the winners.
    """

    def __init__(self,
                 working_memory: Optional['WorkingMemoryCache'] = None,
                 vector_manager: Optional['VectorManager'] = None,
                 surprise_memory: Optional['MemorySystem'] = None):
        """
        Initializes the RecallEngine.

        Args:
            working_memory: WorkingMemoryCache instance for fast retrieval;
                a fresh cache is created when omitted and the module imported.
            vector_manager: VectorManager instance for semantic similarity
                search; a fresh manager is created when omitted and importable.
            surprise_memory: MemorySystem instance for surprise-based storage;
                a fresh system is created when omitted and importable.
        """
        # Only instantiate fallbacks when the optional dependency actually
        # imported; otherwise keep None so the per-call guards below apply
        # instead of raising "'NoneType' object is not callable" here.
        self.working_memory = working_memory if working_memory is not None else (
            WorkingMemoryCache() if WorkingMemoryCache else None)
        self.vector_manager = vector_manager if vector_manager is not None else (
            VectorManager() if VectorManager else None)
        self.surprise_memory = surprise_memory if surprise_memory is not None else (
            MemorySystem() if MemorySystem else None)
        self.recall_threshold = 0.7  # Minimum similarity score for recall

    def recall(self, query: str, domain: str = "general", top_k: int = 5) -> List['Memory']:
        """
        Recalls memories relevant to the given query, considering semantic
        similarity, cache, and tier promotion.

        Args:
            query: The search query.
            domain: The domain of the query (e.g., "programming", "science").
            top_k: The maximum number of memories to return.

        Returns:
            A list of Memory objects ranked by relevance.

        Raises:
            RecallError: If an error occurs during the recall process.
            NoRelevantMemoriesFound: If no memories meet the recall criteria.
        """
        start_time = time.time()
        try:
            # 1. Check the working-memory cache first (cheapest path).
            cached_results = self._check_working_memory(query, domain)
            if cached_results:
                if logger:
                    logger.info(f"Retrieved {len(cached_results)} memories from working memory cache.")
                if METRICS_AVAILABLE and GenesisMetrics:
                    GenesisMetrics.cache_hits.inc(labels={"tier": "working"})
                return cached_results

            # 2. Semantic similarity search using the vector backends.
            if VECTOR_AVAILABLE and self.vector_manager:
                # Over-fetch (3x) so threshold filtering still leaves ~top_k.
                vector_results = self._semantic_search(query, domain, top_k * 3)
            else:
                if logger:
                    logger.warning("Vector search disabled.  Returning empty result.")
                vector_results = []

            # 3. Episodic Memory Search (SQLite/Postgres) - omitted for brevity

            # 4. Semantic Memory Search (Neo4j/MCP) - omitted for brevity

            # 5. Rank and filter against the recall threshold.
            ranked_results = self._rank_and_filter(query, vector_results)

            # 6. Promote memory tiers where scores warrant it.
            self._promote_memory_tiers(ranked_results)

            # 7. Cache the survivors so the next identical query hits step 1.
            self._cache_results(ranked_results)

            if not ranked_results:
                raise NoRelevantMemoriesFound("No relevant memories found for the query.")

            if logger:
                logger.info(f"Recall successful. Retrieved {len(ranked_results)} memories.")
            if METRICS_AVAILABLE and GenesisMetrics:
                GenesisMetrics.memory_operations.inc(labels={"tier": "all", "op": "recall"})
                duration = time.time() - start_time
                GenesisMetrics.memory_latency.observe(duration, labels={"tier": "all"})

            return ranked_results[:top_k]

        except NoRelevantMemoriesFound as e:
            if logger:
                logger.warning(f"No relevant memories found: {e}")
            raise
        except Exception as e:
            if logger:
                logger.error(f"Error during recall: {e}", exc_info=True)
            if METRICS_AVAILABLE and GenesisMetrics:
                GenesisMetrics.recall_errors.inc()
            # Chain the cause so the original traceback survives wrapping.
            raise RecallError(f"Recall failed: {e}") from e

    def _check_working_memory(self, query: str, domain: str) -> List['Memory']:
        """
        Checks the working memory cache for relevant memories.

        Args:
            query: The search query.
            domain: The domain of the query.  NOTE(review): currently unused;
                the lookup keys only on the query hash — confirm whether cache
                entries should be domain-scoped.

        Returns:
            A list of Memory objects from the cache, or an empty list if no
            relevant memories are found (or no usable cache is configured).
        """
        if not self.working_memory or not self.working_memory.available:
            return []

        # Simplified: exact-match lookup keyed on the SHA-256 of the query.
        query_hash = hashlib.sha256(query.encode()).hexdigest()
        memory = self.working_memory.get(query_hash)
        if memory:
            return [memory]

        return []

    def _semantic_search(self, query: str, domain: str, top_k: int) -> List['Memory']:
        """
        Performs semantic similarity search using the configured vector backends.

        Args:
            query: The search query.
            domain: The domain of the query.
            top_k: The number of results to return from each backend.

        Returns:
            A list of Memory objects; empty on any backend failure (search is
            best-effort and never propagates exceptions to the caller).
        """
        try:
            if not VECTOR_AVAILABLE or not self.vector_manager:
                return []

            results: List['Memory'] = []
            documents: List['VectorDocument'] = self.vector_manager.search(query, top_k=top_k)
            for doc in documents:
                try:
                    # Each document carries the serialized Memory in metadata.
                    memory_dict = json.loads(doc.metadata.get("memory_json", "{}"))
                    results.append(self._dict_to_memory(memory_dict))
                except Exception as e:
                    # One corrupt document must not sink the whole search.
                    if logger:
                        logger.error(f"Error deserializing memory from vector document: {e}, doc id: {doc.id}", exc_info=True)
                    continue
            return results
        except Exception as e:
            if logger:
                logger.error(f"Vector search failed: {e}", exc_info=True)
            return []

    def _rank_and_filter(self, query: str, memories: List['Memory']) -> List['Memory']:
        """
        Ranks and filters the retrieved memories based on relevance and a
        similarity threshold.

        Args:
            query: The search query (reserved for future ranking signals).
            memories: A list of Memory objects.

        Returns:
            Memories with score >= ``recall_threshold``, sorted descending.
        """
        # A real implementation would blend semantic similarity, recency and
        # importance; for now we filter on score and sort best-first.
        survivors = [m for m in memories if m.score >= self.recall_threshold]
        survivors.sort(key=lambda m: m.score, reverse=True)
        return survivors

    def _promote_memory_tiers(self, memories: List['Memory']):
        """
        Promotes the memory tiers of relevant memories based on score.

        Promotion is simulated in-place (the tier attribute is updated); no
        storage backend is touched here.

        Args:
            memories: A list of Memory objects.
        """
        for memory in memories:
            if memory.tier == MemoryTier.WORKING and memory.score >= 0.8:
                if logger:
                    logger.info(f"Promoting memory {memory.id} to episodic tier.")
                memory.tier = MemoryTier.EPISODIC  # Simulate promotion
            elif memory.tier == MemoryTier.EPISODIC and memory.score >= 0.95:
                if logger:
                    logger.info(f"Promoting memory {memory.id} to semantic tier.")
                memory.tier = MemoryTier.SEMANTIC  # Simulate promotion

    def _cache_results(self, memories: List['Memory']):
        """
        Caches the retrieved memories in the working memory cache.

        Args:
            memories: A list of Memory objects.
        """
        if self.working_memory and self.working_memory.available:
            for memory in memories:
                # NOTE(review): the content hash OVERWRITES the memory's
                # original id so cache keys match _check_working_memory's
                # lookup scheme — confirm no caller relies on the old id.
                memory.id = hashlib.sha256(memory.content.encode()).hexdigest()
                self.working_memory.set(memory)

    @staticmethod
    def _build_memory(mem_dict: Dict[str, Any], tier: 'MemoryTier') -> 'Memory':
        """Construct a Memory from its dict form with an already-resolved tier."""
        return Memory(
            id=mem_dict['id'],
            content=mem_dict['content'],
            tier=tier,
            score=mem_dict['score'],
            domain=mem_dict['domain'],
            source=mem_dict['source'],
            timestamp=mem_dict['timestamp'],
            embedding=mem_dict.get('embedding'),
            relations=mem_dict.get('relations'),
            access_count=mem_dict.get('access_count', 0),
            last_accessed=mem_dict.get('last_accessed'),
            metadata=mem_dict.get('metadata')
        )

    def _dict_to_memory(self, mem_dict: Dict[str, Any]) -> 'Memory':
        """
        Convert a dictionary to a Memory object, handling potential errors.

        An unknown tier string is coerced to WORKING rather than failing the
        whole conversion; missing required keys or other construction errors
        are re-raised as ValueError.

        Raises:
            ValueError: If a required key is missing or construction fails.
        """
        tier_str = mem_dict.get('tier', 'working')
        try:
            tier = MemoryTier(tier_str)  # Convert string to enum
        except ValueError:
            if logger:
                logger.error(f"Invalid MemoryTier value: {tier_str}.  Defaulting to WORKING.", exc_info=True)
            tier = MemoryTier('working')

        try:
            return self._build_memory(mem_dict, tier)
        except KeyError as e:
            if logger:
                logger.error(f"Missing key in memory dictionary: {e}.", exc_info=True)
            raise ValueError(f"Missing key in memory dictionary: {e}") from e
        except Exception as e:
            if logger:
                logger.error(f"Failed to convert dictionary to Memory object: {e}", exc_info=True)
            raise ValueError(f"Failed to convert dictionary to Memory object: {e}") from e


# --- Test Functions ---
if __name__ == '__main__':
    # Fix: `datetime` was used in create_test_memory but never imported,
    # so every test crashed with NameError before exercising the engine.
    from datetime import datetime

    # Configure basic logging for tests
    logging.basicConfig(level=logging.INFO)

    def create_test_memory(content: str, score: float, domain: str, source: str) -> 'Memory':
        """Helper function to create a Memory object for testing."""
        return Memory(
            id=hashlib.sha256(content.encode()).hexdigest(),
            content=content,
            tier=MemoryTier.WORKING,
            score=score,
            domain=domain,
            source=source,
            timestamp=str(datetime.now()),
            embedding=None,
            relations=None,
            access_count=0,
            last_accessed=None,
            metadata=None
        )

    def test_recall_basic():
        """Tests basic recall functionality against the vector store."""
        recall_engine = RecallEngine()

        # Simulate adding memories to vector store
        test_memory1 = create_test_memory("The sky is blue.", 0.8, "general", "test")
        test_memory2 = create_test_memory("Grass is green.", 0.9, "general", "test")

        if VECTOR_AVAILABLE and recall_engine.vector_manager:
            recall_engine.vector_manager.add([
                VectorDocument(id=test_memory1.id, content=test_memory1.content, metadata={"memory_json": json.dumps(test_memory1.to_dict())}),
                VectorDocument(id=test_memory2.id, content=test_memory2.content, metadata={"memory_json": json.dumps(test_memory2.to_dict())})
            ])

        try:
            results = recall_engine.recall("What color is the sky?", "general")
            assert len(results) > 0
            assert "sky is blue" in results[0].content.lower()
            logger.info("test_recall_basic passed")
        except Exception as e:
            logger.error(f"test_recall_basic failed: {e}", exc_info=True)

        # Clean up the fixtures so tests stay independent.
        if VECTOR_AVAILABLE and recall_engine.vector_manager:
            recall_engine.vector_manager.delete(test_memory1.id)
            recall_engine.vector_manager.delete(test_memory2.id)

    def test_recall_no_results():
        """Tests that an empty result raises NoRelevantMemoriesFound."""
        recall_engine = RecallEngine()
        try:
            recall_engine.recall("This query should return no results.", "unknown")
            assert False, "Expected NoRelevantMemoriesFound exception"
        except NoRelevantMemoriesFound:
            logger.info("test_recall_no_results passed")
        except Exception as e:
            logger.error(f"test_recall_no_results failed: {e}", exc_info=True)

    def test_recall_tier_promotion():
        """Tests memory tier promotion functionality."""
        recall_engine = RecallEngine()
        test_memory = create_test_memory("This is an important memory.", 0.9, "critical", "test")

        if VECTOR_AVAILABLE and recall_engine.vector_manager:
            recall_engine.vector_manager.add(VectorDocument(id=test_memory.id, content=test_memory.content, metadata={"memory_json": json.dumps(test_memory.to_dict())}))

        try:
            results = recall_engine.recall("important memory", "critical")
            assert len(results) > 0
            # Score 0.9 >= 0.8 should have bumped WORKING -> EPISODIC.
            assert results[0].tier == MemoryTier.EPISODIC
            logger.info("test_recall_tier_promotion passed")
        except Exception as e:
            logger.error(f"test_recall_tier_promotion failed: {e}", exc_info=True)

        if VECTOR_AVAILABLE and recall_engine.vector_manager:
            recall_engine.vector_manager.delete(test_memory.id)

    test_recall_basic()
    test_recall_no_results()
    test_recall_tier_promotion()