#!/usr/bin/env python3
"""
Genesis Voice Bridge MCP Server - Test Suite
=============================================
Tests the MCP server tools and connectivity using streamable HTTP.

Run locally:
    python test_server.py

Run against deployed server:
    GENESIS_MCP_URL=https://your-server:8900 python test_server.py
"""

import os
import sys
import json
import requests

# ---------------------------------------------------------------------------
# Configuration (each value can be overridden through the environment)
# ---------------------------------------------------------------------------
MCP_URL = os.getenv("GENESIS_MCP_URL", "http://localhost:8900")
# NOTE(review): the fallback below looks like a live production credential
# committed to source control. Consider requiring GENESIS_MCP_AUTH_TOKEN to
# be set explicitly and rotating this key.
AUTH_TOKEN = os.getenv("GENESIS_MCP_AUTH_TOKEN", "genesis-voice-bridge-2026-production-key")

# Result tallies, updated by test()
tests_run = tests_passed = tests_failed = 0
failures = []  # one "name: detail" string per failed check

# MCP session ID issued by the server; captured from response headers by mcp_request()
session_id = None


def test(name: str, passed: bool, detail: str = ""):
    """Tally a single check and echo its outcome.

    Always bumps ``tests_run``; then bumps either ``tests_passed`` or
    ``tests_failed``. A failure is also appended to ``failures`` as
    ``"name: detail"`` so it can be listed in the final summary.
    """
    global tests_run, tests_passed, tests_failed
    tests_run += 1
    if not passed:
        tests_failed += 1
        failures.append(f"{name}: {detail}")
        print(f"  FAIL: {name} -- {detail}")
        return
    tests_passed += 1
    print(f"  PASS: {name}")


def _parse_sse_messages(body: str) -> list:
    """Decode the JSON payload of every event in a ``text/event-stream`` body.

    Follows the SSE framing rules that matter here:
    - line endings may be ``\\r\\n``, ``\\r`` or ``\\n``;
    - events are separated by a blank line;
    - an event's ``data:`` lines are joined with ``\\n``;
    - a single optional space after the colon is dropped.

    Events with no data, or whose data is not valid JSON, are skipped.
    """
    # Normalise line endings by hand: str.splitlines() would also split on
    # U+2028 / U+0085, which may legally appear unescaped inside JSON strings.
    text = body.replace("\r\n", "\n").replace("\r", "\n")
    messages = []
    for event in text.split("\n\n"):
        data_lines = []
        for line in event.split("\n"):
            if line.startswith("data:"):
                value = line[len("data:"):]
                data_lines.append(value[1:] if value.startswith(" ") else value)
        if not data_lines:
            continue
        try:
            messages.append(json.loads("\n".join(data_lines)))
        except json.JSONDecodeError:
            continue
    return messages


def _select_response(messages: list, request_id: int) -> dict:
    """Pick the JSON-RPC response to *request_id* out of decoded SSE messages.

    The stream may also carry server notifications or requests, which have
    neither ``result`` nor ``error``; those are never returned. Selection
    order:
    1. a response whose ``id`` equals *request_id*;
    2. otherwise the first response of any id (e.g. an error with a null id);
    3. otherwise ``{}``.
    """
    responses = [
        message
        for message in messages
        if isinstance(message, dict) and ("result" in message or "error" in message)
    ]
    for message in responses:
        if message.get("id") == request_id:
            return message
    return responses[0] if responses else {}


