#!/usr/bin/env python3
"""
Letta ↔ MCP Memory Bridge
Syncs Letta agent memory to @modelcontextprotocol/server-memory knowledge graph

Phase 1: Core memory extraction and entity creation
"""

import argparse
import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Optional

import psycopg2

# Configuration
# Every connection setting can be overridden via environment variable so that
# credentials need not live in source control; the defaults preserve the
# previous hard-coded setup.
# NOTE(review): the default password is still committed here — consider
# dropping it once every deployment sets LETTA_DB_PASSWORD.
LETTA_DB = {
    "host": os.environ.get("LETTA_DB_HOST", "localhost"),
    "database": os.environ.get("LETTA_DB_NAME", "letta"),
    "user": os.environ.get("LETTA_DB_USER", "letta"),
    "password": os.environ.get("LETTA_DB_PASSWORD", "genesis2025"),
}

# Letta id of the agent whose memory is synced (used by LettaMCPBridge.full_sync).
AIVA_AGENT_ID = os.environ.get(
    "AIVA_AGENT_ID", "agent-35039b24-fbb9-4139-b036-4a04e9c3d6ac"
)
# NOTE(review): not referenced by the code in this file; kept for reference.
AIVA_MEMORY_BLOCK_ID = "block-3376a951-6e25-4d94-b96d-417efde9eb08"


