"""
Genesis Memory Bridge — MCP Client for Sovereign Memory
=========================================================
Connects to the Graphiti MCP endpoint at 152.53.201.221:8001/mcp.
Used by both Claude Code and Gemini CLI to read/write shared memory.

Usage:
    from core.memory_bridge import MemoryBridge

    bridge = MemoryBridge()
    bridge.save("My learning", "genesis-kinan", name="Learning Title")
    results = bridge.search("Sunaiva pricing", "genesis-kinan")
    facts = bridge.search_facts("product architecture", "genesis-kinan")
    episodes = bridge.get_episodes("genesis-kinan")
    status = bridge.health()
"""

import json
import urllib.request
import urllib.error
import uuid
import logging
from typing import Optional, List, Dict, Any

# MCP protocol revision this client offers in the ``initialize`` handshake.
PROTOCOL_VERSION: str = "2025-03-26"

# Identity reported to the server as ``clientInfo`` during ``initialize``.
CLIENT_INFO: Dict[str, str] = dict(name="genesis-memory-bridge", version="1.0.0")

# Default MCP endpoint of the Graphiti memory server.
# NOTE(review): plain HTTP to a raw IP — memory contents travel unencrypted;
# confirm this is acceptable or put TLS in front of the server.
MCP_ENDPOINT: str = "http://152.53.201.221:8001/mcp"

logger = logging.getLogger(__name__)


