#!/usr/bin/env python3
"""
Genesis Ambient Memory System
==============================
Memory in the Bloodstream - Automatic, Continuous, Unconscious

This module implements the automatic memory capture layer that operates
without explicit storage calls. Every action becomes an event, every
event flows through the memory bloodstream.

Architecture:
    Layer 0: Event Capture (hooks trigger this)
    Layer 1: Event Store (Redis Streams + PostgreSQL)
    Layer 2: Auto-Extraction (surprise scoring)
    Layer 3: Multi-Tier Storage (Working → Episodic → Semantic → Temporal KG)
    Layer 4: Consolidation (background promotion)

Usage:
    # Hook calls this automatically - no manual intervention needed
    python genesis_ambient_memory.py emit tool_call "Tool_Name" '{"input": "data"}' '{"output": "result"}'

    # Or programmatically:
    from genesis_ambient_memory import AmbientMemory
    ambient = AmbientMemory()
    ambient.emit_event("tool_call", tool_name="Read", tool_input={...}, tool_output={...})
"""

import json
import asyncio
import hashlib
import os
import sys
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Literal, Callable
from dataclasses import dataclass, asdict, field
from pathlib import Path
from enum import Enum
import redis
import psycopg2
from psycopg2.extras import Json

# Add parent path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

try:
    from surprise_memory import MemorySystem, MemoryItem, SurpriseScore
except ImportError:
    MemorySystem = None

# Try to import A-Mem for post-emit hook
try:
    from .associative_memory import process_new_event as amem_process
    AMEM_AVAILABLE = True
except ImportError:
    try:
        from associative_memory import process_new_event as amem_process
        AMEM_AVAILABLE = True
    except ImportError:
        AMEM_AVAILABLE = False
        amem_process = None


class EventType(Enum):
    """Types of events that flow through the bloodstream.

    Each member's string value is the serialized form of the type: it is
    what ``AmbientEvent.to_dict`` writes out and what ``EventType(value)``
    parses back in ``AmbientEvent.from_dict`` and ``AmbientMemory.emit_event``.
    """
    # A tool invocation; content carries tool_name / tool_input / tool_output.
    TOOL_CALL = "tool_call"
    # A decision, optionally with its reasoning.
    DECISION = "decision"
    # A plan, usually described by a summary.
    PLAN = "plan"
    # An error report.
    ERROR = "error"
    # Something learned.
    LEARNING = "learning"
    # The remaining types have no dedicated content formatting in the
    # extractor; their content is summarized as truncated JSON.
    USER_INPUT = "user_input"
    AGENT_THOUGHT = "agent_thought"
    SESSION_START = "session_start"
    SESSION_END = "session_end"


@dataclass
class AmbientEvent:
    """One captured event together with its scoring bookkeeping.

    ``surprise_score`` and ``tier`` start at 0.0 / "pending" and are meant
    to be filled in after the event has been scored.
    """
    event_id: str
    event_type: EventType
    timestamp: str
    session_id: str
    content: Dict[str, Any]
    metadata: Dict[str, Any] = field(default_factory=dict)
    surprise_score: float = 0.0
    tier: str = "pending"

    def to_dict(self) -> Dict:
        """Return a plain-dict view of the event.

        The enum in ``event_type`` is replaced by its string value; every
        other field is passed through as-is (no copying).
        """
        serialized = {
            "event_id": self.event_id,
            "event_type": self.event_type.value,
        }
        # Remaining fields map one-to-one onto attributes of the same name.
        for attr in ("timestamp", "session_id", "content",
                     "metadata", "surprise_score", "tier"):
            serialized[attr] = getattr(self, attr)
        return serialized

    @classmethod
    def from_dict(cls, data: Dict) -> 'AmbientEvent':
        """Build an event from a dict shaped like the output of ``to_dict``.

        ``event_id``, ``event_type``, ``timestamp``, ``session_id`` and
        ``content`` are required (``KeyError`` if absent; ``ValueError`` if
        ``event_type`` is not a known value). The other fields fall back to
        the dataclass defaults.
        """
        required = {
            "event_id": data["event_id"],
            "event_type": EventType(data["event_type"]),
            "timestamp": data["timestamp"],
            "session_id": data["session_id"],
            "content": data["content"],
        }
        optional = {
            "metadata": data.get("metadata", {}),
            "surprise_score": data.get("surprise_score", 0.0),
            "tier": data.get("tier", "pending"),
        }
        return cls(**required, **optional)


