# memory_recall_engine.py
"""
High-Performance Memory Recall Engine
======================================
Provides semantic similarity search, caching, and memory tier promotion
for the Genesis Memory Cortex.
"""

import json
import hashlib
import redis
import os
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, asdict
from enum import Enum
import threading
import time
import numpy as np

# Import existing Genesis memory components (mocked for standalone execution)
try:
    from genesis_memory_cortex import MemorySystem, MemoryItem, SurpriseScore, Memory, MemoryTier, WorkingMemoryCache
    from logging_config import get_logger
    from metrics import GenesisMetrics
    from vector_backends import VectorManager, VectorDocument
    MEMORY_CONTEXT_AVAILABLE = True
    logger = get_logger("memory_recall_engine")
except ImportError:
    print("[!] WARNING: Genesis Memory Cortex dependencies not found. Running in standalone mode.")
    MEMORY_CONTEXT_AVAILABLE = False

    class MemorySystem:
        """Standalone stub for the Genesis MemorySystem."""
        def __init__(self):
            pass

    class MemoryItem:
        """Standalone stub; real implementation lives in genesis_memory_cortex."""
        pass

    class SurpriseScore:
        """Standalone stub; real implementation lives in genesis_memory_cortex."""
        pass

    @dataclass
    class Memory:
        """Minimal stand-in for the Genesis Memory record.

        NOTE(review): `tier` is annotated as str, but callers in this file
        pass MemoryTier enum members; dataclasses do not enforce annotations
        so both work — confirm the real Memory's contract.
        """
        id: str
        content: str
        tier: str
        score: float
        domain: str
        source: str
        timestamp: str
        embedding: Optional[List[float]] = None
        relations: Optional[List[str]] = None
        access_count: int = 0
        last_accessed: Optional[str] = None
        metadata: Optional[Dict] = None

        def to_dict(self) -> Dict:
            """Return a plain-dict representation (via dataclasses.asdict)."""
            return asdict(self)

    class MemoryTier(Enum):
        """Storage tiers, ordered roughly by retention/importance."""
        DISCARD = "discard"
        WORKING = "working"
        EPISODIC = "episodic"
        SEMANTIC = "semantic"

    @dataclass
    class VectorDocument:
        """Standalone stub for vector_backends.VectorDocument.

        BUGFIX: this class was missing from the standalone branch, causing a
        NameError wherever VectorDocument was constructed without the real
        vector_backends package (e.g. the __main__ demo and tier promotion).
        """
        id: str
        content: str
        metadata: Optional[Dict] = None
        embedding: Optional[List[float]] = None

    class _MockRedisClient:
        """Mimics the tiny subset of the redis client API this module uses."""

        def __init__(self, cache: Dict):
            self._cache = cache

        def delete(self, key: str) -> int:
            # Keys look like "genesis:working_memory:<memory_id>"; the mock
            # cache is keyed by the bare memory id, so strip the prefix.
            memory_id = key.rsplit(":", 1)[-1]
            return 1 if self._cache.pop(memory_id, None) is not None else 0

    class WorkingMemoryCache:
        """In-process dict-backed replacement for the Redis working-memory cache."""

        def __init__(self, redis_config=None, ttl_seconds=3600):
            self.available = False
            print("[!] WorkingMemoryCache initialized in standalone mode (Redis unavailable).")
            self.cache = {}  # Mock redis cache, keyed by memory id
            # BUGFIX: expose a `client` with delete() so tier-promotion code
            # that calls working_memory.client.delete(...) works standalone.
            self.client = _MockRedisClient(self.cache)

        def set(self, memory: Memory, adaptive_ttl: bool = True):
            """Store a memory; the TTL flag is ignored in standalone mode."""
            self.cache[memory.id] = memory

        def get(self, memory_id: str, extend_ttl: bool = True) -> Optional[Memory]:
            """Return the cached memory or None; TTL extension is a no-op here."""
            if memory_id in self.cache:
                return self.cache[memory_id]
            return None

    def get_logger(name):
        """Return a print-based logger mirroring the logging_config interface."""
        class MockLogger:
            # BUGFIX: accept **kwargs so callers passing exc_info=True (as
            # the real logging API allows, and recall()/promote_memory_tier()
            # do) no longer raise TypeError inside their except handlers.
            def info(self, msg, extra=None, **kwargs):
                print(f"[INFO - {name}] {msg}, {extra}")
            def warning(self, msg, extra=None, **kwargs):
                print(f"[WARNING - {name}] {msg}, {extra}")
            def error(self, msg, extra=None, **kwargs):
                print(f"[ERROR - {name}] {msg}, {extra}")
            def debug(self, msg, extra=None, **kwargs):
                print(f"[DEBUG - {name}] {msg}, {extra}")
        return MockLogger()

    class GenesisMetrics:
        """Metrics stub; attributes are None, so metric calls must be guarded."""
        memory_operations = None
        memory_latency = None
        cache_hits = None
        cache_misses = None

    class VectorManager:
        """No-op stand-in for the vector store backend."""

        def __init__(self):
            pass

        def query(self, query_embedding: List[float], top_k: int = 5, namespace: str = "default") -> List[Tuple[str, float]]:
            """Return (memory_id, similarity) pairs; always empty in standalone mode."""
            return []

        def add(self, document: 'VectorDocument', namespace: str = "default"):
            pass

        def delete(self, memory_id: str, namespace: str = "default"):
            pass

        def update(self, document: 'VectorDocument', namespace: str = "default"):
            pass
    logger = get_logger("memory_recall_engine")


