import hashlib
import logging
import threading
import time
import uuid
from typing import List, Dict, Any, Callable, Optional

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class MemoryStore:
    """
    Abstract interface shared by every memory tier.

    Concrete tiers must implement search/add/delete/get_size/export_data/
    import_data. Every instance carries a display name and a lock that
    subclasses use to guard their backing storage.
    """

    def __init__(self, name: str):
        self.name = name
        self.lock = threading.Lock()  # Guards the subclass's backing storage

    def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Looks up entries relevant to *query*.

        Args:
            query: The search query.
            top_k: Maximum number of results to return.

        Returns:
            A list of result dictionaries; each must carry at least a
            'content' and a 'relevance' key.
        """
        raise NotImplementedError

    def add(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None:
        """
        Stores *content* (with optional *metadata*) in this tier.

        Args:
            content: The content to add.
            metadata: Optional metadata associated with the content.
        """
        raise NotImplementedError

    def delete(self, key: str) -> bool:
        """
        Removes the entry identified by its unique *key*; True on success.
        """
        raise NotImplementedError

    def get_size(self) -> int:
        """Returns the number of entries currently stored in this tier."""
        raise NotImplementedError

    def export_data(self) -> Any:
        """Dumps the tier's full contents (e.g. for backup)."""
        raise NotImplementedError

    def import_data(self, data: Any) -> None:
        """Loads previously exported contents into this tier."""
        raise NotImplementedError


class WorkingMemory(MemoryStore):
    """
    Fast, short-term memory for the current context.

    Entries live in a bounded FIFO list: once ``max_size`` is reached, the
    oldest entry is evicted on each add. Entries carry no keys, so delete()
    is a no-op that always returns False.
    """

    def __init__(self, name: str = "WorkingMemory", max_size: int = 10):
        """
        Args:
            name: Display name, used as the 'source' tag in search results.
            max_size: Maximum number of retained entries (oldest evicted first).
        """
        super().__init__(name)
        self.memory: List[Dict[str, Any]] = []
        self.max_size = max_size

    def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Case-insensitive substring search over stored content.

        Args:
            query: Text to look for.
            top_k: Maximum number of results returned.

        Returns:
            Up to top_k dicts with 'content', 'relevance' and 'source' keys,
            sorted by descending relevance.
        """
        with self.lock:
            needle = query.lower()
            results = [
                {
                    'content': item['content'],
                    'relevance': self._calculate_relevance(query, item['content']),
                    'source': self.name,
                }
                for item in self.memory
                if needle in item['content'].lower()
            ]
            results.sort(key=lambda x: x['relevance'], reverse=True)
            return results[:top_k]

    def add(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None:
        """
        Appends an entry, evicting the oldest one when the store is full.
        """
        with self.lock:
            if len(self.memory) >= self.max_size:
                self.memory.pop(0)  # FIFO eviction of the oldest entry
            self.memory.append({'content': content, 'metadata': metadata})

    def delete(self, key: str) -> bool:
        """Working-memory entries are unkeyed; deletion is unsupported."""
        return False

    def get_size(self) -> int:
        """Returns the number of entries currently held."""
        with self.lock:
            return len(self.memory)

    def export_data(self) -> List[Dict[str, Any]]:
        """Returns a shallow copy of all entries (entry dicts are shared)."""
        with self.lock:
            return self.memory.copy()

    def import_data(self, data: List[Dict[str, Any]]) -> None:
        """Replaces the store's contents, trimming oldest entries to max_size."""
        with self.lock:
            self.memory = data.copy()
            if len(self.memory) > self.max_size:
                # Keep only the newest max_size entries in a single slice,
                # instead of repeated O(n) pop(0) calls.
                self.memory = self.memory[-self.max_size:]

    def _calculate_relevance(self, query: str, content: str) -> float:
        """
        Fraction of query words present in the content (0.0 - 1.0).
        Content words go into a set for O(1) membership tests.
        """
        query_words = query.lower().split()
        if not query_words:
            return 0.0
        content_words = set(content.lower().split())
        common = sum(1 for word in query_words if word in content_words)
        return common / len(query_words)


