#!/usr/bin/env python3
"""
Genesis Memory Consolidation Loop (Enhanced)
=============================================
Background "Sleep" Process for Memory Optimization

Like human sleep consolidation, this process:
1. Promotes frequently accessed episodic memories to semantic tier
2. Links related memories in the knowledge graph
3. Applies VoI-based decay to unused memories
4. Compresses low-value memories to archive
5. Extracts patterns using LLM for meta-learning

Enhanced with:
- VoI (Value of Information) pruning
- LLM-powered pattern extraction via Claude Haiku
- Memory compression/archiving

Runs periodically (every 4-6 hours) or on-demand.

Usage:
    python consolidation_loop.py run       # Run full consolidation
    python consolidation_loop.py dry-run   # Show what would change
    python consolidation_loop.py status    # Show current status
    python consolidation_loop.py voi       # Show VoI distribution
"""

import asyncio
import json
import math
import os
import sys
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
import psycopg2
from psycopg2.extras import Json, RealDictCursor

# Use Anthropic directly for LLM calls
try:
    import anthropic
    ANTHROPIC_AVAILABLE = True
except ImportError:
    ANTHROPIC_AVAILABLE = False

# Import our modules
try:
    from genesis_ambient_memory import AmbientEvent, EventStore, EventType
    from graphiti_integration import GenesisGraphiti
except ImportError:
    from .genesis_ambient_memory import AmbientEvent, EventStore, EventType
    from .graphiti_integration import GenesisGraphiti


@dataclass
class ConsolidationResult:
    """Outcome summary of one consolidation cycle."""
    started_at: datetime        # wall-clock start of the run
    completed_at: datetime      # wall-clock end of the run
    events_processed: int       # rows examined across all phases
    promoted_to_semantic: int   # episodic events promoted to semantic tier
    decayed: int                # compress candidates left in place (no summary)
    archived: int               # events summarized into archived_memories
    deleted: int                # events removed outright (low VoI)
    linked: int                 # cross-references written to metadata
    patterns_extracted: int     # synthetic usage-pattern events stored
    llm_patterns: int           # LLM-derived patterns stored
    errors: List[str]           # collected error messages (never raised)

    def to_dict(self) -> Dict:
        """Serialize to a JSON-friendly dict, adding a derived duration field."""
        counter_fields = (
            "events_processed", "promoted_to_semantic", "decayed",
            "archived", "deleted", "linked", "patterns_extracted",
            "llm_patterns",
        )
        out: Dict = {
            "started_at": self.started_at.isoformat(),
            "completed_at": self.completed_at.isoformat(),
            "duration_seconds": (self.completed_at - self.started_at).total_seconds(),
        }
        out.update((name, getattr(self, name)) for name in counter_fields)
        out["errors"] = self.errors
        return out


class VoICalculator:
    """
    Computes a Value-of-Information (VoI) score for stored memories.

    VoI = Importance × Frequency × Recency, where:
    - Importance = surprise_score scaled to 1-10
    - Frequency  = access_count normalized by the max observed count (0-1)
    - Recency    = e^(-0.1 × days_since_access) (0-1)

    Action thresholds:
    - VoI < 2.0       → delete (forgotten)
    - 2.0 ≤ VoI < 5.0 → compress (summarize to archive)
    - VoI ≥ 5.0       → keep (full fidelity)
    """

    DELETE_THRESHOLD = 2.0
    COMPRESS_THRESHOLD = 5.0
    DECAY_RATE = 0.1  # Lambda for exponential recency decay

    @staticmethod
    def calculate(
        surprise_score: float,
        access_count: int,
        max_access_count: int,
        days_since_access: float
    ) -> float:
        """Return the VoI score for a single memory event."""
        importance = 10 * surprise_score
        # Guard against division by zero when no event was ever accessed.
        frequency = access_count / max(1, max_access_count)
        recency = math.exp(-days_since_access * VoICalculator.DECAY_RATE)
        return importance * frequency * recency

    @staticmethod
    def get_action(voi: float) -> str:
        """Map a VoI score to 'delete', 'compress', or 'keep'."""
        if voi >= VoICalculator.COMPRESS_THRESHOLD:
            return "keep"
        if voi >= VoICalculator.DELETE_THRESHOLD:
            return "compress"
        return "delete"


