# memory_recall_engine.py
"""
High-Performance Memory Recall Engine with Semantic Similarity and Tiered Storage

This module implements a memory recall system that utilizes semantic similarity search,
caching, and memory tier promotion to provide fast and accurate memory retrieval.
"""

import json
import hashlib
import time
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple

# Import existing Genesis memory components
try:
    from genesis_memory_cortex import MemorySystem, MemoryItem, SurpriseScore, Memory, MemoryTier, WorkingMemoryCache
    from vector_backends import VectorManager, VectorDocument
    from logging_config import get_logger
    LOGGING_AVAILABLE = True
    logger = get_logger("memory_recall_engine")
except ImportError as e:
    print(f"ImportError: {e}. Ensure all dependencies are installed.")
    LOGGING_AVAILABLE = False
    logger = None

try:
    from metrics import GenesisMetrics
    METRICS_AVAILABLE = True
except ImportError:
    METRICS_AVAILABLE = False
    GenesisMetrics = None


class MemoryRecallEngine:
    """
    Memory Recall Engine with Semantic Similarity and Tiered Storage.

    Merges candidate memories from up to three optional backends -- a
    working-memory cache (fast key/value tier), a vector store (semantic
    similarity tier), and episodic "surprise" memory -- then deduplicates
    by memory id and returns the top results ranked by score. Each backend
    failure is logged and swallowed so one broken tier cannot break recall
    from the others.
    """

    def __init__(self, working_memory_cache: Optional[WorkingMemoryCache] = None,
                 vector_manager: Optional[VectorManager] = None,
                 surprise_memory: Optional[MemorySystem] = None,
                 similarity_threshold: float = 0.7,
                 promotion_threshold: float = 0.9):
        """
        Initializes the MemoryRecallEngine.

        Args:
            working_memory_cache: An instance of WorkingMemoryCache for fast retrieval.
            vector_manager: An instance of VectorManager for semantic similarity search.
            surprise_memory: An instance of MemorySystem for managing episodic memory.
            similarity_threshold: The minimum similarity score for a memory to be considered relevant.
            promotion_threshold: The minimum score for promoting a memory to a higher tier.
        """
        self.working_memory_cache = working_memory_cache
        self.vector_manager = vector_manager
        self.surprise_memory = surprise_memory
        self.similarity_threshold = similarity_threshold
        self.promotion_threshold = promotion_threshold

        # Every backend is optional; warn once at construction time so a
        # missing dependency is visible before the first recall() call.
        if self.vector_manager is None:
            if logger:
                logger.warning("VectorManager not provided. Semantic similarity search will be unavailable.")
            else:
                print("[!] MemoryRecallEngine: VectorManager not provided.")

        if self.surprise_memory is None:
            if logger:
                logger.warning("SurpriseMemory not provided. Episodic memory will be unavailable.")
            else:
                print("[!] MemoryRecallEngine: SurpriseMemory not provided.")

    def _log_error(self, message: str, **extra: Any) -> None:
        """Report an error via the configured logger, falling back to stdout.

        Keyword arguments are forwarded as the logger's ``extra`` mapping.
        Centralizes the logger-or-print pattern used throughout this class.
        """
        if logger:
            logger.error(message, extra=extra or None)
        else:
            print(f"[!] {message}")

    def recall(self, query: str, domain: str = "general", top_k: int = 5) -> List[Memory]:
        """
        Recalls memories relevant to the given query using semantic similarity search and tiered storage.

        Args:
            query: The query string.
            domain: The domain of the query.
            top_k: The maximum number of memories to return.

        Returns:
            A list of Memory objects ranked by relevance (highest score first).
        """
        start_time = time.time()
        candidates: List[Memory] = []

        # 1. Working Memory Cache (fastest tier).
        if self.working_memory_cache and self.working_memory_cache.available:
            candidates.extend(self._recall_from_working_memory(query, domain, top_k))

        # 2. Semantic Similarity Search (Vector Backends).
        if self.vector_manager:
            candidates.extend(self._semantic_similarity_search(query, domain, top_k))

        # 3. Episodic Memory (Surprise Memory).
        if self.surprise_memory:
            candidates.extend(self._recall_from_episodic_memory(query, domain, top_k))

        # 4. Deduplicate by id, keeping the HIGHEST-scoring copy of each memory.
        #    Keeping the first occurrence (the previous behavior) let a
        #    low-scored cache hit shadow a higher-scored vector/episodic hit
        #    for the same id, demoting that memory in the final ranking.
        unique_results: Dict[str, Memory] = {}
        for candidate in candidates:
            current = unique_results.get(candidate.id)
            if current is None or candidate.score > current.score:
                unique_results[candidate.id] = candidate

        ranked_results = sorted(unique_results.values(), key=lambda m: m.score, reverse=True)[:top_k]

        # Log and Metrics (both optional at import time).
        duration = time.time() - start_time
        if logger:
            logger.info(f"Recall completed in {duration:.4f} seconds",
                        extra={"query": query, "domain": domain, "num_results": len(ranked_results)})
        if METRICS_AVAILABLE and GenesisMetrics:
            GenesisMetrics.recall_latency.observe(duration)
            GenesisMetrics.recall_operations.inc()

        return ranked_results

    def _recall_from_working_memory(self, query: str, domain: str, top_k: int) -> List[Memory]:
        """
        Recalls memories from the working memory cache based on the query and domain.

        Performs a case-insensitive substring match on memory content,
        restricted to the requested domain. Corrupt cache entries are logged
        and skipped; any backend failure returns the results gathered so far.

        Args:
            query: The query string.
            domain: The domain of the query.
            top_k: The maximum number of memories to return.

        Returns:
            A list of Memory objects (at most top_k).
        """
        results: List[Memory] = []
        try:
            if self.working_memory_cache.client:
                # NOTE: KEYS + full scan is O(n) over the whole namespace and
                # only acceptable for small caches; a production deployment
                # should use an index (or at least SCAN) instead.
                keys = self.working_memory_cache.client.keys(pattern=f"{self.working_memory_cache.namespace}:*")
                for key in keys:
                    data = self.working_memory_cache.client.get(key)
                    if not data:
                        continue
                    try:
                        memory = self._dict_to_memory(json.loads(data))
                        if query.lower() in memory.content.lower() and memory.domain.lower() == domain.lower():
                            results.append(memory)
                    except (json.JSONDecodeError, AttributeError) as e:
                        self._log_error(f"Error decoding or processing memory from cache: {e}", key=key)
            else:
                if logger:
                    logger.warning("Working memory cache not available.")
                else:
                    print("[!] Working memory cache not available.")

        except Exception as e:
            # Best-effort: a cache failure must not break recall from other tiers.
            self._log_error(f"Error recalling from working memory: {e}")

        return results[:top_k]

    def _semantic_similarity_search(self, query: str, domain: str, top_k: int) -> List[Memory]:
        """
        Performs semantic similarity search using vector embeddings.

        Documents scoring below ``similarity_threshold`` are discarded;
        documents that cannot be converted are logged and skipped.

        Args:
            query: The query string.
            domain: The domain of the query.
            top_k: The maximum number of memories to return.

        Returns:
            A list of Memory objects.
        """
        results: List[Memory] = []
        try:
            if self.vector_manager:
                documents = self.vector_manager.search(query, top_k=top_k, domain=domain)
                for doc in documents:
                    # Drop weak matches below the configured relevance floor.
                    if doc.score < self.similarity_threshold:
                        continue
                    try:
                        results.append(self._document_to_memory(doc))
                    except Exception as e:
                        self._log_error(f"Error converting document to memory: {e}", document=str(doc))
        except Exception as e:
            self._log_error(f"Error performing semantic similarity search: {e}")
        return results

    def _recall_from_episodic_memory(self, query: str, domain: str, top_k: int) -> List[Memory]:
        """
        Recalls memories from episodic memory (Surprise Memory).

        Results are filtered to the requested domain (case-insensitive);
        items that cannot be converted are logged and skipped.

        Args:
            query: The query string.
            domain: The domain of the query.
            top_k: The maximum number of memories to return.

        Returns:
            A list of Memory objects.
        """
        results: List[Memory] = []
        try:
            if self.surprise_memory:
                memory_items = self.surprise_memory.search_memory(query, top_k=top_k)
                for item in memory_items:
                    try:
                        memory = self._memory_item_to_memory(item)
                        if memory.domain.lower() == domain.lower():
                            results.append(memory)
                    except Exception as e:
                        self._log_error(f"Error converting memory item to memory: {e}", item=str(item))
        except Exception as e:
            self._log_error(f"Error recalling from episodic memory: {e}")
        return results

    def promote_memory_tier(self, memory: Memory) -> None:
        """
        Promotes a memory to a higher tier based on its score.

        WORKING -> EPISODIC: the entry is deleted from the working-memory
        cache and re-added to surprise memory. EPISODIC -> SEMANTIC: only
        the tier field is updated (persistence to a semantic store, e.g. a
        knowledge graph, is not implemented yet). Failures are logged and
        swallowed.

        Args:
            memory: The Memory object to promote.
        """
        try:
            if memory.score >= self.promotion_threshold:
                if memory.tier == MemoryTier.WORKING:
                    memory.tier = MemoryTier.EPISODIC
                    if self.working_memory_cache and self.working_memory_cache.available:
                        # Remove from working memory
                        self.working_memory_cache.client.delete(f"{self.working_memory_cache.namespace}:{memory.id}")
                    if self.surprise_memory:
                        # Add to episodic
                        self.surprise_memory.add_memory(MemoryItem(content=memory.content, source=memory.source, domain=memory.domain))
                    if logger:
                        logger.info(f"Promoted memory {memory.id} to episodic tier.")
                elif memory.tier == MemoryTier.EPISODIC:
                    memory.tier = MemoryTier.SEMANTIC
                    # TODO: Logic to promote to semantic memory (e.g., Knowledge Graph)
                    if logger:
                        logger.info(f"Promoted memory {memory.id} to semantic tier.")
        except Exception as e:
            self._log_error(f"Error promoting memory tier: {e}")

    def _dict_to_memory(self, mem_dict: Dict) -> Memory:
        """Converts a dictionary to a Memory object.

        Raises:
            ValueError: If a required key is missing or a value is invalid
                (e.g. an unknown tier name or non-numeric score).
        """
        try:
            tier_str = mem_dict.get('tier', 'working')  # Default to 'working' if missing
            tier = MemoryTier(tier_str)
            return Memory(
                id=mem_dict['id'],
                content=mem_dict['content'],
                tier=tier,
                score=float(mem_dict['score']),
                domain=mem_dict['domain'],
                source=mem_dict['source'],
                timestamp=mem_dict['timestamp'],
                embedding=mem_dict.get('embedding'),
                relations=mem_dict.get('relations'),
                access_count=mem_dict.get('access_count', 0),
                last_accessed=mem_dict.get('last_accessed'),
                metadata=mem_dict.get('metadata')
            )
        except KeyError as e:
            raise ValueError(f"Missing key in memory dictionary: {e}")
        except ValueError as e:
            raise ValueError(f"Invalid value in memory dictionary: {e}")

    def _document_to_memory(self, doc: 'VectorDocument') -> Memory:
        """Converts a VectorDocument to a Memory object.

        Raises:
            ValueError: If a required attribute is missing or invalid.
        """
        try:
            # Guard: some backends return None metadata; the previous
            # .get() calls raised AttributeError and silently dropped the doc.
            metadata = doc.metadata or {}
            return Memory(
                id=doc.id,
                content=doc.content,
                tier=MemoryTier.SEMANTIC,  # Assuming vector store is semantic
                score=doc.score,
                domain=doc.domain,
                source=metadata.get('source', 'vector_store'),
                timestamp=metadata.get('timestamp', str(datetime.now())),
                embedding=doc.embedding,
                relations=metadata.get('relations'),
                access_count=0,
                last_accessed=None,
                metadata=doc.metadata
            )
        except KeyError as e:
            raise ValueError(f"Missing key in VectorDocument: {e}")
        except ValueError as e:
            raise ValueError(f"Invalid value in VectorDocument: {e}")

    def _memory_item_to_memory(self, item: 'MemoryItem') -> Memory:
        """Converts a MemoryItem to a Memory object.

        Raises:
            ValueError: If a required attribute is missing or invalid.
        """
        try:
            return Memory(
                # MD5 is used only as a content fingerprint for a stable id,
                # not for any security purpose.
                id=hashlib.md5(item.content.encode()).hexdigest(),
                content=item.content,
                tier=MemoryTier.EPISODIC,
                score=item.surprise_score.relevance,
                domain=item.domain,
                source=item.source,
                timestamp=str(datetime.now()),
                embedding=None,  # Episodic memory might not have embeddings
                relations=None,
                access_count=0,
                last_accessed=None,
                metadata={}
            )
        except KeyError as e:
            raise ValueError(f"Missing key in MemoryItem: {e}")
        except ValueError as e:
            raise ValueError(f"Invalid value in MemoryItem: {e}")