class EpisodicMemory(MemoryStore):
    """
    Memory for recent events. Each entry carries a timestamp and a unique
    key; search relevance decays exponentially with the entry's age.
    """

    def __init__(self, name: str = "EpisodicMemory", decay_rate: float = 0.95):
        """
        Args:
            name: Display name, used as the 'source' tag in search results.
            decay_rate: Per-second relevance decay factor (0 < rate <= 1).
        """
        super().__init__(name)
        self.memory: List[Dict[str, Any]] = []
        self.decay_rate = decay_rate

    def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Case-insensitive substring search, down-weighting older entries.

        Returns:
            Up to top_k dicts with 'content', 'relevance', 'source' and
            'key' keys, sorted by descending (decayed) relevance.
        """
        with self.lock:
            now = time.time()
            needle = query.lower()
            results = []
            for item in self.memory:
                if needle in item['content'].lower():
                    relevance = self._calculate_relevance(query, item['content'])
                    # Exponential decay by the entry's age in seconds.
                    relevance *= self.decay_rate ** (now - item['timestamp'])
                    results.append({'content': item['content'], 'relevance': relevance,
                                    'source': self.name, 'key': item['key']})
            results.sort(key=lambda x: x['relevance'], reverse=True)
            return results[:top_k]

    def add(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None:
        """
        Appends a timestamped episode under a guaranteed-unique key.
        """
        with self.lock:
            # uuid4 keys cannot collide in practice; the previous
            # hash(content + str(time.time())) scheme could generate
            # duplicate keys when identical content was added twice within
            # one clock tick.
            key = uuid.uuid4().hex
            self.memory.append({'content': content, 'timestamp': time.time(),
                                'metadata': metadata, 'key': key})

    def delete(self, key: str) -> bool:
        """Removes the episode with the given key; True if it existed."""
        with self.lock:
            for i, item in enumerate(self.memory):
                if item['key'] == key:
                    del self.memory[i]
                    return True
            return False

    def get_size(self) -> int:
        """Returns the number of stored episodes."""
        with self.lock:
            return len(self.memory)

    def export_data(self) -> List[Dict[str, Any]]:
        """Returns a shallow copy of all episodes (entry dicts are shared)."""
        with self.lock:
            return self.memory.copy()

    def import_data(self, data: List[Dict[str, Any]]) -> None:
        """Replaces the store's contents with previously exported episodes."""
        with self.lock:
            self.memory = data.copy()

    def _calculate_relevance(self, query: str, content: str) -> float:
        """
        Fraction of query words present in the content (0.0 - 1.0).
        Content words go into a set for O(1) membership tests.
        """
        query_words = query.lower().split()
        if not query_words:
            return 0.0
        content_words = set(content.lower().split())
        common = sum(1 for word in query_words if word in content_words)
        return common / len(query_words)


class SemanticMemory(MemoryStore):
    """
    Long-term key/value memory. Content is keyed by a stable SHA-256
    digest, so identical content deduplicates — including across processes
    and export/import round-trips. In a real system this would be replaced
    by a vector database.
    """

    def __init__(self, name: str = "SemanticMemory"):
        super().__init__(name)
        self.memory: Dict[str, str] = {}  # content digest -> content

    def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Case-insensitive substring search. A real system would use a
        vector search or other more sophisticated retrieval here.

        Returns:
            Up to top_k dicts with 'content', 'relevance', 'source' and
            'key' keys, sorted by descending relevance.
        """
        with self.lock:
            needle = query.lower()
            results = [
                {
                    'content': content,
                    'relevance': self._calculate_relevance(query, content),
                    'source': self.name,
                    'key': key,
                }
                for key, content in self.memory.items()
                if needle in content.lower()
            ]
            results.sort(key=lambda x: x['relevance'], reverse=True)
            return results[:top_k]

    def add(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None:
        """
        Stores content under a deterministic, content-derived key.
        Note: metadata is currently not persisted by this tier.
        """
        with self.lock:
            # sha256 is stable across processes, unlike the salted built-in
            # hash(), so re-adding exported content elsewhere still dedupes.
            key = hashlib.sha256(content.encode('utf-8')).hexdigest()
            self.memory[key] = content

    def delete(self, key: str) -> bool:
        """Removes the entry with the given key; True if it existed."""
        with self.lock:
            if key in self.memory:
                del self.memory[key]
                return True
            return False

    def get_size(self) -> int:
        """Returns the number of stored entries."""
        with self.lock:
            return len(self.memory)

    def export_data(self) -> Dict[str, str]:
        """Returns a shallow copy of the key -> content mapping."""
        with self.lock:
            return self.memory.copy()

    def import_data(self, data: Dict[str, str]) -> None:
        """Replaces the store's contents with a previously exported mapping."""
        with self.lock:
            self.memory = data.copy()

    def _calculate_relevance(self, query: str, content: str) -> float:
        """
        Fraction of query words present in the content (0.0 - 1.0).
        Content words go into a set for O(1) membership tests.
        """
        query_words = query.lower().split()
        if not query_words:
            return 0.0
        content_words = set(content.lower().split())
        common = sum(1 for word in query_words if word in content_words)
        return common / len(query_words)


