#!/usr/bin/env python3
"""
Sunaiva Memory MCP — Canonical Unified Server
===============================================
Consolidates three legacy memory systems into ONE canonical architecture:
  - System A: Sunaiva/ai-memory/mcp/ (FastMCP vault, keyword search)
  - System B: Sunaiva/memory_mcp/ (Google Takeout ingestor, 3072-dim Qdrant)
  - System C: mcp-servers/genesis-mem0/ (Mem0 + Qdrant 768-dim)

Architecture:
  - Qdrant:      768-dim vectors for semantic search (collection: sunaiva_memory_768)
  - PostgreSQL:  Metadata, user profiles, access control (Elestio)
  - Redis:       Hot cache for recent memories (Elestio)
  - Embeddings:  Gemini text-embedding-004 (768-dim output_dimensionality)

MCP Tools (6):
  - memory_store      — Store a memory with embeddings
  - memory_search     — Semantic search across memories
  - memory_get_all    — Retrieve all memories for a user
  - memory_delete     — Delete a specific memory
  - memory_summarize  — Summarize a user's memory context
  - memory_ingest_takeout — Import from Google Takeout export

Usage:
    # stdio transport (Claude Code MCP)
    python server.py --transport stdio --user-id kinan

    # SSE transport (remote connections)
    python server.py --transport sse --port 8100

Environment Variables (ALL credentials from env — no hardcoded secrets):
    GENESIS_QDRANT_HOST       Qdrant hostname
    GENESIS_QDRANT_PORT       Qdrant port (default: 6333)
    GENESIS_QDRANT_API_KEY    Qdrant API key
    GENESIS_POSTGRES_HOST     PostgreSQL hostname
    GENESIS_POSTGRES_PORT     PostgreSQL port (default: 25432)
    GENESIS_POSTGRES_USER     PostgreSQL user
    GENESIS_POSTGRES_PASSWORD PostgreSQL password
    GENESIS_POSTGRES_DATABASE PostgreSQL database name
    GENESIS_REDIS_HOST        Redis hostname
    GENESIS_REDIS_PORT        Redis port (default: 26379)
    GENESIS_REDIS_USER        Redis user
    GENESIS_REDIS_PASSWORD    Redis password
    GEMINI_API_KEY            Gemini API key for embeddings

VERIFICATION_STAMP
  Story: MEMORY-MCP-CONSOLIDATION
  Verified By: Claude Opus 4.6
  Verified At: 2026-02-26
  Tests: See tests/memory_mcp/test_sunaiva_memory.py
"""

import json
import hashlib
import logging
import os
import sys
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

# ---------------------------------------------------------------------------
# MCP SDK
# ---------------------------------------------------------------------------
# Fail fast with an actionable message if the MCP SDK is missing — every
# tool below depends on FastMCP, so there is nothing useful to do without it.
try:
    from mcp.server.fastmcp import FastMCP
except ImportError:
    print("ERROR: FastMCP not available — install with: pip install 'mcp[server]'", file=sys.stderr)
    sys.exit(1)

# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
# Root logging configuration: single-line records with time-of-day only
# (the date is redundant for a long-running foreground server process).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
    datefmt="%H:%M:%S",
)
# Module-wide logger used by all helpers below.
log = logging.getLogger("sunaiva-memory")

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
EMBED_DIM = 768  # Gemini text-embedding-004 output_dimensionality (see get_embedding)
QDRANT_COLLECTION = "sunaiva_memory_768"  # canonical 768-dim vector collection
PG_SCHEMA = "sunaiva_memory"  # dedicated schema keeps the memories table namespaced
REDIS_PREFIX = "sunaiva:mem:"  # key namespace for every cache entry
REDIS_TTL = 3600  # 1 hour hot cache

# ---------------------------------------------------------------------------
# Configuration (all from environment — NO hardcoded secrets)
# ---------------------------------------------------------------------------

def _env(key: str, default: str = "") -> str:
    return os.environ.get(key, default)


