"""
Genesis Memory System - Memory Consolidation
=============================================
Handles memory promotion across tiers and decay calculations.
Implements Ebbinghaus forgetting curve with spaced repetition.
"""

import math
import re
import sys
from dataclasses import dataclass
from datetime import datetime, timezone, timedelta
from typing import Dict, List, Any, Optional, Tuple

sys.path.insert(0, '/mnt/e/genesis-system/genesis-memory')
from storage.postgresql_store import get_postgresql_store, Episode
from storage.redis_store import get_redis_store, WorkingMemory
from storage.qdrant_store import get_qdrant_store
from intelligence.embedding_generator import get_embedding_generator


@dataclass
class ConsolidationResult:
    """Result of a consolidation run.

    Aggregates the per-phase counters produced by
    MemoryConsolidator.run_full_consolidation, plus any non-fatal
    errors collected along the way.
    """
    working_to_episodic: int   # memories promoted from Redis working memory to episodic storage
    episodic_strengthened: int # episodes whose strength was recalculated in the decay sweep
    episodic_decayed: int      # currently always 0; decay outcomes are reported as archived/forgotten
    episodic_archived: int     # weak episodes moved to archived status
    episodic_forgotten: int    # weak, low-importance episodes dropped entirely
    entities_extracted: int    # semantic entities mined from high-importance episodes
    connections_created: int   # temporal-adjacency edges added between episodes
    errors: List[str]          # human-readable error strings; empty on a clean run


