#!/usr/bin/env python3
"""
Tests for Sunaiva Memory MCP — Canonical Unified Server
========================================================
Covers all 6 MCP tools with mocked external services (Qdrant, PostgreSQL, Redis).

Test count: 32
Categories:
  - Storage (5 tests)
  - Search (4 tests)
  - Retrieval (3 tests)
  - Deletion (3 tests)
  - Summary (2 tests)
  - Takeout Ingestion (3 tests)
  - User Isolation (3 tests)
  - Embedding (2 tests)
  - Config (4 tests)
  - Default User ID (3 tests)

Run:
    cd /mnt/e/genesis-system
    python3 -m pytest tests/memory_mcp/test_sunaiva_memory.py -v
"""

import json
import os
import sys
import hashlib
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch, PropertyMock

import pytest

# Make the server module importable from mcp-servers/sunaiva-memory.
_SERVER_DIR = Path(__file__).parent.parent.parent / "mcp-servers" / "sunaiva-memory"
sys.path.insert(0, str(_SERVER_DIR))

# External dependencies must be neutralised before the server module is
# imported, to avoid actual connections during import. Removing the
# credential variables leaves Config with empty credentials.
for _credential_var in (
    "GENESIS_QDRANT_HOST",
    "GENESIS_QDRANT_API_KEY",
    "GENESIS_POSTGRES_HOST",
    "GENESIS_POSTGRES_PASSWORD",
    "GENESIS_REDIS_HOST",
    "GENESIS_REDIS_PASSWORD",
    "GEMINI_API_KEY",
    "GEMINI_API_KEY_NEW",
):
    os.environ.pop(_credential_var, None)

import server as mem_server


# ============================================================================
# Fixtures
# ============================================================================

@pytest.fixture(autouse=True)
def reset_globals():
    """Set the server's module-level singletons to None before and after each test."""
    singleton_attrs = (
        "_qdrant_client",
        "_pg_pool",
        "_redis_client",
        "_default_user_id",
    )

    def _clear_singletons():
        for attr in singleton_attrs:
            setattr(mem_server, attr, None)

    _clear_singletons()
    yield
    _clear_singletons()


@pytest.fixture
def mock_qdrant():
    """Patch ``_get_qdrant`` to hand back a MagicMock client and yield that client.

    The client lists one collection named ``QDRANT_COLLECTION``, reports
    ``True`` for upsert and delete, and returns no search hits by default.
    """
    collection_entry = MagicMock()
    # ``name`` has to be assigned after construction; as a constructor kwarg
    # it would name the mock itself instead of setting the attribute.
    collection_entry.name = mem_server.QDRANT_COLLECTION

    client = MagicMock()
    client.get_collections.return_value = MagicMock(collections=[collection_entry])
    client.upsert.return_value = True
    client.search.return_value = []  # tests override this to simulate hits
    client.delete.return_value = True

    qdrant_patch = patch.object(mem_server, "_get_qdrant", return_value=client)
    with qdrant_patch:
        yield client


