# memory_recall_engine.py
"""
High-Performance Memory Recall Engine for AIVA.

This module provides a high-performance memory recall system with semantic
similarity search, fast retrieval with caching, memory tier promotion,
comprehensive error handling, logging, and metrics.
"""

import json
import hashlib
import time
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, asdict
from enum import Enum
import threading
import logging
import os

# Import existing Genesis memory components.
# Degraded-mode design: if the Genesis stack is not installed, every imported
# name is nulled out and HAS_DEPENDENCIES records the failure so downstream
# code (e.g. RecallEngine.recall) can short-circuit instead of raising
# NameError on the missing symbols.
try:
    from genesis_memory_cortex import MemorySystem, MemoryItem, SurpriseScore, Memory, MemoryTier, WorkingMemoryCache
    from vector_backends import VectorManager, VectorDocument
    from logging_config import get_logger
    from metrics import GenesisMetrics

    HAS_DEPENDENCIES = True
except ImportError as e:
    print(f"Failed to import dependencies: {e}.  Make sure genesis_memory_cortex, vector_backends, logging_config, and metrics are available.")
    HAS_DEPENDENCIES = False
    # Placeholders keep later truthiness guards (`if GenesisMetrics:` etc.)
    # and type annotations from blowing up in degraded mode.
    MemorySystem = None
    MemoryItem = None
    SurpriseScore = None
    Memory = None
    MemoryTier = None
    WorkingMemoryCache = None
    VectorManager = None
    VectorDocument = None
    get_logger = None
    GenesisMetrics = None


# Configure logging: prefer the project-wide logger factory when available,
# otherwise fall back to the stdlib logger so messages still go somewhere.
# Either way `logger` is guaranteed non-None after this point.
logger = None
if HAS_DEPENDENCIES and get_logger:
    logger = get_logger("memory_recall_engine")
else:
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