class MemoryConsolidator:
    """
    Enhanced memory consolidation engine.

    Implements the "sleep cycle" for Genesis memory:
    - Runs periodically to optimize memory storage
    - Promotes important memories to higher tiers
    - Uses VoI for intelligent forgetting
    - Compresses memories via LLM summarization
    - Extracts patterns with LLM analysis
    """

    # Thresholds
    PROMOTION_ACCESS_THRESHOLD = 3  # min access_count for episodic → semantic promotion
    DECAY_DAYS = 30  # NOTE(review): currently unreferenced in this file — confirm before removing
    SEMANTIC_SURPRISE_THRESHOLD = 0.7  # min surprise_score for promotion
    CLUSTER_MIN_SIZE = 5  # Minimum events for LLM pattern extraction

    # LLM Configuration
    HAIKU_MODEL = "claude-3-5-haiku-20241022"  # low-cost model used for summaries and patterns

    # Prompt used to compress a single memory event into a short archive summary.
    COMPRESSION_PROMPT = """Summarize this memory event in 1-2 sentences, preserving the key information:

Event Type: {event_type}
Content: {content}
Timestamp: {timestamp}

Summary:"""

    # Prompt used to mine structured patterns from a cluster of related events.
    # Doubled braces escape str.format(); literal braces reach the model.
    PATTERN_PROMPT = """Analyze these related memory events and extract meaningful patterns.

EVENTS:
{events}

Extract:
1. Key learnings (what was discovered)
2. Behavioral patterns (how problems were approached)
3. Entity relationships (who/what interacted)
4. Meta-patterns (recurring strategies)

OUTPUT FORMAT (JSON):
{{
  "learnings": ["learning1", "learning2"],
  "behavioral_patterns": ["pattern1", "pattern2"],
  "entity_relations": [{{"from": "A", "to": "B", "relation": "uses"}}],
  "meta_patterns": ["meta1", "meta2"]
}}

Return ONLY the JSON, no other text."""

    def __init__(
        self,
        postgres_config: Dict = None,
        graphiti: GenesisGraphiti = None,
        anthropic_api_key: str = None
    ):
        """
        Args:
            postgres_config: psycopg2 connection kwargs; defaults to the
                GENESIS_PG_* environment variables.
            graphiti: knowledge-graph client; a fresh GenesisGraphiti()
                is created when omitted.
            anthropic_api_key: falls back to the ANTHROPIC_API_KEY env var.
        """
        # SECURITY NOTE(review): real-looking database credentials are
        # hard-coded as env-var fallbacks below — they should be removed
        # from source and rotated.
        self.postgres_config = postgres_config or {
            "host": os.getenv("GENESIS_PG_HOST", "postgresql-genesis-u50607.vm.elestio.app"),
            "port": int(os.getenv("GENESIS_PG_PORT", "25432")),
            "database": os.getenv("GENESIS_PG_DATABASE", "postgres"),
            "user": os.getenv("GENESIS_PG_USER", "postgres"),
            "password": os.getenv("GENESIS_PG_PASSWORD", "etY0eog17tD-dDuj--IRH")
        }
        self.graphiti = graphiti or GenesisGraphiti()
        self.api_key = anthropic_api_key or os.getenv("ANTHROPIC_API_KEY")
        self._client = None  # lazy Anthropic client (see _get_client)
        self._conn = None    # lazy PostgreSQL connection (see _get_conn)

    def _get_conn(self):
        """Return a live PostgreSQL connection, (re)connecting when needed."""
        needs_connect = self._conn is None or self._conn.closed
        if needs_connect:
            conn = psycopg2.connect(**self.postgres_config)
            conn.autocommit = True  # each statement commits immediately
            self._conn = conn
        return self._conn

    def _get_client(self):
        """Lazily build the Anthropic client; None when the SDK is absent."""
        if ANTHROPIC_AVAILABLE and self._client is None:
            self._client = anthropic.Anthropic(api_key=self.api_key)
        return self._client

    async def run_consolidation(self, dry_run: bool = False) -> ConsolidationResult:
        """
        Run a full consolidation cycle with VoI pruning and LLM patterns.

        Phases (each reported via a counter in the result):
          1. Promotion: high-access, high-surprise episodic events become
             'semantic' and are mirrored into the Graphiti knowledge graph.
          2. VoI pruning: events older than 7 days are scored; low-VoI rows
             are deleted, mid-VoI rows are summarized into archived_memories,
             and kept rows get their voi_score column refreshed.
          3. Linking: events sharing more than 5 significant words are
             cross-referenced via metadata->'linked_to'.
          4. Pattern extraction: per-event-type 7-day usage stats are written
             back as synthetic 'learning' events in the semantic tier.
          5. LLM patterns: Haiku extracts structured patterns from clusters.

        Args:
            dry_run: when True, print what each phase would do and make no
                database or knowledge-graph changes.

        Returns:
            ConsolidationResult with per-phase counters; exceptions are
            captured in .errors rather than raised.
        """
        started_at = datetime.now()
        errors = []
        promoted = 0
        decayed = 0
        archived = 0
        deleted = 0
        linked = 0
        patterns = 0
        llm_patterns = 0
        events_processed = 0

        try:
            await self.graphiti.initialize()
            conn = self._get_conn()

            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                # Max access count normalizes VoI's frequency term; floored at 1
                # so an empty or never-accessed table cannot divide by zero.
                cur.execute("SELECT MAX(COALESCE(access_count, 0)) as max_access FROM ambient_events")
                row = cur.fetchone()
                max_access = (row['max_access'] if row else 0) or 1

                # === 1. PROMOTION: High-value episodic → semantic ===
                cur.execute("""
                    SELECT event_id, event_type, content, surprise_score, tier,
                           COALESCE(access_count, 0) as access_count
                    FROM ambient_events
                    WHERE tier = 'episodic'
                      AND COALESCE(access_count, 0) >= %s
                      AND surprise_score >= %s
                    ORDER BY surprise_score DESC
                    LIMIT 100
                """, (self.PROMOTION_ACCESS_THRESHOLD, self.SEMANTIC_SURPRISE_THRESHOLD))

                for row in cur.fetchall():
                    events_processed += 1
                    if dry_run:
                        print(f"[DRY-RUN] PROMOTE: {row['event_id']} (score: {row['surprise_score']:.2f})")
                    else:
                        cur.execute("""
                            UPDATE ambient_events
                            SET tier = 'semantic',
                                processed_at = NOW(),
                                metadata = COALESCE(metadata, '{}'::jsonb) || %s
                            WHERE event_id = %s
                        """, (
                            Json({"promoted_at": datetime.now().isoformat()}),
                            row['event_id']
                        ))
                        # Mirror the promotion into the knowledge graph; graph
                        # failures are recorded but do not block the promotion.
                        try:
                            event = AmbientEvent(
                                event_id=row['event_id'],
                                event_type=EventType(row['event_type']),
                                # NOTE(review): stamps the promotion time, not
                                # the original event time — confirm intended.
                                timestamp=datetime.now().isoformat(),
                                session_id="consolidation",
                                content=row['content'],
                                surprise_score=row['surprise_score'],
                                tier='semantic'
                            )
                            await self.graphiti.ingest_event(event)
                        except Exception as e:
                            errors.append(f"Graphiti error: {e}")
                        promoted += 1

                # === 2. VoI-BASED PRUNING ===
                # Only events at least 7 days old are pruning candidates.
                cur.execute("""
                    SELECT event_id, event_type, content, surprise_score,
                           COALESCE(access_count, 0) as access_count,
                           COALESCE(last_accessed, timestamp) as last_accessed,
                           timestamp
                    FROM ambient_events
                    WHERE tier IN ('episodic', 'working')
                      AND timestamp < NOW() - INTERVAL '7 days'
                    ORDER BY timestamp ASC
                    LIMIT 500
                """)

                for row in cur.fetchall():
                    events_processed += 1
                    last_access = row['last_accessed'] or row['timestamp']
                    # Handle timezone-aware datetimes from PostgreSQL
                    now = datetime.now(timezone.utc)
                    if last_access.tzinfo is None:
                        last_access = last_access.replace(tzinfo=timezone.utc)
                    days_since = (now - last_access).days
                    voi = VoICalculator.calculate(
                        row['surprise_score'],
                        row['access_count'],
                        max_access,
                        days_since
                    )
                    action = VoICalculator.get_action(voi)

                    if action == "delete":
                        if dry_run:
                            print(f"[DRY-RUN] DELETE: {row['event_id']} (VoI: {voi:.2f})")
                        else:
                            cur.execute("""
                                DELETE FROM ambient_events WHERE event_id = %s
                            """, (row['event_id'],))
                            deleted += 1

                    elif action == "compress":
                        if dry_run:
                            print(f"[DRY-RUN] COMPRESS: {row['event_id']} (VoI: {voi:.2f})")
                        else:
                            # Summarize via LLM, archive the summary, then drop
                            # the original row. If summarization fails we keep
                            # the row and just count it as decayed.
                            summary = await self._compress_memory(row)
                            if summary:
                                cur.execute("""
                                    INSERT INTO archived_memories
                                    (original_event_id, summary, original_content, compressed_at)
                                    VALUES (%s, %s, %s, NOW())
                                """, (row['event_id'], summary, Json(row['content'])))
                                cur.execute("""
                                    DELETE FROM ambient_events WHERE event_id = %s
                                """, (row['event_id'],))
                                archived += 1
                            else:
                                decayed += 1

                    # Update VoI score for kept events.
                    # BUG FIX: this UPDATE previously executed even when
                    # dry_run=True, violating the "no changes will be made"
                    # contract of the dry-run mode.
                    if action == "keep" and not dry_run:
                        cur.execute("""
                            UPDATE ambient_events
                            SET voi_score = %s
                            WHERE event_id = %s
                        """, (voi, row['event_id']))

                # === 3. LINK RELATED MEMORIES ===
                # Cheap lexical similarity: events sharing more than 5 words
                # longer than 3 characters are considered related.
                cur.execute("""
                    WITH event_words AS (
                        SELECT event_id,
                               array_agg(word) as words
                        FROM ambient_events,
                             LATERAL unnest(string_to_array(lower(content::text), ' ')) as word
                        WHERE tier IN ('episodic', 'semantic')
                          AND length(word) > 3
                        GROUP BY event_id
                    )
                    SELECT e1.event_id as id1, e2.event_id as id2,
                           array_length(ARRAY(
                               SELECT unnest(e1.words)
                               INTERSECT
                               SELECT unnest(e2.words)
                           ), 1) as overlap
                    FROM event_words e1
                    JOIN event_words e2 ON e1.event_id < e2.event_id
                    WHERE array_length(ARRAY(
                              SELECT unnest(e1.words)
                              INTERSECT
                              SELECT unnest(e2.words)
                          ), 1) > 5
                    LIMIT 100
                """)

                for row in cur.fetchall():
                    if row['overlap'] and row['overlap'] > 5:
                        if dry_run:
                            print(f"[DRY-RUN] LINK: {row['id1']} ↔ {row['id2']} (overlap: {row['overlap']})")
                        else:
                            # Append id2 into id1's metadata.linked_to array
                            # (the link is stored one-directionally on id1).
                            cur.execute("""
                                UPDATE ambient_events
                                SET metadata = jsonb_set(
                                    COALESCE(metadata, '{}'),
                                    '{linked_to}',
                                    COALESCE(metadata->'linked_to', '[]'::jsonb) || %s
                                )
                                WHERE event_id = %s
                            """, (Json([row['id2']]), row['id1']))
                            linked += 1

                # === 4. BASIC PATTERN EXTRACTION ===
                # Aggregate per-event-type usage statistics over the last week.
                cur.execute("""
                    SELECT event_type, COUNT(*) as count,
                           AVG(surprise_score) as avg_surprise,
                           array_agg(DISTINCT content->>'tool_name')
                               FILTER (WHERE content->>'tool_name' IS NOT NULL) as tools
                    FROM ambient_events
                    WHERE timestamp > NOW() - INTERVAL '7 days'
                    GROUP BY event_type
                    ORDER BY count DESC
                """)

                for row in cur.fetchall():
                    pattern = {
                        "type": "usage_pattern",
                        "event_type": row['event_type'],
                        "count_7d": row['count'],
                        "avg_surprise": float(row['avg_surprise'] or 0),
                        "tools": row['tools'] or [],
                        "extracted_at": datetime.now().isoformat()
                    }

                    if dry_run:
                        print(f"[DRY-RUN] PATTERN: {row['event_type']} ({row['count']} events)")
                    else:
                        # One synthetic 'learning' event per event_type per day;
                        # the date-keyed event_id makes re-runs idempotent.
                        cur.execute("""
                            INSERT INTO ambient_events
                            (event_id, event_type, session_id, timestamp, content, metadata, surprise_score, tier)
                            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                            ON CONFLICT (event_id) DO UPDATE SET
                                content = EXCLUDED.content,
                                processed_at = NOW()
                        """, (
                            f"pattern-{row['event_type']}-{datetime.now().strftime('%Y%m%d')}",
                            'learning',
                            'consolidation',
                            datetime.now(),
                            Json(pattern),
                            Json({"source": "consolidation"}),
                            0.8,
                            'semantic'
                        ))
                        patterns += 1

                # === 5. LLM-POWERED PATTERN EXTRACTION (Enhanced REM) ===
                llm_patterns = await self._extract_llm_patterns(cur, dry_run)

        except Exception as e:
            # Collect the error (with traceback) so a partially completed
            # cycle still returns its counters instead of raising.
            errors.append(str(e))
            import traceback
            errors.append(traceback.format_exc())
        finally:
            await self.graphiti.close()

        return ConsolidationResult(
            started_at=started_at,
            completed_at=datetime.now(),
            events_processed=events_processed,
            promoted_to_semantic=promoted,
            decayed=decayed,
            archived=archived,
            deleted=deleted,
            linked=linked,
            patterns_extracted=patterns,
            llm_patterns=llm_patterns,
            errors=errors
        )

    async def _compress_memory(self, event: Dict) -> Optional[str]:
        """
        Summarize one memory event via Haiku for archiving.

        Args:
            event: row dict with 'event_type', 'content', and 'timestamp'.

        Returns:
            The 1-2 sentence summary text, or None when the SDK/client is
            unavailable or the API call fails (errors are printed, not raised).
        """
        if not ANTHROPIC_AVAILABLE:
            return None

        client = self._get_client()
        if not client:
            return None

        try:
            # Cap serialized content at 1000 chars to bound prompt size.
            content_snippet = json.dumps(event['content'], default=str)[:1000]
            ts = event['timestamp']
            prompt = self.COMPRESSION_PROMPT.format(
                event_type=event['event_type'],
                content=content_snippet,
                timestamp=ts.isoformat() if ts else "unknown"
            )

            reply = client.messages.create(
                model=self.HAIKU_MODEL,
                max_tokens=150,
                messages=[{"role": "user", "content": prompt}]
            )
            return reply.content[0].text.strip()

        except Exception as e:
            print(f"Compression error: {e}")
            return None

    async def _extract_llm_patterns(self, cur, dry_run: bool) -> int:
        """
        Extract patterns using LLM analysis of event clusters.

        Groups the last 7 days of episodic/semantic events by event_type,
        sends each cluster (capped at its 20 most recent events) to Haiku
        with PATTERN_PROMPT, parses the JSON reply, and stores it in the
        extracted_patterns table.

        Args:
            cur: open RealDictCursor, reused for both queries and inserts.
            dry_run: when True, only print the clusters that would be analyzed.

        Returns:
            Number of patterns successfully extracted and stored; 0 when the
            Anthropic SDK or client is unavailable. Errors are printed to
            stdout, never raised.
        """
        if not ANTHROPIC_AVAILABLE:
            return 0

        client = self._get_client()
        if not client:
            return 0

        extracted = 0

        try:
            # Find clusters of related events: one cluster per event_type with
            # at least CLUSTER_MIN_SIZE events in the last week, newest first.
            cur.execute("""
                SELECT event_type,
                       array_agg(json_build_object(
                           'id', event_id,
                           'content', content,
                           'timestamp', timestamp,
                           'surprise', surprise_score
                       ) ORDER BY timestamp DESC) as events,
                       COUNT(*) as count
                FROM ambient_events
                WHERE tier IN ('episodic', 'semantic')
                  AND timestamp > NOW() - INTERVAL '7 days'
                GROUP BY event_type
                HAVING COUNT(*) >= %s
                LIMIT 10
            """, (self.CLUSTER_MIN_SIZE,))

            clusters = cur.fetchall()

            for cluster in clusters:
                if dry_run:
                    print(f"[DRY-RUN] LLM PATTERN: {cluster['event_type']} ({cluster['count']} events)")
                    continue

                # Format events for LLM (20 most recent only, to bound tokens)
                events_text = json.dumps(cluster['events'][:20], indent=2, default=str)

                prompt = self.PATTERN_PROMPT.format(events=events_text)

                try:
                    response = client.messages.create(
                        model=self.HAIKU_MODEL,
                        max_tokens=512,
                        messages=[{"role": "user", "content": prompt}]
                    )

                    response_text = response.content[0].text.strip()

                    # Parse JSON response, stripping a leading markdown code
                    # fence (``` or ```json) if the model wrapped its output.
                    if response_text.startswith("```"):
                        response_text = response_text.split("```")[1]
                        if response_text.startswith("json"):
                            response_text = response_text[4:]

                    pattern_data = json.loads(response_text)

                    # Store extracted patterns; source ids are limited to the
                    # same 20 events that were actually shown to the model.
                    cur.execute("""
                        INSERT INTO extracted_patterns
                        (pattern_type, content, confidence, source_event_ids, extracted_at)
                        VALUES (%s, %s, %s, %s, NOW())
                    """, (
                        cluster['event_type'],
                        json.dumps(pattern_data),
                        0.8,  # fixed confidence for LLM-derived patterns
                        [e['id'] for e in cluster['events'][:20]]
                    ))

                    extracted += 1

                except Exception as e:
                    # One bad cluster should not abort the remaining clusters.
                    print(f"LLM pattern error for {cluster['event_type']}: {e}")

        except Exception as e:
            print(f"Cluster query error: {e}")

        return extracted

    async def get_status(self) -> Dict[str, Any]:
        """
        Get current consolidation status and statistics.

        Returns a dict with total/per-tier event counts, average surprise and
        VoI scores, latest-event and last-consolidation timestamps, archive
        and pattern table sizes, and a count of promotion-ready events.
        On failure returns {"error": "<message>"} instead of raising.
        """
        try:
            conn = self._get_conn()
            # Plain (tuple) cursor: rows below are indexed positionally in the
            # same order as the SELECT list.
            with conn.cursor() as cur:
                cur.execute("""
                    SELECT
                        COUNT(*) as total,
                        COUNT(*) FILTER (WHERE tier = 'semantic') as semantic,
                        COUNT(*) FILTER (WHERE tier = 'episodic') as episodic,
                        COUNT(*) FILTER (WHERE tier = 'working') as working,
                        COUNT(*) FILTER (WHERE tier = 'discarded') as discarded,
                        AVG(surprise_score) as avg_surprise,
                        AVG(COALESCE(voi_score, 0)) as avg_voi,
                        MAX(timestamp) as latest_event,
                        MAX(processed_at) as last_consolidation
                    FROM ambient_events
                """)
                row = cur.fetchone()

                # Count archived
                cur.execute("SELECT COUNT(*) FROM archived_memories")
                archived = cur.fetchone()[0]

                # Count patterns
                cur.execute("SELECT COUNT(*) FROM extracted_patterns")
                patterns = cur.fetchone()[0]

                return {
                    "total_events": row[0],
                    "by_tier": {
                        "semantic": row[1],
                        "episodic": row[2],
                        "working": row[3],
                        "discarded": row[4]
                    },
                    "avg_surprise_score": round(float(row[5] or 0), 3),
                    "avg_voi_score": round(float(row[6] or 0), 3),
                    # Timestamps are None when the table is empty / never processed.
                    "latest_event": row[7].isoformat() if row[7] else None,
                    "last_consolidation": row[8].isoformat() if row[8] else None,
                    "archived_memories": archived,
                    "extracted_patterns": patterns,
                    "promotion_ready": self._count_promotion_ready()
                }
        except Exception as e:
            return {"error": str(e)}

    async def get_voi_distribution(self) -> Dict[str, Any]:
        """
        Get VoI score distribution for analysis.

        Buckets episodic/working events into delete/compress/keep zones based
        on the stored voi_score column, and reports the zone thresholds.
        On failure returns {"error": "<message>"} instead of raising.
        """
        try:
            conn = self._get_conn()
            with conn.cursor() as cur:
                # NOTE: voi_score is coalesced to 0, so events that have never
                # been scored by a consolidation run all land in delete_zone.
                cur.execute("""
                    SELECT
                        CASE
                            WHEN COALESCE(voi_score, 0) < 2 THEN 'delete_zone'
                            WHEN COALESCE(voi_score, 0) < 5 THEN 'compress_zone'
                            ELSE 'keep_zone'
                        END as zone,
                        COUNT(*) as count,
                        AVG(surprise_score) as avg_surprise,
                        AVG(COALESCE(access_count, 0)) as avg_access
                    FROM ambient_events
                    WHERE tier IN ('episodic', 'working')
                    GROUP BY zone
                    ORDER BY zone
                """)

                # Positional tuple rows: (zone, count, avg_surprise, avg_access)
                zones = {}
                for row in cur.fetchall():
                    zones[row[0]] = {
                        "count": row[1],
                        "avg_surprise": round(float(row[2] or 0), 3),
                        "avg_access": round(float(row[3] or 0), 1)
                    }

                return {
                    "zones": zones,
                    "thresholds": {
                        "delete": VoICalculator.DELETE_THRESHOLD,
                        "compress": VoICalculator.COMPRESS_THRESHOLD
                    }
                }
        except Exception as e:
            return {"error": str(e)}

    def _count_promotion_ready(self) -> int:
        """
        Count episodic memories that currently satisfy the promotion criteria.

        Mirrors the filter used by run_consolidation's promotion step:
        tier = 'episodic', access_count >= PROMOTION_ACCESS_THRESHOLD, and
        surprise_score >= SEMANTIC_SURPRISE_THRESHOLD. (Previously the
        access-count condition was missing, so the status endpoint
        overstated how many events were actually promotable.)

        Returns:
            The count, or 0 on any database error (best-effort metric).
        """
        try:
            conn = self._get_conn()
            with conn.cursor() as cur:
                cur.execute("""
                    SELECT COUNT(*)
                    FROM ambient_events
                    WHERE tier = 'episodic'
                      AND COALESCE(access_count, 0) >= %s
                      AND surprise_score >= %s
                """, (self.PROMOTION_ACCESS_THRESHOLD, self.SEMANTIC_SURPRISE_THRESHOLD))
                return cur.fetchone()[0]
        except Exception:
            # Narrowed from a bare `except:` which also swallowed
            # KeyboardInterrupt and SystemExit.
            return 0


