# memory_recall_engine.py
"""
High-Performance Memory Recall Engine
=====================================

This module provides a high-performance memory recall system leveraging semantic
similarity search, caching, and memory tier promotion.

Features:
- Semantic similarity search using embeddings and vector backends
- Fast retrieval with caching (Redis)
- Memory tier promotion (working -> episodic -> semantic)
- Comprehensive error handling and logging
- Metrics collection for performance monitoring

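Dependencies:
- genesis_memory_cortex (modified to expose embedding functionality)
- memory_schemas
- redis
- logging_config (optional; falls back to the standard ``logging`` module)
- metrics (optional; recall latency/operation metrics are recorded only when present)
- Standard library: typing, logging, time, hashlib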

Usage:
    from memory_recall_engine import MemoryRecallEngine

    engine = MemoryRecallEngine()
    results = engine.recall("Summarize recent performance improvements")
"""

import json
import hashlib
import time
import logging
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass, asdict
from enum import Enum
import redis

# Import existing Genesis memory components
try:
    from genesis_memory_cortex import MemorySystem, MemoryItem, SurpriseScore, Memory, MemoryTier, WorkingMemoryCache, VectorManager
    from memory_schemas import MemoryItemInput, MemoryOutput
    GENESIS_AVAILABLE = True
except ImportError as e:
    print(f"Error importing Genesis modules: {e}")
    GENESIS_AVAILABLE = False

try:
    from logging_config import get_logger
    LOGGING_AVAILABLE = True
    logger = get_logger("memory_recall_engine")
except ImportError:
    # Fallback: use a plain standard-library logger for this module.
    LOGGING_AVAILABLE = False
    logger = logging.getLogger("memory_recall_engine")
    logger.setLevel(logging.INFO)
    # Attach a simple stream handler so messages are visible even when the host
    # application has not configured logging. logging.getLogger returns the same
    # object on every import, so guard against stacking duplicate handlers when
    # this module is re-imported or reloaded (each extra handler would print
    # every message again).
    if not logger.handlers:
        ch = logging.StreamHandler()
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        ch.setFormatter(formatter)
        logger.addHandler(ch)

try:
    from metrics import GenesisMetrics
    METRICS_AVAILABLE = True
except ImportError:
    METRICS_AVAILABLE = False
    GenesisMetrics = None


class RecallEngineConfig:
    """Tuning parameters consumed by MemoryRecallEngine.

    Attributes:
        top_k: Maximum number of memories returned by a recall.
        min_score: Similarity scores below this value are discarded.
        promote_threshold: Results scoring at least this much are promoted
            one memory tier.
        working_memory_ttl: Lifetime, in seconds, handed to the working
            memory cache.
    """

    def __init__(
        self,
        top_k: int = 5,
        min_score: float = 0.5,
        promote_threshold: float = 0.8,
        working_memory_ttl: int = 3600,
    ):
        # Plain attributes (no validation); order kept stable for vars().
        self.top_k = top_k
        self.min_score = min_score
        self.promote_threshold = promote_threshold
        self.working_memory_ttl = working_memory_ttl

class MemoryRecallEngine:
    """
    High-performance memory recall engine with semantic similarity search,
    caching, and tier promotion.

    ``recall`` runs the pipeline:

    1. Embed the query and search the vector backend.
    2. Resolve each hit to a ``Memory``, checking the working-memory cache
       first and then the memory system.
    3. Drop hits scoring below ``config.min_score`` and keep the
       ``config.top_k`` best.
    4. Promote hits scoring at least ``config.promote_threshold`` by one tier.
    5. Convert the results to ``MemoryOutput``.

    Annotations that name Genesis/project types are written as strings.
    This lets the class still be defined when those imports failed.
    Evaluated eagerly, they would raise NameError at import time, and the
    ``available = False`` fallback in ``__init__`` could never be reached.
    """

    # Floor on the number of candidates requested from the vector backend.
    # More than ``top_k`` are fetched because ``min_score`` filtering and
    # unresolvable ids can shrink the candidate list.
    _MIN_SEARCH_CANDIDATES = 10

    def __init__(self, config: "Optional[RecallEngineConfig]" = None):
        """
        Build the engine and its Genesis collaborators.

        Args:
            config: Tuning parameters; ``RecallEngineConfig()`` defaults are
                used when omitted.

        If the Genesis modules failed to import, the engine is left disabled:
        ``available`` is False, every collaborator attribute is None, and
        ``recall`` returns an empty list.
        """
        if not GENESIS_AVAILABLE:
            logger.error("Genesis Memory Cortex not available. Recall engine will not function.")
            self.available = False
            self.config = None
            self.memory_system = None
            self.working_memory_cache = None
            self.vector_manager = None
            return

        self.available = True
        self.config = config if config else RecallEngineConfig()
        # NOTE(review): default-constructed collaborators are assumed to be
        # sufficient here; confirm against genesis_memory_cortex.
        self.memory_system = MemorySystem()
        self.working_memory_cache = WorkingMemoryCache(ttl_seconds=self.config.working_memory_ttl)
        self.vector_manager = VectorManager()
        if not self.vector_manager.available:
            logger.warning("VectorManager not available. Semantic search will be limited.")

    def recall(self, query: str, domain: str = "default", source: str = "recall_engine") -> "List[MemoryOutput]":
        """
        Recalls relevant memories based on semantic similarity to the query.

        Args:
            query: The search query.
            domain: The domain of the query (e.g., "programming", "finance").
            source: The source of the recall request (e.g., "user", "agent").

        Returns:
            A list of MemoryOutput objects representing the recalled memories,
            sorted by relevance (similarity score), highest first. Returns an
            empty list when the engine is disabled or any pipeline step raises
            (the error is logged, not propagated).
        """
        if not self.available:
            logger.warning("Recall engine not available.")
            return []

        # perf_counter is monotonic, so the latency metric is not skewed by
        # wall-clock adjustments.
        start_time = time.perf_counter()
        try:
            # 1. Semantic search using embeddings.
            results = self._semantic_search(query, domain)

            # 2-4. Keep hits meeting min_score, best first, at most top_k.
            filtered_results = [r for r in results if r[1] >= self.config.min_score]
            sorted_results = sorted(filtered_results, key=lambda r: r[1], reverse=True)
            top_results = sorted_results[:self.config.top_k]

            # 5. Tier promotion (if applicable).
            promoted_memories = self._promote_memories(top_results, domain, source)

            # 6. Format output.
            output = self._format_output(promoted_memories)
        except Exception as e:
            logger.exception(f"Error during recall: {e}")
            return []

        # Recorded outside the main try so a metrics failure cannot discard
        # results that were already computed successfully.
        self._record_metrics(domain, time.perf_counter() - start_time)
        return output

    def _record_metrics(self, domain: str, duration: float) -> None:
        """Best-effort export of recall latency and count; logs and never raises."""
        if not (METRICS_AVAILABLE and GenesisMetrics):
            return
        try:
            GenesisMetrics.recall_latency.observe(duration, labels={"domain": domain})
            GenesisMetrics.recall_operations.inc(labels={"domain": domain})
        except Exception as e:
            logger.warning(f"Failed to record recall metrics: {e}")

    def _semantic_search(self, query: str, domain: str) -> "List[Tuple[Memory, float]]":
        """
        Performs semantic similarity search using embeddings and a vector backend.

        Args:
            query: The search query.
            domain: The domain of the query, forwarded to the vector search.

        Returns:
            A list of (Memory, similarity score) tuples in backend order.
            Hits whose memory cannot be resolved are skipped. Returns an empty
            list if embedding generation fails or the search raises.
        """
        try:
            # 1. Generate embedding for the query.
            query_embedding = self.vector_manager.generate_embedding(query)
            if query_embedding is None:
                logger.warning("Failed to generate embedding for query. Returning empty search results.")
                return []

            # 2. Search the vector database. Fetch at least the historical 10
            # candidates, but never fewer than top_k, so a larger configured
            # top_k is not silently capped.
            search_k = max(self._MIN_SEARCH_CANDIDATES, self.config.top_k)
            search_results = self.vector_manager.search(query_embedding, top_k=search_k, domain=domain)

            # 3. Convert search results to (Memory, score) pairs.
            results = []
            for doc in search_results:
                try:
                    memory = self._get_memory_by_id(doc.memory_id)
                    if memory:
                        results.append((memory, doc.similarity_score))
                except Exception as e:
                    logger.warning(f"Error retrieving memory {doc.memory_id}: {e}")

            return results

        except Exception as e:
            logger.exception(f"Error during semantic search: {e}")
            return []

    def _get_memory_by_id(self, memory_id: str) -> "Optional[Memory]":
        """
        Retrieves a Memory object by its ID, checking the working memory cache first.

        Args:
            memory_id: The ID of the memory to retrieve.

        Returns:
            The Memory object if found, otherwise None. Errors raised by the
            memory system are logged and reported as None.
        """
        # 1. Working memory cache.
        memory = self.working_memory_cache.get(memory_id)
        if memory:
            logger.debug(f"Memory {memory_id} retrieved from working memory cache.")
            return memory

        # 2. Main memory system.
        try:
            memory_item = self.memory_system.get_memory(memory_id)
        except Exception as e:
            logger.error(f"Error retrieving memory {memory_id} from memory system: {e}")
            return None

        if not memory_item:
            logger.debug(f"Memory {memory_id} not found in any tier.")
            return None

        try:
            return self._memory_item_to_memory(memory_item)
        except Exception as e:
            logger.error(f"Error retrieving memory {memory_id} from memory system: {e}")
            return None

    def _memory_item_to_memory(self, memory_item: "MemoryItem") -> "Memory":
        """Converts a MemoryItem to a Memory object, parsing its tier into MemoryTier."""
        return Memory(
            id=memory_item.id,
            content=memory_item.content,
            tier=MemoryTier(memory_item.tier),
            score=memory_item.score,
            domain=memory_item.domain,
            source=memory_item.source,
            timestamp=memory_item.timestamp,
            embedding=memory_item.embedding,
            relations=memory_item.relations,
            access_count=memory_item.access_count,
            last_accessed=memory_item.last_accessed,
            metadata=memory_item.metadata
        )

    def _promote_memories(self, results: "List[Tuple[Memory, float]]", domain: str, source: str) -> "List[Memory]":
        """
        Promotes memories one tier (working -> episodic -> semantic) when their
        relevance score reaches ``config.promote_threshold``.

        Args:
            results: (Memory, similarity score) tuples.
            domain: The domain of the query (currently unused).
            source: The source of the recall request (currently unused).

        Returns:
            The Memory objects in input order. A memory's ``tier`` is changed
            only after the memory system accepted the new tier. Promotion
            errors are logged, and the memory is still returned.
        """
        promoted_memories = []
        for memory, score in results:
            if score >= self.config.promote_threshold:
                try:
                    current_tier = memory.tier
                    if current_tier == MemoryTier.WORKING:
                        new_tier = MemoryTier.EPISODIC
                    elif current_tier == MemoryTier.EPISODIC:
                        new_tier = MemoryTier.SEMANTIC
                    else:
                        new_tier = current_tier  # any other tier is left unchanged

                    if new_tier != current_tier:
                        # Persist first, then update the in-memory object. If the
                        # store update raises, the returned memory keeps its real
                        # tier instead of reporting a promotion that never happened.
                        self.memory_system.update_memory_tier(memory.id, new_tier.value)
                        memory.tier = new_tier
                        logger.info(f"Memory {memory.id} promoted from {current_tier} to {new_tier}")
                except Exception as e:
                    logger.error(f"Error promoting memory {memory.id}: {e}")
            promoted_memories.append(memory)
        return promoted_memories

    def _format_output(self, memories: "List[Memory]") -> "List[MemoryOutput]":
        """
        Formats a list of Memory objects into a list of MemoryOutput objects.

        Args:
            memories: A list of Memory objects.

        Returns:
            A list of MemoryOutput objects in the same order. Missing or empty
            metadata is normalized to ``{}``.
        """
        output = []
        for memory in memories:
            # Where the memory is stored; could be extended to include the vector store.
            stored_in = [memory.tier.value]
            output.append(MemoryOutput(
                id=memory.id,
                content=memory.content,
                tier=memory.tier.value,
                score=memory.score,
                timestamp=memory.timestamp,
                stored_in=stored_in,
                metadata=memory.metadata or {}
            ))
        return output


# Ad-hoc smoke tests; run this module directly to execute them.
if __name__ == "__main__":
    if not GENESIS_AVAILABLE:
        print("Genesis Memory Cortex not available. Skipping tests.")
    else:
        # Only these smoke tests need datetime. It was previously used without
        # being imported, which made every test fail with NameError.
        from datetime import datetime

        def create_test_memory_item(content: str, source: str, domain: str, score: float) -> MemoryItem:
            """Build a working-tier MemoryItem whose id is derived from its content."""
            memory_id = hashlib.sha256(content.encode()).hexdigest()[:16]
            return MemoryItem(
                id=memory_id,
                content=content,
                tier="working",
                score=score,
                domain=domain,
                source=source,
                timestamp=datetime.now().isoformat(),
                embedding=[0.1] * 128,  # Dummy embedding
                relations=[],
                access_count=0,
                last_accessed=None,
                metadata={}
            )

        def test_recall_basic():
            """Tests basic recall functionality."""
            engine = MemoryRecallEngine()
            if not engine.available:
                print("Recall engine not available, skipping test.")
                return

            # Create some test memories.
            memory1 = create_test_memory_item("Parallel execution improves performance.", "test", "performance", 0.7)
            memory2 = create_test_memory_item("Vectorization also improves performance.", "test", "performance", 0.6)
            memory3 = create_test_memory_item("Unrelated information.", "test", "random", 0.4)

            # Store the memories.
            engine.memory_system.store_memory(memory1)
            engine.memory_system.store_memory(memory2)
            engine.memory_system.store_memory(memory3)

            # Recall memories related to "performance".
            results = engine.recall("performance optimization", domain="performance")

            # The best hit should be one of the performance memories.
            assert len(results) >= 1
            assert "performance" in results[0].content.lower()

            print("test_recall_basic passed")

        def test_recall_empty_results():
            """Tests recall when no relevant memories are found."""
            engine = MemoryRecallEngine()
            if not engine.available:
                print("Recall engine not available, skipping test.")
                return

            # NOTE(review): test_recall_basic stored an item in the "random"
            # domain; if the store persists between engines, this assertion
            # relies on that item scoring below min_score.
            results = engine.recall("nonexistent topic", domain="random")

            assert len(results) == 0
            print("test_recall_empty_results passed")

        def test_recall_tier_promotion():
            """Tests memory tier promotion functionality."""
            engine = MemoryRecallEngine(config=RecallEngineConfig(promote_threshold=0.6))
            if not engine.available:
                print("Recall engine not available, skipping test.")
                return

            # Create a test memory in the "working" tier.
            memory = create_test_memory_item("Important information.", "test", "important", 0.8)
            engine.memory_system.store_memory(memory)

            # Recall the memory.
            results = engine.recall("important", domain="important")

            # The memory should now be in the "episodic" tier.
            assert len(results) >= 1
            retrieved_memory = engine._get_memory_by_id(memory.id)
            assert retrieved_memory is not None, "promoted memory could not be retrieved"
            assert retrieved_memory.tier == MemoryTier.EPISODIC

            print("test_recall_tier_promotion passed")

        test_recall_basic()
        test_recall_empty_results()
        test_recall_tier_promotion()