class RecallEngine:
    """
    A high-performance memory recall engine.

    Serves queries through a two-level pipeline: a working-memory cache for
    fast repeat lookups, then a semantic vector search whose results are
    written back to the cache under the same key the lookup uses.

    Cache reads and writes are serialized with an internal lock; vector
    searches themselves are not locked.
    """

    def __init__(self, working_memory_cache: Optional["WorkingMemoryCache"] = None,
                 vector_manager: Optional["VectorManager"] = None):
        """
        Initializes the RecallEngine.

        Args:
            working_memory_cache: An optional WorkingMemoryCache instance
                used as the fast recall path.
            vector_manager: An optional VectorManager instance used for
                semantic similarity search.
        """
        self.working_memory_cache = working_memory_cache
        self.vector_manager = vector_manager
        # Minimum relevance threshold. NOTE(review): not yet consulted by
        # recall() — confirm whether results should be filtered against it.
        self.recall_threshold = 0.7  # Adjust this threshold as needed
        # Serializes cache access across threads.
        self.lock = threading.Lock()

    def recall(self, query: str, domain_filter: Optional[str] = None, top_k: int = 5) -> List["Memory"]:
        """
        Recalls memories based on a query, using semantic similarity search and caching.

        Args:
            query: The query string.
            domain_filter: An optional domain filter.
            top_k: The number of top memories to retrieve.

        Returns:
            A list of Memory objects that match the query, sorted by relevance.
            Returns an empty list on any failure (errors are logged, never raised).
        """
        if not HAS_DEPENDENCIES:
            logger.warning("Recall engine is running in degraded mode due to missing dependencies.")
            return []

        start_time = time.time()
        try:
            # 1. Check Working Memory Cache
            cached_results = self._check_cache(query, domain_filter, top_k)
            if cached_results:
                # `logger` is always set at module import time, so no guard
                # is needed (consistent with the other methods).
                logger.debug(f"Cache hit for query: {query}")
                if GenesisMetrics:
                    GenesisMetrics.cache_hits.inc(labels={"tier": "working"})
                return cached_results

            # 2. Semantic Similarity Search
            if not self.vector_manager:
                logger.warning("Vector manager is not available. Cannot perform semantic search.")
                return []

            results = self._semantic_search(query, domain_filter, top_k)

            # 3. Cache Results.
            # BUGFIX: pass the request's top_k so the stored cache key matches
            # the key _check_cache computes on the next identical query.
            # Previously the key was derived from len(results), so any search
            # returning fewer than top_k hits could never be served from cache.
            self._cache_results(query, domain_filter, results, top_k)

            return results
        except Exception as e:
            logger.exception(f"Error during recall: {e}")
            return []
        finally:
            # Latency is recorded for cache hits, misses, and errors alike.
            duration = time.time() - start_time
            if GenesisMetrics:
                GenesisMetrics.recall_latency.observe(duration)

    def _check_cache(self, query: str, domain_filter: Optional[str], top_k: int) -> Optional[List["Memory"]]:
        """
        Checks the working memory cache for relevant memories.

        Args:
            query: The query string.
            domain_filter: An optional domain filter.
            top_k: The number of top memories to retrieve.

        Returns:
            A list of Memory objects if found in the cache, otherwise None.
        """
        if not self.working_memory_cache or not self.working_memory_cache.available:
            return None

        cache_key = self._generate_cache_key(query, domain_filter, top_k)

        # Use a lock to prevent race conditions during cache access
        with self.lock:
            cached_data = self.working_memory_cache.get(cache_key)
            if cached_data:
                try:
                    memories_json = json.loads(cached_data.content)  # Access content attribute
                    # NOTE(review): assumes Memory(**d) accepts the dicts
                    # produced by Memory.to_dict() in _cache_results (enum
                    # fields such as tier must round-trip) — TODO confirm.
                    memories = [Memory(**m) for m in memories_json]
                    return memories
                except (TypeError, json.JSONDecodeError) as e:
                    # A corrupt entry degrades to a cache miss, not an error.
                    logger.error(f"Error decoding cached data: {e}")
                    return None
            else:
                return None

    def _semantic_search(self, query: str, domain_filter: Optional[str], top_k: int) -> List["Memory"]:
        """
        Performs semantic similarity search using the VectorManager.

        Args:
            query: The query string.
            domain_filter: An optional domain filter.
            top_k: The number of top memories to retrieve.

        Returns:
            A list of Memory objects that match the query, sorted by relevance.
            Returns an empty list when the backend is missing or the search fails.
        """
        if not self.vector_manager:
            logger.warning("Vector manager is not available.")
            return []

        try:
            results = self.vector_manager.search(
                query=query,
                limit=top_k,
                domain=domain_filter
            )

            # Convert VectorDocuments to Memory objects
            memories = []
            for doc in results:
                try:
                    memory = Memory(
                        id=doc.id,
                        content=doc.content,
                        tier=MemoryTier.SEMANTIC,  # Assuming semantic tier after vector search
                        score=doc.score,
                        domain=doc.domain,
                        source="vector_search",
                        timestamp=doc.metadata.get("timestamp", str(time.time())),
                        metadata=doc.metadata,
                    )
                    memories.append(memory)
                except Exception as e:
                    # Skip malformed documents rather than failing the search.
                    logger.error(f"Error converting VectorDocument to Memory: {e}")
                    continue

            return memories

        except Exception as e:
            logger.exception(f"Error during semantic search: {e}")
            return []

    def _cache_results(self, query: str, domain_filter: Optional[str], results: List["Memory"],
                       top_k: Optional[int] = None):
        """
        Caches the search results in the working memory cache.

        Args:
            query: The query string.
            domain_filter: An optional domain filter.
            results: The list of Memory objects to cache.
            top_k: The top_k of the originating recall request. The cache key
                must be built from this value (not len(results)) so that it
                matches the lookup key computed by _check_cache. Falls back to
                len(results) when omitted (legacy behavior).
        """
        if not self.working_memory_cache or not self.working_memory_cache.available:
            return

        try:
            if results:
                key_k = top_k if top_k is not None else len(results)
                cache_key = self._generate_cache_key(query, domain_filter, key_k)
                # Convert Memory objects to dictionaries for JSON serialization
                memories_json = [m.to_dict() for m in results]

                # Store the JSON representation of the list of memories in the cache.
                memory = Memory(
                    id=cache_key,
                    content=json.dumps(memories_json),
                    tier=MemoryTier.WORKING,
                    score=1.0,
                    domain="recall_cache",
                    source="recall_engine",
                    timestamp=str(time.time()),
                )

                # Use a lock to prevent race conditions during cache writes
                with self.lock:
                    self.working_memory_cache.set(memory)
                logger.debug(f"Cached results for query: {query}")
            else:
                logger.info("No results to cache.")
        except Exception as e:
            # Caching is best-effort: a failure here must not break recall.
            logger.error(f"Error caching results: {e}")

    def _generate_cache_key(self, query: str, domain_filter: Optional[str], top_k: int) -> str:
        """
        Generates a cache key based on the query, domain filter, and top_k value.

        Args:
            query: The query string.
            domain_filter: An optional domain filter.
            top_k: The number of top memories to retrieve.

        Returns:
            A unique cache key string.
        """
        # md5 is acceptable here: the digest is a cache key, not a security
        # primitive.
        key_string = f"{query}-{domain_filter}-{top_k}"
        return hashlib.md5(key_string.encode()).hexdigest()

    def promote_memory_tier(self, memory_id: str, new_tier: "MemoryTier"):
        """
        Promotes a memory to a higher tier.  This is a placeholder for future functionality.

        Args:
            memory_id: The id of the memory to promote.
            new_tier: The target MemoryTier.
        """
        logger.info(f"Promoting memory {memory_id} to tier {new_tier}")
        # TODO: Implement actual tier promotion logic (e.g., moving data between databases)


