# memory_recall_engine.py
"""
High-Performance Memory Recall Engine
=====================================

This module provides a memory recall system built on semantic similarity search,
working-memory caching, and memory tier promotion. It integrates with existing Genesis
memory components to provide fast, accurate memory retrieval.

Features:
    - Semantic similarity search using vector embeddings
    - Fast retrieval with caching (Redis)
    - Memory tier promotion (working -> episodic -> semantic)
    - Comprehensive error handling
    - Proper logging and metrics
"""

import json
import hashlib
import redis
import os
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, asdict
from pathlib import Path
from enum import Enum
import threading
import time

# Import existing Genesis memory components
try:
    from surprise_memory import MemorySystem, MemoryItem, SurpriseScore
except ImportError:
    print("[!] surprise_memory not found. Install or mock.")
    MemorySystem = None
    MemoryItem = None
    SurpriseScore = None

# Import observability modules
try:
    from logging_config import get_logger, with_context, OperationTimer
    LOGGING_AVAILABLE = True
    logger = get_logger("genesis.recall")
except ImportError:
    LOGGING_AVAILABLE = False
    logger = None

try:
    from metrics import GenesisMetrics, TimedOperation
    METRICS_AVAILABLE = True
except ImportError:
    METRICS_AVAILABLE = False
    GenesisMetrics = None

# Import circuit breaker for resilience
try:
    from circuit_breaker import get_circuit_breaker, CircuitBreaker
    CIRCUIT_AVAILABLE = True
except ImportError:
    CIRCUIT_AVAILABLE = False
    get_circuit_breaker = None

# Import secrets loader for secure credential management
try:
    from secrets_loader import get_redis_config, RedisConfig
    SECRETS_AVAILABLE = True
except ImportError:
    SECRETS_AVAILABLE = False
    get_redis_config = None
    RedisConfig = None

# Import vector backends for semantic similarity search
try:
    from vector_backends import VectorManager, VectorDocument
    VECTOR_AVAILABLE = True
except ImportError:
    VECTOR_AVAILABLE = False
    VectorManager = None

try:
    from genesis_memory_cortex import Memory, MemoryTier, WorkingMemoryCache
except ImportError:
    print("[!] genesis_memory_cortex not found. Install or mock.")
    Memory = None
    MemoryTier = None
    WorkingMemoryCache = None


