#!/usr/bin/env python3
"""
Tests for Story 3.02: TelnyxWebhookInterceptor — call.initiated Handler
AIVA RLM Nexus PRD v2 — Track A

Black box tests (BB1-BB5): verify the public contract from the outside.
White box tests (WB1-WB5): verify internal paths, DB query shape, logging
                            invariants, and idempotency behaviour.

ALL external dependencies (DB, filesystem) are fully mocked so the suite
runs without any live infrastructure.
"""
import asyncio
import json
import sys
import uuid
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, call, mock_open, patch

import pytest

sys.path.insert(0, "/mnt/e/genesis-system")

from core.interceptors.telnyx_webhook_interceptor import (
    EVENTS_LOG_PATH,
    TelnyxWebhookInterceptor,
    _SESSION_PATH,
)

# ---------------------------------------------------------------------------
# Shared fixtures / helpers
# ---------------------------------------------------------------------------

# Session id embedded in the canonical fixture payload below; tests compare the
# handler's returned/inserted/logged session id against this value.
_VALID_SESSION_ID = "sess-abc-123-xyz"

# Minimal call.initiated webhook body used by most tests. The session id lives at
# data -> payload -> call_session_id, the same path WB1 pins via _SESSION_PATH.
# NOTE(review): presumably mirrors the Telnyx webhook envelope — confirm against
# a captured real event if fields beyond these two start to matter.
_VALID_PAYLOAD = {
    "data": {
        "payload": {
            "call_session_id": _VALID_SESSION_ID,
            "direction": "inbound",
        }
    }
}


def _make_interceptor(db=None, redis=None) -> TelnyxWebhookInterceptor:
    """Build a brand-new interceptor for a single test.

    Args:
        db: Optional DB-connection double (e.g. from ``_make_db_mock``);
            ``None`` exercises the no-database path.
        redis: Optional Redis client double; ``None`` by default.

    Returns:
        A ``TelnyxWebhookInterceptor`` wired to the given doubles.
    """
    interceptor = TelnyxWebhookInterceptor(redis_client=redis, db_conn=db)
    return interceptor


def _run(coro):
    """Run a coroutine synchronously (works on Python 3.7+)."""
    return asyncio.get_event_loop().run_until_complete(coro)


def _make_db_mock() -> MagicMock:
    """Return a minimal psycopg2-connection mock."""
    cursor_mock = MagicMock()
    cursor_mock.__enter__ = MagicMock(return_value=cursor_mock)
    cursor_mock.__exit__ = MagicMock(return_value=False)

    db = MagicMock()
    db.cursor.return_value = cursor_mock
    return db


# ---------------------------------------------------------------------------
# Black-box tests (BB) — treat the class as a black box
# ---------------------------------------------------------------------------


class TestBB1_ValidPayloadReturnsOk:
    """BB1: a valid call.initiated payload yields {"status": "ok", "session_id": ...}."""

    @staticmethod
    def _handle_valid_payload():
        """Run the handler on the canonical payload with DB insert and logging stubbed."""
        interceptor = _make_interceptor()
        stub_insert = patch.object(interceptor, "_insert_conversation")
        stub_log = patch.object(interceptor, "_log_event")
        with stub_insert, stub_log:
            return _run(interceptor.handle_call_initiated(_VALID_PAYLOAD))

    def test_returns_status_ok(self):
        outcome = self._handle_valid_payload()
        assert outcome["status"] == "ok"

    def test_returns_correct_session_id(self):
        outcome = self._handle_valid_payload()
        assert outcome["session_id"] == _VALID_SESSION_ID

    def test_return_type_is_dict(self):
        assert isinstance(self._handle_valid_payload(), dict)


class TestBB2_DBConnectionProvided_InsertExecuted:
    """BB2: when a DB connection is supplied, an INSERT into royal_conversations runs."""

    @staticmethod
    def _handle_with(db):
        """Invoke the handler against *db* with event logging stubbed out."""
        interceptor = _make_interceptor(db=db)
        with patch.object(interceptor, "_log_event"):
            _run(interceptor.handle_call_initiated(_VALID_PAYLOAD))

    def test_db_cursor_called(self):
        connection = _make_db_mock()
        self._handle_with(connection)
        connection.cursor.assert_called_once()

    def test_cursor_execute_called(self):
        connection = _make_db_mock()
        self._handle_with(connection)
        cursor = connection.cursor.return_value
        cursor.execute.assert_called_once()

    def test_db_commit_called(self):
        connection = _make_db_mock()
        self._handle_with(connection)
        connection.commit.assert_called_once()