def mcp_request(method: str, params: dict = None, request_id: int = 1) -> dict:
    """Send a JSON-RPC request via the streamable HTTP MCP transport.

    Adds the bearer token, plus the ``mcp-session-id`` header once a session
    is known. Any ``mcp-session-id`` returned by the server is stored in the
    module global ``session_id`` for later calls.

    Args:
        method: JSON-RPC method name, e.g. ``"tools/call"``.
        params: Optional params object; omitted from the payload when falsy.
        request_id: JSON-RPC ``id`` used to match the response.

    Returns:
        The decoded JSON-RPC response, or ``{}`` if none could be decoded.
        For SSE replies, returns the message that answers *request_id*
        rather than simply the first event, because the server may emit
        notifications ahead of the response.
    """
    global session_id

    payload = {
        "jsonrpc": "2.0",
        "id": request_id,
        "method": method,
    }
    if params:
        payload["params"] = params

    headers = {
        "Authorization": f"Bearer {AUTH_TOKEN}",
        "Content-Type": "application/json",
        "Accept": "application/json, text/event-stream",
    }
    if session_id:
        headers["mcp-session-id"] = session_id

    resp = requests.post(f"{MCP_URL}/mcp", json=payload, headers=headers, timeout=30)

    # Adopt whatever session ID the server assigns (normally on initialize).
    if "mcp-session-id" in resp.headers:
        session_id = resp.headers["mcp-session-id"]

    # Media types are case-insensitive; normalise before comparing.
    content_type = resp.headers.get("content-type", "").lower()
    if content_type.startswith("text/event-stream"):
        return _select_response(_parse_sse_messages(resp.text), request_id)

    try:
        return resp.json()
    except ValueError:
        # Non-JSON body (e.g. an HTML error page or an empty body).
        return {}


def mcp_notify(method: str, params: dict = None):
    """Fire a JSON-RPC notification at the MCP endpoint.

    A notification carries no ``id``. The current ``session_id`` (if any) is
    attached. The HTTP response is not inspected and nothing is returned.
    """
    endpoint = f"{MCP_URL}/mcp"
    message = {"jsonrpc": "2.0", "method": method, **({"params": params} if params else {})}

    request_headers = {
        "Authorization": f"Bearer {AUTH_TOKEN}",
        "Content-Type": "application/json",
        "Accept": "application/json, text/event-stream",
    }
    if session_id:
        request_headers["mcp-session-id"] = session_id

    requests.post(endpoint, json=message, headers=request_headers, timeout=10)


def call_tool(name: str, arguments: dict = None) -> str:
    """Call an MCP tool via ``tools/call`` and return its text output.

    Args:
        name: Tool name as advertised by ``tools/list``.
        arguments: Tool arguments; ``None`` sends an empty object.

    Returns:
        - When the result has content: the ``text`` of every content block
          that has one, joined with newlines. If the result sets
          ``isError``, that text is prefixed with ``"ERROR: "``.
        - When there is no content but a JSON-RPC ``error`` is present:
          ``"ERROR: <message>"``.
        - Otherwise ``""``.
    """
    params = {"name": name, "arguments": arguments or {}}
    data = mcp_request("tools/call", params)
    if not isinstance(data, dict):
        data = {}

    # ``result`` may be absent or explicitly null; treat both as "no content".
    result = data.get("result")
    if not isinstance(result, dict):
        result = {}

    content = result.get("content") or []
    if content:
        # A tool may return several content blocks; keep all of their text,
        # not just the first. Blocks without a text field (e.g. images) are skipped.
        text = "\n".join(
            item["text"]
            for item in content
            if isinstance(item, dict) and isinstance(item.get("text"), str)
        )
        # MCP reports tool-level failures in-band via ``isError``; surface
        # them so callers' error checks can see them.
        if result.get("isError"):
            return f"ERROR: {text}"
        return text

    error = data.get("error")
    if error:
        if isinstance(error, dict):
            return f"ERROR: {error.get('message', str(error))}"
        return f"ERROR: {error}"
    return ""


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_health():
    """Probe the unauthenticated /health endpoint and validate its JSON payload."""
    print("\n--- Health Check ---")
    try:
        response = requests.get(f"{MCP_URL}/health", timeout=5)
        code = response.status_code
        test("health returns 200", code == 200, f"got {code}")

        payload = response.json()
        test("status is healthy", payload.get("status") == "healthy", f"got {payload}")
        service = payload.get("service", "")
        test("service name correct", "genesis-voice-bridge" in service, str(payload))
    except Exception as exc:
        # Any transport or decoding problem is reported as a single failure.
        test("health reachable", False, str(exc))