class MemoryBridge:
    """MCP client for the sovereign Graphiti memory endpoint.

    Every tool call performs a fresh MCP handshake (``initialize`` followed by
    the ``notifications/initialized`` notification) and then issues
    ``tools/call``; no session is reused between calls. Transport, protocol
    and tool failures are all reported as ``{"error": "<message>"}`` dicts
    rather than raised, so the public helpers do not throw on network problems.
    """

    def __init__(self, endpoint: str = MCP_ENDPOINT):
        """Create a bridge for *endpoint*. No network I/O happens here."""
        self.endpoint = endpoint
        # Session ID from the most recent ``initialize`` ("" if the server sent none).
        self._session_id: Optional[str] = None

    # ============================================
    # TRANSPORT HELPERS
    # ============================================

    def _build_request(self, message: Dict[str, Any], session_id: str = "") -> urllib.request.Request:
        """Build a POST request carrying one JSON-RPC *message* to the endpoint.

        Args:
            message: JSON-RPC message serialized as the request body.
            session_id: MCP session ID to attach as ``Mcp-Session-Id``; omitted if empty.
        """
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json, text/event-stream",
        }
        if session_id:
            headers["Mcp-Session-Id"] = session_id
        return urllib.request.Request(
            self.endpoint,
            data=json.dumps(message).encode("utf-8"),
            headers=headers,
            method="POST",
        )

    def _init_session(self) -> str:
        """Initialize an MCP session and return its session ID.

        Returns "" when the response carries no ``Mcp-Session-Id`` header.
        Network errors from the ``initialize`` request propagate to the caller.
        """
        request = self._build_request({
            "jsonrpc": "2.0",
            "method": "initialize",
            "params": {
                "protocolVersion": PROTOCOL_VERSION,
                "capabilities": {},
                "clientInfo": CLIENT_INFO,
            },
            "id": f"init-{uuid.uuid4().hex[:8]}",
        })
        with urllib.request.urlopen(request, timeout=15) as resp:
            # HTTPMessage.get is case-insensitive, so this covers every
            # capitalization of the header name.
            session_id = resp.headers.get("mcp-session-id", "") or ""

        self._session_id = session_id
        if session_id:
            # The MCP lifecycle requires this notification before any other request.
            self._notify_initialized(session_id)
        return session_id

    def _notify_initialized(self, session_id: str) -> None:
        """Send ``notifications/initialized`` for *session_id* (best-effort).

        A failure is logged and otherwise ignored; the following tool call
        will surface any real problem with the session.
        """
        request = self._build_request(
            {"jsonrpc": "2.0", "method": "notifications/initialized"}, session_id
        )
        try:
            with urllib.request.urlopen(request, timeout=15):
                pass
        except Exception as e:  # best-effort, see docstring
            logger.warning("MCP initialized notification failed: %s", e)

    @staticmethod
    def _parse_messages(body: str, content_type: str) -> List[Any]:
        """Decode the JSON-RPC message(s) contained in an HTTP response body.

        Handles a plain JSON body (single message or batch array) as well as a
        Server-Sent Events stream, where each event's ``data:`` lines, joined
        with newlines, form one JSON message.

        Raises:
            json.JSONDecodeError: if a payload is not valid JSON.
        """
        is_sse = "text/event-stream" in content_type.lower() or body.startswith(
            ("event:", "data:", "id:", ":")
        )
        if not is_sse:
            payload = json.loads(body)
            return payload if isinstance(payload, list) else [payload]

        messages: List[Any] = []
        data_lines: List[str] = []
        # SSE permits CRLF, CR or LF line endings; normalize before splitting.
        for line in body.replace("\r\n", "\n").replace("\r", "\n").split("\n"):
            if not line:
                # A blank line terminates an event: dispatch its data, if any.
                if data_lines:
                    messages.append(json.loads("\n".join(data_lines)))
                    data_lines = []
            elif line.startswith("data:"):
                value = line[len("data:"):]
                # The spec strips exactly one leading space from the field value.
                data_lines.append(value[1:] if value.startswith(" ") else value)
        if data_lines:  # stream ended without a trailing blank line
            messages.append(json.loads("\n".join(data_lines)))
        return messages

    @staticmethod
    def _unwrap_response(response: Dict[str, Any]) -> Dict[str, Any]:
        """Turn a JSON-RPC response to ``tools/call`` into this class's result dict."""
        if "error" in response:
            # Protocol-level failure (unknown tool, invalid params, ...).
            error = response["error"]
            message = error.get("message") if isinstance(error, dict) else None
            return {"error": message or str(error)}

        result = response.get("result")
        if not isinstance(result, dict):
            return {"result": result}

        contents = result.get("content") or []
        if result.get("isError"):
            # Tool-level failure: the error text is carried in the content items.
            return {"error": "".join(
                item.get("text", "") for item in contents if isinstance(item, dict)
            )}

        structured = result.get("structuredContent")
        if structured:
            return structured

        if contents:
            first = contents[0]
            text = first.get("text", "") if isinstance(first, dict) else ""
            try:
                parsed = json.loads(text)
            except json.JSONDecodeError:
                return {"message": text}
            # Keep the Dict return contract even for list/scalar JSON payloads.
            return parsed if isinstance(parsed, dict) else {"result": parsed}
        return result

    def _call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Call an MCP tool in a fresh session and return its unwrapped result.

        Args:
            tool_name: Name of the MCP tool to invoke.
            arguments: Tool arguments, sent unchanged.

        Returns:
            The tool's ``structuredContent``; otherwise its first text content
            item parsed as JSON (non-dict JSON wrapped as ``{"result": ...}``,
            non-JSON text as ``{"message": text}``); otherwise the raw result.
            Any failure (network, HTTP status, JSON-RPC error, ``isError`` tool
            result, malformed or missing reply) yields ``{"error": "<message>"}``.
            Never raises.
        """
        try:
            # Inside the try so an unreachable server yields an error dict
            # instead of an exception escaping to callers.
            session_id = self._init_session()
            if not session_id:
                return {"error": "Failed to initialize MCP session"}

            request = self._build_request({
                "jsonrpc": "2.0",
                "method": "tools/call",
                "params": {
                    "name": tool_name,
                    "arguments": arguments,
                },
                "id": f"{tool_name}-{uuid.uuid4().hex[:8]}",
            }, session_id)
            with urllib.request.urlopen(request, timeout=120) as resp:
                content_type = resp.headers.get("Content-Type", "") or ""
                body = resp.read().decode("utf-8")

            messages = self._parse_messages(body, content_type)
            # An SSE stream may also carry server notifications/requests
            # (these have a "method"); select the actual response.
            response = next(
                (
                    msg for msg in messages
                    if isinstance(msg, dict) and "method" not in msg
                    and ("result" in msg or "error" in msg)
                ),
                None,
            )
            if response is None:
                logger.error("MCP call returned no JSON-RPC response for %s", tool_name)
                return {"error": f"No JSON-RPC response for MCP tool '{tool_name}'"}
            return self._unwrap_response(response)
        except urllib.error.URLError as e:
            logger.error("MCP call failed: %s", e)
            return {"error": str(e)}
        except Exception as e:
            logger.error("MCP call error: %s", e)
            return {"error": str(e)}

    def _list_tool(self, tool_name: str, arguments: Dict[str, Any], key: str,
                   failure_label: str) -> List[Dict[str, Any]]:
        """Call a tool whose result holds a list under *key*; return [] on any failure.

        Failures are logged as ``"<failure_label> failed: <error>"``.
        """
        result = self._call_tool(tool_name, arguments)
        if "error" in result:
            logger.warning("%s failed: %s", failure_label, result["error"])
            return []
        items = result.get(key)
        return items if isinstance(items, list) else []

    # ============================================
    # PUBLIC API
    # ============================================

    def save(self, content: str, group_id: str = "genesis-kinan", name: Optional[str] = None,
             source: str = "text", source_description: str = "") -> Dict[str, Any]:
        """Store *content* as a memory episode via the ``add_memory`` tool.

        Args:
            content: Episode body to store.
            group_id: Memory group to write into.
            name: Episode title; a random ``Memory-xxxxxxxx`` name is used if empty.
            source: Source type passed through to the server.
            source_description: Optional description of the source; omitted
                from the request when empty.

        Returns:
            The tool result dict, or ``{"error": ...}`` on failure.
        """
        if not name:
            name = f"Memory-{uuid.uuid4().hex[:8]}"
        args: Dict[str, Any] = {"name": name, "episode_body": content, "group_id": group_id, "source": source}
        if source_description:
            args["source_description"] = source_description
        result = self._call_tool("add_memory", args)
        if "error" in result:
            logger.warning("Memory save failed: %s -> %s", name, result["error"])
        else:
            logger.info("Memory saved: %s -> %s", name, result)
        return result

    def search(self, query: str, group_id: str = "genesis-kinan", max_nodes: int = 10) -> List[Dict[str, Any]]:
        """Search entity nodes in *group_id* matching *query*; [] on failure."""
        args: Dict[str, Any] = {"query": query, "group_ids": [group_id], "max_nodes": max_nodes}
        return self._list_tool("search_nodes", args, "nodes", "Search")

    def search_facts(self, query: str, group_id: str = "genesis-kinan", max_facts: int = 10) -> List[Dict[str, Any]]:
        """Search facts in *group_id* matching *query*; [] on failure."""
        args: Dict[str, Any] = {"query": query, "group_ids": [group_id], "max_facts": max_facts}
        return self._list_tool("search_memory_facts", args, "facts", "Fact search")

    def get_episodes(self, group_id: str = "genesis-kinan", max_episodes: int = 10) -> List[Dict[str, Any]]:
        """Return up to *max_episodes* episodes from *group_id*; [] on failure."""
        args: Dict[str, Any] = {"group_ids": [group_id], "max_episodes": max_episodes}
        return self._list_tool("get_episodes", args, "episodes", "Get episodes")

    def health(self) -> Dict[str, Any]:
        """Return the ``get_status`` tool result, or ``{"error": ...}`` on failure."""
        return self._call_tool("get_status", {})


# Shared, ready-to-use bridge for callers that do not need a custom endpoint.
# Construction only stores the endpoint; no network I/O happens at import time.
memory: MemoryBridge = MemoryBridge(endpoint=MCP_ENDPOINT)