class TestBB3_DBFailure_StillReturnsOk:
    """BB3: DB failures are non-fatal — the handler still returns {"status": "ok"}."""

    @staticmethod
    def _handle_with_failing(db):
        """Invoke the handler against a (broken) *db* with logging stubbed out."""
        interceptor = _make_interceptor(db=db)
        with patch.object(interceptor, "_log_event"):
            return _run(interceptor.handle_call_initiated(_VALID_PAYLOAD))

    def test_db_exception_does_not_propagate(self):
        connection = _make_db_mock()
        connection.cursor.side_effect = Exception("PG connection refused")
        # Must NOT raise.
        outcome = self._handle_with_failing(connection)
        assert outcome["status"] == "ok"

    def test_db_execute_failure_still_returns_ok(self):
        connection = _make_db_mock()
        connection.cursor.return_value.execute.side_effect = Exception("relation does not exist")
        assert self._handle_with_failing(connection)["status"] == "ok"

    def test_db_failure_logs_db_error_event(self):
        connection = _make_db_mock()
        connection.cursor.side_effect = Exception("timeout")
        interceptor = _make_interceptor(db=connection)

        # Replace the logger with a recorder so we can inspect emitted events.
        recorded = []
        interceptor._log_event = lambda event_type, data: recorded.append((event_type, data))
        _run(interceptor.handle_call_initiated(_VALID_PAYLOAD))

        db_errors = [data for event_type, data in recorded if event_type == "db_error"]
        assert len(db_errors) == 1
        assert "error" in db_errors[0]


class TestBB4_MissingSessionId_UUID4Fallback:
    """BB4: with no usable session id in the payload, a UUID4 is generated instead."""

    @staticmethod
    def _session_id_for(payload):
        """Return the session_id the handler reports for *payload* (side effects stubbed)."""
        interceptor = _make_interceptor()
        with patch.object(interceptor, "_insert_conversation"):
            with patch.object(interceptor, "_log_event"):
                outcome = _run(interceptor.handle_call_initiated(payload))
        return outcome["session_id"]

    @staticmethod
    def _assert_canonical_uuid4(value):
        # Passing version=4 rewrites the version/variant bits, so the round-trip
        # only matches when *value* already was a canonical lowercase UUID4.
        assert str(uuid.UUID(value, version=4)) == value

    def test_empty_payload_returns_uuid4(self):
        self._assert_canonical_uuid4(self._session_id_for({}))

    def test_missing_call_session_id_key_returns_uuid4(self):
        # Envelope is present but the call_session_id key itself is absent.
        self._assert_canonical_uuid4(self._session_id_for({"data": {"payload": {}}}))

    def test_none_call_session_id_falls_back_to_uuid4(self):
        self._assert_canonical_uuid4(
            self._session_id_for({"data": {"payload": {"call_session_id": None}}})
        )


class TestBB5_NoDBConnection_NocrashReturnsOk:
    """BB5: with no DB connection (None) the handler neither crashes nor fails."""

    def test_none_db_returns_ok(self):
        interceptor = _make_interceptor(db=None)
        with patch.object(interceptor, "_log_event"):
            outcome = _run(interceptor.handle_call_initiated(_VALID_PAYLOAD))
        assert outcome["status"] == "ok"

    def test_none_db_does_not_raise(self):
        interceptor = _make_interceptor(db=None)
        with patch.object(interceptor, "_log_event"):
            # Any exception escaping the handler here is a contract violation.
            try:
                _run(interceptor.handle_call_initiated(_VALID_PAYLOAD))
            except Exception as err:
                pytest.fail(f"Unexpected exception with db=None: {err}")


# ---------------------------------------------------------------------------
# White-box tests (WB) — test internals, SQL shape, event format
# ---------------------------------------------------------------------------


