# memory_recall_engine.py
"""
High-Performance Memory Recall Engine with Semantic Similarity Search and Tier Management.
"""

import json
import hashlib
import time
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, asdict
from enum import Enum

# Import existing Genesis memory components
try:
    from surprise_memory import MemorySystem, MemoryItem, SurpriseScore
except ImportError:
    print("[!] surprise_memory not found.  Please install it.")
    MemorySystem = None
    MemoryItem = None
    SurpriseScore = None

# Import observability modules
try:
    from logging_config import get_logger
    LOGGING_AVAILABLE = True
    logger = get_logger("genesis.recall")
except ImportError:
    LOGGING_AVAILABLE = False
    logger = None

try:
    from metrics import GenesisMetrics, TimedOperation
    METRICS_AVAILABLE = True
except ImportError:
    METRICS_AVAILABLE = False
    GenesisMetrics = None

# Import vector backends for semantic similarity search
try:
    from vector_backends import VectorManager, VectorDocument
    VECTOR_AVAILABLE = True
except ImportError:
    VECTOR_AVAILABLE = False
    VectorManager = None

# Import WorkingMemoryCache
try:
    from genesis_memory_cortex import WorkingMemoryCache, Memory, MemoryTier
except ImportError:
    print("[!] genesis_memory_cortex not found. Please install it.")
    WorkingMemoryCache = None
    Memory = None
    MemoryTier = None

# Import schema
try:
    from memory_schemas import MemoryItemInput, MemoryOutput
except ImportError:
    print("[!] memory_schemas not found. Please install it.")
    MemoryItemInput = None
    MemoryOutput = None


