#!/usr/bin/env python3
"""
Genesis Associative Memory Engine (A-Mem)
==========================================
Autonomous Non-Obvious Connection Discovery

When a new memory is ingested, this engine:
1. Retrieves candidate memories (currently the most recent high-surprise episodic/semantic events)
2. Uses Claude Haiku to detect non-obvious relationships
3. Stores associations with confidence scores

6 Relationship Types:
- CONTRADICTS: Old belief vs new info
- SAFETY_ALERT: Allergy + restaurant, medical + activity
- PREREQUISITE: Must do X before Y
- REINFORCES: Multiple sources confirm
- TEMPORAL_UPDATE: Info has changed over time
- CROSS_DOMAIN: Unexpected cross-topic connection

Cost: ~$0.0004/event × 200 events/day = $0.08/day
"""

import asyncio
import json
import os
import uuid
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from enum import Enum
import psycopg2
from psycopg2.extras import Json, RealDictCursor

# Use Anthropic directly (not Modal) for cost control
try:
    import anthropic
    ANTHROPIC_AVAILABLE = True
except ImportError:
    ANTHROPIC_AVAILABLE = False
    print("Warning: anthropic package not available")


class RelationshipType(Enum):
    """Closed set of the six association kinds the engine can record.

    Each member's value is the lowercase label used when the association
    is serialized (``.value``).
    """

    # New information conflicts with an older belief.
    CONTRADICTS = "contradicts"
    # Safety-critical link (e.g. allergy + restaurant, medical + activity).
    SAFETY_ALERT = "safety_alert"
    # X has to happen before Y.
    PREREQUISITE = "prerequisite"
    # Several sources confirm the same thing.
    REINFORCES = "reinforces"
    # The information changed over time.
    TEMPORAL_UPDATE = "temporal_update"
    # Unexpected link across topics.
    CROSS_DOMAIN = "cross_domain"


@dataclass
class MemoryAssociation:
    """A detected association between two memories.

    ``source_event_id`` is the event the association was detected for and
    ``target_event_id`` the memory it was linked to. ``confidence`` and
    ``reasoning`` carry the detector's score and explanation.
    """
    source_event_id: str
    target_event_id: str
    relationship_type: RelationshipType
    confidence: float
    reasoning: str
    # None means "not supplied"; __post_init__ replaces it with the
    # construction time, so instances never keep None here.
    detected_at: Optional[datetime] = None

    def __post_init__(self):
        """Stamp the association with the current local time if none was given."""
        if self.detected_at is None:
            # Naive local time, matching what the rest of this module stores.
            self.detected_at = datetime.now()


