"""
Genesis Memory System - PostgreSQL Episodic Store
================================================
Handles all episodic memory operations on Elestio PostgreSQL.
Implements temporal knowledge graph, decay tracking, and consolidation.
"""

import hashlib
import json
import math
import os
from dataclasses import dataclass, asdict
from datetime import datetime, timezone
from typing import Dict, List, Any, Optional, Tuple

import psycopg2
from psycopg2 import pool
from psycopg2.extras import RealDictCursor, execute_values


@dataclass
class Episode:
    """Represents an episodic memory.

    Mirrors one row of the memory_episodes table. Timestamps are carried as
    ISO-8601 strings (or None). Status transitions are driven by
    run_decay_sweep(): "active" -> "archived" -> "forgotten".
    """
    episode_id: str                      # primary key, stringified by the store
    content: str                         # raw memory text
    content_hash: str                    # SHA-256 of content, used for deduplication
    embedding_id: Optional[str] = None   # reference to an external embedding row, if any
    created_at: Optional[str] = None     # ISO timestamp of insertion
    accessed_at: Optional[str] = None    # ISO timestamp of last retrieval/rehearsal
    strength: float = 1.0                # current retention; clamped to [0, 1] by decay code
    decay_rate: float = 0.1              # per-episode decay parameter (stored, see calculate_decay)
    consolidation_count: int = 0         # number of rehearsals / duplicate re-insertions
    source_type: str = "observation"     # origin category, e.g. "observation" or "system"
    source_ref: Optional[str] = None     # opaque pointer back to the originating source
    agent_id: Optional[str] = None       # agent that produced the memory
    importance_score: float = 0.5        # salience; influences forget/archive decisions
    surprise_score: float = 0.0          # novelty at encoding time
    status: str = "active"               # "active" | "archived" | "forgotten"

    def to_dict(self) -> Dict:
        """Return a plain dict of all fields (via dataclasses.asdict)."""
        return asdict(self)


@dataclass
class Connection:
    """Represents a connection (graph edge) between episodes.

    Mirrors one row of the memory_connections table.
    """
    connection_id: str                       # primary key of the edge row
    source_episode_id: str                   # edge origin episode
    target_episode_id: str                   # edge destination episode
    relation_type: str                       # free-form relation label chosen by callers
    strength: float = 1.0                    # edge weight; neighbors are ranked by this
    confidence: float = 1.0                  # confidence in the relation
    temporal_distance: Optional[str] = None  # time gap between episodes; not set by create_connection
    causal_direction: int = 0                # causality indicator; presumably -1/0/+1 — confirm with schema


@dataclass
class Entity:
    """Represents a semantic entity.

    Mirrors one row of the semantic_entities table. Entities are
    deduplicated on (canonical_name, entity_type) by store_entity().
    """
    entity_id: str                        # primary key
    name: str                             # display name as originally seen
    entity_type: str                      # category label; part of the dedup key
    canonical_name: Optional[str] = None  # lowercased/stripped name used as dedup key
    properties: Optional[Dict] = None     # arbitrary JSON-serializable attributes
    embedding_id: Optional[str] = None    # reference to an external embedding row, if any
    importance: float = 0.5               # salience; store keeps the max on conflict