class MemoryRecallEngine:
    """
    Engine for recalling memories based on semantic similarity and memory tier.

    Features:
    - Semantic similarity search via an injected vector manager (duck-typed:
      any object exposing ``search(query, top_k=..., domain=...)`` works, so
      the optional ``vector_backends`` module is not required at call time).
    - Fast retrieval from WorkingMemoryCache.
    - Memory tier promotion (placeholder hook).
    - Comprehensive error handling: ``recall`` never raises; it logs and
      degrades to an empty result list.
    - Logging and metrics when the observability modules are importable.
    """

    def __init__(self, vector_manager: Optional[VectorManager] = None,
                 working_memory_cache: Optional[WorkingMemoryCache] = None,
                 similarity_threshold: float = 0.7):
        """
        Initializes the MemoryRecallEngine.

        Args:
            vector_manager: Backend for semantic similarity search, or None
                to disable semantic search.
            working_memory_cache: Cache for fast exact-key retrieval, or None
                to disable cache lookups.
            similarity_threshold: Minimum similarity score for a vector hit
                to be considered relevant.
        """
        self.vector_manager = vector_manager
        self.working_memory_cache = working_memory_cache
        self.similarity_threshold = similarity_threshold

        # Warn only when the corresponding backend module imported successfully,
        # i.e. the caller *could* have supplied an instance but did not.
        if self.vector_manager is None and VECTOR_AVAILABLE:
            self._warn("MemoryRecallEngine: VectorManager is None, semantic search will be disabled.")

        if self.working_memory_cache is None and WorkingMemoryCache:
            self._warn("MemoryRecallEngine: WorkingMemoryCache is None, cache retrieval will be disabled.")

    def _warn(self, message: str) -> None:
        """Emit a warning to stdout and, when available, to the structured logger."""
        print(f"[!] {message}")
        if logger:
            logger.warning(message)

    def recall(self, query: str, domain_filter: Optional[str] = None, top_k: int = 5) -> List[Memory]:
        """
        Recalls memories matching *query* from the cache and the vector store.

        Args:
            query: The query string (also used verbatim as the cache key).
            domain_filter: Optional domain restriction for the vector search.
            top_k: Maximum number of results to return.

        Returns:
            Up to ``top_k`` unique Memory objects, sorted by score descending.
            Returns [] on any internal error (errors are logged, not raised).
        """
        start_time = time.time()
        results: List[Memory] = []

        try:
            # 1. Fast path: exact-key lookup in the working-memory cache.
            # getattr guards against cache objects without an `available` flag.
            if self.working_memory_cache and getattr(self.working_memory_cache, "available", False):
                cached_memory = self._retrieve_from_cache(query)
                if cached_memory:
                    results.append(cached_memory)
                    if logger:
                        logger.debug(f"Memory recalled from cache: {cached_memory.id}")
                    if METRICS_AVAILABLE and GenesisMetrics:
                        GenesisMetrics.cache_hits.inc(labels={"tier": "working"})

            # 2. Semantic similarity search.  Gate on the injected manager
            # itself rather than VECTOR_AVAILABLE: a caller-supplied
            # (duck-typed) manager must keep working even when the optional
            # vector_backends module failed to import.
            if self.vector_manager is not None:
                results.extend(self._semantic_search(query, domain_filter, top_k))
            elif logger:
                logger.warning("Vector search disabled. No VectorManager available.")

            # 3. Tier promotion logic (placeholder).  A real implementation
            # would track access counts and promote memories based on usage.
            # self._promote_memory_tier(results)

            # De-duplicate by memory ID, keeping the first occurrence (cache
            # hits win over vector hits because they were appended first).
            unique_results = []
            seen_ids = set()
            for result in results:
                if result.id not in seen_ids:
                    unique_results.append(result)
                    seen_ids.add(result.id)

            # Highest-scoring memories first.
            unique_results.sort(key=lambda x: x.score, reverse=True)

            if METRICS_AVAILABLE and GenesisMetrics:
                duration = time.time() - start_time
                GenesisMetrics.recall_latency.observe(duration)
                GenesisMetrics.memory_operations.inc(labels={"tier": "all", "op": "recall"})

            return unique_results[:top_k]

        except Exception as e:
            # Recall is best-effort: never propagate, just report and return [].
            if logger:
                logger.error(f"Error during memory recall: {e}")
            else:
                print(f"[!] Error during memory recall: {e}")
            return []

    def _retrieve_from_cache(self, query: str) -> Optional[Memory]:
        """
        Retrieves a memory from the WorkingMemoryCache based on the query.

        For simplicity, this implementation uses the query as the memory ID.
        A more sophisticated implementation would use a hash of the query.

        Args:
            query: The query string (used as memory ID).

        Returns:
            The cached Memory object if found, otherwise None.
        """
        try:
            # getattr guards against cache objects without an `available` flag.
            if self.working_memory_cache and getattr(self.working_memory_cache, "available", False):
                return self.working_memory_cache.get(query)
            if logger:
                logger.warning("WorkingMemoryCache is not available.")
            return None
        except Exception as e:
            if logger:
                logger.error(f"Error retrieving from cache: {e}")
            return None

    def _semantic_search(self, query: str, domain_filter: Optional[str], top_k: int) -> List[Memory]:
        """
        Performs semantic similarity search through the vector manager.

        Args:
            query: The query string.
            domain_filter: Optional domain restriction forwarded to the backend.
            top_k: Maximum number of documents requested from the backend.

        Returns:
            Memories whose score meets ``self.similarity_threshold``; [] when
            no manager is configured or on any backend error.
        """
        try:
            # Gate on the injected manager, not VECTOR_AVAILABLE (see recall()).
            if self.vector_manager is None:
                if logger:
                    logger.warning("VectorManager is not available.")
                return []

            documents = self.vector_manager.search(query, top_k=top_k, domain=domain_filter)
            memories: List[Memory] = []
            for doc in documents:
                if doc.score >= self.similarity_threshold:
                    try:
                        memories.append(self._document_to_memory(doc))
                    except Exception as e:
                        # One malformed document must not abort the whole search.
                        if logger:
                            logger.error(f"Error converting VectorDocument to Memory: {e}")
                        else:
                            print(f"[!] Error converting VectorDocument to Memory: {e}")
            return memories
        except Exception as e:
            if logger:
                logger.error(f"Error during semantic search: {e}")
            else:
                print(f"[!] Error during semantic search: {e}")
            return []

    def _promote_memory_tier(self, memories: List[Memory]) -> None:
        """
        Promotes the memory tier based on access count and other factors.

        This is a placeholder implementation.  A real implementation would
        involve more complex logic and interaction with the underlying
        memory systems (e.g., updating Redis, moving data to a knowledge graph).

        Args:
            memories: A list of Memory objects to consider for promotion.
        """
        # Placeholder heuristic: frequently accessed working memories graduate
        # to the episodic tier.
        for memory in memories:
            if memory.access_count > 10 and memory.tier == MemoryTier.WORKING:
                memory.tier = MemoryTier.EPISODIC
                if logger:
                    logger.info(f"Promoted memory {memory.id} to episodic tier.")

    def _document_to_memory(self, doc: 'VectorDocument') -> Memory:
        """Converts a VectorDocument to a Memory object.

        Metadata may arrive as a JSON string or a dict; invalid or missing
        metadata is replaced with an empty dict rather than failing the recall.
        """
        try:
            metadata = json.loads(doc.metadata) if isinstance(doc.metadata, str) else doc.metadata
        except (TypeError, json.JSONDecodeError):
            metadata = {}  # Handle cases where metadata is None or invalid JSON.

        return Memory(
            id=doc.id,
            content=doc.content,
            tier=MemoryTier.WORKING,  # Assuming all vector search results start in working memory
            score=doc.score,
            domain=doc.domain,
            source="vector_search",
            timestamp=str(datetime.now()),
            embedding=doc.embedding,
            relations=[],
            metadata=metadata
        )