class Config:
    """Centralised configuration — every credential comes from os.environ.

    NOTE(review): every attribute is evaluated once at class-definition
    (import) time, so environment changes after import are not picked up.
    The int() conversions raise ValueError at import if a port variable
    holds a non-numeric value — confirm that is the desired fail-fast
    behaviour.
    """

    # Qdrant (vector store)
    qdrant_host: str = _env("GENESIS_QDRANT_HOST")
    qdrant_port: int = int(_env("GENESIS_QDRANT_PORT", "6333"))
    qdrant_api_key: str = _env("GENESIS_QDRANT_API_KEY")

    # PostgreSQL (metadata store)
    pg_host: str = _env("GENESIS_POSTGRES_HOST")
    pg_port: int = int(_env("GENESIS_POSTGRES_PORT", "25432"))
    pg_user: str = _env("GENESIS_POSTGRES_USER", "postgres")
    pg_password: str = _env("GENESIS_POSTGRES_PASSWORD")
    pg_database: str = _env("GENESIS_POSTGRES_DATABASE", "postgres")

    # Redis (hot cache)
    redis_host: str = _env("GENESIS_REDIS_HOST")
    redis_port: int = int(_env("GENESIS_REDIS_PORT", "26379"))
    redis_user: str = _env("GENESIS_REDIS_USER", "default")
    redis_password: str = _env("GENESIS_REDIS_PASSWORD")

    # Gemini embeddings — falls back to GEMINI_API_KEY_NEW (not listed in
    # the module header; presumably kept for rotated-key deployments).
    gemini_api_key: str = _env("GEMINI_API_KEY") or _env("GEMINI_API_KEY_NEW")

    @classmethod
    def qdrant_url(cls) -> str:
        """Build the Qdrant base URL (always https)."""
        return f"https://{cls.qdrant_host}:{cls.qdrant_port}"

    @classmethod
    def pg_dsn(cls) -> str:
        """Build a libpq-style keyword/value DSN for PostgreSQL."""
        return (
            f"host={cls.pg_host} port={cls.pg_port} user={cls.pg_user} "
            f"password={cls.pg_password} dbname={cls.pg_database}"
        )


# ---------------------------------------------------------------------------
# Embedding engine
# ---------------------------------------------------------------------------

def get_embedding(text: str, api_key: Optional[str] = None) -> List[float]:
    """
    Embed *text* into a 768-dimensional vector via Gemini text-embedding-004.

    Any failure path — missing key, SDK unavailable, API error, or an
    unexpected dimensionality — degrades to a zero vector so that storage
    never fails (search quality simply suffers).
    """
    zero_vector = [0.0] * EMBED_DIM

    key = api_key or Config.gemini_api_key
    if not key:
        log.warning("No GEMINI_API_KEY set — returning zero vector")
        return zero_vector

    try:
        from google import genai as google_genai

        response = google_genai.Client(api_key=key).models.embed_content(
            model="text-embedding-004",
            contents=text,
            config={"output_dimensionality": EMBED_DIM},
        )
        vec = list(response.embeddings[0].values)
        if len(vec) == EMBED_DIM:
            return vec
        log.warning(f"Embedding dim mismatch: got {len(vec)}, expected {EMBED_DIM}")
        return zero_vector
    except Exception as e:
        log.error(f"Embedding failed: {e}")
        return zero_vector


# ---------------------------------------------------------------------------
# Qdrant client
# ---------------------------------------------------------------------------

# Cached Qdrant client instance (lazy singleton; None until first use).
_qdrant_client = None


def _get_qdrant():
    """Lazy singleton Qdrant client.

    On first use, connects and ensures the canonical 768-dim collection
    exists; the client is then cached for subsequent calls. Returns None
    (after logging) when Qdrant is unconfigured or unreachable, so callers
    degrade to the PostgreSQL-only path.
    """
    global _qdrant_client
    if _qdrant_client is not None:
        return _qdrant_client

    # Both host and API key are required; without them vector ops are off.
    if not Config.qdrant_host or not Config.qdrant_api_key:
        log.warning("Qdrant not configured — vector operations will be unavailable")
        return None

    try:
        from qdrant_client import QdrantClient
        from qdrant_client.models import Distance, VectorParams

        _qdrant_client = QdrantClient(
            url=Config.qdrant_url(),
            api_key=Config.qdrant_api_key,
            timeout=30,
        )
        # Ensure collection exists (idempotent bootstrap on first connect).
        existing = {c.name for c in _qdrant_client.get_collections().collections}
        if QDRANT_COLLECTION not in existing:
            _qdrant_client.create_collection(
                collection_name=QDRANT_COLLECTION,
                vectors_config=VectorParams(size=EMBED_DIM, distance=Distance.COSINE),
            )
            log.info(f"Created Qdrant collection: {QDRANT_COLLECTION}")
        return _qdrant_client
    except Exception as e:
        # Reset the cached client so a later call retries from scratch.
        log.error(f"Qdrant init failed: {e}")
        _qdrant_client = None
        return None


