#!/usr/bin/env python3
"""
Tests for Story 3.01: Telnyx Webhook Signature Verification
AIVA RLM Nexus PRD v2 — Track A

Black box tests (BB1-BB4): test the public API from the outside.
White box tests (WB1-WB4): test internal branches, edge cases, and
                            implementation invariants.
"""
import base64
import hashlib
import hmac
import sys
import time
from unittest.mock import patch

import pytest

sys.path.insert(0, "/mnt/e/genesis-system")

from core.interceptors.telnyx_signature import (
    MAX_TIMESTAMP_AGE,
    SignatureError,
    _decode_signature,
    verify_telnyx_signature,
    verify_timestamp_freshness,
)

# ---------------------------------------------------------------------------
# Shared test helpers (plain functions, not pytest fixtures)
# ---------------------------------------------------------------------------

# Shared secret: _make_sig uses it as the HMAC key by default, and the tests
# pass the same value to verify_telnyx_signature via its ``public_key`` argument.
_TEST_KEY = "test-signing-secret-abc123"
# Representative webhook payload; these exact bytes are what get signed.
_TEST_BODY = b'{"event_type":"call.initiated","id":"ev-001"}'


def _make_sig(body: bytes, timestamp: str, key: str = _TEST_KEY) -> str:
    """Return the hex HMAC-SHA256 digest of ``"<timestamp>." + body`` under *key*.

    This mirrors the signing scheme the verifier is expected to check, so the
    result is a valid signature for (body, timestamp, key).
    """
    mac = hmac.new(key.encode(), digestmod=hashlib.sha256)
    # The timestamp and a literal '.' separator are authenticated ahead of the body.
    mac.update(f"{timestamp}.".encode())
    mac.update(body)
    return mac.hexdigest()


def _now_str() -> str:
    """Return the current Unix time (float seconds) formatted as a string."""
    return f"{time.time()}"


# ---------------------------------------------------------------------------
# Black-box tests (BB) — treat the module as a black box
# ---------------------------------------------------------------------------


class TestBB1_ValidSignature:
    """BB1: a signature produced with the known key verifies as True."""

    def test_valid_hex_signature_returns_true(self):
        stamp = _now_str()
        signature = _make_sig(_TEST_BODY, stamp)
        outcome = verify_telnyx_signature(
            _TEST_BODY, signature, stamp, public_key=_TEST_KEY
        )
        assert outcome is True

    def test_valid_base64_encoded_signature_returns_true(self):
        """A base64-encoded signature must be accepted as well as hex."""
        stamp = _now_str()
        # base64-encode the raw digest bytes, not the hex text of the digest.
        digest_bytes = bytes.fromhex(_make_sig(_TEST_BODY, stamp))
        encoded = base64.b64encode(digest_bytes).decode()
        outcome = verify_telnyx_signature(
            _TEST_BODY, encoded, stamp, public_key=_TEST_KEY
        )
        assert outcome is True


class TestBB2_TamperedPayload:
    """BB2: reusing a valid signature with an altered payload → False."""

    @staticmethod
    def _verify_altered(payload: bytes):
        """Sign the pristine body, then verify *payload* against that signature."""
        stamp = _now_str()
        signature = _make_sig(_TEST_BODY, stamp)
        return verify_telnyx_signature(payload, signature, stamp, public_key=_TEST_KEY)

    def test_tampered_body_fails(self):
        assert self._verify_altered(_TEST_BODY + b"TAMPERED") is False

    def test_truncated_body_fails(self):
        assert self._verify_altered(_TEST_BODY[:-5]) is False


class TestBB3_WrongSignature:
    """BB3: unmodified payload but an incorrect signature → False."""

    def test_wrong_key_fails(self):
        stamp = _now_str()
        foreign_signature = _make_sig(_TEST_BODY, stamp, key="wrong-key")
        assert (
            verify_telnyx_signature(
                _TEST_BODY, foreign_signature, stamp, public_key=_TEST_KEY
            )
            is False
        )

    def test_all_zeros_signature_fails(self):
        stamp = _now_str()
        zeros = "".join("0" for _ in range(64))  # well-formed hex, wrong digest
        assert (
            verify_telnyx_signature(_TEST_BODY, zeros, stamp, public_key=_TEST_KEY)
            is False
        )