# Example Usage / Tests
if __name__ == "__main__":
    # Stand-in backends so the engine can be exercised without real services.
    class MockVectorManager:
        def search(self, query: str, top_k: int = 5, domain: Optional[str] = None) -> List['VectorDocument']:
            # Only one canned query yields hits; everything else is empty.
            if query != "performance optimization":
                return []
            return [
                VectorDocument(id="perf1", content="Parallel execution improves performance",
                               score=0.8, domain="optimization", embedding=[0.1, 0.2],
                               metadata='{"type": "insight"}'),
                VectorDocument(id="perf2", content="Caching frequently used data",
                               score=0.75, domain="optimization", embedding=[0.2, 0.3],
                               metadata='{"type": "insight"}'),
            ]

    class MockWorkingMemoryCache:
        def __init__(self):
            self.cache = {}
            self.available = True

        def get(self, memory_id: str) -> Optional[Memory]:
            return self.cache.get(memory_id)

        def set(self, memory: Memory, adaptive_ttl: bool = True):
            self.cache[memory.id] = memory

    def test_recall_from_cache():
        """Tests recall from the working memory cache."""
        cache = MockWorkingMemoryCache()
        seeded = Memory(id="test_cache", content="This is a test memory",
                        tier=MemoryTier.WORKING, score=0.9, domain="test",
                        source="test", timestamp=str(datetime.now()))
        cache.set(seeded)
        engine = MemoryRecallEngine(working_memory_cache=cache)
        hits = engine.recall("test_cache")
        assert len(hits) == 1
        assert hits[0].id == "test_cache"
        print("[OK] test_recall_from_cache passed")

    def test_semantic_search():
        """Tests semantic similarity search."""
        engine = MemoryRecallEngine(vector_manager=MockVectorManager())
        hits = engine.recall("performance optimization")
        assert len(hits) > 0
        assert "performance" in hits[0].content
        print("[OK] test_semantic_search passed")

    def test_no_results():
        """Tests the case where no results are found."""
        engine = MemoryRecallEngine(vector_manager=MockVectorManager())
        assert len(engine.recall("nonexistent query")) == 0
        print("[OK] test_no_results passed")

    test_recall_from_cache()
    test_semantic_search()
    test_no_results()