class LettaMCPBridge:
    """Bridge between Letta PostgreSQL and MCP Memory knowledge graph.

    Reads an agent's core memory blocks, archival passages and recent
    messages from Letta's PostgreSQL database (SELECT queries only) and
    reshapes them into entity/relation dicts for the MCP memory server.
    """

    def __init__(self, db_config: Dict):
        """Store connection settings; nothing is opened until ``connect``.

        Args:
            db_config: Keyword arguments forwarded to ``psycopg2.connect``.
        """
        self.db_config = db_config
        self.conn = None  # open psycopg2 connection while connected, else None

    def connect(self) -> bool:
        """Open the PostgreSQL connection.

        Returns:
            True on success; False if connecting failed (the error is printed,
            not raised).
        """
        try:
            self.conn = psycopg2.connect(**self.db_config)
            print(f"✓ Connected to Letta database: {self.db_config.get('database')}")
            return True
        except Exception as e:
            print(f"✗ Failed to connect to database: {e}")
            return False

    def disconnect(self):
        """Close the database connection if one is open; safe to call twice."""
        if self.conn is not None:
            self.conn.close()
            self.conn = None
            print("✓ Database connection closed")

    def _rollback_quietly(self):
        """Reset the connection after a failed statement.

        psycopg2 leaves the transaction aborted after an error, and every later
        statement on the same connection fails until a rollback. Without this,
        one failed extraction would poison all the following ones.
        """
        if self.conn is None:
            return
        try:
            self.conn.rollback()
        except Exception:
            # Best effort: the connection itself may already be broken/closed.
            pass

    @staticmethod
    def _short_id(raw_id) -> str:
        """Return an 8-character fragment of an id for use in entity names.

        Letta ids seen in this file carry a type prefix (``agent-…``,
        ``block-…``); passage ids presumably do too (``passage-<uuid>``).
        Slicing such an id directly would give ``"passage-"`` for every
        passage, so a leading non-hex prefix is dropped before slicing.
        """
        text = str(raw_id)
        prefix, sep, rest = text.partition("-")
        if sep and any(ch not in "0123456789abcdefABCDEF" for ch in prefix):
            text = rest
        return text[:8]

    def extract_core_memory(self, agent_id: str) -> Dict[str, str]:
        """Extract the core memory blocks (persona, human, ...) linked to an agent.

        Args:
            agent_id: Letta agent id, matched through ``blocks_agents``.

        Returns:
            Mapping of block label to block value; empty if the query fails.
        """
        query = """
            SELECT b.label, b.value, b.template_name
            FROM block b
            JOIN blocks_agents ba ON b.id = ba.block_id
            WHERE ba.agent_id = %s
            AND b.is_deleted = false
            ORDER BY b.label;
        """

        try:
            with self.conn.cursor() as cursor:
                cursor.execute(query, (agent_id,))
                rows = cursor.fetchall()
        except Exception as e:
            print(f"✗ Failed to extract core memory: {e}")
            self._rollback_quietly()
            return {}

        core_memory = {}
        for label, value, _template in rows:
            core_memory[label] = value
            # `value or ""` keeps a NULL block value from crashing the report.
            print(f"  ✓ Extracted block: {label} ({len(value or '')} chars)")
        return core_memory

    def extract_archival_memories(self, agent_id: str, since: Optional[datetime] = None) -> List[Dict]:
        """Extract archival passages from every archive attached to an agent.

        Args:
            agent_id: Letta agent id, matched through ``archives_agents``.
            since: If given, only passages created strictly after this time.

        Returns:
            Passage dicts (``id``, ``text``, ``metadata``, ``created_at`` as
            ISO string or None, ``archive`` name), newest first across all
            archives; empty if there are no archives or a query fails.
        """
        archive_query = """
            SELECT DISTINCT a.id, a.name
            FROM archives a
            JOIN archives_agents aa ON a.id = aa.archive_id
            WHERE aa.agent_id = %s
            AND a.is_deleted = false;
        """

        passages_query = """
            SELECT id, text, metadata_, created_at
            FROM archival_passages
            WHERE archive_id = %s
            AND is_deleted = false
        """
        if since:
            passages_query += " AND created_at > %s"
        passages_query += " ORDER BY created_at DESC;"

        try:
            with self.conn.cursor() as cursor:
                cursor.execute(archive_query, (agent_id,))
                archives = cursor.fetchall()

                if not archives:
                    print("  ⚠ No archives found for agent")
                    return []

                # Rows are (archive_name, passage_id, text, metadata, created_at).
                rows = []
                for archive_id, archive_name in archives:
                    params = (archive_id, since) if since else (archive_id,)
                    cursor.execute(passages_query, params)
                    rows.extend((archive_name, *row) for row in cursor.fetchall())
        except Exception as e:
            print(f"✗ Failed to extract archival memories: {e}")
            self._rollback_quietly()
            return []

        # SQL orders each archive separately; concatenating archives breaks the
        # global order, so re-sort newest-first. NULL timestamps sort first,
        # matching PostgreSQL's default NULLS FIRST for DESC. The sort is stable,
        # so a single archive keeps its SQL order.
        rows.sort(key=lambda r: (r[4] is None, r[4]), reverse=True)

        all_passages = [
            {
                "id": passage_id,
                "text": text,
                "metadata": metadata,
                "created_at": created_at.isoformat() if created_at else None,
                "archive": archive_name,
            }
            for archive_name, passage_id, text, metadata, created_at in rows
        ]
        print(f"  ✓ Extracted {len(all_passages)} archival passages")
        return all_passages

    def extract_conversation_summary(self, agent_id: str, limit: int = 10) -> List[Dict]:
        """Extract the agent's most recent non-empty messages.

        Args:
            agent_id: Letta agent id.
            limit: Maximum number of messages, highest ``sequence_id`` first.

        Returns:
            Dicts with ``role``, ``text`` and ``created_at`` (ISO string or
            None); empty if the query fails.
        """
        query = """
            SELECT role, text, created_at
            FROM messages
            WHERE agent_id = %s
            AND is_deleted = false
            AND text IS NOT NULL
            ORDER BY sequence_id DESC
            LIMIT %s;
        """

        try:
            with self.conn.cursor() as cursor:
                cursor.execute(query, (agent_id, limit))
                rows = cursor.fetchall()
        except Exception as e:
            print(f"✗ Failed to extract messages: {e}")
            self._rollback_quietly()
            return []

        messages = [
            {
                "role": role,
                "text": text,
                "created_at": created_at.isoformat() if created_at else None,
            }
            for role, text, created_at in rows
        ]
        print(f"  ✓ Extracted {len(messages)} recent messages")
        return messages

    def transform_to_entities(self, core_memory: Dict, archival: List[Dict], max_memories: int = 20) -> List[Dict]:
        """Transform Letta data into MCP entity dicts.

        Args:
            core_memory: Label → value mapping from ``extract_core_memory``;
                the ``persona`` and ``human`` blocks become the Aiva and Kinan
                entities.
            archival: Passages from ``extract_archival_memories``.
            max_memories: How many leading passages become memory entities.

        Returns:
            List of ``{"name", "entityType", "observations"}`` dicts.
        """
        entities = []
        # One timestamp for the whole batch so entities agree on sync time.
        synced = datetime.now().isoformat()

        if "persona" in core_memory:
            entities.append({
                "name": "Aiva",
                "entityType": "agent",
                "observations": [
                    core_memory["persona"],
                    "Architecture: Letta + PostgreSQL + MCP Memory",
                    "Memory System: Bidirectional sync with knowledge graph",
                    f"Agent ID: {AIVA_AGENT_ID}",
                    f"Synced: {synced}",
                ],
            })
            print("  ✓ Created entity: Aiva (agent)")

        if "human" in core_memory:
            entities.append({
                "name": "Kinan",
                "entityType": "person",
                "observations": [
                    core_memory["human"],
                    "Role: Creator and director of Genesis System",
                    "Project: Building autonomous development framework",
                    f"Synced: {synced}",
                ],
            })
            print("  ✓ Created entity: Kinan (person)")

        for passage in archival[:max_memories]:
            name = f"Memory_{self._short_id(passage['id'])}"
            entities.append({
                "name": name,
                "entityType": "memory",
                "observations": [
                    passage["text"],
                    f"Created: {passage['created_at']}",
                    f"Archive: {passage['archive']}",
                    "Source: Letta archival memory",
                ],
            })
            print(f"  ✓ Created entity: {name} (memory)")

        return entities

    def create_relations(self, entities: List[Dict], archival_count: int) -> List[Dict]:
        """Build relations between Aiva, Kinan, Genesis System and memories.

        Args:
            entities: Entities produced by ``transform_to_entities``.
            archival_count: Number of memory entities; used only for the
                progress message.

        Returns:
            List of ``{"from", "to", "relationType"}`` dicts.
        """
        relations = []

        names = {e["name"] for e in entities}
        has_aiva = "Aiva" in names
        has_kinan = "Kinan" in names
        # "Genesis System" is never created here; it is assumed to already
        # exist in the knowledge graph.
        has_genesis = True

        if has_aiva and has_kinan:
            relations.append({"from": "Aiva", "to": "Kinan", "relationType": "serves"})
            print("  ✓ Created relation: Aiva → serves → Kinan")

        if has_aiva and has_genesis:
            relations.append({"from": "Aiva", "to": "Genesis System", "relationType": "uses"})
            print("  ✓ Created relation: Aiva → uses → Genesis System")

        # Emitted even when this batch has no Aiva entity — presumably relying
        # on Aiva already existing in the graph.
        relations.extend(
            {"from": "Aiva", "to": e["name"], "relationType": "created"}
            for e in entities
            if e["entityType"] == "memory"
        )

        if archival_count > 0:
            print(f"  ✓ Created {archival_count} memory creation relations")

        return relations

    def full_sync(self) -> Dict:
        """Extract all of Aiva's data and build the MCP entity/relation payload.

        Connects before extracting and always disconnects afterwards.

        Returns:
            ``{"entities", "relations", "stats"}`` on success, or
            ``{"entities": [], "relations": [], "error": <message>}`` on failure.
        """
        print("\n=== FULL SYNC: Letta → MCP Memory ===\n")

        if not self.connect():
            return {"entities": [], "relations": [], "error": "Database connection failed"}

        try:
            print("1. Extracting core memory...")
            core_memory = self.extract_core_memory(AIVA_AGENT_ID)

            print("\n2. Extracting archival memories...")
            archival = self.extract_archival_memories(AIVA_AGENT_ID)

            print("\n3. Extracting conversation summary...")
            messages = self.extract_conversation_summary(AIVA_AGENT_ID)

            print("\n4. Transforming to MCP entities...")
            entities = self.transform_to_entities(core_memory, archival)

            print("\n5. Creating relations...")
            memory_count = sum(1 for e in entities if e["entityType"] == "memory")
            relations = self.create_relations(entities, memory_count)

            result = {
                "entities": entities,
                "relations": relations,
                "stats": {
                    "core_memory_blocks": len(core_memory),
                    "archival_passages": len(archival),
                    "recent_messages": len(messages),
                    "entities_created": len(entities),
                    "relations_created": len(relations),
                    "sync_time": datetime.now().isoformat(),
                },
            }

            print("\n=== Sync Data Ready ===")
            print(f"Entities: {len(entities)}")
            print(f"Relations: {len(relations)}")

            return result

        except Exception as e:
            print(f"\n✗ Sync failed: {e}")
            import traceback
            traceback.print_exc()
            return {"entities": [], "relations": [], "error": str(e)}

        finally:
            self.disconnect()