class TestBB4_StaleTimestamp:
    """BB4: timestamps outside the MAX_TIMESTAMP_AGE window → False.

    Covers both stale (past) and too-far-future timestamps (replay protection).
    Each signature is valid for its timestamp, so any rejection comes from the
    freshness check alone.
    """

    def test_timestamp_just_over_max_age_fails(self):
        stale_ts = str(time.time() - MAX_TIMESTAMP_AGE - 1)
        sig = _make_sig(_TEST_BODY, stale_ts)
        result = verify_telnyx_signature(_TEST_BODY, sig, stale_ts, public_key=_TEST_KEY)
        assert result is False

    def test_timestamp_well_in_past_fails(self):
        ancient_ts = str(time.time() - 3600)  # 1 hour ago
        sig = _make_sig(_TEST_BODY, ancient_ts)
        result = verify_telnyx_signature(_TEST_BODY, sig, ancient_ts, public_key=_TEST_KEY)
        assert result is False

    def test_timestamp_in_future_beyond_window_fails(self):
        """Future timestamps beyond the window are also rejected."""
        future_ts = str(time.time() + MAX_TIMESTAMP_AGE + 1)
        sig = _make_sig(_TEST_BODY, future_ts)
        result = verify_telnyx_signature(_TEST_BODY, sig, future_ts, public_key=_TEST_KEY)
        assert result is False

    def test_timestamp_just_within_max_age_passes(self):
        """Near-boundary check: a timestamp 5 s inside the window still passes.

        The 5 s margin, rather than exactly MAX_TIMESTAMP_AGE, absorbs the time
        that elapses between building the timestamp here and the verifier
        reading the clock. Without it the test would be flaky.
        """
        fresh_ts = str(time.time() - MAX_TIMESTAMP_AGE + 5)
        sig = _make_sig(_TEST_BODY, fresh_ts)
        result = verify_telnyx_signature(_TEST_BODY, sig, fresh_ts, public_key=_TEST_KEY)
        assert result is True


# ---------------------------------------------------------------------------
# White-box tests (WB) — test internal branches and error paths
# ---------------------------------------------------------------------------


class TestWB1_EmptySignature:
    """WB1: a blank signature is rejected with SignatureError."""

    def test_empty_string_raises(self):
        stamp = _now_str()
        with pytest.raises(SignatureError, match="Empty signature"):
            verify_telnyx_signature(_TEST_BODY, "", stamp, public_key=_TEST_KEY)

    def test_whitespace_only_raises(self):
        """A whitespace-only signature must also raise SignatureError.

        Only the exception type is pinned here, not which validation branch
        (empty vs. malformed) reports it.
        """
        stamp = _now_str()
        blank = " " * 3
        with pytest.raises(SignatureError):
            verify_telnyx_signature(_TEST_BODY, blank, stamp, public_key=_TEST_KEY)


class TestWB2_EmptyTimestamp:
    """WB2: a blank timestamp header is rejected with SignatureError."""

    def test_empty_timestamp_raises(self):
        # The signature is valid for the real timestamp; only the header is blanked.
        signature = _make_sig(_TEST_BODY, _now_str())
        with pytest.raises(SignatureError, match="Empty timestamp"):
            verify_telnyx_signature(_TEST_BODY, signature, "", public_key=_TEST_KEY)


class TestWB3_NonNumericTimestamp:
    """WB3: a timestamp that cannot be parsed as a number raises."""

    @staticmethod
    def _signature_for_now() -> str:
        """Return a signature valid for the current time (not the bad timestamp)."""
        return _make_sig(_TEST_BODY, _now_str())

    def test_string_timestamp_raises(self):
        signature = self._signature_for_now()
        with pytest.raises(SignatureError, match="Malformed timestamp"):
            verify_telnyx_signature(
                _TEST_BODY, signature, "not-a-number", public_key=_TEST_KEY
            )

    def test_none_timestamp_raises(self):
        signature = self._signature_for_now()
        accepted_errors = (SignatureError, TypeError)
        with pytest.raises(accepted_errors):
            verify_telnyx_signature(_TEST_BODY, signature, None, public_key=_TEST_KEY)  # type: ignore[arg-type]

    def test_alpha_in_timestamp_raises(self):
        signature = self._signature_for_now()
        with pytest.raises(SignatureError, match="Malformed timestamp"):
            verify_telnyx_signature(
                _TEST_BODY, signature, "17000abc00", public_key=_TEST_KEY
            )