@pytest.fixture
def mock_pg():
    """Replace the server's PostgreSQL helpers with an in-memory dict store.

    ``_pg_execute`` is patched with a dispatcher that inspects the upper-cased
    SQL text and emulates INSERT, SELECT (by user, by keyword, by id), DELETE
    and CREATE against a plain dict; ``_get_pg`` is patched to return a
    ``MagicMock``. Yields the dict of stored rows, keyed by memory id.
    """
    stored_memories = {}

    def _rows_for(owner, needle=None):
        """Collect (id, content, metadata, created_at) rows for one user.

        When ``needle`` is given, only rows whose lower-cased content
        contains it are kept.
        """
        return [
            (rec["id"], rec["content"], rec["metadata"], rec["created_at"])
            for rec in stored_memories.values()
            if rec["user_id"] == owner
            and (needle is None or needle in rec["content"].lower())
        ]

    def mock_execute(query, params=(), fetch=False):
        sql = query.strip().upper()

        if sql.startswith("INSERT"):
            # Inserts with fewer than four parameters are accepted but not recorded.
            if len(params) >= 4:
                memory_id, user_id, content, meta = (
                    params[0], params[1], params[2], params[3],
                )
                stored_memories[memory_id] = {
                    "id": memory_id,
                    "user_id": user_id,
                    "content": content,
                    "metadata": json.loads(meta) if isinstance(meta, str) else meta,
                    "created_at": "2026-02-26T12:00:00+00:00",
                }
            return None

        if sql.startswith("SELECT") and fetch:
            if "WHERE USER_ID" in sql:
                owner = params[0]
                if "ILIKE" in sql:
                    # Keyword search: params are (user_id, '%term%', limit).
                    needle = params[1].strip("%").lower()
                    cap = params[2] if len(params) > 2 else 100
                    return _rows_for(owner, needle)[:cap]
                # Plain listing: params are (user_id, limit).
                cap = params[1] if len(params) > 1 else 100
                return _rows_for(owner)[:cap]

            if "WHERE ID" in sql:
                # Lookup by id returns a single-column row holding the owner.
                record = stored_memories.get(params[0])
                return [] if record is None else [(record["user_id"],)]

            # Any other SELECT yields no rows.
            return []

        if sql.startswith("DELETE"):
            stored_memories.pop(params[0], None)
            return None

        if sql.startswith("CREATE"):
            return None

        return [] if fetch else None

    with patch.object(mem_server, "_pg_execute", side_effect=mock_execute), \
            patch.object(mem_server, "_get_pg", return_value=MagicMock()):
        yield stored_memories


@pytest.fixture
def mock_redis():
    """Back the server's cache helpers with a plain dict and yield that dict.

    ``_cache_set``, ``_cache_get`` and ``_cache_delete`` are patched for the
    duration of the test.
    """
    cache = {}

    def mock_cache_set(key, value, ttl=3600):
        # ttl is accepted but has no effect on the dict-backed cache.
        cache[key] = value

    def mock_cache_get(key):
        return cache.get(key)

    def mock_cache_delete(key):
        cache.pop(key, None)

    set_patch = patch.object(mem_server, "_cache_set", side_effect=mock_cache_set)
    get_patch = patch.object(mem_server, "_cache_get", side_effect=mock_cache_get)
    delete_patch = patch.object(mem_server, "_cache_delete", side_effect=mock_cache_delete)
    with set_patch, get_patch, delete_patch:
        yield cache


@pytest.fixture
def mock_embedding():
    """Patch ``get_embedding`` with a deterministic, hash-derived 768-dim vector."""
    dims = 768

    def fake_embed(text, api_key=None):
        # Scale each SHA-256 digest byte into [0, 1], then cycle through the
        # scaled bytes until the vector has the required length.
        digest = hashlib.sha256(text.encode()).digest()
        unit = [float(byte) / 255.0 for byte in digest]
        return [unit[i % len(unit)] for i in range(dims)]

    embed_patch = patch.object(mem_server, "get_embedding", side_effect=fake_embed)
    with embed_patch:
        yield fake_embed


@pytest.fixture
def full_mock(mock_qdrant, mock_pg, mock_redis, mock_embedding):
    """Activate every backend mock at once and expose them by short name."""
    return dict(
        qdrant=mock_qdrant,
        pg=mock_pg,
        redis=mock_redis,
        embedding=mock_embedding,
    )


# ============================================================================
# Storage Tests (5)
# ============================================================================