class RecallEngine:
    """
    The core memory recall engine, responsible for:
    - Semantic similarity search
    - Fast retrieval with caching (Redis)
    - Memory tier promotion
    - Error handling and logging
    """

    def __init__(self, working_memory: Optional[WorkingMemoryCache] = None, vector_manager: Optional[VectorManager] = None):
        """
        Initializes the RecallEngine.

        Args:
            working_memory: An instance of WorkingMemoryCache for fast
                retrieval. A default instance is created when omitted.
            vector_manager: An instance of VectorManager for semantic search.
                A default instance is created when omitted.
        """
        self.working_memory = working_memory if working_memory else WorkingMemoryCache()
        self.vector_manager = vector_manager if vector_manager else VectorManager()
        self.namespace = "recall_engine"
        self.embedding_dimension = 1536  # Adjust based on your embedding model
        # Maps each supported tier to the storage backend serving it.
        self.memory_systems = {
            MemoryTier.WORKING: self.working_memory,
            # Add other memory tiers as needed (e.g., episodic, semantic)
        }

    def embed_query(self, query: str) -> List[float]:
        """
        Generates a deterministic mock embedding for the given query.

        This is a placeholder; replace with your actual embedding model.

        BUGFIX: the previous implementation sliced a single SHA-256 digest
        (32 bytes), so vectors had at most 32 components even though
        self.embedding_dimension declares 1536. The digest is now re-hashed
        with an incrementing counter until the declared dimension is filled.

        Args:
            query: The query string.

        Returns:
            A list of exactly self.embedding_dimension floats in [0.0, 1.0].
        """
        raw = b""
        counter = 0
        query_bytes = query.encode()
        while len(raw) < self.embedding_dimension:
            raw += hashlib.sha256(query_bytes + counter.to_bytes(4, "big")).digest()
            counter += 1
        return [float(byte) / 255.0 for byte in raw[:self.embedding_dimension]]

    def recall(self, query: str, top_k: int = 5, domain: str = "default") -> List[Memory]:
        """
        Recalls relevant memories based on semantic similarity to the query.

        Args:
            query: The query string.
            top_k: The number of memories to retrieve.
            domain: The domain of the query (used as the vector-store namespace).

        Returns:
            A list of Memory objects sorted by descending similarity to the
            query; an empty list when nothing matches or an error occurs.
        """
        start_time = time.time()
        try:
            # 1. Generate query embedding
            query_embedding = self.embed_query(query)

            # 2. Perform semantic similarity search (vector store)
            results = self.vector_manager.query(query_embedding, top_k=top_k, namespace=domain)

            # 3. Retrieve Memory objects from cache/storage based on IDs,
            #    keeping each one paired with its similarity score.
            scored: List[Tuple[Memory, float]] = []
            for memory_id, similarity in results:
                # First, check working memory (fastest)
                memory = self.working_memory.get(memory_id)
                if memory:
                    scored.append((memory, similarity))
                    continue

                # If not in working memory, retrieve from other tiers (e.g.,
                # episodic, semantic). Implementation depends on your storage
                # setup; currently such hits are silently dropped.

            # 4. Sort by the similarity score returned by the vector store.
            # BUGFIX: previously sorted by Memory.score (the stored importance
            # score), not by relevance to the query as the docstring promised.
            scored.sort(key=lambda pair: pair[1], reverse=True)
            memories = [memory for memory, _ in scored]

            if memories:
                logger.info(f"Successfully recalled {len(memories)} memories for query: {query}", extra={"domain": domain})
            else:
                logger.info(f"No memories found for query: {query}", extra={"domain": domain})

            return memories

        except Exception as e:
            logger.error(f"Error during memory recall: {e}", exc_info=True, extra={"query": query, "domain": domain})
            return []  # Return empty list on error
        finally:
            duration = time.time() - start_time
            # Guard the attribute too: the standalone GenesisMetrics stub
            # leaves memory_latency as None.
            if MEMORY_CONTEXT_AVAILABLE and GenesisMetrics and getattr(GenesisMetrics, "memory_latency", None):
                GenesisMetrics.memory_latency.observe(duration, labels={"tier": "recall"})

    def promote_memory_tier(self, memory: Memory, new_tier: MemoryTier) -> bool:
        """
        Promotes a memory to a higher tier of storage based on importance and access frequency.

        Args:
            memory: The Memory object to promote.
            new_tier: The target MemoryTier.

        Returns:
            True if the promotion was successful, False otherwise.
        """
        try:
            if memory.tier == new_tier:
                logger.warning(f"Memory {memory.id} already at tier {new_tier.value}")
                return True

            # 1. Remove from current tier (if applicable)
            if memory.tier != MemoryTier.DISCARD:
                if memory.tier in self.memory_systems:
                    if memory.tier == MemoryTier.WORKING:
                        self._evict_from_working_memory(memory.id)
                    else:
                        logger.warning(f"Deletion from tier {memory.tier} not implemented.")
                else:
                    logger.error(f"Invalid memory tier: {memory.tier}")
                    return False

            # 2. Store in the new tier
            memory.tier = new_tier
            if new_tier in self.memory_systems:
                if new_tier == MemoryTier.WORKING:
                    self.working_memory.set(memory)
                    logger.info(f"Memory {memory.id} promoted to working memory")
                else:
                    logger.warning(f"Storage in tier {new_tier} not implemented.")
            else:
                logger.error(f"Invalid target memory tier: {new_tier}")
                return False

            # 3. Update vector store (if applicable)
            if memory.embedding:
                vector_doc = VectorDocument(id=memory.id, content=memory.content, metadata=memory.to_dict(), embedding=memory.embedding)
                self.vector_manager.update(vector_doc, namespace=memory.domain)
                logger.info(f"Vector store updated for memory {memory.id}")

            return True

        except Exception as e:
            logger.error(f"Error promoting memory {memory.id} to tier {new_tier.value}: {e}", exc_info=True)
            return False

    def _evict_from_working_memory(self, memory_id: str) -> None:
        """
        Best-effort removal of a memory from the working-memory cache.

        BUGFIX: the previous code unconditionally dereferenced
        self.working_memory.client, which raises AttributeError on cache
        implementations without a redis client (e.g. the standalone mock),
        making every WORKING-tier promotion fail in standalone mode.
        """
        client = getattr(self.working_memory, "client", None)
        if client is not None:
            # Redis-backed cache: delete by the namespaced key.
            client.delete(f"genesis:working_memory:{memory_id}")
            return
        # Dict-backed fallback cache: drop the entry directly if present.
        fallback = getattr(self.working_memory, "cache", None)
        if isinstance(fallback, dict):
            fallback.pop(memory_id, None)
        else:
            logger.warning(f"Could not evict memory {memory_id} from working memory: no known cache interface.")