class EventStore:
    """
    Event sourcing backbone for Genesis Ambient Memory.

    Events are appended to a Redis Stream (for real-time consumers) and
    upserted into the PostgreSQL table ``ambient_events`` (for durability
    and replay). All operations are best-effort: backend failures are
    reported on stderr and never raised to the caller.
    """

    # Redis stream that receives every emitted event.
    STREAM_KEY = "genesis:ambient:events"
    # Cap passed to XADD so the stream keeps roughly the last 10k events.
    STREAM_MAXLEN = 10000

    # Columns read back by replay(), in the order the row is unpacked.
    # Listed explicitly so extra columns (e.g. the access-tracking ones)
    # can never shift the positional indexes.
    _EVENT_COLUMNS = (
        "event_id",
        "event_type",
        "session_id",
        "timestamp",
        "content",
        "metadata",
        "surprise_score",
        "tier",
    )

    def __init__(
        self,
        redis_url: Optional[str] = None,
        postgres_config: Optional[Dict] = None
    ):
        """
        Args:
            redis_url: Redis connection URL. Falls back to the
                GENESIS_REDIS_URL environment variable, then to a built-in
                default.
            postgres_config: Keyword arguments for ``psycopg2.connect``.
                Falls back to the GENESIS_PG_* environment variables, then
                to built-in defaults.

        Note: construction immediately tries to connect to PostgreSQL to
        create the schema; a failure there is only reported on stderr.
        """
        # SECURITY: the fallback URL/password below are live-looking
        # credentials committed to source. They are kept only so existing
        # deployments that rely on the defaults keep working; they should be
        # rotated and supplied exclusively through the environment.
        self.redis_url = redis_url or os.getenv(
            "GENESIS_REDIS_URL",
            "redis://default:e2ZyYYr4oWRdASI2CaLc-@redis-genesis-u50607.vm.elestio.app:26379"
        )

        # PostgreSQL for durable storage
        self.postgres_config = postgres_config or {
            "host": os.getenv("GENESIS_PG_HOST", "postgresql-genesis-u50607.vm.elestio.app"),
            "port": int(os.getenv("GENESIS_PG_PORT", "25432")),
            "database": os.getenv("GENESIS_PG_DATABASE", "postgres"),
            "user": os.getenv("GENESIS_PG_USER", "postgres"),
            "password": os.getenv("GENESIS_PG_PASSWORD", "etY0eog17tD-dDuj--IRH")
        }

        self._redis = None
        self._pg_conn = None
        self._ensure_schema()

    def _get_redis(self) -> redis.Redis:
        """Return the cached Redis client, creating it on first use."""
        if self._redis is None:
            self._redis = redis.from_url(self.redis_url)
        return self._redis

    def _get_postgres(self):
        """Return the cached PostgreSQL connection, (re)connecting if it is
        missing or closed. The connection runs in autocommit mode."""
        if self._pg_conn is None or self._pg_conn.closed:
            self._pg_conn = psycopg2.connect(**self.postgres_config)
            self._pg_conn.autocommit = True
        return self._pg_conn

    def _ensure_schema(self):
        """Create the event table, its indexes and the access-tracking
        columns if they don't exist. Failures are reported on stderr."""
        try:
            conn = self._get_postgres()
            with conn.cursor() as cur:
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS ambient_events (
                        event_id VARCHAR(64) PRIMARY KEY,
                        event_type VARCHAR(50) NOT NULL,
                        session_id VARCHAR(64) NOT NULL,
                        timestamp TIMESTAMPTZ NOT NULL,
                        content JSONB NOT NULL,
                        metadata JSONB DEFAULT '{}',
                        surprise_score FLOAT DEFAULT 0.0,
                        tier VARCHAR(20) DEFAULT 'pending',
                        created_at TIMESTAMPTZ DEFAULT NOW(),
                        processed_at TIMESTAMPTZ,
                        CONSTRAINT valid_tier CHECK (tier IN ('pending', 'discarded', 'working', 'episodic', 'semantic', 'temporal'))
                    );

                    CREATE INDEX IF NOT EXISTS idx_ambient_events_session
                        ON ambient_events(session_id);
                    CREATE INDEX IF NOT EXISTS idx_ambient_events_type
                        ON ambient_events(event_type);
                    CREATE INDEX IF NOT EXISTS idx_ambient_events_timestamp
                        ON ambient_events(timestamp DESC);
                    CREATE INDEX IF NOT EXISTS idx_ambient_events_tier
                        ON ambient_events(tier);
                """)
                # track_access() updates these columns. Without them every
                # access-tracking UPDATE fails with "column does not exist".
                # IF NOT EXISTS keeps this a no-op when they are already there.
                cur.execute("""
                    ALTER TABLE ambient_events
                        ADD COLUMN IF NOT EXISTS access_count INTEGER DEFAULT 0,
                        ADD COLUMN IF NOT EXISTS last_accessed TIMESTAMPTZ
                """)
        except Exception as e:
            print(f"Warning: Could not ensure schema: {e}", file=sys.stderr)

    def emit(self, event: AmbientEvent) -> str:
        """
        Emit an event to both Redis Stream and PostgreSQL.

        The two writes are independent: a failure in one is reported on
        stderr and does not prevent the other. Re-emitting an existing
        ``event_id`` updates its surprise_score, tier and processed_at.

        Returns the event_id (even if both writes failed).
        """
        event_data = event.to_dict()

        # 1. Append to Redis Stream (real-time consumers)
        try:
            self._get_redis().xadd(
                self.STREAM_KEY,
                {"data": json.dumps(event_data)},
                maxlen=self.STREAM_MAXLEN
            )
        except Exception as e:
            print(f"Redis emit warning: {e}", file=sys.stderr)

        # 2. Persist to PostgreSQL (durability)
        try:
            conn = self._get_postgres()
            with conn.cursor() as cur:
                cur.execute("""
                    INSERT INTO ambient_events
                    (event_id, event_type, session_id, timestamp, content, metadata, surprise_score, tier)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                    ON CONFLICT (event_id) DO UPDATE SET
                        surprise_score = EXCLUDED.surprise_score,
                        tier = EXCLUDED.tier,
                        processed_at = NOW()
                """, (
                    event.event_id,
                    event.event_type.value,
                    event.session_id,
                    event.timestamp,
                    Json(event.content),
                    Json(event.metadata),
                    event.surprise_score,
                    event.tier
                ))
        except Exception as e:
            print(f"PostgreSQL emit warning: {e}", file=sys.stderr)

        return event.event_id

    def track_access(self, event_ids: List[str]) -> int:
        """
        Track access for VoI calculations.

        Increments access_count and sets last_accessed to NOW() for the
        given event IDs.

        Returns the number of rows updated (0 for an empty list or on error).
        """
        if not event_ids:
            return 0

        try:
            conn = self._get_postgres()
            with conn.cursor() as cur:
                cur.execute("""
                    UPDATE ambient_events
                    SET access_count = COALESCE(access_count, 0) + 1,
                        last_accessed = NOW()
                    WHERE event_id = ANY(%s)
                """, (list(event_ids),))
                return cur.rowcount
        except Exception as e:
            print(f"Access tracking error: {e}", file=sys.stderr)
            return 0

    def replay(
        self,
        from_timestamp: Optional[datetime] = None,
        session_id: Optional[str] = None,
        event_type: Optional[EventType] = None,
        limit: int = 1000,
        track_access: bool = True
    ) -> List[AmbientEvent]:
        """
        Replay stored events in ascending timestamp order.

        Args:
            from_timestamp: Only events at or after this time.
            session_id: Only events of this session.
            event_type: Only events of this type.
            limit: Maximum number of rows returned (oldest first).
            track_access: If True (default), updates access_count and
                last_accessed of the returned events for VoI calculations.

        Returns:
            The matching events, or an empty list if the query fails
            (the error is reported on stderr).
        """
        conditions = []
        params = []

        if from_timestamp is not None:
            conditions.append("timestamp >= %s")
            params.append(from_timestamp)

        if session_id:
            conditions.append("session_id = %s")
            params.append(session_id)

        if event_type is not None:
            conditions.append("event_type = %s")
            params.append(event_type.value)

        # Only fixed column names and fixed condition fragments are
        # interpolated; every value goes through a bound parameter.
        query = f"SELECT {', '.join(self._EVENT_COLUMNS)} FROM ambient_events"
        if conditions:
            query += " WHERE " + " AND ".join(conditions)
        query += " ORDER BY timestamp ASC LIMIT %s"
        params.append(limit)

        try:
            conn = self._get_postgres()
            with conn.cursor() as cur:
                cur.execute(query, params)
                rows = cur.fetchall()

            events = [
                AmbientEvent(
                    event_id=row[0],
                    event_type=EventType(row[1]),
                    session_id=row[2],
                    timestamp=row[3].isoformat(),
                    content=row[4],
                    # metadata is nullable in the schema; normalize to a dict.
                    metadata=row[5] if row[5] is not None else {},
                    surprise_score=row[6],
                    tier=row[7]
                )
                for row in rows
            ]

            # Track access for VoI calculations
            if track_access and events:
                self.track_access([event.event_id for event in events])

            return events
        except Exception as e:
            print(f"Replay error: {e}", file=sys.stderr)
            return []

    def get_session_events(self, session_id: str) -> List[AmbientEvent]:
        """Get the events for a specific session (subject to replay's
        default limit and access tracking)."""
        return self.replay(session_id=session_id)


class AmbientExtractor:
    """
    Auto-extracts structured memories from raw events.
    Uses surprise scoring to determine what's worth keeping.
    """

    def __init__(self, memory_system: 'MemorySystem' = None):
        self.memory = memory_system or (MemorySystem() if MemorySystem else None)

    def extract(self, event: AmbientEvent) -> Dict[str, Any]:
        """
        Extract memory-worthy information from an event.
        Returns surprise score and recommended tier.
        """
        # Build content string for surprise calculation
        content = self._event_to_content(event)

        # Determine domain from event type
        domain = self._event_type_to_domain(event.event_type)

        # Calculate surprise score
        if self.memory:
            result = self.memory.evaluate(content, source="ambient", domain=domain)
            return {
                "content": content,
                "surprise_score": result["score"]["total"],
                "tier": self._score_to_tier(result["score"]["total"]),
                "score_breakdown": result["score"]
            }
        else:
            # Fallback: simple heuristic scoring
            score = self._heuristic_score(event)
            return {
                "content": content,
                "surprise_score": score,
                "tier": self._score_to_tier(score),
                "score_breakdown": {"heuristic": score}
            }

    def _event_to_content(self, event: AmbientEvent) -> str:
        """Convert event to content string for analysis."""
        parts = []

        if event.event_type == EventType.TOOL_CALL:
            tool = event.content.get("tool_name", "unknown")
            parts.append(f"Tool '{tool}' was called")
            if event.content.get("tool_output"):
                output = str(event.content["tool_output"])[:200]
                parts.append(f"with result: {output}")

        elif event.event_type == EventType.DECISION:
            parts.append(f"Decision made: {event.content.get('decision', 'unknown')}")
            if event.content.get("reasoning"):
                parts.append(f"Reasoning: {event.content['reasoning']}")

        elif event.event_type == EventType.PLAN:
            parts.append(f"Plan created: {event.content.get('summary', json.dumps(event.content)[:200])}")

        elif event.event_type == EventType.ERROR:
            parts.append(f"Error occurred: {event.content.get('error', 'unknown')}")

        elif event.event_type == EventType.LEARNING:
            parts.append(f"Learning: {event.content.get('learning', json.dumps(event.content)[:200])}")

        else:
            parts.append(json.dumps(event.content)[:300])

        return " ".join(parts)

    def _event_type_to_domain(self, event_type: EventType) -> str:
        """Map event type to memory domain."""
        mapping = {
            EventType.TOOL_CALL: "technical",
            EventType.DECISION: "decision",
            EventType.PLAN: "decision",
            EventType.ERROR: "error",
            EventType.LEARNING: "learning",
            EventType.USER_INPUT: "general",
            EventType.AGENT_THOUGHT: "learning",
            EventType.SESSION_START: "general",
            EventType.SESSION_END: "general"
        }
        return mapping.get(event_type, "general")

    def _score_to_tier(self, score: float) -> str:
        """Map surprise score to memory tier."""
        if score < 0.3:
            return "discarded"
        elif score < 0.5:
            return "working"
        elif score < 0.7:
            return "episodic"
        else:
            return "semantic"

    def _heuristic_score(self, event: AmbientEvent) -> float:
        """Fallback heuristic scoring when MemorySystem unavailable."""
        base_scores = {
            EventType.ERROR: 0.8,
            EventType.LEARNING: 0.75,
            EventType.DECISION: 0.7,
            EventType.PLAN: 0.65,
            EventType.TOOL_CALL: 0.4,
            EventType.USER_INPUT: 0.5,
            EventType.AGENT_THOUGHT: 0.45,
            EventType.SESSION_START: 0.3,
            EventType.SESSION_END: 0.35
        }
        return base_scores.get(event.event_type, 0.4)


class AmbientMemory:
    """
    Main interface for Genesis Ambient Memory System.

    This is the entry point called by hooks. It captures events
    automatically and routes them through the memory bloodstream.

    Enhanced with:
    - Post-emit hook for A-Mem associative memory engine
    - Access tracking for VoI calculations
    - Background async task support

    Usage:
        ambient = AmbientMemory()

        # Called automatically by hooks:
        ambient.emit_event(
            event_type="tool_call",
            tool_name="Read",
            tool_input={"file_path": "/some/file"},
            tool_output={"content": "..."}
        )
    """

    # Configuration
    AMEM_SURPRISE_THRESHOLD = 0.5  # Only run A-Mem on high-surprise events

    def __init__(
        self,
        session_id: str = None,
        redis_url: str = None,
        postgres_config: Dict = None,
        enable_amem: bool = True
    ):
        self.session_id = session_id or self._generate_session_id()
        self.event_store = EventStore(redis_url, postgres_config)
        self.extractor = AmbientExtractor()
        self.enable_amem = enable_amem and AMEM_AVAILABLE

        # Track events in current session
        self._session_events: List[str] = []

        # Background task queue for async operations
        self._pending_tasks: List[asyncio.Task] = []

    def _generate_session_id(self) -> str:
        """Generate unique session ID."""
        timestamp = datetime.now().isoformat()
        return hashlib.sha256(f"session-{timestamp}".encode()).hexdigest()[:16]

    def _generate_event_id(self, content: Dict) -> str:
        """Generate unique event ID."""
        timestamp = datetime.now().isoformat()
        content_hash = hashlib.sha256(json.dumps(content, sort_keys=True).encode()).hexdigest()[:8]
        return f"evt-{timestamp[:19].replace(':', '')}-{content_hash}"

    def emit_event(
        self,
        event_type: str,
        **content
    ) -> Dict[str, Any]:
        """
        Emit an ambient event. This is the main entry point.

        Enhanced with post-emit A-Mem hook for association detection.

        Args:
            event_type: Type of event (tool_call, decision, plan, error, learning, etc.)
            **content: Event content as keyword arguments

        Returns:
            Dict with event_id, surprise_score, tier, and associations
        """
        # Create event
        event = AmbientEvent(
            event_id=self._generate_event_id(content),
            event_type=EventType(event_type),
            timestamp=datetime.now().isoformat(),
            session_id=self.session_id,
            content=content,
            metadata={
                "hostname": os.uname().nodename if hasattr(os, 'uname') else "unknown",
                "cwd": os.getcwd()
            }
        )

        # Extract memory-worthy content and score
        extraction = self.extractor.extract(event)
        event.surprise_score = extraction["surprise_score"]
        event.tier = extraction["tier"]

        # Store in event stream
        self.event_store.emit(event)
        self._session_events.append(event.event_id)

        # Prepare result
        result = {
            "event_id": event.event_id,
            "surprise_score": event.surprise_score,
            "tier": event.tier,
            "stored": event.tier not in ["discarded", "working"],
            "associations": []
        }

        # POST-EMIT HOOK: A-Mem associative memory detection
        # Run async in background for high-surprise events
        if (self.enable_amem and
            event.surprise_score >= self.AMEM_SURPRISE_THRESHOLD and
            amem_process is not None):
            try:
                # Try to run async, fall back to sync
                loop = asyncio.get_event_loop()
                if loop.is_running():
                    # Schedule for later execution
                    task = asyncio.create_task(
                        amem_process(event.event_id, content, event.surprise_score)
                    )
                    self._pending_tasks.append(task)
                else:
                    # Run synchronously
                    associations = loop.run_until_complete(
                        amem_process(event.event_id, content, event.surprise_score)
                    )
                    result["associations"] = associations
            except RuntimeError:
                # No event loop, skip A-Mem
                pass
            except Exception as e:
                # Don't fail the main emit on A-Mem errors
                result["amem_error"] = str(e)

        return result

    def emit_tool_call(
        self,
        tool_name: str,
        tool_input: Dict,
        tool_output: Dict = None,
        success: bool = True
    ) -> Dict[str, Any]:
        """Convenience method for tool call events."""
        return self.emit_event(
            event_type="tool_call",
            tool_name=tool_name,
            tool_input=tool_input,
            tool_output=tool_output,
            success=success
        )

    def emit_decision(
        self,
        decision: str,
        reasoning: str = None,
        alternatives: List[str] = None
    ) -> Dict[str, Any]:
        """Convenience method for decision events."""
        return self.emit_event(
            event_type="decision",
            decision=decision,
            reasoning=reasoning,
            alternatives=alternatives or []
        )

    def emit_plan(
        self,
        summary: str,
        steps: List[str] = None,
        context: str = None
    ) -> Dict[str, Any]:
        """Convenience method for plan events."""
        return self.emit_event(
            event_type="plan",
            summary=summary,
            steps=steps or [],
            context=context
        )

    def emit_error(
        self,
        error: str,
        stack_trace: str = None,
        context: Dict = None
    ) -> Dict[str, Any]:
        """Convenience method for error events."""
        return self.emit_event(
            event_type="error",
            error=error,
            stack_trace=stack_trace,
            context=context or {}
        )

    def emit_learning(
        self,
        learning: str,
        source: str = None,
        confidence: float = 1.0
    ) -> Dict[str, Any]:
        """Convenience method for learning events."""
        return self.emit_event(
            event_type="learning",
            learning=learning,
            source=source,
            confidence=confidence
        )

    def get_session_summary(self) -> Dict[str, Any]:
        """Get summary of current session's events."""
        events = self.event_store.get_session_events(self.session_id)

        tier_counts = {}
        type_counts = {}
        for event in events:
            tier_counts[event.tier] = tier_counts.get(event.tier, 0) + 1
            type_counts[event.event_type.value] = type_counts.get(event.event_type.value, 0) + 1

        avg_score = 0
        if events:
            avg_score = sum(e.surprise_score for e in events) / len(events)

        return {
            "session_id": self.session_id,
            "total_events": len(events),
            "by_tier": tier_counts,
            "by_type": type_counts,
            "average_surprise": round(avg_score, 3),
            "stored_count": sum(1 for e in events if e.tier in ["episodic", "semantic"])
        }

    def replay_session(self, session_id: str = None) -> List[Dict]:
        """Replay all events from a session."""
        sid = session_id or self.session_id
        events = self.event_store.get_session_events(sid)
        return [e.to_dict() for e in events]

    def time_travel(self, to_timestamp: datetime) -> List[Dict]:
        """Get state of all events up to a specific time."""
        events = self.event_store.replay(from_timestamp=datetime.min)
        return [e.to_dict() for e in events if datetime.fromisoformat(e.timestamp) <= to_timestamp]