class AssociativeMemoryEngine:
    """
    A-Mem: Autonomous associative memory engine.

    Finds non-obvious connections between memories using Claude Haiku.
    Designed for low-latency, low-cost operation.

    Pipeline (see ``find_associations``): skip low-surprise events, fetch
    candidate memories from ``ambient_events``, ask the model for links,
    validate each proposed link, then persist the accepted ones in
    ``memory_associations``.

    Database and model-call errors inside the individual steps are printed
    and converted into empty results.
    """

    # Configuration
    SURPRISE_THRESHOLD = 0.5  # Only analyze high-surprise events
    SIMILAR_MEMORIES_LIMIT = 20  # How many candidate memories to retrieve
    CANDIDATE_MIN_SURPRISE = 0.4  # Minimum surprise score of a candidate memory
    CONFIDENCE_THRESHOLD = 0.7  # Minimum confidence to store association
    HAIKU_MODEL = "claude-3-5-haiku-20241022"

    # System prompt for association detection
    ASSOCIATION_PROMPT = """You are a cognitive association engine for Genesis, an autonomous AI system.

Your task: Analyze a NEW memory and a set of EXISTING memories to find non-obvious but meaningful connections.

RELATIONSHIP TYPES (choose one):
- CONTRADICTS: The new info contradicts or conflicts with existing info
- SAFETY_ALERT: Safety-critical connection (allergies, medical, security risks)
- PREREQUISITE: One thing must happen before another
- REINFORCES: Multiple independent sources confirm the same thing
- TEMPORAL_UPDATE: Information has changed over time
- CROSS_DOMAIN: Unexpected connection between different domains/topics

CRITICAL RULES:
1. Only report associations with confidence >= 0.7
2. Prioritize SAFETY_ALERT - these are most important
3. Look for non-obvious connections, not just keyword matches
4. Be specific in your reasoning
5. Return empty list if no meaningful associations found

OUTPUT FORMAT (JSON array):
[
  {
    "target_id": "<event_id of related memory>",
    "relationship_type": "<one of the 6 types>",
    "confidence": 0.7-1.0,
    "reasoning": "<brief explanation of the connection>"
  }
]

Return ONLY the JSON array, no other text."""

    def __init__(
        self,
        postgres_config: Optional[Dict] = None,
        anthropic_api_key: Optional[str] = None
    ):
        """
        Configure the engine; no connection is opened here.

        Args:
            postgres_config: Keyword arguments for ``psycopg2.connect``.
                When omitted, values are read from the ``GENESIS_PG_*``
                environment variables with built-in fallbacks.
            anthropic_api_key: API key for the Anthropic client. Falls back
                to the ``ANTHROPIC_API_KEY`` environment variable.
        """
        self.postgres_config = postgres_config or {
            "host": os.getenv("GENESIS_PG_HOST", "postgresql-genesis-u50607.vm.elestio.app"),
            "port": int(os.getenv("GENESIS_PG_PORT", "25432")),
            "database": os.getenv("GENESIS_PG_DATABASE", "postgres"),
            "user": os.getenv("GENESIS_PG_USER", "postgres"),
            # SECURITY(review): a live-looking credential is embedded as the
            # fallback below. It is kept so existing deployments keep working,
            # but it should be rotated and supplied only via GENESIS_PG_PASSWORD.
            "password": os.getenv("GENESIS_PG_PASSWORD", "etY0eog17tD-dDuj--IRH")
        }

        self.api_key = anthropic_api_key or os.getenv("ANTHROPIC_API_KEY")
        self._client = None  # Anthropic client, created on first use
        self._conn = None    # psycopg2 connection, created on first use

    def _get_conn(self):
        """Return the cached PostgreSQL connection, (re)connecting if needed."""
        if self._conn is None or self._conn.closed:
            self._conn = psycopg2.connect(**self.postgres_config)
            # Every statement commits on its own; no explicit commit calls below.
            self._conn.autocommit = True
        return self._conn

    def _get_client(self):
        """Return the cached Anthropic client, or None if the SDK is missing."""
        if self._client is None and ANTHROPIC_AVAILABLE:
            self._client = anthropic.Anthropic(api_key=self.api_key)
        return self._client

    def _fetch_all(self, query: str, params: Optional[Tuple] = None) -> List[Dict]:
        """Run a SELECT and return every row as a plain dict.

        Exceptions are not handled here; each caller reports them with its
        own message.
        """
        conn = self._get_conn()
        with conn.cursor(cursor_factory=RealDictCursor) as cur:
            # Pass None (not an empty tuple) when there are no parameters so
            # the query text is sent without placeholder processing.
            cur.execute(query, params or None)
            return [dict(row) for row in cur.fetchall()]

    async def find_associations(
        self,
        event_id: str,
        content: Dict,
        surprise_score: float
    ) -> List[MemoryAssociation]:
        """
        Find associations for a newly ingested event.

        Called automatically after emit_event() in genesis_ambient_memory.py

        Args:
            event_id: The new event's ID
            content: The event content
            surprise_score: The event's surprise score

        Returns:
            List of detected associations that were handed to storage.
            Empty when the event is below ``SURPRISE_THRESHOLD``, when no
            candidate memories exist, or when detection yields nothing.
        """
        # Skip low-surprise events (they're likely routine)
        if surprise_score < self.SURPRISE_THRESHOLD:
            return []

        # Get candidate memories
        similar = await self._get_similar_memories(event_id, content)
        if not similar:
            return []

        # Use Haiku to find associations
        associations = await self._detect_associations(event_id, content, similar)

        # Store valid associations
        return await self._store_associations(event_id, associations)

    async def _get_similar_memories(
        self,
        event_id: str,
        content: Dict
    ) -> List[Dict]:
        """Retrieve candidate memories to compare the new event against.

        The query does not use ``content``: it returns the most recent
        episodic/semantic events (excluding ``event_id`` itself) whose
        surprise score is at least ``CANDIDATE_MIN_SURPRISE``, capped at
        ``SIMILAR_MEMORIES_LIMIT`` rows. The parameter is kept for interface
        compatibility.

        Returns:
            Rows as dicts, or an empty list if the query fails.
        """
        try:
            return self._fetch_all("""
                    SELECT event_id, event_type, content, surprise_score,
                           timestamp, tier, metadata
                    FROM ambient_events
                    WHERE event_id != %s
                      AND tier IN ('episodic', 'semantic')
                      AND surprise_score >= %s
                    ORDER BY timestamp DESC
                    LIMIT %s
                """, (event_id, self.CANDIDATE_MIN_SURPRISE, self.SIMILAR_MEMORIES_LIMIT))
        except Exception as e:
            print(f"Error getting similar memories: {e}")
            return []

    @staticmethod
    def _summarize_memory(memory: Dict) -> Dict:
        """Reduce a candidate row to the fields shown to the model.

        Content that is stored as text is decoded as JSON when possible and
        passed through unchanged otherwise, so one odd row cannot abort the
        whole detection run.
        """
        raw_content = memory.get("content")
        if isinstance(raw_content, dict):
            parsed = raw_content
        elif not raw_content:
            parsed = {}
        else:
            try:
                parsed = json.loads(raw_content)
            except (TypeError, ValueError):
                parsed = raw_content

        stamp = memory.get("timestamp")
        if not stamp:
            stamp_text = None
        elif hasattr(stamp, "isoformat"):
            stamp_text = stamp.isoformat()
        else:
            stamp_text = str(stamp)

        return {
            "event_id": memory.get("event_id"),
            "type": memory.get("event_type"),
            "content": parsed,
            "timestamp": stamp_text
        }

    def _build_user_prompt(
        self,
        event_id: str,
        content: Dict,
        similar_memories: List[Dict]
    ) -> str:
        """Render the user message listing the new memory and the candidates."""
        new_memory = {
            "event_id": event_id,
            "content": content
        }
        existing_memories = [self._summarize_memory(m) for m in similar_memories]

        return f"""NEW MEMORY:
{json.dumps(new_memory, indent=2, default=str)}

EXISTING MEMORIES:
{json.dumps(existing_memories, indent=2, default=str)}

Find meaningful associations between the NEW memory and EXISTING memories."""

    @staticmethod
    def _parse_response_json(response_text: str) -> Any:
        """Decode the model's reply, tolerating a markdown fence or stray prose.

        Raises:
            json.JSONDecodeError: if no JSON value can be decoded.
        """
        text = response_text.strip()

        # Handle potential markdown code blocks
        if text.startswith("```"):
            text = text.split("```")[1]
            if text.startswith("json"):
                text = text[4:]

        try:
            return json.loads(text)
        except json.JSONDecodeError:
            # Fall back to the outermost [...] span if the array is wrapped in text.
            start, end = text.find("["), text.rfind("]")
            if start == -1 or end <= start:
                raise
            return json.loads(text[start:end + 1])

    def _build_association(
        self,
        event_id: str,
        item: Any,
        known_targets: set
    ) -> Optional[MemoryAssociation]:
        """Validate one model-proposed link; return None if it must be dropped.

        A link is dropped when it is not an object, its confidence is not a
        number or is below ``CONFIDENCE_THRESHOLD``, its target is not one of
        the candidate ids offered to the model (or is the event itself), or
        its relationship type is not a ``RelationshipType`` value.
        """
        if not isinstance(item, dict):
            return None

        raw_confidence = item.get("confidence")
        if isinstance(raw_confidence, bool):
            return None
        try:
            confidence = float(raw_confidence)
        except (TypeError, ValueError):
            return None
        # Written as "not >=" so NaN is rejected too.
        if not confidence >= self.CONFIDENCE_THRESHOLD:
            return None

        target_id = item.get("target_id")
        if target_id is None:
            return None
        target_key = str(target_id)
        if target_key not in known_targets or target_key == str(event_id):
            return None

        type_label = str(item.get("relationship_type", "")).strip().lower()
        type_label = type_label.replace("-", "_").replace(" ", "_")
        try:
            rel_type = RelationshipType(type_label)
        except ValueError:
            return None  # Invalid relationship type

        return MemoryAssociation(
            source_event_id=event_id,
            target_event_id=target_key,
            relationship_type=rel_type,
            confidence=min(confidence, 1.0),
            reasoning=str(item.get("reasoning") or "")
        )

    async def _detect_associations(
        self,
        event_id: str,
        content: Dict,
        similar_memories: List[Dict]
    ) -> List[MemoryAssociation]:
        """Use Claude Haiku to detect associations.

        Args:
            event_id: ID of the new event (becomes ``source_event_id``).
            content: Content of the new event.
            similar_memories: Candidate rows; each needs ``event_id`` and may
                carry ``event_type``, ``content`` and ``timestamp``.

        Returns:
            Validated associations. Empty if the SDK/client is unavailable,
            there are no candidates, the call fails, or the reply cannot be
            parsed as a JSON array.
        """
        if not ANTHROPIC_AVAILABLE or not similar_memories:
            return []

        client = self._get_client()
        if not client:
            return []

        # Only ids that were actually offered to the model are acceptable targets.
        known_targets = {
            str(m["event_id"]) for m in similar_memories
            if m.get("event_id") is not None
        }

        try:
            user_prompt = self._build_user_prompt(event_id, content, similar_memories)

            # NOTE(review): this is a synchronous SDK call inside a coroutine,
            # so the event loop is blocked until the request returns.
            response = client.messages.create(
                model=self.HAIKU_MODEL,
                max_tokens=1024,
                system=self.ASSOCIATION_PROMPT,
                messages=[{"role": "user", "content": user_prompt}]
            )

            associations_data = self._parse_response_json(response.content[0].text)

        except json.JSONDecodeError as e:
            print(f"Failed to parse Haiku response: {e}")
            return []
        except Exception as e:
            print(f"Haiku association detection error: {e}")
            return []

        if not isinstance(associations_data, list):
            print(
                "Haiku association detection error: expected a JSON array, "
                f"got {type(associations_data).__name__}"
            )
            return []

        # Validate item by item so one malformed entry cannot discard the rest.
        associations = []
        for item in associations_data:
            association = self._build_association(event_id, item, known_targets)
            if association is not None:
                associations.append(association)

        return associations

    async def _store_associations(
        self,
        source_id: str,
        associations: List[MemoryAssociation]
    ) -> List[MemoryAssociation]:
        """Store associations in PostgreSQL.

        Args:
            source_id: ID of the event being processed. Not used by the
                INSERT (each association carries its own source id); kept
                for interface compatibility.
            associations: Associations to insert.

        Returns:
            The associations whose INSERT statement ran without error. Rows
            skipped by ``ON CONFLICT DO NOTHING`` are still included. If a
            statement fails, the error is printed and the remaining
            associations are not attempted.
        """
        if not associations:
            return []

        stored = []
        try:
            conn = self._get_conn()
            with conn.cursor() as cur:
                for assoc in associations:
                    cur.execute("""
                        INSERT INTO memory_associations
                        (source_event_id, target_event_id, relationship_type,
                         confidence, reasoning, detected_at)
                        VALUES (%s, %s, %s, %s, %s, %s)
                        ON CONFLICT DO NOTHING
                    """, (
                        assoc.source_event_id,
                        assoc.target_event_id,
                        assoc.relationship_type.value,
                        assoc.confidence,
                        assoc.reasoning,
                        assoc.detected_at
                    ))
                    stored.append(assoc)

        except Exception as e:
            print(f"Error storing associations: {e}")

        return stored

    async def get_associations_for_event(
        self,
        event_id: str
    ) -> List[Dict]:
        """Get all associations where the event is the source or the target.

        Returns:
            Rows ordered by confidence (highest first), or an empty list if
            the query fails.
        """
        try:
            return self._fetch_all("""
                    SELECT * FROM memory_associations
                    WHERE source_event_id = %s OR target_event_id = %s
                    ORDER BY confidence DESC
                """, (event_id, event_id))
        except Exception as e:
            print(f"Error getting associations: {e}")
            return []

    async def get_safety_alerts(self) -> List[Dict]:
        """Get all SAFETY_ALERT associations (highest priority).

        Returns:
            Association rows joined with the source and target event content,
            ordered by confidence then detection time (both descending), or
            an empty list if the query fails.
        """
        try:
            return self._fetch_all("""
                    SELECT ma.*,
                           ae1.content as source_content,
                           ae2.content as target_content
                    FROM memory_associations ma
                    JOIN ambient_events ae1 ON ma.source_event_id = ae1.event_id
                    JOIN ambient_events ae2 ON ma.target_event_id = ae2.event_id
                    WHERE ma.relationship_type = 'safety_alert'
                    ORDER BY ma.confidence DESC, ma.detected_at DESC
                """)
        except Exception as e:
            print(f"Error getting safety alerts: {e}")
            return []

    async def get_contradictions(self) -> List[Dict]:
        """Get all CONTRADICTS associations for conflict resolution.

        Returns:
            Association rows joined with the source and target event content,
            newest first, or an empty list if the query fails.
        """
        try:
            return self._fetch_all("""
                    SELECT ma.*,
                           ae1.content as source_content,
                           ae2.content as target_content
                    FROM memory_associations ma
                    JOIN ambient_events ae1 ON ma.source_event_id = ae1.event_id
                    JOIN ambient_events ae2 ON ma.target_event_id = ae2.event_id
                    WHERE ma.relationship_type = 'contradicts'
                    ORDER BY ma.detected_at DESC
                """)
        except Exception as e:
            print(f"Error getting contradictions: {e}")
            return []

    def close(self):
        """Close the PostgreSQL connection if one was opened."""
        if self._conn:
            self._conn.close()