class MemoryConsolidator:
    """
    Handles memory tier transitions and consolidation.

    Tier Promotion:
    - Working → Episodic: After sufficient access or age
    - Episodic → Semantic: High importance memories get entity extraction

    Decay:
    - Ebbinghaus forgetting curve: R(t) = e^(-t/S)
    - Stability increases with each successful recall
    - Spaced repetition strengthens memories
    """

    # Thresholds (the strength/importance cutoffs are enforced by the
    # PostgreSQL store's decay sweep — see run_decay_sweep's docstring).
    WORKING_PROMOTE_AGE_SECONDS = 300  # 5 minutes
    EPISODIC_ARCHIVE_STRENGTH = 0.3
    EPISODIC_FORGET_STRENGTH = 0.1
    EPISODIC_FORGET_IMPORTANCE = 0.3

    # Decay parameters
    BASE_STABILITY_HOURS = 24  # Initial stability in hours
    STABILITY_MULTIPLIER = 0.5  # Stability increase per consolidation

    def __init__(self):
        # Backing stores for each memory tier plus the embedding model.
        self.pg_store = get_postgresql_store()       # episodic (durable)
        self.redis_store = get_redis_store()         # working (short-lived)
        self.qdrant_store = get_qdrant_store()       # vector index
        self.embedding_gen = get_embedding_generator()
        # Per-agent failures from the most recent promotion pass.
        # (Previously these were collected into a local and discarded.)
        self.last_promotion_errors: List[str] = []

    def calculate_decay(
        self,
        hours_since_access: float,
        consolidation_count: int,
        base_decay_rate: float = 0.1
    ) -> float:
        """
        Calculate memory strength using Ebbinghaus forgetting curve.

        R(t) = e^(-t/S) where:
        - R = retention (strength)
        - t = time since last access (hours)
        - S = stability (increases with consolidations)

        Args:
            hours_since_access: Time since last access in hours
            consolidation_count: Number of times memory was consolidated
            base_decay_rate: Unused by this formula; retained so existing
                callers that pass it keep working.

        Returns:
            Current strength, clamped to [0.0, 1.0]
        """
        # Stability increases with each consolidation (spaced repetition effect)
        stability = self.BASE_STABILITY_HOURS * (1 + consolidation_count * self.STABILITY_MULTIPLIER)

        # Ebbinghaus decay
        strength = math.exp(-hours_since_access / stability)

        return max(0.0, min(1.0, strength))

    def promote_working_to_episodic(
        self,
        agent_id: Optional[str] = None,
        batch_size: int = 50
    ) -> Tuple[int, List[str]]:
        """
        Promote memories from working memory to episodic storage.

        Criteria:
        - Age > WORKING_PROMOTE_AGE_SECONDS

        Per-agent failures do not abort the sweep; they are recorded in
        ``self.last_promotion_errors`` for the caller to inspect.

        Args:
            agent_id: Restrict the sweep to one agent; None scans all
                active Redis sessions.
            batch_size: Maximum memories promoted per agent per call.

        Returns:
            (count_promoted, list_of_episode_ids)
        """
        promoted = 0
        episode_ids: List[str] = []
        self.last_promotion_errors = []

        # Get all agent sessions if no specific agent
        if agent_id:
            agents = [agent_id]
        else:
            # Active sessions are keyed "session:<agent_id>" in Redis.
            # Depending on decode_responses, KEYS may return bytes — decode
            # defensively so .split() doesn't raise TypeError.
            keys = self.redis_store.client.keys("session:*")
            agents = [
                (k.decode() if isinstance(k, bytes) else k).split(":")[-1]
                for k in keys
            ]

        for agent in agents:
            try:
                promotable = self.redis_store.get_promotable_memories(
                    agent,
                    age_threshold_seconds=self.WORKING_PROMOTE_AGE_SECONDS
                )

                # Fetch the agent's working set once and index it by id.
                # (The original re-fetched the full set for every candidate,
                # an O(n^2) round-trip pattern.)
                memories = self.redis_store.get_working_memory(agent, limit=100)
                memories_by_id = {m.memory_id: m for m in memories}

                for memory_id in promotable[:batch_size]:
                    memory = memories_by_id.get(memory_id)
                    if not memory:
                        continue

                    # Embed first: if embedding fails, the working-memory
                    # copy is left untouched for a later retry.
                    emb_result = self.embedding_gen.generate(memory.content)

                    # Store in PostgreSQL
                    episode_id = self.pg_store.store_episode(
                        content=memory.content,
                        source_type=memory.source,
                        agent_id=agent,
                        importance_score=memory.importance,
                        embedding_id=None  # Will update after Qdrant storage
                    )

                    if episode_id:
                        # Store in Qdrant
                        self.qdrant_store.store_episode_embedding(
                            episode_id=episode_id,
                            embedding=emb_result.embedding,
                            content=memory.content,
                            source_type=memory.source,
                            agent_id=agent,
                            importance=memory.importance
                        )

                        # Promotion complete; drop the working-memory copy.
                        self.redis_store.remove_from_working_memory(agent, memory_id)

                        episode_ids.append(episode_id)
                        promoted += 1

            except Exception as e:
                # Record and continue: one bad agent must not block the rest.
                self.last_promotion_errors.append(f"Agent {agent}: {e}")

        return promoted, episode_ids

    def run_decay_sweep(self, batch_size: int = 100) -> Dict[str, int]:
        """
        Run decay calculations and update episode statuses.

        Delegates to the PostgreSQL store, which:
        - Recalculates strength based on time since access
        - Archives weak memories (strength < EPISODIC_ARCHIVE_STRENGTH)
        - Forgets unimportant weak memories (strength < EPISODIC_FORGET_STRENGTH,
          importance < EPISODIC_FORGET_IMPORTANCE)

        Args:
            batch_size: Maximum episodes processed in this sweep.

        Returns:
            Counts of updated, archived, and forgotten episodes
        """
        return self.pg_store.run_decay_sweep(batch_size)

    def strengthen_memory(
        self,
        episode_id: str,
        amount: float = 0.1
    ) -> float:
        """
        Strengthen a memory (spaced repetition effect).

        Called when:
        - Memory is accessed/recalled
        - Memory proves useful in a task
        - User explicitly reinforces

        Args:
            episode_id: Episode to strengthen.
            amount: Strength increment applied by the store.

        Returns:
            New strength value
        """
        return self.pg_store.strengthen_episode(episode_id, amount)

    def extract_entities_from_episode(
        self,
        episode_id: str
    ) -> List[str]:
        """
        Extract semantic entities from an episode.

        Uses simple regex NER patterns for now.
        Could be enhanced with spaCy or LLM-based extraction.

        Args:
            episode_id: Episode whose content is mined for entities.

        Returns:
            List of entity IDs created (empty if the episode is missing)
        """
        episode = self.pg_store.get_episode(episode_id)
        if not episode:
            return []

        entity_ids: List[str] = []

        # Extract capitalized phrases (potential named entities)
        capitalized = re.findall(r'\b([A-Z][a-z]+(?: [A-Z][a-z]+)*)\b', episode.content)

        # Extract technical terms (lowercase with underscores or camelCase)
        technical = re.findall(r'\b([a-z]+(?:_[a-z]+)+|[a-z]+(?:[A-Z][a-z]+)+)\b', episode.content)

        for entity_name in set(capitalized + technical):
            # Very short matches ("A", "To", ...) are mostly noise.
            if len(entity_name) < 3:
                continue

            # Capitalized → named entity; otherwise a technical concept.
            entity_type = "named_entity" if entity_name[0].isupper() else "concept"

            # Generate embedding for entity
            emb_result = self.embedding_gen.generate(entity_name)

            # Store entity
            entity_id = self.pg_store.store_entity(
                name=entity_name,
                entity_type=entity_type,
                embedding_id=None
            )
            if not entity_id:
                # Store rejected the entity — don't index or link a null ID
                # (mirrors the episode_id guard used during promotion).
                continue

            # Store entity embedding
            self.qdrant_store.store_entity_embedding(
                entity_id=entity_id,
                embedding=emb_result.embedding,
                name=entity_name,
                entity_type=entity_type
            )

            # Link to episode
            self.pg_store.link_entity_to_episode(
                episode_id=episode_id,
                entity_id=entity_id,
                mention_text=entity_name
            )

            entity_ids.append(entity_id)

        return entity_ids

    def create_temporal_connections(
        self,
        episode_id: str,
        window_size: int = 10
    ) -> List[str]:
        """
        Create connections to temporally adjacent episodes.

        Connects episodes that occurred close in time,
        building a temporal knowledge graph.

        Args:
            episode_id: Source episode for the new edges.
            window_size: Maximum number of connections to create.

        Returns:
            List of connection IDs created (empty if the episode is missing)
        """
        episode = self.pg_store.get_episode(episode_id)
        if not episode:
            return []

        connection_ids: List[str] = []

        # Find recent episodes from same source/agent; over-fetch so we
        # still fill the window after skipping the episode itself.
        recent = self.pg_store.search_episodes(
            agent_id=episode.agent_id,
            limit=window_size * 2
        )

        # Create connections to nearby episodes
        for other in recent:
            if other.episode_id == episode_id:
                continue

            # Determine temporal relationship
            conn_id = self.pg_store.create_connection(
                source_id=episode_id,
                target_id=other.episode_id,
                relation_type="temporal_adjacent",
                strength=0.5,
                causal_direction=1  # Source came after
            )

            if conn_id:
                connection_ids.append(conn_id)

            if len(connection_ids) >= window_size:
                break

        return connection_ids

    def run_full_consolidation(
        self,
        batch_size: int = 50
    ) -> "ConsolidationResult":
        """
        Run complete consolidation cycle.

        1. Promote working memory to episodic
        2. Calculate decay for all episodic memories
        3. Extract entities from high-importance episodes
        4. Create temporal connections

        Each phase is isolated: a failure in one phase is recorded and
        the remaining phases still run.

        Args:
            batch_size: Per-phase batch limit.

        Returns:
            ConsolidationResult with counts and errors
        """
        errors: List[str] = []

        # 1. Promote working to episodic
        try:
            promoted, episode_ids = self.promote_working_to_episodic(batch_size=batch_size)
        except Exception as e:
            promoted = 0
            episode_ids = []
            errors.append(f"Promotion error: {str(e)}")
        # Surface per-agent failures that promotion swallowed internally.
        errors.extend(self.last_promotion_errors)

        # 2. Run decay sweep
        try:
            decay_result = self.run_decay_sweep(batch_size=batch_size)
        except Exception as e:
            decay_result = {"updated": 0, "archived": 0, "forgotten": 0}
            errors.append(f"Decay error: {str(e)}")

        # 3. Extract entities from high-importance episodes
        entities_extracted = 0
        try:
            important = self.pg_store.search_episodes(
                min_importance=0.7,
                limit=batch_size
            )

            for ep in important:
                entity_ids = self.extract_entities_from_episode(ep.episode_id)
                entities_extracted += len(entity_ids)
        except Exception as e:
            errors.append(f"Entity extraction error: {str(e)}")

        # 4. Create temporal connections for new episodes
        connections_created = 0
        try:
            for ep_id in episode_ids:
                conn_ids = self.create_temporal_connections(ep_id)
                connections_created += len(conn_ids)
        except Exception as e:
            errors.append(f"Connection error: {str(e)}")

        return ConsolidationResult(
            working_to_episodic=promoted,
            episodic_strengthened=decay_result.get("updated", 0),
            episodic_decayed=0,  # Included in archived/forgotten
            episodic_archived=decay_result.get("archived", 0),
            episodic_forgotten=decay_result.get("forgotten", 0),
            entities_extracted=entities_extracted,
            connections_created=connections_created,
            errors=errors
        )