class MemoryRecallEngine:
    """
    Engine for recalling memories based on semantic similarity and tiering.

    A recall first consults the working-memory cache.  On a miss it runs a
    semantic search through the vector manager, keeps hits scoring at least
    ``recall_threshold``, promotes WORKING-tier hits scoring at least
    ``promotion_threshold`` to EPISODIC, and caches the kept result list so
    a repeated query returns the same memories.

    Attributes:
        vector_manager: Backend used for similarity search; may be None.
        working_memory_cache: Cache consulted before searching; may be None.
        recall_threshold: Minimum similarity score for a hit to be returned.
        promotion_threshold: Minimum score for a WORKING memory to be promoted.
    """

    def __init__(self, vector_manager: Optional[VectorManager] = None,
                 working_memory_cache: Optional['WorkingMemoryCache'] = None):
        """
        Initializes the MemoryRecallEngine.

        Missing collaborators are created from the optional Genesis modules
        when those imported successfully; otherwise a warning is emitted and
        the engine runs without them.

        Args:
            vector_manager: The VectorManager instance for semantic search.
                When None and ``vector_backends`` imported, a default
                ``VectorManager()`` is created.
            working_memory_cache: The WorkingMemoryCache instance for fast
                retrieval.  When None and both ``secrets_loader`` and
                ``genesis_memory_cortex`` imported, one is built from
                ``get_redis_config()``.
        """
        self.vector_manager = vector_manager
        self.working_memory_cache = working_memory_cache
        self.recall_threshold = 0.7  # Similarity threshold for recall
        self.promotion_threshold = 0.8  # Score threshold for tier promotion

        # Only warn when the engine really ends up without a collaborator;
        # an injected instance works even if its optional module is missing.
        if self.vector_manager is None:
            if VECTOR_AVAILABLE:
                self.vector_manager = VectorManager()
            else:
                self._report("warning", "VectorManager unavailable, recall may be limited.")

        if self.working_memory_cache is None:
            # WorkingMemoryCache is None when genesis_memory_cortex failed to import.
            if SECRETS_AVAILABLE and WorkingMemoryCache is not None:
                redis_config = get_redis_config() if get_redis_config else None
                self.working_memory_cache = WorkingMemoryCache(redis_config)
            else:
                self._report("warning", "WorkingMemoryCache unavailable, recall may be slower.")

    def recall(self, query: str, domain: str = "general", top_k: int = 5) -> List[Memory]:
        """
        Recalls memories based on semantic similarity to the query.

        Args:
            query: The query string.
            domain: The domain of the query.  Used for metrics labels and as
                part of the cache key; it is not passed to the vector search.
            top_k: The maximum number of candidates requested from the
                vector search (also part of the cache key).

        Returns:
            A list of Memory objects whose score is at least
            ``recall_threshold``, in the order the vector manager returned
            them.  Cache hits rebuild each memory from its id, content, tier,
            score, domain, source and timestamp only (no embedding, relations
            or metadata).  Returns an empty list when nothing matches or when
            recall fails.
        """
        start_time = time.perf_counter()
        recalled_memories: List[Memory] = []

        try:
            cache_key = self._recall_cache_key(query, domain, top_k)

            # 1. Check Working Memory Cache
            if self._cache_ready():
                cached_memories = self._read_recall_cache(cache_key)
                if cached_memories:
                    if logger:
                        logger.debug(f"Recalled from cache: {cache_key}")
                    if METRICS_AVAILABLE and GenesisMetrics:
                        GenesisMetrics.cache_hits.inc(labels={"tier": "working"})
                    return cached_memories  # Return immediately if found in cache

            # 2. Semantic Similarity Search
            if self.vector_manager is not None:
                for doc in self.vector_manager.search(query, top_k=top_k) or []:
                    try:
                        memory = self._document_to_memory(doc)
                        if memory.score >= self.recall_threshold:
                            recalled_memories.append(memory)
                    except Exception as e:
                        # Skip only this document; the others can still be recalled.
                        self._report("error", f"Error converting document to memory: {e}")

            # 3. Post-processing: tier promotion, then ONE cache write for the
            #    whole result list (writing per memory kept only the last one).
            for memory in recalled_memories:
                if memory.score >= self.promotion_threshold and memory.tier == MemoryTier.WORKING:
                    self._promote_memory_tier(memory, MemoryTier.EPISODIC)

            if recalled_memories and self._cache_ready():
                self._write_recall_cache(cache_key, recalled_memories, domain)

            if METRICS_AVAILABLE and GenesisMetrics:
                duration = time.perf_counter() - start_time
                GenesisMetrics.recall_latency.observe(duration, labels={"domain": domain})
                GenesisMetrics.memories_recalled.inc(len(recalled_memories), labels={"domain": domain})
            if logger:
                logger.info(f"Recalled {len(recalled_memories)} memories for query: {query}")

            return recalled_memories

        except Exception as e:
            self._report("error", f"Recall failed for query: {query}. Error: {e}")
            return []  # Return empty list on error

    @staticmethod
    def _report(level: str, message: str) -> None:
        """Log *message* at *level* via the module logger, or print it with a "[!] " prefix if there is none."""
        if logger:
            getattr(logger, level)(message)
        else:
            print(f"[!] {message}")

    def _cache_ready(self) -> bool:
        """Return True when a working memory cache is configured and reports itself available."""
        # Compare with None instead of relying on truthiness: a cache object that
        # defines __len__ would be falsy while empty and never be used.
        return self.working_memory_cache is not None and bool(self.working_memory_cache.available)

    @staticmethod
    def _recall_cache_key(query: str, domain: str, top_k: int) -> str:
        """Build the working-memory cache key for one recall request."""
        # top_k changes the result set, so it must be part of the key; JSON-encoding
        # the fields keeps their boundaries unambiguous before hashing.
        raw = json.dumps([query, domain, top_k])
        return f"recall:{hashlib.md5(raw.encode('utf-8')).hexdigest()}"

    def _read_recall_cache(self, cache_key: str) -> Optional[List[Memory]]:
        """
        Return the cached result list for *cache_key*, or None on a miss.

        The cache entry's ``content`` holds a JSON list written by
        ``_write_recall_cache``.  Unreadable entries (e.g. written by an older
        version of this engine) and cache errors are logged and treated as misses.
        """
        try:
            cached = self.working_memory_cache.get(cache_key)
            if not cached:
                return None
            entries = json.loads(cached.content)
            return [
                Memory(
                    id=entry["id"],
                    content=entry["content"],
                    tier=MemoryTier[entry["tier"]],
                    score=entry["score"],
                    domain=entry["domain"],
                    source=entry["source"],
                    timestamp=entry["timestamp"],
                )
                for entry in entries
            ]
        except Exception as e:
            # The cache is an optimisation: fall back to a fresh search on any failure.
            self._report("warning", f"Ignoring unreadable recall cache entry {cache_key}: {e}")
            return None

    def _write_recall_cache(self, cache_key: str, memories: List[Memory], domain: str) -> None:
        """
        Store *memories* as a single WORKING-tier cache entry keyed by *cache_key*.

        Each memory's id, content, tier name, score, domain, source and
        timestamp are JSON-encoded into the entry's ``content``; the entry's
        score is the highest score in the list.  Failures are logged and
        never propagate, so results already found are still returned.
        """
        try:
            payload = [
                {
                    "id": memory.id,
                    "content": memory.content,
                    "tier": memory.tier.name,
                    "score": float(memory.score),
                    "domain": memory.domain,
                    "source": memory.source,
                    "timestamp": memory.timestamp,
                }
                for memory in memories
            ]
            self.working_memory_cache.set(Memory(
                id=cache_key,
                content=json.dumps(payload),
                tier=MemoryTier.WORKING,
                score=max(entry["score"] for entry in payload),
                domain=domain,
                source="recall_engine",
                timestamp=datetime.now().isoformat(),
            ))
        except Exception as e:
            self._report("error", f"Failed to cache recall results for {cache_key}: {e}")

    def _promote_memory_tier(self, memory: Memory, new_tier: MemoryTier) -> None:
        """
        Promotes a memory to a higher tier.

        This is a placeholder for a real data migration (e.g. from Redis to
        PostgreSQL).  Today it only logs the promotion, sets ``memory.tier``
        in place and deletes ``memory.id`` from the working memory cache.
        Failures are logged, not raised.

        Args:
            memory: The memory to promote (mutated in place).
            new_tier: The new memory tier.
        """
        try:
            # A real implementation would also write the memory to the new
            # tier's storage before removing it from the old one.
            self._report("info", f"Promoting memory {memory.id} from {memory.tier} to {new_tier}")

            memory.tier = new_tier  # Update the memory tier

            # Remove from old storage (e.g., Redis)
            if self._cache_ready():
                self.working_memory_cache.delete(memory.id)

        except Exception as e:
            self._report("error", f"Failed to promote memory {memory.id} to {new_tier}. Error: {e}")

    def _document_to_memory(self, doc: 'VectorDocument') -> Memory:
        """
        Converts a VectorDocument to a Memory object.

        The memory is tagged EPISODIC; domain, source, timestamp and relations
        come from ``doc.metadata`` with defaults when absent.

        Args:
            doc: The VectorDocument to convert.

        Returns:
            A Memory object whose score is ``doc.similarity``.

        Raises:
            Exception: Any conversion error is logged and re-raised.
        """
        try:
            metadata = doc.metadata if doc.metadata else {}
            return Memory(
                id=doc.id,
                content=doc.content,
                tier=MemoryTier.EPISODIC,  # Assuming episodic by default
                score=doc.similarity,
                domain=metadata.get("domain", "general"),
                source=metadata.get("source", "vector_search"),
                timestamp=metadata.get("timestamp", datetime.now().isoformat()),
                embedding=doc.embedding,
                relations=metadata.get("relations", []),
                metadata=metadata
            )
        except Exception as e:
            self._report("error", f"Failed to convert VectorDocument to Memory: {e}")
            raise  # Re-raise the exception after logging

    def clear_cache(self) -> None:
        """Clears the working memory cache via its ``clear()`` method; errors are logged, not raised."""
        if not self._cache_ready():
            return
        try:
            self.working_memory_cache.clear()
            self._report("info", "Working memory cache cleared.")
        except Exception as e:
            self._report("error", f"Failed to clear working memory cache: {e}")