# Module-wide engine shared by the hook; created lazily by get_engine().
_engine_instance: Optional[AssociativeMemoryEngine] = None


def get_engine() -> AssociativeMemoryEngine:
    """Return the shared A-Mem engine, constructing it on the first call."""
    global _engine_instance
    engine = _engine_instance
    if engine is None:
        engine = AssociativeMemoryEngine()
        _engine_instance = engine
    return engine


async def process_new_event(event_id: str, content: Dict, surprise_score: float) -> List[Dict]:
    """
    Hook entry point intended to run after emit_event().

    Runs the shared engine's find_associations() for the event and flattens
    each resulting association into a plain dict with the keys "target",
    "type", "confidence" and "reasoning" (for logging/debugging).
    """
    found = await get_engine().find_associations(event_id, content, surprise_score)

    summaries = []
    for assoc in found:
        summaries.append({
            "target": assoc.target_event_id,
            "type": assoc.relationship_type.value,
            "confidence": assoc.confidence,
            "reasoning": assoc.reasoning
        })
    return summaries


# CLI Interface
async def main():
    """Command-line entry point.

    Reads the subcommand from ``sys.argv``:
      - ``safety-alerts``: print SAFETY_ALERT associations as JSON
      - ``contradictions``: print CONTRADICTS associations as JSON
      - ``event <id>``: print associations touching the given event as JSON
      - ``test``: run find_associations() on a built-in sample event

    With no subcommand the usage text is printed. The engine is always
    closed before returning.
    """
    import sys

    if len(sys.argv) < 2:
        print("Genesis Associative Memory Engine (A-Mem)")
        print("\nUsage:")
        print("  python associative_memory.py safety-alerts    # Show safety alerts")
        print("  python associative_memory.py contradictions   # Show contradictions")
        print("  python associative_memory.py event <id>       # Show associations for event")
        print("  python associative_memory.py test             # Test with sample event")
        return

    engine = AssociativeMemoryEngine()
    command = sys.argv[1]

    try:
        if command == "safety-alerts":
            alerts = await engine.get_safety_alerts()
            print(json.dumps(alerts, indent=2, default=str))

        elif command == "contradictions":
            contradictions = await engine.get_contradictions()
            print(json.dumps(contradictions, indent=2, default=str))

        elif command == "event":
            if len(sys.argv) < 3:
                # Previously this fell through to "Unknown command: event",
                # which hid the real problem (the missing id).
                print("Usage: python associative_memory.py event <id>")
            else:
                event_id = sys.argv[2]
                associations = await engine.get_associations_for_event(event_id)
                print(json.dumps(associations, indent=2, default=str))

        elif command == "test":
            # Test with a sample event; the random suffix keeps the id unique per run
            test_content = {
                "tool_name": "Edit",
                "action": "Modified authentication logic",
                "file": "src/auth/login.py"
            }
            associations = await engine.find_associations(
                f"test-{uuid.uuid4().hex[:8]}",
                test_content,
                surprise_score=0.75
            )
            print(f"Found {len(associations)} associations:")
            for a in associations:
                print(f"  - {a.relationship_type.value}: {a.target_event_id}")
                print(f"    Confidence: {a.confidence}")
                print(f"    Reasoning: {a.reasoning}")

        else:
            print(f"Unknown command: {command}")

    finally:
        engine.close()


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    asyncio.run(main())