def test_auth():
    """Check that /mcp refuses a missing token (401) and a wrong token (403)."""
    print("\n--- Authentication ---")
    endpoint = f"{MCP_URL}/mcp"
    try:
        # Request without any Authorization header.
        anonymous = requests.post(endpoint, json={}, timeout=5)
        status = anonymous.status_code
        test("rejects no auth (401)", status == 401, f"got {status}")

        # Request with a syntactically valid but incorrect bearer token.
        forged = requests.post(
            endpoint,
            json={},
            headers={"Authorization": "Bearer wrong-token"},
            timeout=5,
        )
        status = forged.status_code
        test("rejects wrong token (403)", status == 403, f"got {status}")
    except Exception as exc:
        test("auth enforcement", False, str(exc))


def test_initialize():
    """Open an MCP session and acknowledge it with notifications/initialized."""
    print("\n--- Initialize ---")
    handshake = {
        "protocolVersion": "2025-03-26",
        "capabilities": {},
        "clientInfo": {"name": "genesis-test", "version": "1.0"},
    }
    try:
        reply = mcp_request("initialize", handshake)
        test("initialize returns result", "result" in reply, str(reply)[:200])

        info = reply.get("result", {}).get("serverInfo", {})
        test("server name correct", info.get("name") == "Genesis Voice Bridge", str(info))

        # mcp_request() stores the server-issued session ID in the module global.
        test("session ID obtained", session_id is not None, "no session ID")

        # Only confirms the POST did not raise; the reply is not inspected.
        mcp_notify("notifications/initialized")
        test("initialized notification sent", True)
    except Exception as exc:
        test("initialize", False, str(exc))


def test_tools_list():
    """Verify tools/list advertises every tool this suite exercises."""
    print("\n--- Tools List ---")
    required = (
        "get_project_status", "get_war_room", "search_memory",
        "query_knowledge_graph", "get_recent_decisions", "get_agent_status",
        "get_memory_context", "get_session_progress",
        "get_architecture_summary", "get_revenue_status",
    )
    try:
        listing = mcp_request("tools/list", {})
        tools = listing.get("result", {}).get("tools", [])
        test("returns tools", len(tools) > 0, f"got {len(tools)} tools")

        advertised = [tool["name"] for tool in tools]
        for tool_name in required:
            test(f"tool '{tool_name}' exists", tool_name in advertised, f"not in {advertised}")
    except Exception as exc:
        test("tools/list", False, str(exc))


def test_get_project_status():
    """Exercise get_project_status for all projects, then for a single one."""
    print("\n--- get_project_status ---")
    overview = call_tool("get_project_status", {"project": "all"})
    test("returns content", len(overview) > 50, f"got {len(overview)} chars")
    for project in ("ReceptionistAI", "Sunaiva", "AIVA"):
        test(f"contains {project}", project in overview, "missing")

    # Narrow the query to a single project.
    single = call_tool("get_project_status", {"project": "receptionistai"})
    test("single project works", "ReceptionistAI" in single, "missing")


def test_get_war_room():
    """Check that get_war_room output mentions its key sections (case-insensitive)."""
    print("\n--- get_war_room ---")
    report = call_tool("get_war_room")
    test("returns content", len(report) > 50, f"got {len(report)} chars")
    lowered = report.lower()
    for keyword in ("mission", "completed", "blocker"):
        test(f"contains {keyword}", keyword in lowered, "missing")


def test_search_memory():
    """Run a small search_memory query and check the reply does not look like an error."""
    print("\n--- search_memory ---")
    hits = call_tool("search_memory", {"query": "Genesis architecture", "limit": 3})
    size = len(hits)
    test("returns content", size > 10, f"got {size} chars")
    # Only the opening of the (lower-cased) reply is scanned for "error".
    opening = hits.lower()[:50]
    test("not an error", "error" not in opening, hits[:100])