class TestWB1_SessionIdExtraction:
    """WB1: session_id extracted from payload["data"]["payload"]["call_session_id"]."""

    def test_correct_path_extracted(self):
        interceptor = _make_interceptor()
        result = interceptor._extract_session_id(_VALID_PAYLOAD)
        assert result == _VALID_SESSION_ID

    def test_extraction_path_matches_constant(self):
        """Canonical path must match _SESSION_PATH tuple defined in module."""
        assert _SESSION_PATH == ("data", "payload", "call_session_id")

    def test_deeply_nested_path_is_reached(self):
        """The extractor must traverse every level of _SESSION_PATH.

        The fixture is first checked to be shaped per _SESSION_PATH. Then a
        payload is built purely from _SESSION_PATH with a distinct value, and
        the interceptor itself is asked to extract it. This exercises the code
        under test rather than just the fixture.
        """
        # Guard: the shared fixture follows the canonical path.
        node = _VALID_PAYLOAD
        for key in _SESSION_PATH:
            node = node[key]
        assert node == _VALID_SESSION_ID

        # Build {"data": {"payload": {"call_session_id": sentinel}}} from the path.
        sentinel = "sess-nested-path-check"
        payload = sentinel
        for key in reversed(_SESSION_PATH):
            payload = {key: payload}

        interceptor = _make_interceptor()
        assert interceptor._extract_session_id(payload) == sentinel


class TestWB2_InsertUsesCorrectTable:
    """WB2: the INSERT targets royal_conversations and binds the expected params."""

    @staticmethod
    def _execute_positional_args():
        """Drive the handler against a mock DB and return execute()'s positional args."""
        connection = _make_db_mock()
        interceptor = _make_interceptor(db=connection)
        with patch.object(interceptor, "_log_event"):
            _run(interceptor.handle_call_initiated(_VALID_PAYLOAD))
        return connection.cursor.return_value.execute.call_args[0]

    def test_sql_references_royal_conversations(self):
        sql = self._execute_positional_args()[0]
        assert "royal_conversations" in sql

    def test_insert_params_include_session_id(self):
        bound = self._execute_positional_args()[1]
        assert _VALID_SESSION_ID in bound

    def test_insert_params_include_started_at(self):
        bound = self._execute_positional_args()[1]
        # At least one bound parameter must be a datetime (the started_at value).
        assert any(isinstance(value, datetime) for value in bound)

    def test_insert_params_include_participants_json(self):
        bound = self._execute_positional_args()[1]
        # The third bound parameter is a JSON string naming the participants.
        participants = json.loads(bound[2])
        assert "kinan" in participants
        assert "aiva" in participants


class TestWB3_EventLoggedToEventsJsonl:
    """WB3: Event logged to events.jsonl."""

    def test_log_event_called_on_success(self):
        interceptor = _make_interceptor()
        log_calls = []

        def capture(event_type, data):
            log_calls.append((event_type, data))

        interceptor._log_event = capture

        with patch.object(interceptor, "_insert_conversation"):
            _run(interceptor.handle_call_initiated(_VALID_PAYLOAD))

        call_initiated_events = [c for c in log_calls if c[0] == "call_initiated"]
        assert len(call_initiated_events) == 1
        assert call_initiated_events[0][1]["session_id"] == _VALID_SESSION_ID

    def test_log_writes_to_events_jsonl_path(self):
        """Verify _log_event opens EVENTS_LOG_PATH in append mode and writes to it."""
        interceptor = _make_interceptor()

        mock_file = mock_open()
        with patch("builtins.open", mock_file):
            with patch("pathlib.Path.mkdir"):
                interceptor._log_event("call_initiated", {"session_id": "test-123"})

        # Collect every open() aimed at the configured log path. Comparing as Path
        # objects tolerates the module passing either a str or a Path.
        log_opens = [
            (args, kwargs)
            for args, kwargs in mock_file.call_args_list
            if args and Path(args[0]) == Path(EVENTS_LOG_PATH)
        ]
        assert log_opens, f"open() was never called with {EVENTS_LOG_PATH!r}"

        args, kwargs = log_opens[0]
        mode = args[1] if len(args) > 1 else kwargs.get("mode", "r")
        # The JSONL log must be appended to, never truncated.
        assert mode == "a"

        # mock_open hands back one shared handle; it must have received a write.
        assert mock_file.return_value.write.called

    def test_log_event_json_contains_event_type(self):
        """Verify the written JSON has the telnyx_ prefixed event_type."""
        interceptor = _make_interceptor()

        written_content = []
        # Captured before patching so the passthrough below reaches the real open().
        original_open = open

        def capture_write(path, mode="r", **kwargs):
            if mode == "a":
                m = MagicMock()
                m.__enter__ = lambda s: s
                m.__exit__ = MagicMock(return_value=False)
                m.write = lambda data: written_content.append(data)
                return m
            return original_open(path, mode, **kwargs)

        with patch("builtins.open", side_effect=capture_write):
            with patch("pathlib.Path.mkdir"):
                interceptor._log_event("call_initiated", {"session_id": "test-999"})

        assert len(written_content) == 1
        record = json.loads(written_content[0].strip())
        assert record["event_type"] == "telnyx_call_initiated"
        assert record["session_id"] == "test-999"