# CLI Interface
async def main():
    """Dispatch the CLI subcommand given in sys.argv[1]; print usage otherwise."""
    usage = [
        "Genesis Memory Consolidation Loop (Enhanced)",
        "\nUsage:",
        "  python consolidation_loop.py run       # Run full consolidation",
        "  python consolidation_loop.py dry-run   # Show what would change",
        "  python consolidation_loop.py status    # Show current status",
        "  python consolidation_loop.py voi       # Show VoI distribution",
    ]
    if len(sys.argv) < 2:
        for line in usage:
            print(line)
        return

    command = sys.argv[1]
    consolidator = MemoryConsolidator()

    if command in ("run", "dry-run"):
        dry = command == "dry-run"
        print("Dry run (no changes will be made)..." if dry else "Starting consolidation run...")
        result = await consolidator.run_consolidation(dry_run=dry)
        print("\nDry run complete:" if dry else "\nConsolidation complete:")
        print(json.dumps(result.to_dict(), indent=2))

    elif command == "status":
        status = await consolidator.get_status()
        print(json.dumps(status, indent=2))

    elif command == "voi":
        dist = await consolidator.get_voi_distribution()
        print(json.dumps(dist, indent=2))

    else:
        print(f"Unknown command: {command}")


if __name__ == "__main__":
    asyncio.run(main())