# ---------------------------------------------------------------------------
# PostgreSQL helpers
# ---------------------------------------------------------------------------

# Cached PostgreSQL connection pool (lazy singleton; None until first use).
_pg_pool = None


def _get_pg():
    """Lazy PostgreSQL connection pool (1-5 threaded connections).

    On first use, builds the pool and bootstraps the schema, memories
    table, and user_id index (all idempotent DDL). Returns None (after
    logging) when PostgreSQL is unconfigured or unreachable.
    """
    global _pg_pool
    # Recreate the pool if it was closed (e.g. after a fatal error).
    if _pg_pool is not None and not _pg_pool.closed:
        return _pg_pool

    if not Config.pg_host or not Config.pg_password:
        log.warning("PostgreSQL not configured — metadata operations unavailable")
        return None

    try:
        import psycopg2
        from psycopg2.pool import ThreadedConnectionPool

        _pg_pool = ThreadedConnectionPool(
            1, 5,
            host=Config.pg_host,
            port=Config.pg_port,
            user=Config.pg_user,
            password=Config.pg_password,
            database=Config.pg_database,
            connect_timeout=10,
        )
        # Ensure schema and table exist.  PG_SCHEMA is a module constant,
        # not user input, so the f-string interpolation is safe here.
        conn = _pg_pool.getconn()
        try:
            cur = conn.cursor()
            cur.execute(f"CREATE SCHEMA IF NOT EXISTS {PG_SCHEMA}")
            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS {PG_SCHEMA}.memories (
                    id          TEXT PRIMARY KEY,
                    user_id     TEXT NOT NULL,
                    content     TEXT NOT NULL,
                    metadata    JSONB DEFAULT '{{}}'::jsonb,
                    created_at  TIMESTAMPTZ DEFAULT NOW(),
                    updated_at  TIMESTAMPTZ DEFAULT NOW()
                )
            """)
            cur.execute(f"""
                CREATE INDEX IF NOT EXISTS idx_memories_user_id
                ON {PG_SCHEMA}.memories (user_id)
            """)
            conn.commit()
            cur.close()
        finally:
            _pg_pool.putconn(conn)
        return _pg_pool
    except Exception as e:
        log.error(f"PostgreSQL init failed: {e}")
        _pg_pool = None
        return None


def _pg_execute(query: str, params: tuple = (), fetch: bool = False) -> Optional[List]:
    """Execute a query against PostgreSQL with connection pool management.

    Args:
        query: SQL statement (with %s placeholders).
        params: Values bound server-side by psycopg2 (prevents injection).
        fetch: When True, return cur.fetchall() on success.

    Returns:
        Row list when fetch=True and the query succeeds; None when
        fetch=False, when the pool is unavailable, or on any query error.
    """
    pool = _get_pg()
    if pool is None:
        return None
    conn = pool.getconn()
    try:
        # The cursor context manager guarantees the cursor is closed even
        # when execute/fetch raises (the previous version leaked cursors
        # on the exception path).
        with conn.cursor() as cur:
            cur.execute(query, params)
            result = cur.fetchall() if fetch else None
        conn.commit()
        return result
    except Exception as e:
        conn.rollback()
        log.error(f"PostgreSQL query failed: {e}")
        return None
    finally:
        # Always return the connection to the pool.
        pool.putconn(conn)


# ---------------------------------------------------------------------------
# Redis helpers
# ---------------------------------------------------------------------------

# Cached Redis client (lazy singleton; None until first successful connect).
_redis_client = None


def _get_redis():
    """Lazy Redis client, with a liveness ping before reusing the cache."""
    global _redis_client

    # Reuse the cached client only if it still answers a ping.
    if _redis_client is not None:
        try:
            _redis_client.ping()
        except Exception:
            _redis_client = None
        else:
            return _redis_client

    if not (Config.redis_host and Config.redis_password):
        return None

    try:
        import redis

        client = redis.Redis(
            host=Config.redis_host,
            port=Config.redis_port,
            username=Config.redis_user,
            password=Config.redis_password,
            decode_responses=True,
            socket_timeout=5,
        )
        client.ping()
    except Exception as e:
        log.warning(f"Redis unavailable: {e}")
        _redis_client = None
        return None

    _redis_client = client
    return _redis_client


def _cache_set(key: str, value: str, ttl: int = REDIS_TTL) -> None:
    """Best-effort write to the Redis hot cache; failures are ignored."""
    client = _get_redis()
    if client is None:
        return
    try:
        client.setex(f"{REDIS_PREFIX}{key}", ttl, value)
    except Exception:
        pass

def _cache_get(key: str) -> Optional[str]:
    """Best-effort read from the Redis hot cache; None on miss or error."""
    client = _get_redis()
    if client is None:
        return None
    try:
        return client.get(f"{REDIS_PREFIX}{key}")
    except Exception:
        return None


def _cache_delete(key: str) -> None:
    """Best-effort removal of a cache entry; failures are ignored."""
    client = _get_redis()
    if client is None:
        return
    try:
        client.delete(f"{REDIS_PREFIX}{key}")
    except Exception:
        pass


# ---------------------------------------------------------------------------
# Core memory operations
# ---------------------------------------------------------------------------

def _generate_memory_id(content: str, user_id: str) -> str:
    """Generate a deterministic memory ID from content + user_id."""
    h = hashlib.sha256(f"{user_id}:{content}".encode()).hexdigest()[:24]
    return f"mem_{h}"


def store_memory(content: str, user_id: str, metadata: Optional[Dict] = None) -> Dict[str, Any]:
    """
    Store a memory: embed → Qdrant + PostgreSQL, then invalidate Redis caches.

    Args:
        content: Text to remember.
        user_id: Owner of the memory (used for isolation and cache keys).
        metadata: Optional extra metadata merged into the stored payload.

    Returns:
        {"memory_id": str, "stored": bool, "vector_stored": bool, "metadata_stored": bool}
    """
    memory_id = _generate_memory_id(content, user_id)
    # Copy so we never mutate the caller's metadata dict in place.
    meta = dict(metadata) if metadata else {}
    meta["stored_at"] = datetime.now(timezone.utc).isoformat()
    result = {"memory_id": memory_id, "stored": False, "vector_stored": False, "metadata_stored": False}

    # 1. Embed and store the vector in Qdrant (best-effort; failure logged).
    embedding = get_embedding(content)
    qdrant = _get_qdrant()
    if qdrant is not None:
        try:
            from qdrant_client.models import PointStruct

            point = PointStruct(
                id=memory_id,
                vector=embedding,
                payload={
                    "user_id": user_id,
                    "content": content,
                    "metadata": meta,
                    "created_at": meta["stored_at"],
                },
            )
            qdrant.upsert(collection_name=QDRANT_COLLECTION, points=[point])
            result["vector_stored"] = True
        except Exception as e:
            log.error(f"Qdrant upsert failed: {e}")

    # 2. Upsert metadata in PostgreSQL.  RETURNING + fetch=True makes the
    #    outcome observable: _pg_execute returns None for non-fetch
    #    statements whether they succeed or fail, so the previous check
    #    ("pg_result is not None or _get_pg() is not None") reported
    #    success whenever the pool existed, even if the INSERT failed.
    pg_rows = _pg_execute(
        f"""
        INSERT INTO {PG_SCHEMA}.memories (id, user_id, content, metadata, created_at, updated_at)
        VALUES (%s, %s, %s, %s::jsonb, NOW(), NOW())
        ON CONFLICT (id) DO UPDATE SET
            content = EXCLUDED.content,
            metadata = EXCLUDED.metadata,
            updated_at = NOW()
        RETURNING id
        """,
        (memory_id, user_id, content, json.dumps(meta)),
        fetch=True,
    )
    result["metadata_stored"] = bool(pg_rows)

    # 3. Invalidate per-user Redis caches (listing and summary are now stale).
    _cache_delete(f"all:{user_id}")
    _cache_delete(f"summary:{user_id}")

    result["stored"] = result["vector_stored"] or result["metadata_stored"]
    return result


def search_memories(query: str, user_id: str, limit: int = 10) -> List[Dict[str, Any]]:
    """
    Semantic search across a user's memories via Qdrant.

    Falls back to PostgreSQL ILIKE keyword search when Qdrant is
    unavailable, the query embedding is a zero vector, or the semantic
    search returns no hits above the 0.3 score threshold.  Results are
    cached in Redis for 5 minutes, keyed on (user_id, query, limit).

    Returns:
        List of {"memory_id", "content", "score", "metadata", "created_at",
        "source"} dicts; "source" marks which backend produced each hit.
    """
    # Check Redis cache first.  md5 here is only a cache-key digest, not a
    # security hash; the short prefix keeps keys compact.
    cache_key = f"search:{user_id}:{hashlib.md5(f'{query}:{limit}'.encode()).hexdigest()[:12]}"
    cached = _cache_get(cache_key)
    if cached:
        try:
            return json.loads(cached)
        except (json.JSONDecodeError, TypeError):
            pass

    results = []

    # Try Qdrant semantic search, scoped to this user via a payload filter.
    qdrant = _get_qdrant()
    if qdrant is not None:
        try:
            from qdrant_client.models import Filter, FieldCondition, MatchValue

            query_vector = get_embedding(query)
            # Skip Qdrant search if we got a zero vector (embedding failed:
            # cosine similarity against a zero vector is meaningless).
            if not all(v == 0.0 for v in query_vector):
                hits = qdrant.search(
                    collection_name=QDRANT_COLLECTION,
                    query_vector=query_vector,
                    query_filter=Filter(
                        must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
                    ),
                    limit=limit,
                    score_threshold=0.3,
                )
                for hit in hits:
                    payload = hit.payload or {}
                    results.append({
                        "memory_id": hit.id,
                        "content": payload.get("content", ""),
                        "score": round(hit.score, 4),
                        "metadata": payload.get("metadata", {}),
                        "created_at": payload.get("created_at", ""),
                        "source": "qdrant",
                    })
        except Exception as e:
            log.error(f"Qdrant search failed: {e}")

    # Fallback: PostgreSQL keyword search (also reached when the semantic
    # search simply found nothing).
    if not results:
        rows = _pg_execute(
            f"""
            SELECT id, content, metadata, created_at
            FROM {PG_SCHEMA}.memories
            WHERE user_id = %s AND content ILIKE %s
            ORDER BY created_at DESC
            LIMIT %s
            """,
            (user_id, f"%{query}%", limit),
            fetch=True,
        )
        if rows:
            for row in rows:
                # JSONB may come back as a dict or a JSON string depending
                # on the psycopg2 type adaptation in effect.
                meta = row[2] if isinstance(row[2], dict) else json.loads(row[2]) if row[2] else {}
                results.append({
                    "memory_id": row[0],
                    "content": row[1],
                    "score": 0.5,  # Fixed score for keyword matches
                    "metadata": meta,
                    "created_at": str(row[3]) if row[3] else "",
                    "source": "postgresql_fallback",
                })

    # Cache results (short 5-minute TTL; searches go stale fast).
    if results:
        _cache_set(cache_key, json.dumps(results, default=str), ttl=300)

    return results


def get_all_memories(user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
    """Retrieve all memories for a user, ordered by creation time (newest first).

    Args:
        user_id: Owner whose memories are listed.
        limit: Maximum number of memories to return.

    Returns:
        List of {"memory_id", "content", "metadata", "created_at"} dicts.
    """
    # Check Redis cache.
    cache_key = f"all:{user_id}"
    cached = _cache_get(cache_key)
    if cached:
        try:
            cached_list = json.loads(cached)
            # Only serve from cache when it can satisfy this request.
            # Previously a call with a small limit cached its truncated
            # list, and later calls with a larger limit silently got the
            # short list back until the cache expired.
            if isinstance(cached_list, list) and len(cached_list) >= limit:
                return cached_list[:limit]
        except (json.JSONDecodeError, TypeError):
            pass

    results = []
    rows = _pg_execute(
        f"""
        SELECT id, content, metadata, created_at
        FROM {PG_SCHEMA}.memories
        WHERE user_id = %s
        ORDER BY created_at DESC
        LIMIT %s
        """,
        (user_id, limit),
        fetch=True,
    )
    if rows:
        for row in rows:
            # JSONB may come back as a dict or a JSON string depending on
            # the psycopg2 type adaptation in effect.
            meta = row[2] if isinstance(row[2], dict) else json.loads(row[2]) if row[2] else {}
            results.append({
                "memory_id": row[0],
                "content": row[1],
                "metadata": meta,
                "created_at": str(row[3]) if row[3] else "",
            })

    # Cache the (possibly limit-truncated) listing; the same key is
    # invalidated by store_memory/delete_memory.
    if results:
        _cache_set(cache_key, json.dumps(results, default=str))

    return results


def delete_memory(memory_id: str, user_id: str) -> Dict[str, Any]:
    """
    Delete a specific memory by ID, verifying user ownership.

    Args:
        memory_id: ID of the memory to remove.
        user_id: Caller identity; must match the stored owner.

    Returns:
        {"deleted": bool, "memory_id": str, "vector_deleted": bool, "metadata_deleted": bool}
    """
    result = {"deleted": False, "memory_id": memory_id, "vector_deleted": False, "metadata_deleted": False}

    # Verify ownership in PostgreSQL before touching either store.
    rows = _pg_execute(
        f"SELECT user_id FROM {PG_SCHEMA}.memories WHERE id = %s",
        (memory_id,),
        fetch=True,
    )
    if not rows:
        return result
    if rows[0][0] != user_id:
        log.warning(f"User {user_id} attempted to delete memory owned by {rows[0][0]}")
        return result

    # Delete the vector from Qdrant (best-effort; failure is logged).
    qdrant = _get_qdrant()
    if qdrant is not None:
        try:
            from qdrant_client.models import PointIdsList

            qdrant.delete(
                collection_name=QDRANT_COLLECTION,
                points_selector=PointIdsList(points=[memory_id]),
            )
            result["vector_deleted"] = True
        except Exception as e:
            log.error(f"Qdrant delete failed: {e}")

    # Delete metadata from PostgreSQL.  RETURNING + fetch=True makes the
    # outcome observable — the previous code reported metadata_deleted and
    # deleted as True unconditionally, because _pg_execute returns None
    # for non-fetch statements whether they succeed or fail.
    deleted_rows = _pg_execute(
        f"DELETE FROM {PG_SCHEMA}.memories WHERE id = %s AND user_id = %s RETURNING id",
        (memory_id, user_id),
        fetch=True,
    )
    result["metadata_deleted"] = bool(deleted_rows)
    result["deleted"] = result["vector_deleted"] or result["metadata_deleted"]

    # Invalidate per-user caches so reads don't resurrect the memory.
    _cache_delete(f"all:{user_id}")
    _cache_delete(f"summary:{user_id}")

    return result


def summarize_memories(user_id: str) -> str:
    """
    Build a markdown summary of a user's stored memory context.

    Memories are grouped by their metadata "type" (falling back to
    "source", then "general"); at most 10 entries are shown per group and
    each entry is truncated to 200 characters.  The rendered summary is
    cached in Redis for 10 minutes.
    """
    # Serve from cache when available.
    cache_key = f"summary:{user_id}"
    cached = _cache_get(cache_key)
    if cached:
        return cached

    memories = get_all_memories(user_id, limit=200)
    if not memories:
        return f"No memories found for user '{user_id}'."

    # Bucket memories by category, truncating long content for display.
    grouped: Dict[str, List[str]] = {}
    for entry in memories:
        meta = entry.get("metadata", {})
        category = meta.get("type", meta.get("source", "general"))
        text = entry["content"]
        text = text[:200] + "..." if len(text) > 200 else text
        grouped.setdefault(category, []).append(text)

    # Render the markdown in pieces, then join once.
    parts = [
        f"# Memory Summary for {user_id}\n",
        f"Total memories: {len(memories)}\n\n",
    ]
    for category, items in sorted(grouped.items()):
        parts.append(f"## {category.replace('_', ' ').title()} ({len(items)})\n")
        parts.extend(f"- {item}\n" for item in items[:10])  # Max 10 per category
        if len(items) > 10:
            parts.append(f"- ... and {len(items) - 10} more\n")
        parts.append("\n")

    summary = "".join(parts)
    _cache_set(cache_key, summary, ttl=600)
    return summary


def ingest_takeout(file_path: str, user_id: str) -> Dict[str, Any]:
    """
    Import memories from a Google Takeout export (ZIP or JSON).

    Every JSON document found (all .json entries inside a ZIP, or the
    single JSON file) is treated as a conversation record; recognised
    formats are flattened to text, chunked to 2000 characters, and each
    chunk stored via store_memory with type "google_takeout".

    Returns:
        {"ingested": int, "errors": int, "file": str, "user_id": str}
        plus an "error" key on fatal failures (missing file, bad archive,
        unsupported extension).
    """
    import zipfile
    from pathlib import Path

    result = {"ingested": 0, "errors": 0, "file": file_path, "user_id": user_id}
    path = Path(file_path)

    if not path.exists():
        return {**result, "error": f"File not found: {file_path}"}

    conversations = []

    # Handle ZIP files: collect every parseable .json entry; individual
    # bad entries count as errors but do not abort the import.
    if path.suffix.lower() == ".zip":
        try:
            with zipfile.ZipFile(path, "r") as zf:
                for name in zf.namelist():
                    if name.endswith(".json"):
                        try:
                            data = json.loads(zf.read(name))
                            if isinstance(data, list):
                                conversations.extend(data)
                            elif isinstance(data, dict):
                                conversations.append(data)
                        except (json.JSONDecodeError, UnicodeDecodeError):
                            result["errors"] += 1
        except zipfile.BadZipFile:
            return {**result, "error": "Invalid ZIP file"}

    # Handle bare JSON files (a list of conversations or a single one).
    elif path.suffix.lower() == ".json":
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
            if isinstance(data, list):
                conversations = data
            elif isinstance(data, dict):
                conversations = [data]
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            return {**result, "error": f"Invalid JSON: {e}"}
    else:
        return {**result, "error": f"Unsupported file type: {path.suffix}"}

    # Process each conversation; format detection is first-match-wins.
    for conv in conversations:
        try:
            # Extract text content from the various Takeout formats.
            text_parts = []

            # Format 1: Bard/Gemini conversation export ("responses" list)
            if "responses" in conv:
                for resp in conv.get("responses", []):
                    text_parts.append(str(resp.get("response", "")))

            # Format 2: MyActivity.json format (title + time [+ description])
            elif "title" in conv and "time" in conv:
                text_parts.append(f"{conv.get('title', '')}: {conv.get('description', '')}")

            # Format 3: Generic conversation with a "messages" list
            elif "messages" in conv:
                for msg in conv.get("messages", []):
                    role = msg.get("role", msg.get("author", "unknown"))
                    content = msg.get("content", msg.get("text", ""))
                    text_parts.append(f"[{role}]: {content}")

            # Format 4: Simple "content" or "text" field
            elif "content" in conv:
                text_parts.append(str(conv.get("content", "")))
            elif "text" in conv:
                text_parts.append(str(conv.get("text", "")))

            if not text_parts:
                continue

            # Skip trivially short records (under 10 chars of real text).
            combined_text = "\n".join(text_parts).strip()
            if not combined_text or len(combined_text) < 10:
                continue

            # Chunk long texts (max 2000 chars per memory) so each stored
            # memory stays embeddable and readable.
            chunks = [combined_text[i:i + 2000] for i in range(0, len(combined_text), 2000)]

            for chunk_idx, chunk in enumerate(chunks):
                meta = {
                    "type": "google_takeout",
                    "source_file": path.name,
                    "chunk_index": chunk_idx,
                }
                # Preserve original timestamp/title when the record has them.
                if "time" in conv:
                    meta["original_date"] = str(conv["time"])
                if "title" in conv:
                    meta["title"] = str(conv.get("title", ""))[:200]

                store_result = store_memory(chunk, user_id, metadata=meta)
                if store_result.get("stored"):
                    result["ingested"] += 1
                else:
                    result["errors"] += 1

        except Exception as e:
            # One malformed conversation should not abort the whole import.
            log.error(f"Error processing conversation: {e}")
            result["errors"] += 1

    return result


# ---------------------------------------------------------------------------
# MCP Server Definition
# ---------------------------------------------------------------------------

# FastMCP server instance; the tool functions below register themselves
# against it via the @mcp.tool() decorator.
mcp = FastMCP("Sunaiva Memory")

# Module-level default user_id (set via --user-id for stdio mode)
_default_user_id: Optional[str] = None


def _resolve_user_id(user_id: Optional[str] = None) -> str:
    """Resolve user_id from argument or default."""
    uid = user_id or _default_user_id
    if not uid:
        raise ValueError("user_id is required (provide via argument or --user-id flag)")
    return uid


# ---------------------------------------------------------------------------
# MCP Tools
# ---------------------------------------------------------------------------

@mcp.tool()
def memory_store(content: str, user_id: str = "", metadata: str = "{}") -> str:
    """Store a new memory with semantic embeddings.

    The memory is embedded (768-dim), stored in Qdrant for semantic search,
    and indexed in PostgreSQL for metadata queries; stale Redis caches for
    the user are invalidated.

    Args:
        content: The text content to remember
        user_id: User identifier for memory isolation (optional if --user-id set)
        metadata: JSON string of additional metadata (e.g., '{"type": "preference"}')

    Returns:
        JSON result with memory_id and storage status
    """
    uid = _resolve_user_id(user_id or None)
    if metadata:
        try:
            parsed_meta = json.loads(metadata)
        except json.JSONDecodeError:
            # Unparseable metadata is preserved verbatim rather than dropped.
            parsed_meta = {"raw_metadata": metadata}
    else:
        parsed_meta = {}

    return json.dumps(store_memory(content, uid, parsed_meta), indent=2)


@mcp.tool()
def memory_search(query: str, user_id: str = "", limit: int = 10) -> str:
    """Semantic search across your memories.

    Uses 768-dim vector similarity (Qdrant) with PostgreSQL keyword fallback;
    results come back ranked by relevance score.

    Args:
        query: Natural language search query
        user_id: User identifier (optional if --user-id set)
        limit: Maximum results to return (default: 10)

    Returns:
        JSON array of matching memories with scores
    """
    uid = _resolve_user_id(user_id or None)
    matches = search_memories(query, uid, limit)

    if not matches:
        return json.dumps({"results": [], "message": f"No memories found matching '{query}'"})

    payload = {"results": matches, "count": len(matches)}
    return json.dumps(payload, indent=2, default=str)


@mcp.tool()
def memory_get_all(user_id: str = "", limit: int = 100) -> str:
    """Retrieve all memories for a user, ordered by most recent first.

    Args:
        user_id: User identifier (optional if --user-id set)
        limit: Maximum memories to return (default: 100)

    Returns:
        JSON array of all memories
    """
    uid = _resolve_user_id(user_id or None)
    memories = get_all_memories(uid, limit)
    payload = {"memories": memories, "count": len(memories)}
    return json.dumps(payload, indent=2, default=str)


@mcp.tool()
def memory_delete(memory_id: str, user_id: str = "") -> str:
    """Delete a specific memory by its ID.

    Ownership is verified before deletion; the memory is removed from both
    Qdrant (vector) and PostgreSQL (metadata).

    Args:
        memory_id: The ID of the memory to delete
        user_id: User identifier for ownership verification

    Returns:
        JSON result with deletion status
    """
    uid = _resolve_user_id(user_id or None)
    outcome = delete_memory(memory_id, uid)
    return json.dumps(outcome, indent=2)


@mcp.tool()
def memory_summarize(user_id: str = "") -> str:
    """Generate a structured summary of all stored memories.

    Groups memories by type/source into a markdown overview — useful for
    understanding what the memory system knows about a user.

    Args:
        user_id: User identifier (optional if --user-id set)

    Returns:
        Markdown-formatted memory summary
    """
    return summarize_memories(_resolve_user_id(user_id or None))


@mcp.tool()
def memory_ingest_takeout(file_path: str, user_id: str = "") -> str:
    """Import memories from a Google Takeout export file.

    Accepts ZIP archives and JSON files from Google Takeout and parses
    multiple conversation formats (Bard/Gemini, MyActivity, generic).

    Args:
        file_path: Path to the Takeout export (ZIP or JSON)
        user_id: User identifier for the imported memories

    Returns:
        JSON result with ingestion count and any errors
    """
    uid = _resolve_user_id(user_id or None)
    outcome = ingest_takeout(file_path, uid)
    return json.dumps(outcome, indent=2)


# ---------------------------------------------------------------------------
# Server Entry Point
# ---------------------------------------------------------------------------

def main():
    """Parse CLI flags, set the default user, and launch the MCP server."""
    import argparse

    cli = argparse.ArgumentParser(description="Sunaiva Memory MCP Server (Canonical)")
    cli.add_argument(
        "--transport",
        choices=["stdio", "sse"],
        default="stdio",
        help="Transport mode (default: stdio for Claude Code)",
    )
    cli.add_argument(
        "--port",
        type=int,
        default=int(os.environ.get("MCP_PORT", "8100")),
        help="Port for SSE transport (default: 8100)",
    )
    cli.add_argument(
        "--user-id",
        type=str,
        default=os.environ.get("SUNAIVA_DEFAULT_USER_ID", ""),
        help="Default user_id for stdio mode",
    )
    opts = cli.parse_args()

    # Propagate the CLI/env default user to the tool layer.
    global _default_user_id
    if opts.user_id:
        _default_user_id = opts.user_id

    if opts.transport == "sse":
        log.info(f"Starting Sunaiva Memory MCP (SSE) on port {opts.port}")
        mcp.run(transport="sse", port=opts.port)
        return

    log.info("Starting Sunaiva Memory MCP (stdio)")
    if _default_user_id:
        log.info(f"Default user_id: {_default_user_id}")
    mcp.run(transport="stdio")


# Standard script entry guard — run the server only when executed directly.
if __name__ == "__main__":
    main()