class PostgreSQLStore:
    """
    PostgreSQL-backed episodic memory store.

    Features:
    - Connection pooling for high throughput
    - Temporal knowledge graph with edges
    - Ebbinghaus decay calculations
    - Entity extraction and linking

    Every public method checks a connection out of the pool and returns it
    in a finally block; any method that executes SQL rolls back on error so
    an aborted transaction is never handed back to the pool.
    """

    def __init__(
        self,
        host: str = os.environ.get("GENESIS_PG_HOST", "postgresql-genesis-u50607.vm.elestio.app"),
        port: int = int(os.environ.get("GENESIS_PG_PORT", "25432")),
        user: str = os.environ.get("GENESIS_PG_USER", "postgres"),
        # SECURITY: the fallback literal below is a live credential committed
        # to source control. It should be rotated and supplied only via the
        # GENESIS_PG_PASSWORD environment variable.
        password: str = os.environ.get("GENESIS_PG_PASSWORD", "etY0eog17tD-dDuj--IRH"),
        database: str = os.environ.get("GENESIS_PG_DATABASE", "postgres"),
        min_connections: int = 2,
        max_connections: int = 10
    ):
        """Create the store and its threaded connection pool.

        Connection parameters default to the GENESIS_PG_* environment
        variables when set, falling back to the historical hard-coded values
        so existing callers keep working unchanged.
        """
        self.connection_params = {
            "host": host,
            "port": port,
            "user": user,
            "password": password,
            "database": database
        }

        # ThreadedConnectionPool is safe to share across threads; each call
        # site checks a connection out and returns it when done.
        self.pool = pool.ThreadedConnectionPool(
            min_connections,
            max_connections,
            **self.connection_params
        )

    def _get_conn(self):
        """Get connection from pool."""
        return self.pool.getconn()

    def _put_conn(self, conn):
        """Return connection to pool."""
        self.pool.putconn(conn)

    def _generate_hash(self, content: str) -> str:
        """Generate SHA-256 content hash (hex) for deduplication."""
        return hashlib.sha256(content.encode()).hexdigest()

    @staticmethod
    def _row_to_episode(row: Dict[str, Any]) -> Episode:
        """Build an Episode from a RealDictCursor row of memory_episodes.

        NUMERIC columns arrive from psycopg2 as Decimal; coerce them to
        float so instances match the dataclass's declared field types.
        Missing/NULL values fall back to the dataclass defaults.
        """
        return Episode(
            episode_id=str(row['episode_id']),
            content=row['content'],
            content_hash=row['content_hash'],
            embedding_id=str(row['embedding_id']) if row.get('embedding_id') else None,
            created_at=row['created_at'].isoformat() if row.get('created_at') else None,
            accessed_at=row['accessed_at'].isoformat() if row.get('accessed_at') else None,
            strength=float(row['strength']) if row.get('strength') is not None else 1.0,
            decay_rate=float(row['decay_rate']) if row.get('decay_rate') is not None else 0.1,
            consolidation_count=row.get('consolidation_count') or 0,
            source_type=row['source_type'],
            source_ref=row.get('source_ref'),
            agent_id=row.get('agent_id'),
            importance_score=float(row['importance_score']) if row.get('importance_score') is not None else 0.5,
            surprise_score=float(row['surprise_score']) if row.get('surprise_score') is not None else 0.0,
            status=row['status']
        )

    # =========================================================================
    # EPISODE OPERATIONS
    # =========================================================================

    def store_episode(
        self,
        content: str,
        source_type: str = "observation",
        source_ref: Optional[str] = None,
        agent_id: Optional[str] = None,
        importance_score: float = 0.5,
        surprise_score: float = 0.0,
        embedding_id: Optional[str] = None
    ) -> Optional[str]:
        """
        Store a new episodic memory.

        Duplicate content (same SHA-256 hash) is not re-inserted; instead the
        existing row's accessed_at is refreshed and its consolidation_count
        bumped, which strengthens it against decay.

        Returns episode_id if a new row was created, None if duplicate.
        """
        content_hash = self._generate_hash(content)
        conn = self._get_conn()

        try:
            with conn.cursor() as cur:
                # (xmax = 0) distinguishes a fresh INSERT from an ON CONFLICT
                # UPDATE: rows touched by the update path carry a non-zero xmax.
                cur.execute("""
                    INSERT INTO memory_episodes (
                        content, content_hash, embedding_id,
                        source_type, source_ref, agent_id,
                        importance_score, surprise_score
                    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                    ON CONFLICT (content_hash) DO UPDATE SET
                        accessed_at = NOW(),
                        consolidation_count = memory_episodes.consolidation_count + 1
                    RETURNING episode_id, (xmax = 0) as is_new
                """, (
                    content, content_hash, embedding_id,
                    source_type, source_ref, agent_id,
                    importance_score, surprise_score
                ))

                row = cur.fetchone()
                conn.commit()

                if row:
                    episode_id = str(row[0])
                    is_new = row[1]
                    return episode_id if is_new else None
                return None

        except Exception:
            conn.rollback()
            raise
        finally:
            self._put_conn(conn)

    def get_episode(self, episode_id: str) -> Optional[Episode]:
        """Retrieve an episode by ID and update its access time.

        Returns None when no row matches.
        """
        conn = self._get_conn()

        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                # Touching accessed_at makes retrieval count as a rehearsal
                # event for the decay calculations.
                cur.execute("""
                    UPDATE memory_episodes
                    SET accessed_at = NOW()
                    WHERE episode_id = %s
                    RETURNING *
                """, (episode_id,))

                row = cur.fetchone()
                conn.commit()

                return self._row_to_episode(row) if row else None

        except Exception:
            # This method writes (accessed_at); don't return an aborted
            # transaction to the pool.
            conn.rollback()
            raise
        finally:
            self._put_conn(conn)

    def search_episodes(
        self,
        query: Optional[str] = None,
        source_type: Optional[str] = None,
        agent_id: Optional[str] = None,
        min_importance: float = 0.0,
        min_strength: float = 0.0,
        status: str = "active",
        limit: int = 20,
        offset: int = 0
    ) -> List[Episode]:
        """Search episodes with various filters.

        `query` performs a case-insensitive substring match on content.
        Results are ordered by importance, then recency. All user-supplied
        values are passed as bind parameters; only the fixed condition
        templates are interpolated into the SQL, so this is injection-safe.
        """
        conn = self._get_conn()

        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                conditions = ["status = %s"]
                params: List[Any] = [status]

                if query:
                    conditions.append("content ILIKE %s")
                    params.append(f"%{query}%")

                if source_type:
                    conditions.append("source_type = %s")
                    params.append(source_type)

                if agent_id:
                    conditions.append("agent_id = %s")
                    params.append(agent_id)

                conditions.append("importance_score >= %s")
                params.append(min_importance)

                conditions.append("strength >= %s")
                params.append(min_strength)

                params.extend([limit, offset])

                cur.execute(f"""
                    SELECT * FROM memory_episodes
                    WHERE {' AND '.join(conditions)}
                    ORDER BY importance_score DESC, created_at DESC
                    LIMIT %s OFFSET %s
                """, params)

                return [self._row_to_episode(row) for row in cur.fetchall()]

        except Exception:
            conn.rollback()
            raise
        finally:
            self._put_conn(conn)

    # =========================================================================
    # DECAY AND CONSOLIDATION
    # =========================================================================

    def calculate_decay(self, episode_id: str) -> float:
        """
        Calculate current strength using Ebbinghaus forgetting curve.

        R(t) = e^(-t/S) where:
        - R = retention (strength)
        - t = time since last access (hours)
        - S = stability (increases with rehearsals)

        Returns 0.0 for an unknown episode_id.

        NOTE: the per-episode decay_rate column is selected for future use
        but is NOT part of this formula; stability is derived solely from
        consolidation_count (same model as run_decay_sweep).
        """
        conn = self._get_conn()

        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute("""
                    SELECT
                        strength,
                        decay_rate,
                        consolidation_count,
                        EXTRACT(EPOCH FROM (NOW() - accessed_at)) / 3600 as hours_since_access
                    FROM memory_episodes
                    WHERE episode_id = %s
                """, (episode_id,))

                row = cur.fetchone()
                if not row:
                    return 0.0

                # EXTRACT(EPOCH ...) comes back as Decimal; normalize to float.
                hours = float(row['hours_since_access'] or 0)
                consolidation = row['consolidation_count'] or 0

                # Stability (in hours) grows with each consolidation.
                stability = 24 * (1 + consolidation * 0.5)

                # Ebbinghaus decay, clamped to [0, 1].
                decayed_strength = math.exp(-hours / stability)

                return max(0.0, min(1.0, decayed_strength))

        except Exception:
            conn.rollback()
            raise
        finally:
            self._put_conn(conn)

    def run_decay_sweep(self, batch_size: int = 100) -> Dict[str, int]:
        """
        Run decay calculations on a batch of the least recently accessed
        active episodes.

        Policy: strength < 0.1 with importance < 0.3 -> 'forgotten';
        strength < 0.3 -> 'archived'; otherwise strength is just refreshed.
        All updates commit atomically at the end of the batch.

        Returns counts of episodes updated, archived, and forgotten.
        """
        conn = self._get_conn()
        results = {"updated": 0, "archived": 0, "forgotten": 0}

        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                # Process the stalest episodes first.
                cur.execute("""
                    SELECT episode_id, strength, decay_rate, consolidation_count,
                           importance_score,
                           EXTRACT(EPOCH FROM (NOW() - accessed_at)) / 3600 as hours
                    FROM memory_episodes
                    WHERE status = 'active'
                    ORDER BY accessed_at ASC
                    LIMIT %s
                """, (batch_size,))

                episodes = cur.fetchall()

                for ep in episodes:
                    hours = float(ep['hours'] or 0)
                    consolidation = ep['consolidation_count'] or 0
                    importance = float(ep['importance_score']) if ep['importance_score'] is not None else 0.5

                    # Same stability model as calculate_decay().
                    stability = 24 * (1 + consolidation * 0.5)
                    new_strength = math.exp(-hours / stability)

                    if new_strength < 0.1 and importance < 0.3:
                        # Forget unimportant weak memories
                        cur.execute("""
                            UPDATE memory_episodes
                            SET status = 'forgotten', strength = %s
                            WHERE episode_id = %s
                        """, (new_strength, ep['episode_id']))
                        results["forgotten"] += 1
                    elif new_strength < 0.3:
                        # Archive weak memories
                        cur.execute("""
                            UPDATE memory_episodes
                            SET status = 'archived', strength = %s
                            WHERE episode_id = %s
                        """, (new_strength, ep['episode_id']))
                        results["archived"] += 1
                    else:
                        # Still healthy: just persist the decayed strength.
                        cur.execute("""
                            UPDATE memory_episodes
                            SET strength = %s
                            WHERE episode_id = %s
                        """, (new_strength, ep['episode_id']))
                        results["updated"] += 1

                conn.commit()
                return results

        except Exception:
            conn.rollback()
            raise
        finally:
            self._put_conn(conn)

    def strengthen_episode(self, episode_id: str, amount: float = 0.1) -> float:
        """Strengthen a memory (spaced repetition effect).

        Caps strength at 1.0, bumps consolidation_count, and refreshes both
        rehearsal and access timestamps. Returns the new strength, or 0.0 if
        the episode does not exist.
        """
        conn = self._get_conn()

        try:
            with conn.cursor() as cur:
                cur.execute("""
                    UPDATE memory_episodes
                    SET
                        strength = LEAST(1.0, strength + %s),
                        consolidation_count = consolidation_count + 1,
                        last_rehearsal = NOW(),
                        accessed_at = NOW()
                    WHERE episode_id = %s
                    RETURNING strength
                """, (amount, episode_id))

                row = cur.fetchone()
                conn.commit()

                # NUMERIC strength arrives as Decimal; match the float contract.
                return float(row[0]) if row else 0.0

        except Exception:
            conn.rollback()
            raise
        finally:
            self._put_conn(conn)

    # =========================================================================
    # CONNECTIONS (TEMPORAL KNOWLEDGE GRAPH)
    # =========================================================================

    def create_connection(
        self,
        source_id: str,
        target_id: str,
        relation_type: str,
        strength: float = 1.0,
        confidence: float = 1.0,
        causal_direction: int = 0
    ) -> Optional[str]:
        """Create a connection between two episodes.

        Returns the new connection_id, or None when ON CONFLICT DO NOTHING
        suppressed the insert (the edge already existed).
        """
        conn = self._get_conn()

        try:
            with conn.cursor() as cur:
                cur.execute("""
                    INSERT INTO memory_connections (
                        source_episode_id, target_episode_id,
                        relation_type, strength, confidence, causal_direction
                    ) VALUES (%s, %s, %s, %s, %s, %s)
                    ON CONFLICT DO NOTHING
                    RETURNING connection_id
                """, (source_id, target_id, relation_type, strength, confidence, causal_direction))

                row = cur.fetchone()
                conn.commit()

                return str(row[0]) if row else None

        except Exception:
            conn.rollback()
            raise
        finally:
            self._put_conn(conn)

    def get_connected_episodes(
        self,
        episode_id: str,
        relation_type: Optional[str] = None,
        direction: str = "both",
        limit: int = 20
    ) -> List[Tuple[Episode, Connection]]:
        """Get episodes connected to the given episode.

        direction: "outgoing" (edges leaving this episode), "incoming"
        (edges pointing at it), or "both".

        Returns (neighbor episode, connecting edge) pairs ordered by
        descending edge strength.

        Raises ValueError for an unrecognized direction (previously this
        crashed with an opaque IndexError).
        """
        if direction not in ("both", "outgoing", "incoming"):
            raise ValueError(f"invalid direction: {direction!r}")

        conn = self._get_conn()

        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                conditions = []
                params = []

                if direction in ("both", "outgoing"):
                    conditions.append("c.source_episode_id = %s")
                    params.append(episode_id)

                if direction in ("both", "incoming"):
                    conditions.append("c.target_episode_id = %s")
                    params.append(episode_id)

                where_clause = " OR ".join(conditions) if direction == "both" else conditions[0]

                if relation_type:
                    where_clause = f"({where_clause}) AND c.relation_type = %s"
                    params.append(relation_type)

                params.append(limit)

                # The CASE picks whichever endpoint is NOT the queried
                # episode, so the JOIN always yields the neighbor's row.
                cur.execute(f"""
                    SELECT e.*, c.connection_id, c.source_episode_id, c.target_episode_id,
                           c.relation_type, c.strength as conn_strength, c.confidence
                    FROM memory_connections c
                    JOIN memory_episodes e ON (
                        e.episode_id = CASE
                            WHEN c.source_episode_id = %s THEN c.target_episode_id
                            ELSE c.source_episode_id
                        END
                    )
                    WHERE {where_clause}
                    ORDER BY c.strength DESC
                    LIMIT %s
                """, [episode_id] + params)

                results = []
                for row in cur.fetchall():
                    # e.* provides the full episode row; edge columns use
                    # non-colliding names/aliases (conn_strength).
                    episode = self._row_to_episode(row)
                    connection = Connection(
                        connection_id=str(row['connection_id']),
                        source_episode_id=str(row['source_episode_id']),
                        target_episode_id=str(row['target_episode_id']),
                        relation_type=row['relation_type'],
                        strength=float(row['conn_strength']) if row['conn_strength'] is not None else 1.0,
                        confidence=float(row['confidence']) if row['confidence'] is not None else 1.0
                    )
                    results.append((episode, connection))

                return results

        except Exception:
            conn.rollback()
            raise
        finally:
            self._put_conn(conn)

    # =========================================================================
    # ENTITIES
    # =========================================================================

    def store_entity(
        self,
        name: str,
        entity_type: str,
        properties: Optional[Dict] = None,
        embedding_id: Optional[str] = None,
        importance: float = 0.5
    ) -> str:
        """Store or update a semantic entity.

        Entities are deduplicated on (canonical_name, entity_type); on
        conflict the mention count and last_seen are bumped and importance
        keeps its maximum. Returns the entity_id in either case.
        """
        conn = self._get_conn()
        # Canonical form is the dedup key.
        canonical = name.lower().strip()

        try:
            with conn.cursor() as cur:
                cur.execute("""
                    INSERT INTO semantic_entities (
                        name, entity_type, canonical_name,
                        properties, embedding_id, importance
                    ) VALUES (%s, %s, %s, %s, %s, %s)
                    ON CONFLICT (canonical_name, entity_type) DO UPDATE SET
                        mention_count = semantic_entities.mention_count + 1,
                        last_seen = NOW(),
                        importance = GREATEST(semantic_entities.importance, EXCLUDED.importance)
                    RETURNING entity_id
                """, (
                    name, entity_type, canonical,
                    json.dumps(properties) if properties else None,
                    embedding_id, importance
                ))

                row = cur.fetchone()
                conn.commit()

                return str(row[0])

        except Exception:
            conn.rollback()
            raise
        finally:
            self._put_conn(conn)

    def link_entity_to_episode(
        self,
        episode_id: str,
        entity_id: str,
        mention_text: str,
        confidence: float = 1.0
    ) -> str:
        """Link an entity mention to an episode. Returns the mention_id."""
        conn = self._get_conn()

        try:
            with conn.cursor() as cur:
                cur.execute("""
                    INSERT INTO entity_mentions (
                        episode_id, entity_id, mention_text, confidence
                    ) VALUES (%s, %s, %s, %s)
                    RETURNING mention_id
                """, (episode_id, entity_id, mention_text, confidence))

                row = cur.fetchone()
                conn.commit()

                return str(row[0])

        except Exception:
            conn.rollback()
            raise
        finally:
            self._put_conn(conn)

    def get_entity_episodes(
        self,
        entity_id: str,
        limit: int = 20
    ) -> List[Episode]:
        """Get the most recent episodes mentioning an entity."""
        conn = self._get_conn()

        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute("""
                    SELECT DISTINCT e.*
                    FROM memory_episodes e
                    JOIN entity_mentions m ON e.episode_id = m.episode_id
                    WHERE m.entity_id = %s
                    ORDER BY e.created_at DESC
                    LIMIT %s
                """, (entity_id, limit))

                return [self._row_to_episode(row) for row in cur.fetchall()]

        except Exception:
            conn.rollback()
            raise
        finally:
            self._put_conn(conn)

    def get_entity(self, name: str) -> Optional[Dict]:
        """Retrieve a semantic entity row by name (canonicalized), or None."""
        conn = self._get_conn()
        canonical = name.lower().strip()

        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute("""
                    SELECT * FROM semantic_entities 
                    WHERE canonical_name = %s
                """, (canonical,))

                return cur.fetchone()

        except Exception:
            conn.rollback()
            raise
        finally:
            self._put_conn(conn)

    def get_all_entities(self, limit: int = 100) -> List[Dict]:
        """Retrieve all semantic entities, most important/mentioned first."""
        conn = self._get_conn()

        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute("""
                    SELECT * FROM semantic_entities 
                    ORDER BY importance DESC, mention_count DESC
                    LIMIT %s
                """, (limit,))

                return cur.fetchall()

        except Exception:
            conn.rollback()
            raise
        finally:
            self._put_conn(conn)

    # =========================================================================
    # STATISTICS
    # =========================================================================

    def get_stats(self) -> Dict[str, Any]:
        """Get comprehensive memory statistics.

        Returns a JSON-serializable dict; AVG() results are coerced from
        Decimal to float before rounding.
        """
        conn = self._get_conn()

        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute("""
                    SELECT
                        COUNT(*) as total_episodes,
                        COUNT(*) FILTER (WHERE status = 'active') as active,
                        COUNT(*) FILTER (WHERE status = 'archived') as archived,
                        COUNT(*) FILTER (WHERE status = 'forgotten') as forgotten,
                        AVG(strength) as avg_strength,
                        AVG(importance_score) as avg_importance
                    FROM memory_episodes
                """)
                episode_stats = cur.fetchone()

                cur.execute("SELECT COUNT(*) as total FROM memory_connections")
                connection_count = cur.fetchone()['total']

                cur.execute("SELECT COUNT(*) as total FROM semantic_entities")
                entity_count = cur.fetchone()['total']

                cur.execute("""
                    SELECT source_type, COUNT(*) as count
                    FROM memory_episodes
                    GROUP BY source_type
                """)
                by_source = {row['source_type']: row['count'] for row in cur.fetchall()}

                return {
                    "episodes": {
                        "total": episode_stats['total_episodes'],
                        "active": episode_stats['active'],
                        "archived": episode_stats['archived'],
                        "forgotten": episode_stats['forgotten'],
                        # AVG() is Decimal, and None on an empty table.
                        "avg_strength": round(float(episode_stats['avg_strength'] or 0), 3),
                        "avg_importance": round(float(episode_stats['avg_importance'] or 0), 3)
                    },
                    "connections": connection_count,
                    "entities": entity_count,
                    "by_source": by_source
                }

        except Exception:
            conn.rollback()
            raise
        finally:
            self._put_conn(conn)

    def close(self):
        """Close all connections in the pool; the store is unusable after."""
        self.pool.closeall()