class TestWB4_TimestampIsUTC:
    """WB4: Timestamp in event and DB insert are UTC."""

    def test_insert_started_at_is_utc(self):
        db = _make_db_mock()
        interceptor = _make_interceptor(db=db)

        with patch.object(interceptor, "_log_event"):
            _run(interceptor.handle_call_initiated(_VALID_PAYLOAD))

        execute_call = db.cursor.return_value.execute.call_args
        params = execute_call[0][1]
        dt_param = next((p for p in params if isinstance(p, datetime)), None)
        assert dt_param is not None
        assert dt_param.tzinfo == timezone.utc

    def test_log_event_timestamp_is_utc_iso_format(self):
        """The timestamp field in the log record is a UTC ISO-8601 string."""
        interceptor = _make_interceptor()

        written_content = []
        # Capture the real open() BEFORE patching. While builtins.open is patched,
        # both the bare name ``open`` and ``__builtins__["open"]`` resolve to the
        # mock itself, so falling back to them would recurse forever.
        original_open = open

        def capture_write(path, mode="r", **kwargs):
            if mode == "a":
                m = MagicMock()
                m.__enter__ = lambda s: s
                m.__exit__ = MagicMock(return_value=False)
                m.write = lambda data: written_content.append(data)
                return m
            return original_open(path, mode, **kwargs)

        with patch("builtins.open", side_effect=capture_write):
            with patch("pathlib.Path.mkdir"):
                interceptor._log_event("call_initiated", {"session_id": "ts-test"})

        # Fail with a clear message rather than an IndexError if nothing was written.
        assert len(written_content) == 1, "expected exactly one appended log line"
        record = json.loads(written_content[0].strip())

        ts_str = record["timestamp"]
        # datetime.fromisoformat only accepts a trailing "Z" from Python 3.11 on.
        if ts_str.endswith("Z"):
            ts_str = ts_str[:-1] + "+00:00"
        dt = datetime.fromisoformat(ts_str)

        # Must be timezone-aware AND at UTC (offset +00:00), not merely aware.
        assert dt.tzinfo is not None
        assert dt.utcoffset().total_seconds() == 0


class TestWB5_OnConflictIdempotency:
    """WB5: ON CONFLICT DO NOTHING makes duplicate call.initiated deliveries idempotent."""

    def test_sql_contains_on_conflict_do_nothing(self):
        connection = _make_db_mock()
        interceptor = _make_interceptor(db=connection)
        with patch.object(interceptor, "_log_event"):
            _run(interceptor.handle_call_initiated(_VALID_PAYLOAD))

        # Case-insensitive match on the statement text.
        upper_sql = connection.cursor.return_value.execute.call_args[0][0].upper()
        for clause in ("ON CONFLICT", "DO NOTHING"):
            assert clause in upper_sql

    def test_calling_twice_with_same_session_does_not_raise(self):
        """Two deliveries of the same event must both succeed; the DB absorbs the duplicate."""
        connection = _make_db_mock()
        interceptor = _make_interceptor(db=connection)

        with patch.object(interceptor, "_log_event"):
            first, second = [
                _run(interceptor.handle_call_initiated(_VALID_PAYLOAD)) for _ in range(2)
            ]

        assert first["status"] == "ok"
        assert second["status"] == "ok"
        # Both deliveries report the session_id carried by the payload.
        assert first["session_id"] == second["session_id"] == _VALID_SESSION_ID


# ---------------------------------------------------------------------------
# Run summary
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Run this file directly under pytest and propagate its exit code.
    sys.exit(pytest.main([__file__, "-v", "--tb=short"]))
