# memory_recall_engine.py
"""
High-Performance Memory Recall Engine for AIVA

This module provides a high-performance memory recall system with semantic
similarity search, caching, memory tier promotion, error handling,
logging, and metrics.
"""

import hashlib
import json
import threading
import time
from dataclasses import asdict, dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

# Import existing Genesis memory components (mocked for standalone execution)
try:
    from genesis_memory_cortex import MemorySystem, MemoryItem, SurpriseScore, Memory, MemoryTier, WorkingMemoryCache
    from logging_config import get_logger
    from metrics import GenesisMetrics
    from vector_backends import VectorManager, VectorDocument
    GENESIS_AVAILABLE = True
except ImportError:
    print("Genesis dependencies not found. Running in standalone mode.")
    GENESIS_AVAILABLE = False

    class MemorySystem:
        """Placeholder for the Genesis memory system (unused standalone)."""
        pass

    class MemoryItem:
        """Placeholder for a Genesis memory item (unused standalone)."""
        pass

    class SurpriseScore:
        """Placeholder for the Genesis surprise scorer (unused standalone)."""
        pass

    @dataclass
    class Memory:
        """Standalone stand-in for a Genesis memory record."""
        id: str
        content: str
        tier: 'MemoryTier'  # Forward reference: MemoryTier is defined below.
        score: float
        domain: str
        source: str
        timestamp: str
        embedding: Optional[List[float]] = None
        relations: Optional[List[str]] = None
        access_count: int = 0
        last_accessed: Optional[str] = None
        metadata: Optional[Dict] = None

        def to_dict(self) -> Dict:
            """Return a JSON-serializable dict (tier flattened to its string value)."""
            d = asdict(self)
            d['tier'] = self.tier.value
            return d

    class MemoryTier(Enum):
        """Memory tiers, ordered from most ephemeral to most durable."""
        DISCARD = "discard"
        WORKING = "working"
        EPISODIC = "episodic"
        SEMANTIC = "semantic"

    class WorkingMemoryCache:
        """No-op cache stand-in: never hits, silently drops writes."""

        def __init__(self, *args, **kwargs):
            pass

        def get(self, memory_id: str, extend_ttl: bool = True) -> Optional['Memory']:
            return None

        def set(self, memory: 'Memory', adaptive_ttl: bool = True):
            pass

    def get_logger(name):
        """Return a print-based logger matching the Genesis logger API."""
        class MockLogger:
            def info(self, msg, extra=None):
                print(f"INFO: {msg} {extra}")
            def warning(self, msg, extra=None):
                print(f"WARNING: {msg} {extra}")
            def error(self, msg, extra=None):
                print(f"ERROR: {msg} {extra}")
            def debug(self, msg, extra=None):
                print(f"DEBUG: {msg} {extra}")
        return MockLogger()

    class GenesisMetrics:
        """Metric handles left as None: standalone code paths are guarded
        by GENESIS_AVAILABLE and never dereference them."""
        memory_operations = None
        memory_latency = None
        cache_hits = None
        cache_misses = None

    class VectorManager:
        """Deterministic keyword-overlap search stand-in.

        The previous stub always returned [], which meant standalone recall
        could never surface anything and the module's own self-tests failed.
        This version matches query tokens against the same two memories that
        MemoryRecallEngine._retrieve_memory mocks, so the full recall path
        can be exercised end to end without Genesis.
        """

        # Lowercase token corpora mirroring the mocked memories in
        # MemoryRecallEngine._retrieve_memory.
        _MOCK_INDEX = {
            "memory_1": "parallel execution improves performance by 37%",
            "memory_2": "vector search is efficient for semantic similarity",
        }

        def __init__(self, *args, **kwargs):
            pass

        def search(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
            """Return (memory_id, score) pairs for documents sharing at least
            one lowercase token with the query, best-scoring first."""
            query_terms = set(query.lower().split())
            hits: List[Tuple[str, float]] = []
            for doc_id, text in self._MOCK_INDEX.items():
                overlap = query_terms & set(text.split())
                if overlap:
                    # Base score sits above the engine's default similarity
                    # threshold (0.7); each extra matching term adds 0.1.
                    hits.append((doc_id, round(min(0.95, 0.7 + 0.1 * len(overlap)), 2)))
            hits.sort(key=lambda pair: pair[1], reverse=True)
            return hits[:top_k]

    class VectorDocument:
        """Placeholder for a Genesis vector document (unused standalone)."""
        pass

logger = get_logger("memory_recall_engine")

class MemoryRecallEngine:
    """
    Engine for recalling memories based on semantic similarity and tiering.

    Recall pipeline:
        1. Working-memory cache lookup (keyed by SHA-256 of the query).
        2. Vector similarity search.
        3. Threshold filtering, access accounting and tier promotion.
        4. Write-back of the result set into the working-memory cache.
    """

    def __init__(
        self,
        working_memory_cache: Optional[WorkingMemoryCache] = None,
        vector_manager: Optional[VectorManager] = None,
        similarity_threshold: float = 0.7,
        promotion_threshold: float = 0.8,
    ):
        """
        Initializes the MemoryRecallEngine.

        Args:
            working_memory_cache: The WorkingMemoryCache instance.
            vector_manager: The VectorManager instance for similarity search.
            similarity_threshold: The minimum similarity score for a memory to be considered relevant.
            promotion_threshold: The minimum score to promote a memory to a higher tier.
        """
        self.working_memory_cache = working_memory_cache or WorkingMemoryCache()
        self.vector_manager = vector_manager or VectorManager()
        self.similarity_threshold = similarity_threshold
        self.promotion_threshold = promotion_threshold
        self.lock = threading.Lock()  # Guards access-count/timestamp updates.

    def recall(self, query: str, top_k: int = 5) -> List[Memory]:
        """
        Recalls memories relevant to the given query.

        Args:
            query: The query string.
            top_k: The maximum number of memories to return.

        Returns:
            A list of Memory objects, sorted by relevance (best first).
            Returns an empty list on any internal error (best-effort recall).
        """
        try:
            start_time = time.time()
            logger.info(f"Recalling memories for query: {query}")

            # 1. Check Working Memory Cache
            cached_results = self._check_cache(query)
            if cached_results:
                logger.info(f"Cache hit for query: {query}")
                if GENESIS_AVAILABLE and GenesisMetrics:
                    GenesisMetrics.cache_hits.inc(labels={"tier": "working"})
                return cached_results

            # Record the miss so cache hit-rate metrics stay meaningful.
            if GENESIS_AVAILABLE and GenesisMetrics:
                GenesisMetrics.cache_misses.inc(labels={"tier": "working"})

            # 2. Semantic Similarity Search
            vector_results = self._semantic_search(query, top_k)

            # 3. Post-processing and Tier Promotion
            recalled_memories = self._post_process_results(vector_results)

            # 4. Store in Working Memory Cache
            self._store_in_cache(query, recalled_memories)

            duration = time.time() - start_time
            logger.info(f"Memory recall completed in {duration:.4f} seconds")
            return recalled_memories

        except Exception as e:
            # Recall is deliberately best-effort: log and degrade to
            # "nothing found" rather than propagating into the caller.
            logger.error(f"Error during memory recall: {e}")
            return []

    def _check_cache(self, query: str) -> Optional[List[Memory]]:
        """
        Checks the working memory cache for previously recalled results.

        ``_store_in_cache`` wraps the serialized result list in a single
        Memory whose ``content`` holds a JSON array, so a hit must be
        unwrapped and deserialized here — returning the wrapper itself
        (the previous behavior) handed callers a bogus memory whose
        content was raw JSON.

        Args:
            query: The query string.

        Returns:
            The cached list of Memory objects, or None on a miss, an
            empty cached result, or a corrupt cache entry.
        """
        # Hash the query for cache key stability
        query_hash = hashlib.sha256(query.encode()).hexdigest()
        cached_entry = self.working_memory_cache.get(query_hash)
        if cached_entry is None:
            return None

        # Unwrap the JSON payload from the wrapper Memory (or accept a raw
        # JSON string, in case the backing cache stores plain strings).
        payload = cached_entry.content if isinstance(cached_entry, Memory) else cached_entry
        try:
            items = json.loads(payload)
            memories = []
            for item in items:
                # to_dict() flattened the tier enum to its string value;
                # restore it before rebuilding the dataclass.
                if isinstance(item.get('tier'), str):
                    item['tier'] = MemoryTier(item['tier'])
                memories.append(Memory(**item))
            return memories or None
        except (TypeError, ValueError) as e:
            # ValueError also covers json.JSONDecodeError and bad tier values.
            logger.error(f"Error decoding cached memory: {e}")
            return None

    def _semantic_search(self, query: str, top_k: int) -> List[Tuple[str, float]]:
        """
        Performs semantic similarity search using the VectorManager.

        Args:
            query: The query string.
            top_k: The maximum number of results to return.

        Returns:
            A list of (memory_id, similarity_score) tuples; empty on error
            or when no VectorManager is available.
        """
        try:
            if self.vector_manager:
                return self.vector_manager.search(query, top_k)
            else:
                logger.warning("VectorManager not available. Skipping semantic search.")
                return []
        except Exception as e:
            logger.error(f"Error during semantic search: {e}")
            return []

    def _post_process_results(self, vector_results: List[Tuple[str, float]]) -> List[Memory]:
        """
        Post-processes the results from the vector search, including tier promotion.

        Memories below the similarity threshold are dropped; the rest get
        their access counters bumped and are considered for promotion.
        A failure on one memory is logged and does not abort the others.

        Args:
            vector_results: A list of (memory_id, similarity_score) tuples.

        Returns:
            A list of Memory objects, sorted by score descending.
        """
        recalled_memories: List[Memory] = []
        for memory_id, score in vector_results:
            try:
                # Mock memory retrieval - replace with actual DB call
                memory = self._retrieve_memory(memory_id)
                if memory:
                    if score >= self.similarity_threshold:
                        memory.score = score
                        self._update_access_count(memory)
                        self._promote_tier(memory)
                        recalled_memories.append(memory)
                    else:
                        logger.debug(f"Memory {memory_id} below similarity threshold ({score} < {self.similarity_threshold})")
                else:
                    logger.warning(f"Memory {memory_id} not found.")
            except Exception as e:
                logger.error(f"Error processing memory {memory_id}: {e}")

        recalled_memories.sort(key=lambda x: x.score, reverse=True)
        return recalled_memories

    def _retrieve_memory(self, memory_id: str) -> Optional[Memory]:
        """
        Retrieves a memory from the database (mocked for now).

        Args:
            memory_id: The ID of the memory to retrieve.

        Returns:
            The Memory object if found, otherwise None.
        """
        # Replace with actual database retrieval logic
        # Example:
        # memory_data = self.db_client.get_memory(memory_id)
        # if memory_data:
        #     return Memory(**memory_data)

        # Mocked memory for testing
        if memory_id == "memory_1":
            return Memory(
                id="memory_1",
                content="Parallel execution improves performance by 37%",
                tier=MemoryTier.WORKING,
                score=0.6,
                domain="performance",
                source="experiment",
                timestamp="2024-01-01T00:00:00",
                embedding=[0.1, 0.2, 0.3],
                relations=[],
                access_count=0,
                last_accessed=None,
                metadata={}
            )
        elif memory_id == "memory_2":
            return Memory(
                id="memory_2",
                content="Vector search is efficient for semantic similarity.",
                tier=MemoryTier.EPISODIC,
                score=0.7,
                domain="search",
                source="documentation",
                timestamp="2024-01-02T00:00:00",
                embedding=[0.4, 0.5, 0.6],
                relations=[],
                access_count=0,
                last_accessed=None,
                metadata={}
            )
        else:
            return None

    def _update_access_count(self, memory: Memory):
        """Updates access count and last accessed timestamp in a thread-safe way."""
        with self.lock:
            memory.access_count += 1
            memory.last_accessed = datetime.now().isoformat()
            # Replace with actual DB update logic
            # Example:
            # self.db_client.update_memory(memory.id, {"access_count": memory.access_count, "last_accessed": memory.last_accessed})
            logger.debug(f"Updated access count for memory {memory.id} to {memory.access_count}")

    def _promote_tier(self, memory: Memory):
        """
        Promotes the memory tier one step (WORKING -> EPISODIC -> SEMANTIC)
        if the score meets the promotion threshold. SEMANTIC is terminal.
        """
        if memory.score >= self.promotion_threshold:
            if memory.tier == MemoryTier.WORKING:
                memory.tier = MemoryTier.EPISODIC
                logger.info(f"Promoted memory {memory.id} to EPISODIC tier")
                # Replace with actual DB update logic
                # self.db_client.update_memory(memory.id, {"tier": "episodic"})
            elif memory.tier == MemoryTier.EPISODIC:
                memory.tier = MemoryTier.SEMANTIC
                logger.info(f"Promoted memory {memory.id} to SEMANTIC tier")
                # Replace with actual DB update logic
                # self.db_client.update_memory(memory.id, {"tier": "semantic"})
        else:
            logger.debug(f"Memory {memory.id} not promoted (score {memory.score} below threshold {self.promotion_threshold})")

    def _store_in_cache(self, query: str, memories: List[Memory]):
        """
        Stores the recalled memories in the working memory cache.

        The result list is serialized to JSON and wrapped in a single Memory
        keyed by the query hash; ``_check_cache`` performs the inverse.

        Args:
            query: The query string.
            memories: A list of Memory objects to store.
        """
        try:
            if not memories:
                # An empty entry could never count as a hit; skip the write.
                return
            # Hash the query for cache key stability
            query_hash = hashlib.sha256(query.encode()).hexdigest()
            # Serialize memory objects to JSON before storing
            memory_data = json.dumps([memory.to_dict() for memory in memories])
            if self.working_memory_cache:
                self.working_memory_cache.set(Memory(id=query_hash, content=memory_data, tier=MemoryTier.WORKING, score=0.9, domain="recall", source="engine", timestamp=datetime.now().isoformat()), adaptive_ttl=True)
                logger.info(f"Stored results for query '{query}' in working memory cache")
            else:
                logger.warning("WorkingMemoryCache not available. Skipping cache storage.")
        except Exception as e:
            logger.error(f"Error storing results in cache: {e}")


if __name__ == "__main__":
    # Lightweight self-tests executed when the module is run directly.

    def _test_basic_recall():
        """A performance-related query should surface a matching memory."""
        engine = MemoryRecallEngine()
        hits = engine.recall("performance optimization")
        assert len(hits) > 0
        assert "performance" in hits[0].content.lower()
        print("Test: Basic recall passed.")

    def _test_empty_recall():
        """An unrelated query should come back with no memories."""
        engine = MemoryRecallEngine()
        hits = engine.recall("nonexistent topic")
        assert len(hits) == 0
        print("Test: No results recall passed.")

    def _test_promotion():
        """A high-scoring memory should climb WORKING -> EPISODIC -> SEMANTIC."""
        # Low promotion threshold so a single score triggers both steps.
        engine = MemoryRecallEngine(promotion_threshold=0.5)
        memory = engine._retrieve_memory("memory_1")
        assert memory is not None, "Memory not found for tier promotion test."
        memory.score = 0.9  # Above the (lowered) promotion threshold.
        engine._promote_tier(memory)
        assert memory.tier == MemoryTier.EPISODIC
        engine._promote_tier(memory)
        assert memory.tier == MemoryTier.SEMANTIC
        print("Test: Tier promotion passed.")

    _test_basic_recall()
    _test_empty_recall()
    _test_promotion()