def main():
    """Command-line entry point: run a sync and save the payload as JSON.

    ``--full`` and ``--dry-run`` both extract the same data and write it to
    ``--output``, which is resolved against ``--output-dir`` (an absolute
    ``--output`` is used as-is). Neither flag calls the MCP server itself;
    ``--dry-run`` additionally prints how to apply the payload. With neither
    flag, the help text is printed.

    Returns:
        The sync result dict when a sync ran, otherwise None.
    """
    parser = argparse.ArgumentParser(description="Letta ↔ MCP Memory Bridge")
    parser.add_argument("--full", action="store_true", help="Full sync (all data)")
    parser.add_argument("--dry-run", action="store_true", help="Show what would be synced")
    parser.add_argument("--output", default="sync_data.json", help="Output file for sync data")
    parser.add_argument(
        "--output-dir",
        default="/mnt/e/genesis-system",
        help="Directory that a relative --output path is resolved against",
    )

    args = parser.parse_args()

    if not (args.full or args.dry_run):
        parser.print_help()
        return None

    result = LettaMCPBridge(LETTA_DB).full_sync()

    # Path joining (unlike string concatenation) keeps an absolute --output
    # intact instead of appending it to the output directory.
    output_path = Path(args.output_dir) / args.output
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)

    print(f"\n✓ Sync data saved to: {output_path}")

    if args.dry_run:
        print("\n=== DRY RUN - No changes made ===")
        print("\nTo execute sync, run:")
        print("  python letta_mcp_bridge.py --full")
        print("\nThen use Claude Code to execute MCP tool calls:")
        print("  mcp__memory__create_entities(result['entities'])")
        print("  mcp__memory__create_relations(result['relations'])")

    return result


if __name__ == "__main__":
    outcome = main()
    # Report failed syncs to shells/schedulers through a non-zero exit status;
    # main() returns a dict with an "error" key when the sync failed.
    if isinstance(outcome, dict) and outcome.get("error"):
        sys.exit(1)