class TestMemoryStore:
    """Tests for the ``memory_store`` tool with all backends mocked."""

    def test_store_basic(self, full_mock):
        """A plain store call reports success and a ``mem_``-prefixed id."""
        raw = mem_server.memory_store(
            content="I prefer Python over JavaScript",
            user_id="test_user",
        )
        payload = json.loads(raw)
        assert "memory_id" in payload
        assert payload["memory_id"].startswith("mem_")
        assert payload["stored"] is True

    def test_store_with_metadata(self, full_mock):
        """Storing with a JSON metadata string also succeeds."""
        metadata_json = json.dumps({"type": "preference", "category": "programming"})
        raw = mem_server.memory_store(
            content="I use VSCode for editing",
            user_id="test_user",
            metadata=metadata_json,
        )
        payload = json.loads(raw)
        assert payload["stored"] is True
        assert payload["memory_id"].startswith("mem_")

    def test_store_deterministic_id(self, full_mock):
        """Storing identical content twice for one user yields the same id."""
        first = json.loads(mem_server.memory_store("test content", "user_a"))
        second = json.loads(mem_server.memory_store("test content", "user_a"))
        assert first["memory_id"] == second["memory_id"]

    def test_store_different_users_different_ids(self, full_mock):
        """Identical content stored by two users yields two distinct ids."""
        for_a = json.loads(mem_server.memory_store("shared content", "user_a"))
        for_b = json.loads(mem_server.memory_store("shared content", "user_b"))
        assert for_a["memory_id"] != for_b["memory_id"]

    def test_store_qdrant_upsert_called(self, full_mock):
        """Storing upserts into Qdrant exactly once, naming the right collection."""
        mem_server.memory_store("test memory", "user_x")
        upsert = full_mock["qdrant"].upsert
        upsert.assert_called_once()
        # The collection must be passed by keyword for this check to see it.
        passed_collection = upsert.call_args.kwargs.get("collection_name")
        assert passed_collection == mem_server.QDRANT_COLLECTION


# ============================================================================
# Search Tests (4)
# ============================================================================

class TestMemorySearch:
    """Tests for the ``memory_search`` tool with all backends mocked."""

    def test_search_empty_results(self, full_mock):
        """A query that matches nothing yields an empty result list."""
        payload = json.loads(mem_server.memory_search("nonexistent query", "user_a"))
        assert payload["results"] == []

    def test_search_pg_fallback(self, full_mock):
        """With no Qdrant hits, a stored memory is still found by keyword."""
        # Seed the mocked PostgreSQL store; the Qdrant mock returns no hits.
        mem_server.store_memory("I love machine learning and neural networks", "user_a")

        payload = json.loads(mem_server.memory_search("machine learning", "user_a"))
        assert payload["count"] >= 1
        top_content = payload["results"][0]["content"]
        assert "machine learning" in top_content.lower()

    def test_search_qdrant_semantic(self, full_mock):
        """A Qdrant hit is surfaced with its id, score and a ``qdrant`` source tag."""
        hit = MagicMock(
            id="mem_abc123",
            score=0.85,
            payload={
                "user_id": "user_a",
                "content": "Python is my favourite language",
                "metadata": {"type": "preference"},
                "created_at": "2026-02-26T10:00:00",
            },
        )
        full_mock["qdrant"].search.return_value = [hit]

        payload = json.loads(mem_server.memory_search("programming language", "user_a"))
        assert payload["count"] == 1
        top = payload["results"][0]
        assert top["memory_id"] == "mem_abc123"
        assert top["score"] == 0.85
        assert top["source"] == "qdrant"

    def test_search_respects_limit(self, full_mock):
        """The result count never exceeds the requested limit."""
        for idx in range(5):
            mem_server.store_memory(f"Memory about topic alpha number {idx}", "user_a")

        payload = json.loads(mem_server.memory_search("topic alpha", "user_a", limit=2))
        assert payload["count"] <= 2


# ============================================================================
# Retrieval Tests (3)
# ============================================================================

class TestMemoryGetAll:
    """Tests for the ``memory_get_all`` tool with all backends mocked."""

    def test_get_all_empty(self, full_mock):
        """A user with nothing stored gets a zero count and an empty list."""
        payload = json.loads(mem_server.memory_get_all("nonexistent_user"))
        assert payload["count"] == 0
        assert payload["memories"] == []

    def test_get_all_returns_stored(self, full_mock):
        """Both memories stored for a user are counted."""
        for text in ("First memory", "Second memory"):
            mem_server.store_memory(text, "user_a")

        payload = json.loads(mem_server.memory_get_all("user_a"))
        assert payload["count"] == 2

    def test_get_all_respects_limit(self, full_mock):
        """The returned count never exceeds the requested limit."""
        for idx in range(10):
            mem_server.store_memory(f"Memory number {idx}", "user_a")

        payload = json.loads(mem_server.memory_get_all("user_a", limit=3))
        assert payload["count"] <= 3


# ============================================================================
# Deletion Tests (3)
# ============================================================================