# --- Test Functions ---
if __name__ == '__main__':
    # Lightweight stand-ins for the real backends (replace with your actual setup).
    class MockRedisClient:
        """Dict-backed substitute for a Redis client; never expires entries."""

        def __init__(self):
            self.data = {}

        def get(self, key):
            return self.data.get(key)

        def setex(self, key, ttl, value):
            # TTL is ignored by the mock.
            self.data[key] = value

        def keys(self, pattern):
            # Crude glob emulation: strip '*' and substring-match the remainder.
            needle = pattern.replace("*", "")
            return [stored for stored in self.data if needle in stored]

        def delete(self, key):
            self.data.pop(key, None)

        def ttl(self, key):
            return 60  # Mock TTL value

        def expire(self, key, ttl):
            pass

        def hincrby(self, name, key, amount=1):
            return 1

    class MockWorkingMemoryCache:
        """Minimal working-memory cache double storing JSON blobs in MockRedisClient."""

        def __init__(self):
            self.client = MockRedisClient()
            self.available = True
            self.namespace = "test_namespace"

        def _key(self, memory_id):
            # Same key format the engine builds when scanning the namespace.
            return f"{self.namespace}:{memory_id}"

        def set(self, memory: Memory, adaptive_ttl: bool = True):
            self.client.setex(self._key(memory.id), 60, json.dumps(memory.to_dict()))

        def get(self, memory_id: str, extend_ttl: bool = True) -> Optional[Memory]:
            raw = self.client.get(self._key(memory_id))
            if not raw:
                return None
            fields = json.loads(raw)
            return Memory(
                id=fields['id'],
                content=fields['content'],
                tier=MemoryTier(fields['tier']),
                score=float(fields['score']),
                domain=fields['domain'],
                source=fields['source'],
                timestamp=fields['timestamp'],
                embedding=fields.get('embedding'),
                relations=fields.get('relations'),
                access_count=fields.get('access_count', 0),
                last_accessed=fields.get('last_accessed'),
                metadata=fields.get('metadata'),
            )

    class MockVectorManager:
        """Returns a single canned document for any query."""

        def search(self, query: str, top_k: int, domain: str) -> List['VectorDocument']:
            doc = VectorDocument(
                id="vector_doc_1",
                content="Test vector memory",
                embedding=[0.1, 0.2],
                score=0.8,
                domain="general",
                metadata={},
            )
            return [doc]

    class MockSurpriseMemory:
        """Returns a single canned episodic item; add_memory is a no-op."""

        def search_memory(self, query: str, top_k: int) -> List['MemoryItem']:
            item = MemoryItem(
                content="Test episodic memory",
                source="test",
                domain="general",
                surprise_score=SurpriseScore(relevance=0.6, surprise=0.2),
            )
            return [item]

        def add_memory(self, item: 'MemoryItem') -> None:
            pass

    def test_recall_from_working_memory():
        cache = MockWorkingMemoryCache()
        engine = MemoryRecallEngine(working_memory_cache=cache)
        cache.set(Memory(
            id="test_memory",
            content="Test working memory",
            tier=MemoryTier.WORKING,
            score=0.5,
            domain="general",
            source="test",
            timestamp=str(datetime.now()),
        ))
        hits = engine.recall("working memory", domain="general")
        assert len(hits) > 0
        assert hits[0].id == "test_memory"
        print("test_recall_from_working_memory passed")

    def test_semantic_similarity_search():
        engine = MemoryRecallEngine(vector_manager=MockVectorManager())
        hits = engine.recall("Test", domain="general")
        assert len(hits) > 0
        assert hits[0].id == "vector_doc_1"
        print("test_semantic_similarity_search passed")

    def test_episodic_memory_recall():
        engine = MemoryRecallEngine(surprise_memory=MockSurpriseMemory())
        hits = engine.recall("Test", domain="general")
        assert len(hits) > 0
        assert hits[0].content == "Test episodic memory"
        print("test_episodic_memory_recall passed")

    for case in (test_recall_from_working_memory,
                 test_semantic_similarity_search,
                 test_episodic_memory_recall):
        case()