# CLI Interface - Called by hooks
def main():
    """
    CLI interface for hook integration.

    Usage:
        # Emit a tool call event
        python genesis_ambient_memory.py emit tool_call "Read" '{"file": "x.py"}' '{"content": "..."}'

        # Get session summary
        python genesis_ambient_memory.py summary

        # Replay session
        python genesis_ambient_memory.py replay [session_id]

        # Aggregate statistics over the last 7 days
        python genesis_ambient_memory.py stats

    The session ID is taken from the GENESIS_SESSION_ID environment variable
    when set. Exits with status 0 after printing usage (no arguments) and
    with status 1 on bad arguments, an unknown command or a failed stats
    query.
    """
    if len(sys.argv) < 2:
        print("Genesis Ambient Memory - Memory in the Bloodstream")
        print("\nUsage:")
        print("  emit <event_type> <tool_name> [tool_input_json] [tool_output_json]")
        print("  summary")
        print("  replay [session_id]")
        print("  stats")
        sys.exit(0)

    command = sys.argv[1]

    # Reject bad invocations before AmbientMemory opens any connections.
    if command not in ("emit", "summary", "replay", "stats"):
        print(f"Unknown command: {command}")
        sys.exit(1)

    if command == "emit":
        if len(sys.argv) < 4:
            print("Usage: emit <event_type> <tool_name> [tool_input] [tool_output]")
            sys.exit(1)

        event_type = sys.argv[2]
        tool_name = sys.argv[3]
        try:
            tool_input = json.loads(sys.argv[4]) if len(sys.argv) > 4 else {}
            tool_output = json.loads(sys.argv[5]) if len(sys.argv) > 5 else {}
        except json.JSONDecodeError as e:
            print(f"Invalid JSON argument: {e}", file=sys.stderr)
            sys.exit(1)

        if event_type != "tool_call":
            try:
                EventType(event_type)
            except ValueError:
                valid = ", ".join(t.value for t in EventType)
                print(f"Unknown event type: {event_type} (expected one of: {valid})", file=sys.stderr)
                sys.exit(1)

            # For generic events the JSON payloads are merged into the event
            # content as keyword arguments, so both must be JSON objects.
            if not (isinstance(tool_input, dict) and isinstance(tool_output, dict)):
                print("tool_input and tool_output must be JSON objects", file=sys.stderr)
                sys.exit(1)

            event_content = {"tool_name": tool_name, **tool_input, **tool_output}
            # These keys would collide with emit_event's own parameters.
            reserved = {"event_type", "self"} & event_content.keys()
            if reserved:
                print(f"Reserved key(s) in payload: {', '.join(sorted(reserved))}", file=sys.stderr)
                sys.exit(1)

    # Get or create session ID from environment
    session_id = os.getenv("GENESIS_SESSION_ID")
    ambient = AmbientMemory(session_id=session_id)

    if command == "emit":
        if event_type == "tool_call":
            result = ambient.emit_tool_call(tool_name, tool_input, tool_output)
        else:
            result = ambient.emit_event(event_type, **event_content)

        print(json.dumps(result, indent=2))

    elif command == "summary":
        summary = ambient.get_session_summary()
        print(json.dumps(summary, indent=2))

    elif command == "replay":
        session_id = sys.argv[2] if len(sys.argv) > 2 else None
        events = ambient.replay_session(session_id)
        print(json.dumps(events, indent=2))

    else:  # command == "stats"
        # Get overall stats from PostgreSQL
        try:
            conn = ambient.event_store._get_postgres()
            with conn.cursor() as cur:
                cur.execute("""
                    SELECT
                        COUNT(*) as total,
                        COUNT(DISTINCT session_id) as sessions,
                        AVG(surprise_score) as avg_score,
                        COUNT(*) FILTER (WHERE tier = 'semantic') as semantic_count,
                        COUNT(*) FILTER (WHERE tier = 'episodic') as episodic_count
                    FROM ambient_events
                    WHERE timestamp > NOW() - INTERVAL '7 days'
                """)
                row = cur.fetchone()
                print(json.dumps({
                    "total_events_7d": row[0],
                    "unique_sessions": row[1],
                    # AVG() is NULL when no rows match.
                    "avg_surprise_score": round(float(row[2] or 0), 3),
                    "semantic_memories": row[3],
                    "episodic_memories": row[4]
                }, indent=2))
        except Exception as e:
            print(f"Error getting stats: {e}", file=sys.stderr)
            sys.exit(1)


if __name__ == "__main__":
    main()