class TestMemoryDelete:
    """Tests for the ``memory_delete`` tool with all backends mocked."""

    def test_delete_existing(self, full_mock):
        """The owner can delete a stored memory; the response echoes its id."""
        target_id = mem_server.store_memory("Delete me please", "user_a")["memory_id"]

        outcome = json.loads(mem_server.memory_delete(target_id, "user_a"))
        assert outcome["deleted"] is True
        assert outcome["memory_id"] == target_id

    def test_delete_nonexistent(self, full_mock):
        """Deleting an unknown id reports ``deleted`` as False."""
        outcome = json.loads(mem_server.memory_delete("mem_nonexistent", "user_a"))
        assert outcome["deleted"] is False

    def test_delete_wrong_user(self, full_mock):
        """A delete request from a non-owner reports ``deleted`` as False."""
        target_id = mem_server.store_memory("Private memory", "user_a")["memory_id"]

        outcome = json.loads(mem_server.memory_delete(target_id, "user_b"))
        assert outcome["deleted"] is False


# ============================================================================
# Summary Tests (2)
# ============================================================================

class TestMemorySummarize:
    """Tests for the ``memory_summarize`` tool with all backends mocked."""

    def test_summarize_empty(self, full_mock):
        """A user with nothing stored gets the 'No memories found' message."""
        summary = mem_server.memory_summarize("empty_user")
        assert "No memories found" in summary

    def test_summarize_with_data(self, full_mock):
        """With typed memories stored, the summary has its header and a total line."""
        seeds = [
            ("I like Python", {"type": "preference"}),
            ("Met John today", {"type": "event"}),
            ("Decided to use FastAPI", {"type": "decision"}),
        ]
        for text, meta in seeds:
            mem_server.store_memory(text, "user_a", meta)

        summary = mem_server.memory_summarize("user_a")
        assert "Memory Summary" in summary
        assert "Total memories:" in summary


# ============================================================================
# Takeout Ingestion Tests (3)
# ============================================================================

class TestMemoryIngestTakeout:
    """Tests for the ``memory_ingest_takeout`` tool with all backends mocked."""

    def test_ingest_json_file(self, full_mock, tmp_path):
        """A two-entry JSON list is ingested completely and without errors."""
        conversations = [
            {"content": "First conversation about AI", "title": "AI Chat"},
            {"content": "Second conversation about Python", "title": "Python Chat"},
        ]
        export_path = tmp_path / "takeout.json"
        export_path.write_text(json.dumps(conversations))

        report = json.loads(mem_server.memory_ingest_takeout(str(export_path), "user_a"))
        assert report["ingested"] == 2
        assert report["errors"] == 0

    def test_ingest_missing_file(self, full_mock):
        """A path that does not exist produces a 'not found' error payload."""
        report = json.loads(
            mem_server.memory_ingest_takeout("/nonexistent/file.json", "user_a")
        )
        assert "error" in report
        assert "not found" in report["error"].lower()

    def test_ingest_bard_format(self, full_mock, tmp_path):
        """An export shaped as a list of ``responses`` entries ingests at least one item."""
        responses = [
            {"response": "Here is my answer about machine learning"},
            {"response": "And here is another response about AI"},
        ]
        export_path = tmp_path / "bard_export.json"
        export_path.write_text(json.dumps([{"responses": responses}]))

        report = json.loads(mem_server.memory_ingest_takeout(str(export_path), "user_a"))
        assert report["ingested"] >= 1


# ============================================================================
# User Isolation Tests (3)
# ============================================================================