class TestWB4_TimingSafeComparison:
    """WB4: hmac.compare_digest is used (not == operator)."""

    @staticmethod
    def _verify_with_spy(signature: str, timestamp: str):
        """Verify *signature* while spying on ``hmac.compare_digest``.

        Returns ``(result, spy)``, where ``spy`` is the mock that stood in for
        ``compare_digest`` during the call. Its ``call_count`` records how many
        times the verifier used the constant-time comparison.
        """
        import core.interceptors.telnyx_signature as mod

        # wraps= forwards every call to the real compare_digest, so verification
        # still returns its genuine result while the mock records the calls.
        with patch.object(
            mod.hmac, "compare_digest", wraps=hmac.compare_digest
        ) as spy:
            result = verify_telnyx_signature(
                _TEST_BODY, signature, timestamp, public_key=_TEST_KEY
            )
        return result, spy

    def test_compare_digest_called(self):
        """
        Patch hmac.compare_digest to confirm it is invoked during verification.
        If the implementation used == instead, this test would catch it.
        """
        ts = _now_str()
        result, spy = self._verify_with_spy(_make_sig(_TEST_BODY, ts), ts)

        assert spy.call_count == 1, "hmac.compare_digest must be called exactly once"
        assert result is True

    def test_compare_digest_called_on_failure(self):
        """compare_digest is still called even when the signature is wrong."""
        ts = _now_str()
        bad_sig = "a" * 64  # valid hex, wrong value
        result, spy = self._verify_with_spy(bad_sig, ts)

        assert spy.call_count == 1, "compare_digest must be called even on failure"
        assert result is False


# ---------------------------------------------------------------------------
# Additional tests — verify_timestamp_freshness helper
# ---------------------------------------------------------------------------


class TestVerifyTimestampFreshness:
    """Direct tests of the standalone verify_timestamp_freshness helper."""

    def test_fresh_timestamp_returns_true(self):
        current = _now_str()
        assert verify_timestamp_freshness(current) is True

    def test_stale_timestamp_returns_false(self):
        expired = str(time.time() - MAX_TIMESTAMP_AGE - 10)
        assert verify_timestamp_freshness(expired) is False

    def test_non_numeric_returns_false(self):
        garbage = "not-a-number"
        assert verify_timestamp_freshness(garbage) is False

    def test_custom_max_age(self):
        ten_seconds_ago = str(time.time() - 10)
        # A 10 s old stamp is too old for a 5 s window but fine for a 30 s one.
        for window, expected in ((5, False), (30, True)):
            assert verify_timestamp_freshness(ten_seconds_ago, max_age=window) is expected


# ---------------------------------------------------------------------------
# Additional tests — _decode_signature internal helper
# ---------------------------------------------------------------------------


class TestDecodeSignature:
    """Tests for the private _decode_signature normaliser (hex/base64 → hex)."""

    def test_hex_string_returned_lowercased(self):
        # "A" * 64 is also valid base64, so this additionally pins that the
        # hex interpretation wins over the base64 one.
        upper_hex = "".join(["A"] * 64)
        assert _decode_signature(upper_hex) == upper_hex.lower()

    def test_base64_decoded_to_hex(self):
        payload = bytes(range(32))  # 32 arbitrary bytes
        encoded = base64.b64encode(payload).decode()
        expected = payload.hex()
        assert _decode_signature(encoded) == expected

    def test_malformed_raises_signature_error(self):
        garbage = "not-hex-not-base64!!!@@##"
        with pytest.raises(SignatureError, match="Malformed signature"):
            _decode_signature(garbage)


# ---------------------------------------------------------------------------
# No-key tests
# ---------------------------------------------------------------------------


class TestNoKey:
    """Verify that a missing signing key raises SignatureError."""

    def test_no_key_no_env_raises(self):
        """With no explicit key and an empty module default, verification raises.

        NOTE(review): only the module-level TELNYX_PUBLIC_KEY default is cleared
        here. If the implementation also reads the environment at call time,
        this test may fail on hosts where that variable is set. Confirm.
        """
        ts = _now_str()
        sig = _make_sig(_TEST_BODY, ts)
        import core.interceptors.telnyx_signature as mod

        # patch.object restores the original value on exit, even if the
        # assertion inside fails.
        with patch.object(mod, "TELNYX_PUBLIC_KEY", ""):
            with pytest.raises(SignatureError, match="No signing key"):
                verify_telnyx_signature(_TEST_BODY, sig, ts)


# ---------------------------------------------------------------------------
# Direct-execution entry point (python <this file>)
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Running the file directly delegates to pytest and propagates its exit code.
    sys.exit(pytest.main([__file__, "-v", "--tb=short"]))