# Singleton instance (module-level; created lazily on first access)
_store = None

def get_postgresql_store() -> PostgreSQLStore:
    """Get or create the PostgreSQL store singleton.

    NOTE(review): the lazy check-then-create is not guarded by a lock, so
    concurrent first calls from multiple threads could construct two pools —
    acceptable if first access happens on a single thread; confirm usage.
    """
    global _store
    if _store is None:
        _store = PostgreSQLStore()
    return _store


if __name__ == "__main__":
    # Smoke test: exercises store -> retrieve -> stats against the live
    # database configured by PostgreSQLStore's constructor defaults.
    store = get_postgresql_store()

    print("PostgreSQL Episodic Store Test")
    print("=" * 50)

    # Store a test episode
    episode_id = store.store_episode(
        content="Genesis Memory System initialized successfully with PostgreSQL backend.",
        source_type="system",
        agent_id="test_agent",
        importance_score=0.8
    )

    # store_episode returns None when the content hash already exists, so a
    # rerun of this script skips the retrieval section below.
    if episode_id:
        print(f"Created episode: {episode_id}")

        # Retrieve it (this also bumps accessed_at)
        episode = store.get_episode(episode_id)
        if episode:
            print(f"Retrieved: {episode.content[:50]}...")
            print(f"Strength: {episode.strength}")
            print(f"Importance: {episode.importance_score}")

    # Get stats
    stats = store.get_stats()
    print(f"\nStats: {stats}")

    store.close()