class TestUserIsolation:
    """Tests that one user's memories are invisible and undeletable to another."""

    def test_user_a_cannot_see_user_b(self, full_mock):
        """Memories stored by user A should not appear in user B's get_all."""
        mem_server.store_memory("User A secret", "user_a")
        mem_server.store_memory("User B memory", "user_b")

        result_a = json.loads(mem_server.memory_get_all("user_a"))
        result_b = json.loads(mem_server.memory_get_all("user_b"))

        a_contents = [m["content"] for m in result_a["memories"]]
        b_contents = [m["content"] for m in result_b["memories"]]

        assert "User A secret" in a_contents
        assert "User B memory" not in a_contents
        assert "User B memory" in b_contents
        assert "User A secret" not in b_contents

    def test_search_isolated_by_user(self, full_mock):
        """Search only returns memories belonging to the querying user."""
        mem_server.store_memory("Confidential data about project X", "user_a")
        mem_server.store_memory("Public information about project X", "user_b")

        result = json.loads(mem_server.memory_search("project X", "user_a"))
        hits = result.get("results", [])

        # Without this guard the per-hit checks below would pass trivially on
        # an empty result set and the test would prove nothing.
        assert hits, "expected user_a's own matching memory to be returned"

        for hit in hits:
            assert "Confidential" in hit["content"]
            assert "Public" not in hit["content"]

    def test_delete_isolated_by_user(self, full_mock):
        """A user cannot delete memories they do not own."""
        store_result = mem_server.store_memory("My private note", "user_a")
        memory_id = store_result["memory_id"]

        # user_b tries to delete user_a's memory
        delete_result = json.loads(mem_server.memory_delete(memory_id, "user_b"))
        assert delete_result["deleted"] is False

        # The memory must still be listed for its owner afterwards.
        all_a = json.loads(mem_server.memory_get_all("user_a"))
        assert all_a["count"] == 1


# ============================================================================
# Embedding Tests (2)
# ============================================================================

class TestEmbedding:
    """Tests for ``get_embedding`` when it is called unmocked."""

    def test_embedding_dimension(self):
        """The vector has EMBED_DIM entries, all zero (expected with no API key set)."""
        vector = mem_server.get_embedding("test text")
        assert len(vector) == mem_server.EMBED_DIM
        # No API key = zero vector
        assert all(component == 0.0 for component in vector)

    def test_embedding_deterministic_zero(self):
        """Two different inputs yield the same 768-entry vector (expected with no API key set)."""
        first, second = (mem_server.get_embedding(word) for word in ("hello", "world"))
        assert first == second  # Both zero vectors
        assert len(first) == 768


# ============================================================================
# Config & Integration Tests
# ============================================================================

class TestConfig:
    """Tests for server constants, the id format and the absence of secrets in source."""

    def test_no_hardcoded_secrets(self):
        """Verify no secrets are hardcoded in the server module source."""
        server_path = Path(__file__).parent.parent.parent / "mcp-servers" / "sunaiva-memory" / "server.py"
        # Decode explicitly as UTF-8 (Python's default source encoding) so the
        # check does not depend on the platform's locale encoding.
        source = server_path.read_text(encoding="utf-8")

        # Should NOT contain actual API keys, passwords, or connection strings
        assert "7b74e6621bd0e6650789f6662bca4cbf" not in source
        assert "etY0eog17tD" not in source
        assert "AIzaSy" not in source
        assert "e2ZyYYr4" not in source
        # Should use os.environ.get for all credentials
        assert "os.environ.get" in source

    def test_embed_dim_constant(self):
        """Verify EMBED_DIM is set to 768."""
        assert mem_server.EMBED_DIM == 768

    def test_collection_name(self):
        """Verify Qdrant collection name is standardised."""
        assert mem_server.QDRANT_COLLECTION == "sunaiva_memory_768"

    def test_memory_id_format(self):
        """Memory IDs follow the mem_ prefix convention."""
        mid = mem_server._generate_memory_id("test", "user")
        assert mid.startswith("mem_")
        assert len(mid) == 28  # "mem_" + 24 hex chars


# ============================================================================
# Default User ID Tests
# ============================================================================

class TestDefaultUserId:
    """Tests for ``_resolve_user_id`` and the module-level default user id."""

    def test_resolve_explicit_user(self):
        """An explicitly passed user id wins over the configured default."""
        mem_server._default_user_id = "default_user"
        resolved = mem_server._resolve_user_id("explicit_user")
        assert resolved == "explicit_user"

    def test_resolve_default_user(self):
        """Passing None falls back to the configured default user id."""
        mem_server._default_user_id = "default_user"
        resolved = mem_server._resolve_user_id(None)
        assert resolved == "default_user"

    def test_resolve_no_user_raises(self):
        """With no explicit id and no default, a ValueError is raised."""
        mem_server._default_user_id = None
        expect_error = pytest.raises(ValueError, match="user_id is required")
        with expect_error:
            mem_server._resolve_user_id(None)