class UnifiedMemoryManager:
    """
    Unified memory manager that handles all memory operations: tier routing,
    cross-memory search, result fusion, query caching, access-based
    migration, garbage collection, and statistics.
    """

    MAX_CACHE_ENTRIES = 256  # Upper bound on cached queries (prevents unbounded growth)

    def __init__(self, working_memory: WorkingMemory, episodic_memory: EpisodicMemory, semantic_memory: SemanticMemory):
        """
        Args:
            working_memory: Short-term, unkeyed FIFO tier.
            episodic_memory: Time-stamped, keyed mid-term tier.
            semantic_memory: Keyed long-term tier.
        """
        self.working_memory = working_memory
        self.episodic_memory = episodic_memory
        self.semantic_memory = semantic_memory
        # Cache keyed by (query, query_type, top_k): caching on the query
        # string alone would return results routed for a different
        # query_type or truncated to a different top_k.
        self.cache: Dict[Any, List[Dict[str, Any]]] = {}
        self.cache_lock = threading.Lock()
        self.migration_threshold = 5  # Accesses before an entry is promoted a tier
        self.access_counts: Dict[str, int] = {}  # key -> access count, for migration
        # RLock, not Lock: _track_access calls _migrate_memory while holding
        # this lock, and _migrate_memory acquires it again. A plain
        # (non-reentrant) Lock would deadlock on the first migration.
        self.migration_lock = threading.RLock()
        self.stats = {
            "working_memory_hits": 0,
            "episodic_memory_hits": 0,
            "semantic_memory_hits": 0,
            "cache_hits": 0,
            "migrations": 0
        }
        self.stats_lock = threading.Lock()
        self.garbage_collection_interval = 3600  # Run GC every hour
        self.gc_thread = threading.Thread(target=self._garbage_collection_loop, daemon=True)
        self.gc_thread.start()

    def _garbage_collection_loop(self):
        """Background daemon loop: run garbage_collection() at a fixed interval."""
        while True:
            time.sleep(self.garbage_collection_interval)
            self.garbage_collection()

    def _invalidate_cache(self) -> None:
        """Drops every cached query result. Called after any mutation so the
        cache can never serve stale data."""
        with self.cache_lock:
            self.cache.clear()

    def add(self, content: str, memory_type: str = "working", metadata: Optional[Dict[str, Any]] = None) -> None:
        """
        Adds content to the specified memory tier.

        Args:
            content: The content to store.
            memory_type: "working", "episodic" or "semantic".
            metadata: Optional metadata forwarded to the tier.

        Raises:
            ValueError: If memory_type names no known tier.
        """
        if memory_type == "working":
            self.working_memory.add(content, metadata)
        elif memory_type == "episodic":
            self.episodic_memory.add(content, metadata)
        elif memory_type == "semantic":
            self.semantic_memory.add(content, metadata)
        else:
            raise ValueError(f"Invalid memory type: {memory_type}")
        self._invalidate_cache()  # New content may change any cached result
        logging.info(f"Added content to {memory_type} memory.")

    def retrieve(self, query: str, query_type: str = "general", top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Retrieves information from the unified memory system.

        Args:
            query: Search text.
            query_type: One of "factual", "recent", "contextual", "general";
                controls which tiers are consulted (see _route_query).
            top_k: Maximum number of fused results.
        """
        # 1. Check cache (keyed on the full request, not just the query text)
        cache_key = (query, query_type, top_k)
        with self.cache_lock:
            cached = self.cache.get(cache_key)
        if cached is not None:
            with self.stats_lock:
                self.stats["cache_hits"] += 1
            logging.debug(f"Retrieving '{query}' from cache.")
            return list(cached)  # Copy so callers cannot mutate the cache

        # 2. Query routing
        memory_stores = self._route_query(query_type)

        # 3. Multi-memory search
        results = []
        for memory_store in memory_stores:
            store_results = memory_store.search(query, top_k=top_k)
            results.extend(store_results)
            with self.stats_lock:
                if memory_store is self.working_memory:
                    self.stats["working_memory_hits"] += len(store_results)
                elif memory_store is self.episodic_memory:
                    self.stats["episodic_memory_hits"] += len(store_results)
                elif memory_store is self.semantic_memory:
                    self.stats["semantic_memory_hits"] += len(store_results)

        # 4. Result fusion
        fused_results = self._fuse_results(results, top_k=top_k)

        # 5. Update cache (bounded; evict the oldest insertion when full)
        with self.cache_lock:
            if len(self.cache) >= self.MAX_CACHE_ENTRIES:
                self.cache.pop(next(iter(self.cache)))
            self.cache[cache_key] = fused_results

        # 6. Track access for potential memory migration
        self._track_access(query, fused_results)

        return fused_results

    def _route_query(self, query_type: str) -> List[MemoryStore]:
        """
        Maps a query type to an ordered list of memory tiers to search.
        """
        if query_type == "factual":
            return [self.semantic_memory]
        if query_type == "recent":
            return [self.episodic_memory, self.semantic_memory]  # Episodic first
        # "contextual" and "general" both search every tier, working first.
        return [self.working_memory, self.episodic_memory, self.semantic_memory]

    def _fuse_results(self, results: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Deduplicates merged results by content (keeping the first occurrence)
        and ranks them by descending relevance.
        """
        seen_content = set()
        deduplicated_results = []
        for result in results:
            if result['content'] not in seen_content:
                seen_content.add(result['content'])
                deduplicated_results.append(result)
        deduplicated_results.sort(key=lambda x: x['relevance'], reverse=True)
        return deduplicated_results[:top_k]

    def _track_access(self, query: str, results: List[Dict[str, Any]]) -> None:
        """
        Counts accesses to the best keyed result and triggers a tier
        migration once the count reaches migration_threshold.
        """
        if not results:
            return  # Nothing found, nothing to track
        best_result = results[0]
        key = best_result.get('key')  # Only episodic/semantic results carry keys
        if not key:
            return
        with self.migration_lock:
            self.access_counts[key] = self.access_counts.get(key, 0) + 1
            if self.access_counts[key] >= self.migration_threshold:
                # NOTE(review): search results do not carry metadata, so this
                # is always None and any metadata is lost on migration —
                # propagate it through search results if that matters.
                self._migrate_memory(key, best_result['content'], best_result.get('metadata'))
                del self.access_counts[key]  # Reset tracking for this entry

    def _migrate_memory(self, key: str, content: str, metadata: Optional[Dict[str, Any]]) -> None:
        """
        Promotes an entry one tier (working -> episodic -> semantic).

        Re-acquires migration_lock; safe because the lock is re-entrant.
        Note that WorkingMemory.delete always returns False (its entries are
        unkeyed), so in practice only episodic -> semantic promotion occurs.
        """
        with self.migration_lock:
            if self.working_memory.delete(key):
                self.episodic_memory.add(content, metadata)
                logging.info(f"Migrated '{content[:50]}...' from Working to Episodic Memory.")
            elif self.episodic_memory.delete(key):
                self.semantic_memory.add(content, metadata)
                logging.info(f"Migrated '{content[:50]}...' from Episodic to Semantic Memory.")
            else:
                logging.warning(f"Could not find memory with key '{key}' for migration.")
                return

            self._invalidate_cache()  # Tier contents changed
            with self.stats_lock:
                self.stats["migrations"] += 1

    def delete(self, key: str, memory_type: str) -> bool:
        """
        Deletes an entry from the specified tier; True on success.

        Raises:
            ValueError: If memory_type names no known tier.
        """
        stores = {
            "working": self.working_memory,
            "episodic": self.episodic_memory,
            "semantic": self.semantic_memory,
        }
        store = stores.get(memory_type)
        if store is None:
            raise ValueError(f"Invalid memory type: {memory_type}")
        deleted = store.delete(key)
        if deleted:
            self._invalidate_cache()  # Cached results may reference the entry
        return deleted

    def garbage_collection(self) -> None:
        """
        Removes episodic entries older than 7 days and trims semantic memory
        to a fixed size. Placeholder for more sophisticated GC policies.
        """
        logging.info("Starting garbage collection...")
        now = time.time()
        cutoff_time = now - 86400 * 7  # Keep one week of episodes
        with self.episodic_memory.lock:
            self.episodic_memory.memory = [
                item for item in self.episodic_memory.memory
                if item['timestamp'] > cutoff_time
            ]

        # Very basic size cap on semantic memory; replace with something better.
        max_semantic_size = 100
        with self.semantic_memory.lock:
            overflow = len(self.semantic_memory.memory) - max_semantic_size
            if overflow > 0:
                # Evict the oldest-inserted keys (dicts preserve insertion order).
                for key in list(self.semantic_memory.memory.keys())[:overflow]:
                    del self.semantic_memory.memory[key]

        self._invalidate_cache()  # Cached results may reference removed entries
        logging.info("Garbage collection complete.")

    def get_memory_stats(self) -> Dict[str, int]:
        """Returns a snapshot of hit/migration counters plus current tier sizes."""
        with self.stats_lock:
            stats = self.stats.copy()  # Copy so callers cannot mutate our counters
        stats["working_memory_size"] = self.working_memory.get_size()
        stats["episodic_memory_size"] = self.episodic_memory.get_size()
        stats["semantic_memory_size"] = self.semantic_memory.get_size()
        return stats

    def export_memory(self) -> Dict[str, Any]:
        """Exports all tiers plus migration access counts for backup."""
        return {
            "working_memory": self.working_memory.export_data(),
            "episodic_memory": self.episodic_memory.export_data(),
            "semantic_memory": self.semantic_memory.export_data(),
            "access_counts": self.access_counts.copy()  # Keep migration state consistent
        }

    def import_memory(self, data: Dict[str, Any]) -> None:
        """Restores all tiers (and access counts) from an export_memory() dump."""
        self.working_memory.import_data(data["working_memory"])
        self.episodic_memory.import_data(data["episodic_memory"])
        self.semantic_memory.import_data(data["semantic_memory"])
        with self.migration_lock:
            self.access_counts = data.get("access_counts", {}).copy()
        self._invalidate_cache()  # Cached results no longer reflect the stores

    def preload_semantic_memory(self, data: List[str]) -> None:
        """
        Preloads the semantic tier with initial content.
        """
        for item in data:
            self.semantic_memory.add(item)
        self._invalidate_cache()

    def batch_retrieve(self, queries: List[str], query_type: str = "general", top_k: int = 5) -> Dict[str, List[Dict[str, Any]]]:
        """
        Runs retrieve() for each query; repeated queries are served from the
        cache.
        """
        return {query: self.retrieve(query, query_type, top_k) for query in queries}


# Example Usage
if __name__ == '__main__':
    # Build the three tiers and the manager that coordinates them.
    wm = WorkingMemory()
    em = EpisodicMemory()
    sm = SemanticMemory()
    unified_memory = UnifiedMemoryManager(wm, em, sm)

    # Seed the long-term (semantic) tier with a handful of facts.
    facts = [
        "The capital of France is Paris.",
        "Elephants are the largest land animals.",
        "The speed of light is approximately 299,792,458 meters per second.",
        "Python is a popular programming language.",
        "Albert Einstein developed the theory of relativity."
    ]
    unified_memory.preload_semantic_memory(facts)

    # Current context goes into working memory.
    unified_memory.add("I am currently testing the unified memory system.", memory_type="working")
    unified_memory.add("The current task is to retrieve information efficiently.", memory_type="working")
    unified_memory.add("The sky is blue.", memory_type="working", metadata={"color": "blue"})  # With metadata

    # Recent events go into episodic memory.
    unified_memory.add("I had lunch at a cafe yesterday.", memory_type="episodic")
    unified_memory.add("I attended a meeting this morning.", memory_type="episodic")

    # One query per routing style.
    query1 = "What is the capital of France?"
    query2 = "programming language"
    query3 = "current task"
    query4 = "yesterday"
    query5 = "blue sky"

    for q, qt in [
        (query1, "factual"),
        (query2, "general"),
        (query3, "contextual"),
        (query4, "recent"),
        (query5, "general"),
    ]:
        print(f"Results for '{q}': {unified_memory.retrieve(q, query_type=qt)}")

    # A repeated query is served from the cache.
    cached_results = unified_memory.retrieve(query1, query_type="factual")
    print(f"Results for '{query1}' (cached): {cached_results}")

    # Batch retrieval over several queries at once.
    batch_results = unified_memory.batch_retrieve([query1, query2, query3], query_type="general")
    print(f"Batch results: {batch_results}")

    # Repeated access to the same result can trigger tier migration.
    for _ in range(unified_memory.migration_threshold + 1):
        unified_memory.retrieve(query5, query_type="general")  # Access "blue sky" repeatedly

    print(f"Memory Stats: {unified_memory.get_memory_stats()}")

    # Round-trip export/import of the full memory state.
    snapshot = unified_memory.export_memory()
    unified_memory.import_memory(snapshot)

    print("Demonstrating Garbage Collection")
    unified_memory.garbage_collection()