################################################################################
# Test Functions
################################################################################

if __name__ == '__main__':
    # Mock VectorManager and WorkingMemoryCache for testing.  The real Memory,
    # MemoryTier and VectorDocument classes are still needed, so stop with a
    # clear message instead of a NameError/AttributeError when they are missing.
    if not VECTOR_AVAILABLE or Memory is None or MemoryTier is None:
        raise SystemExit("[!] Self-tests require vector_backends and genesis_memory_cortex.")

    class MockVectorManager:
        """Search stub returning one hit above and one below the default recall threshold."""

        def search(self, query: str, top_k: int = 5) -> List['VectorDocument']:
            # Return some mock VectorDocument objects
            return [
                VectorDocument(
                    id="mock_memory_1",
                    content="This is a test memory about performance.",
                    embedding=[0.1, 0.2, 0.3],
                    similarity=0.8,
                    metadata={"domain": "performance", "source": "test"}
                ),
                VectorDocument(
                    id="mock_memory_2",
                    content="Another test memory about optimization.",
                    embedding=[0.4, 0.5, 0.6],
                    similarity=0.6,
                    metadata={"domain": "optimization", "source": "test"}
                )
            ]

    class MockWorkingMemoryCache:
        """In-process dict standing in for the working memory cache, keyed by memory id."""

        def __init__(self):
            self.cache = {}
            self.available = True

        def get(self, memory_id: str) -> Optional[Memory]:
            return self.cache.get(memory_id)

        def set(self, memory: Memory, adaptive_ttl: bool = True):
            self.cache[memory.id] = memory

        def delete(self, memory_id: str) -> None:
            if memory_id in self.cache:
                del self.cache[memory_id]

        def clear(self) -> None:
            self.cache = {}

    def test_recall_basic():
        """Tests basic recall functionality."""
        mock_vector_manager = MockVectorManager()
        mock_cache = MockWorkingMemoryCache()
        recall_engine = MemoryRecallEngine(vector_manager=mock_vector_manager, working_memory_cache=mock_cache)
        results = recall_engine.recall("performance", domain="general")
        assert len(results) > 0
        assert "performance" in results[0].content

    def test_recall_cache():
        """Tests that a repeated query is served from the working memory cache."""
        mock_vector_manager = MockVectorManager()
        mock_cache = MockWorkingMemoryCache()
        recall_engine = MemoryRecallEngine(vector_manager=mock_vector_manager, working_memory_cache=mock_cache)

        # First recall should populate the cache
        results1 = recall_engine.recall("performance", domain="general")
        assert len(results1) > 0
        assert "performance" in results1[0].content
        assert mock_cache.cache, "first recall should write to the cache"

        # Second recall must come from the cache: disable the vector search so
        # a silent cache miss cannot still produce results.
        recall_engine.vector_manager = None
        results2 = recall_engine.recall("performance", domain="general")
        assert len(results2) > 0
        assert "performance" in results2[0].content
        assert results1[0].id == results2[0].id  # Verify it's the same memory

    def test_recall_no_vector_manager():
        """Tests recall without a vector manager."""
        mock_cache = MockWorkingMemoryCache()
        # Passing vector_manager=None would make __init__ build a real
        # VectorManager (vector_backends is importable here), so construct with
        # the mock and then remove it to exercise the "no vector manager" path.
        recall_engine = MemoryRecallEngine(vector_manager=MockVectorManager(), working_memory_cache=mock_cache)
        recall_engine.vector_manager = None
        results = recall_engine.recall("performance", domain="general")
        assert len(results) == 0  # Should return an empty list if no vector manager

    # Run tests
    test_recall_basic()
    test_recall_cache()
    test_recall_no_vector_manager()

    print("All tests passed!")