# Lazily-created module-level singleton.
_consolidator = None


def get_consolidator() -> MemoryConsolidator:
    """Return the shared MemoryConsolidator, constructing it on first call."""
    global _consolidator
    if _consolidator is not None:
        return _consolidator
    _consolidator = MemoryConsolidator()
    return _consolidator


if __name__ == "__main__":
    # Test consolidation
    consolidator = get_consolidator()

    print("Memory Consolidation Test")
    print("=" * 50)

    # Test decay calculation
    for hours in [0, 1, 6, 24, 48, 168]:
        for consol in [0, 1, 3, 5]:
            strength = consolidator.calculate_decay(hours, consol)
            print(f"Hours: {hours:3d}, Consolidations: {consol}, Strength: {strength:.3f}")

    print("\nRunning full consolidation...")
    result = consolidator.run_full_consolidation(batch_size=10)

    print(f"\nConsolidation Results:")
    print(f"  Working → Episodic: {result.working_to_episodic}")
    print(f"  Strengthened: {result.episodic_strengthened}")
    print(f"  Archived: {result.episodic_archived}")
    print(f"  Forgotten: {result.episodic_forgotten}")
    print(f"  Entities extracted: {result.entities_extracted}")
    print(f"  Connections created: {result.connections_created}")

    if result.errors:
        print(f"\nErrors:")
        for e in result.errors:
            print(f"  - {e}")