# Example usage and tests (mocked)
if __name__ == "__main__":
    # Build an engine with default (mock-friendly) collaborators.
    recall_engine = RecallEngine()

    def _sample_memory(mem_id: str, text: str, importance: float) -> Memory:
        """Build a working-tier observation memory with a mock embedding."""
        return Memory(
            id=mem_id,
            content=text,
            tier=MemoryTier.WORKING,
            score=importance,
            domain="general",
            source="observation",
            timestamp=str(datetime.now()),
            embedding=recall_engine.embed_query(text),
        )

    memory1 = _sample_memory("1", "The sky is blue.", 0.8)
    memory2 = _sample_memory("2", "Grass is green.", 0.7)
    memory3 = _sample_memory("3", "Water is wet.", 0.9)

    # Seed the working-memory cache with the samples.
    for sample in (memory1, memory2, memory3):
        recall_engine.working_memory.set(sample)

    # Seed the vector database with matching documents.
    for sample in (memory1, memory2, memory3):
        recall_engine.vector_manager.add(
            VectorDocument(id=sample.id, content=sample.content, metadata=sample.to_dict(), embedding=sample.embedding)
        )

    def test_recall_basic():
        """Tests basic recall functionality."""
        results = recall_engine.recall("What color is the sky?")
        assert len(results) > 0
        assert results[0].content == "The sky is blue."
        print("[OK] test_recall_basic passed")

    def test_recall_no_results():
        """Tests recall when no relevant memories exist."""
        results = recall_engine.recall("What is the meaning of life?")
        assert len(results) == 0
        print("[OK] test_recall_no_results passed")

    def test_memory_promotion():
        """Tests memory tier promotion."""
        success = recall_engine.promote_memory_tier(memory1, MemoryTier.EPISODIC)
        assert success
        assert memory1.tier == MemoryTier.EPISODIC
        print("[OK] test_memory_promotion passed")

    # Run tests
    test_recall_basic()
    test_recall_no_results()
    test_memory_promotion()