if __name__ == "__main__":
    # Example usage and smoke tests.

    # Provide lightweight stand-ins when the Genesis stack is unavailable so
    # the module can still be exercised by hand in degraded environments.
    if not HAS_DEPENDENCIES:
        from dataclasses import make_dataclass, field

        class MockWorkingMemoryCache:
            """In-process dict-backed stand-in for WorkingMemoryCache."""
            def __init__(self):
                self.cache = {}

            def get(self, key):
                return self.cache.get(key)

            # BUGFIX: RecallEngine calls cache.set(memory) with a single
            # argument; the previous two-argument set(key, value) signature
            # would have raised TypeError.
            def set(self, memory):
                self.cache[memory.id] = memory

            @property
            def available(self):
                return True

        MemoryTier = Enum('MemoryTier', ['WORKING', 'EPISODIC', 'SEMANTIC'])
        # BUGFIX: build the mock types with make_dataclass. Calling the
        # @dataclass decorator with a name and a field list (as before)
        # raises TypeError — it expects a class, not a field spec.
        Memory = make_dataclass(
            'Memory',
            [('id', str), ('content', str), ('tier', MemoryTier), ('score', float),
             ('domain', str), ('source', str), ('timestamp', str),
             ('metadata', Optional[Dict], field(default=None))],
        )
        # BUGFIX: the real VectorDocument is None in degraded mode, so the
        # mock search needs its own document type rather than the unimported one.
        VectorDocument = make_dataclass(
            'VectorDocument',
            [('id', str), ('content', str), ('score', float), ('domain', str),
             ('metadata', Optional[Dict], field(default=None))],
        )

        class MockVectorManager:
            """Returns two canned documents regardless of the query."""
            def search(self, query, limit, domain):
                # Return some mock VectorDocument objects
                return [
                    VectorDocument(id="1", content="This is a test memory about performance.", score=0.8, domain="performance", metadata={"timestamp": str(time.time())}),
                    VectorDocument(id="2", content="Another memory about optimization.", score=0.7, domain="optimization", metadata={"timestamp": str(time.time())}),
                ]

        WorkingMemoryCache = MockWorkingMemoryCache
        VectorManager = MockVectorManager

    def test_recall_engine_basic():
        """Basic test to ensure recall engine initializes and returns results."""
        if not HAS_DEPENDENCIES:
            print("Skipping test_recall_engine_basic due to missing dependencies.")
            return

        # The guard above already returned in degraded mode, so the redundant
        # `if HAS_DEPENDENCIES else None` conditionals were dropped.
        working_memory_cache = WorkingMemoryCache()
        vector_manager = VectorManager()
        # BUGFIX: the cache argument was previously passed the vector manager
        # (working_memory_cache=vector_manager).
        recall_engine = RecallEngine(working_memory_cache=working_memory_cache,
                                     vector_manager=vector_manager)

        results = recall_engine.recall("performance", top_k=2)
        assert isinstance(results, list)
        if results:
            assert len(results) <= 2
            assert isinstance(results[0], Memory)
        print("test_recall_engine_basic passed")

    def test_recall_engine_caching():
        """Test to verify caching functionality."""
        if not HAS_DEPENDENCIES:
            print("Skipping test_recall_engine_caching due to missing dependencies.")
            return

        working_memory_cache = WorkingMemoryCache()
        vector_manager = VectorManager()
        recall_engine = RecallEngine(working_memory_cache=working_memory_cache,
                                     vector_manager=vector_manager)

        # First call should populate the cache
        results1 = recall_engine.recall("performance", top_k=1)

        # Second call should retrieve from the cache
        results2 = recall_engine.recall("performance", top_k=1)

        assert isinstance(results2, list)
        if results2:
            assert len(results2) <= 1
            assert isinstance(results2[0], Memory)
        print("test_recall_engine_caching passed")

    def test_recall_engine_domain_filter():
        """Test to verify domain filtering."""
        if not HAS_DEPENDENCIES:
            print("Skipping test_recall_engine_domain_filter due to missing dependencies.")
            return

        working_memory_cache = WorkingMemoryCache()
        vector_manager = VectorManager()
        recall_engine = RecallEngine(working_memory_cache=working_memory_cache,
                                     vector_manager=vector_manager)

        results = recall_engine.recall("performance", domain_filter="performance", top_k=1)
        assert isinstance(results, list)
        if results:
            assert len(results) <= 1
            assert isinstance(results[0], Memory)
            assert results[0].domain == "performance"
        print("test_recall_engine_domain_filter passed")

    # Run tests
    test_recall_engine_basic()
    test_recall_engine_caching()
    test_recall_engine_domain_filter()