def test_query_knowledge_graph():
    """Request knowledge-graph stats; any non-empty reply counts as graceful."""
    print("\n--- query_knowledge_graph ---")
    stats = call_tool("query_knowledge_graph", {"query": "stats", "query_type": "stats"})
    test("returns content", len(stats) > 10, f"got {len(stats)} chars")
    # May return "unavailable" if FalkorDB is not reachable -- that's OK
    test("handles gracefully", bool(stats), "empty")


def test_get_recent_decisions():
    """Ask get_recent_decisions for a 14-day window and expect some text back."""
    print("\n--- get_recent_decisions ---")
    decisions = call_tool("get_recent_decisions", {"days": 14})
    chars = len(decisions)
    test("returns content", chars > 10, f"got {chars} chars")


def test_get_agent_status():
    """Check that get_agent_status mentions agents and a skills count."""
    print("\n--- get_agent_status ---")
    status = call_tool("get_agent_status")
    test("returns content", len(status) > 50, f"got {len(status)} chars")
    lowered = status.lower()
    test("contains agents", "agent" in lowered, "missing")
    # Accept either the literal "64" or the word "skills" anywhere in the text.
    mentions_skills = "64" in status or "skills" in lowered
    test("contains skills count", mentions_skills, "missing")


def test_get_memory_context():
    """Check that get_memory_context mentions memory or directives."""
    print("\n--- get_memory_context ---")
    context = call_tool("get_memory_context")
    test("returns content", len(context) > 50, f"got {len(context)} chars")
    lowered = context.lower()
    relevant = any(word in lowered for word in ("memory", "directive"))
    test("contains Prime Directives or memory", relevant, "missing")


def test_get_architecture_summary():
    """Check that get_architecture_summary is substantial and names expected terms."""
    print("\n--- get_architecture_summary ---")
    summary = call_tool("get_architecture_summary")
    test("returns content", len(summary) > 100, f"got {len(summary)} chars")
    for term in ("Elestio", "Telnyx"):
        test(f"contains {term}", term in summary, "missing")


def test_get_revenue_status():
    """Check that get_revenue_status mentions MRR and at least one dollar figure."""
    print("\n--- get_revenue_status ---")
    revenue = call_tool("get_revenue_status")
    test("returns content", len(revenue) > 50, f"got {len(revenue)} chars")
    # Check label -> substring that must appear (case-sensitive).
    required = {"contains MRR": "MRR", "contains pricing": "$"}
    for label, needle in required.items():
        test(label, needle in revenue, "missing")


def main():
    """Run every test group, print a summary, and return a process exit code.

    Each group runs in isolation. An unexpected exception is recorded as a
    failure and the remaining groups still run, so the summary is always
    printed. An example is a network error raised from ``call_tool`` in the
    tool-level tests, which have no ``try`` of their own.

    Returns:
        0 when every check passed, 1 otherwise.
    """
    print("=" * 60)
    print("Genesis Voice Bridge MCP Server - Test Suite")
    print(f"Target: {MCP_URL}")
    print("=" * 60)

    suite = (
        test_health,
        test_auth,
        test_initialize,
        test_tools_list,
        test_get_project_status,
        test_get_war_room,
        test_search_memory,
        test_query_knowledge_graph,
        test_get_recent_decisions,
        test_get_agent_status,
        test_get_memory_context,
        test_get_architecture_summary,
        test_get_revenue_status,
    )
    for group in suite:
        try:
            group()
        except Exception as exc:
            # Record the crash instead of aborting the whole run.
            test(f"{group.__name__} completed", False, f"{type(exc).__name__}: {exc}")

    # Summary
    print("\n" + "=" * 60)
    print(f"RESULTS: {tests_passed}/{tests_run} passed, {tests_failed} failed")
    print("=" * 60)

    if failures:
        print("\nFAILURES:")
        for entry in failures:
            print(f"  - {entry}")

    return 0 if tests_failed == 0 else 1


if __name__ == "__main__":
    sys.exit(main())
