#!/usr/bin/env python3
"""
Tests for Story 5.06 (Track B): ColdLedgerInterceptor — Wires L4 Into Interceptor Chain

Black Box tests (BB): verify the public contract as seen from the chain runner —
    correct event types written, payload fields present, failures swallowed silently.
White Box tests (WB): verify internals — priority=5, pass-through semantics,
    error_class field name, attempt forwarding, BaseInterceptor inheritance.

ALL external calls (ColdLedger, Postgres) are mocked. Zero real I/O.

Story: 5.06
File under test: core/storage/cold_ledger_interceptor.py
"""

from __future__ import annotations

import asyncio
import pathlib
import sys

sys.path.insert(0, "/mnt/e/genesis-system")

import logging
from unittest.mock import MagicMock, patch

import pytest

# ---------------------------------------------------------------------------
# Module under test
# ---------------------------------------------------------------------------

from core.storage.cold_ledger_interceptor import ColdLedgerInterceptor
from core.storage import ColdLedgerInterceptor as ColdLedgerInterceptorFromPackage
from core.interceptors.base_interceptor import BaseInterceptor, InterceptorMetadata


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_interceptor(side_effect=None):
    """Build a ColdLedgerInterceptor backed by a MagicMock ledger.

    The mock is also attached to the returned interceptor as ``_mock_ledger``
    so tests can inspect the recorded ``write_event`` calls.

    Args:
        side_effect: Optional exception (or other mock side effect) that
            ``ledger.write_event`` should produce when called.  ``None``
            leaves the mock's default behaviour untouched.

    Returns:
        The interceptor instance, carrying the mock as ``_mock_ledger``.
    """
    fake_ledger = MagicMock()
    if side_effect is not None:
        # Only configure the failure mode when the caller asked for one.
        setattr(fake_ledger.write_event, "side_effect", side_effect)
    subject = ColdLedgerInterceptor(ledger=fake_ledger)
    subject._mock_ledger = fake_ledger  # test-only handle to the mock
    return subject


def _run(coro):
    """Execute an awaitable synchronously and return its result.

    A private event loop is created for each call and closed afterwards.
    It is deliberately *not* installed as the thread's current loop, so the
    helper neither depends on nor disturbs global asyncio state.  (Relying on
    ``asyncio.get_event_loop()`` here fails with ``RuntimeError`` whenever the
    current loop has been cleared, e.g. after an ``asyncio.run()`` call made
    elsewhere in the same process, and is deprecated outside a running loop.)

    Args:
        coro: Coroutine object (or other awaitable) to drive to completion.

    Returns:
        The value produced by ``coro``.

    Raises:
        Any exception raised while ``coro`` runs is propagated unchanged.
    """
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(coro)
    finally:
        # Always release the loop's resources, even if the coroutine raised.
        loop.close()


# ---------------------------------------------------------------------------
# Canonical payloads used across multiple tests
# ---------------------------------------------------------------------------

# Task payload passed to pre_execute / post_execute / on_error in the tests.
TASK_PAYLOAD = dict(
    session_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
    task_type="forge",
    tier="gold",
)

# Execution result passed as the first argument of post_execute.
RESULT_PAYLOAD = dict(success=True, output="ok")

# Correction payload passed to on_correction; carries an explicit attempt number.
CORRECTION_PAYLOAD = dict(
    session_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
    attempt=3,
    prompt="CORRECTION: retry",
)


# ===========================================================================
# Black Box tests
# ===========================================================================


class TestBB1PreExecuteWritesDispatchStart:
    """BB1: pre_execute must record a 'dispatch_start' event on the ledger."""

    @staticmethod
    def _write_event_after_pre_execute():
        """Run pre_execute with TASK_PAYLOAD and return the write_event mock."""
        subject = _make_interceptor()
        _run(subject.pre_execute(TASK_PAYLOAD))
        return subject._mock_ledger.write_event

    def test_event_type_is_dispatch_start(self):
        write_event = self._write_event_after_pre_execute()
        write_event.assert_called_once()
        positional = write_event.call_args[0]
        # The event type is checked at positional index 1.
        assert positional[1] == "dispatch_start"

    def test_session_id_forwarded(self):
        positional = self._write_event_after_pre_execute().call_args[0]
        assert positional[0] == TASK_PAYLOAD["session_id"]

    def test_payload_contains_task_type(self):
        recorded = self._write_event_after_pre_execute().call_args[0][2]
        assert recorded["task_type"] == "forge"

    def test_payload_contains_tier(self):
        recorded = self._write_event_after_pre_execute().call_args[0][2]
        assert recorded["tier"] == "gold"

    def test_payload_contains_timestamp(self):
        recorded = self._write_event_after_pre_execute().call_args[0][2]
        assert "timestamp" in recorded


class TestBB2PostExecuteWritesDispatchComplete:
    """BB2: post_execute must record a 'dispatch_complete' event on the ledger."""

    @staticmethod
    def _positional_args_for(result):
        """Run post_execute(result, TASK_PAYLOAD); return write_event's positional args."""
        subject = _make_interceptor()
        _run(subject.post_execute(result, TASK_PAYLOAD))
        return subject._mock_ledger.write_event.call_args[0]

    def test_event_type_is_dispatch_complete(self):
        positional = self._positional_args_for(RESULT_PAYLOAD)
        assert positional[1] == "dispatch_complete"

    def test_payload_contains_success_flag(self):
        recorded = self._positional_args_for(RESULT_PAYLOAD)[2]
        assert "success" in recorded
        assert recorded["success"] is True

    def test_session_id_forwarded(self):
        positional = self._positional_args_for(RESULT_PAYLOAD)
        assert positional[0] == TASK_PAYLOAD["session_id"]

    def test_success_false_when_result_says_so(self):
        recorded = self._positional_args_for({"success": False})[2]
        assert recorded["success"] is False


class TestBB3OnErrorWritesDispatchError:
    """BB3: on_error must record a 'dispatch_error' event whose payload carries
    error_class and error_message."""

    @staticmethod
    def _positional_args_for(error):
        """Run on_error(error, TASK_PAYLOAD); return write_event's positional args."""
        subject = _make_interceptor()
        _run(subject.on_error(error, TASK_PAYLOAD))
        return subject._mock_ledger.write_event.call_args[0]

    def test_event_type_is_dispatch_error(self):
        positional = self._positional_args_for(ValueError("bad input"))
        assert positional[1] == "dispatch_error"

    def test_payload_contains_error_class(self):
        recorded = self._positional_args_for(ValueError("bad input"))[2]
        assert "error_class" in recorded
        assert recorded["error_class"] == "ValueError"

    def test_payload_contains_error_message(self):
        recorded = self._positional_args_for(ValueError("bad input"))[2]
        assert "error_message" in recorded
        assert "bad input" in recorded["error_message"]

    def test_custom_exception_class_name(self):
        recorded = self._positional_args_for(RuntimeError("oops"))[2]
        assert recorded["error_class"] == "RuntimeError"


class TestBB4ColdLedgerUnavailableDoesNotRaise:
    """BB4: when ColdLedger.write_event raises, every hook still completes and
    no exception reaches the caller."""

    @staticmethod
    def _failing_interceptor():
        """Interceptor whose ledger raises RuntimeError on every write_event call."""
        return _make_interceptor(side_effect=RuntimeError("DB down"))

    def test_pre_execute_swallows_ledger_error(self):
        subject = self._failing_interceptor()
        # The call itself must not raise, and the payload must still come back.
        outcome = _run(subject.pre_execute(TASK_PAYLOAD))
        assert outcome == TASK_PAYLOAD

    def test_post_execute_swallows_ledger_error(self):
        subject = self._failing_interceptor()
        # Completing without an exception is the whole assertion here.
        _run(subject.post_execute(RESULT_PAYLOAD, TASK_PAYLOAD))

    def test_on_error_swallows_ledger_error(self):
        subject = self._failing_interceptor()
        outcome = _run(subject.on_error(ValueError("err"), TASK_PAYLOAD))
        assert outcome == TASK_PAYLOAD

    def test_on_correction_swallows_ledger_error(self):
        subject = self._failing_interceptor()
        outcome = _run(subject.on_correction(CORRECTION_PAYLOAD))
        assert outcome == CORRECTION_PAYLOAD

    def test_warning_is_logged_on_ledger_failure(self, caplog):
        subject = self._failing_interceptor()
        target_logger = "core.storage.cold_ledger_interceptor"
        with caplog.at_level(logging.WARNING, logger=target_logger):
            _run(subject.pre_execute(TASK_PAYLOAD))
        logged = [record.message for record in caplog.records]
        assert any("write_event failed" in text for text in logged)


# ===========================================================================
# White Box tests
# ===========================================================================


class TestWB1PriorityIs5:
    """WB1: the interceptor's metadata declares priority 5 (and name 'cold_ledger')."""

    def test_metadata_priority(self):
        # Checked through an instance.
        instance_meta = _make_interceptor().metadata
        assert instance_meta.priority == 5

    def test_class_level_metadata_priority(self):
        # Checked directly on the class, without instantiating.
        class_meta = ColdLedgerInterceptor.metadata
        assert class_meta.priority == 5

    def test_metadata_name(self):
        class_meta = ColdLedgerInterceptor.metadata
        assert class_meta.name == "cold_ledger"


class TestWB2PreExecutePassThrough:
    """WB2: pre_execute returns the task_payload dict unchanged (same object identity)."""

    def test_returns_same_dict(self):
        interceptor = _make_interceptor()
        payload = {"session_id": "s1", "task_type": "x", "tier": "silver"}
        result = _run(interceptor.pre_execute(payload))
        # Identity, not just equality: the very same dict must come back.
        assert result is payload

    def test_does_not_mutate_payload(self):
        interceptor = _make_interceptor()
        payload = {"session_id": "s1", "task_type": "x", "tier": "silver"}
        # Snapshot keys AND values; comparing key sets alone would miss a
        # hook that rewrites an existing value in place.
        snapshot = dict(payload)
        _run(interceptor.pre_execute(payload))
        assert payload == snapshot

    def test_returns_payload_even_when_ledger_fails(self):
        # A failing ledger write must not prevent the payload from being returned.
        interceptor = _make_interceptor(side_effect=Exception("fail"))
        payload = {"session_id": "s1"}
        result = _run(interceptor.pre_execute(payload))
        assert result is payload


class TestWB3OnErrorPayloadContainsTypeName:
    """WB3: the 'error_class' written by on_error is type(error).__name__."""

    @staticmethod
    def _recorded_payload_for(error):
        """Run on_error(error, TASK_PAYLOAD); return the payload given to write_event."""
        subject = _make_interceptor()
        _run(subject.on_error(error, TASK_PAYLOAD))
        return subject._mock_ledger.write_event.call_args[0][2]

    def test_uses_type_name_not_repr(self):
        failure = KeyError("missing_key")
        recorded = self._recorded_payload_for(failure)
        assert recorded["error_class"] == type(failure).__name__

    def test_nested_exception_type(self):
        # Exception class defined locally inside the test function.
        class CustomDBError(Exception):
            pass

        recorded = self._recorded_payload_for(CustomDBError("db fail"))
        assert recorded["error_class"] == "CustomDBError"


class TestWB4OnCorrectionIncludesAttemptNumber:
    """WB4: on_correction forwards 'attempt' from the correction payload."""

    @staticmethod
    def _positional_args_for(correction):
        """Run on_correction(correction); return write_event's positional args."""
        subject = _make_interceptor()
        _run(subject.on_correction(correction))
        return subject._mock_ledger.write_event.call_args[0]

    def test_attempt_forwarded(self):
        recorded = self._positional_args_for(CORRECTION_PAYLOAD)[2]
        assert recorded["attempt"] == 3

    def test_attempt_defaults_to_zero_when_absent(self):
        # No 'attempt' key supplied at all.
        recorded = self._positional_args_for({"session_id": "s1"})[2]
        assert recorded["attempt"] == 0

    def test_event_type_is_dispatch_correction(self):
        positional = self._positional_args_for(CORRECTION_PAYLOAD)
        assert positional[1] == "dispatch_correction"

    def test_correction_payload_returned_unchanged(self):
        subject = _make_interceptor()
        returned = _run(subject.on_correction(CORRECTION_PAYLOAD))
        # Same object must come back, not a copy.
        assert returned is CORRECTION_PAYLOAD


class TestWB5InheritsFromBaseInterceptor:
    """WB5: ColdLedgerInterceptor must be a subclass of BaseInterceptor."""

    def test_is_subclass(self):
        assert issubclass(ColdLedgerInterceptor, BaseInterceptor)

    def test_instance_is_base_interceptor(self):
        interceptor = _make_interceptor()
        assert isinstance(interceptor, BaseInterceptor)

    def test_no_sqlite3_import_in_source(self):
        # Inspect the file of the module that was actually imported rather
        # than a hard-coded absolute path, so the check follows the code
        # under test regardless of where the repository is checked out.
        module = sys.modules[ColdLedgerInterceptor.__module__]
        source = pathlib.Path(module.__file__).read_text(encoding="utf-8")
        # Cover both "import sqlite3" and "from sqlite3 import ..." forms.
        for forbidden in ("import sqlite3", "from sqlite3"):
            assert forbidden not in source, (
                "cold_ledger_interceptor.py must NOT import sqlite3 — Genesis Rule 7"
            )


# ===========================================================================
# Package export tests
# ===========================================================================


class TestPackageExports:
    """The core.storage package must re-export ColdLedgerInterceptor."""

    def test_importable_from_package(self):
        # Both import paths must resolve to the very same class object.
        package_level = ColdLedgerInterceptorFromPackage
        module_level = ColdLedgerInterceptor
        assert package_level is module_level

    def test_all_includes_cold_ledger_interceptor(self):
        from core.storage import __all__ as exported_names
        assert "ColdLedgerInterceptor" in exported_names


# ===========================================================================
# Unknown session_id fallback
# ===========================================================================


class TestUnknownSessionIdFallback:
    """A payload without session_id is recorded under the session id 'unknown'."""

    @staticmethod
    def _recorded_session_id(subject):
        """Return the first positional argument of the last write_event call."""
        return subject._mock_ledger.write_event.call_args[0][0]

    def test_pre_execute_unknown_session(self):
        subject = _make_interceptor()
        _run(subject.pre_execute({"task_type": "x"}))
        assert self._recorded_session_id(subject) == "unknown"

    def test_on_error_unknown_session(self):
        subject = _make_interceptor()
        # Completely empty task payload.
        _run(subject.on_error(ValueError("e"), {}))
        assert self._recorded_session_id(subject) == "unknown"


# ===========================================================================
# Standalone runner
# ===========================================================================


if __name__ == "__main__":
    # Minimal runner so the file can be executed directly without pytest.
    # Tests that need pytest fixtures (e.g. caplog) are not listed here.
    import traceback

    # Short aliases; a fresh instance is still created for every test entry.
    bb1 = TestBB1PreExecuteWritesDispatchStart
    bb2 = TestBB2PostExecuteWritesDispatchComplete
    bb3 = TestBB3OnErrorWritesDispatchError
    bb4 = TestBB4ColdLedgerUnavailableDoesNotRaise
    wb1 = TestWB1PriorityIs5
    wb2 = TestWB2PreExecutePassThrough
    wb3 = TestWB3OnErrorPayloadContainsTypeName
    wb4 = TestWB4OnCorrectionIncludesAttemptNumber
    wb5 = TestWB5InheritsFromBaseInterceptor
    pkg = TestPackageExports
    fallback = TestUnknownSessionIdFallback

    cases = [
        # BB1
        ("BB1a: pre_execute event_type=dispatch_start",
         bb1().test_event_type_is_dispatch_start),
        ("BB1b: pre_execute session_id forwarded",
         bb1().test_session_id_forwarded),
        ("BB1c: pre_execute payload contains task_type",
         bb1().test_payload_contains_task_type),
        ("BB1d: pre_execute payload contains tier",
         bb1().test_payload_contains_tier),
        ("BB1e: pre_execute payload contains timestamp",
         bb1().test_payload_contains_timestamp),
        # BB2
        ("BB2a: post_execute event_type=dispatch_complete",
         bb2().test_event_type_is_dispatch_complete),
        ("BB2b: post_execute payload contains success",
         bb2().test_payload_contains_success_flag),
        ("BB2c: post_execute session_id forwarded",
         bb2().test_session_id_forwarded),
        ("BB2d: post_execute success=False propagated",
         bb2().test_success_false_when_result_says_so),
        # BB3
        ("BB3a: on_error event_type=dispatch_error",
         bb3().test_event_type_is_dispatch_error),
        ("BB3b: on_error payload contains error_class",
         bb3().test_payload_contains_error_class),
        ("BB3c: on_error payload contains error_message",
         bb3().test_payload_contains_error_message),
        ("BB3d: on_error custom exception class name",
         bb3().test_custom_exception_class_name),
        # BB4
        ("BB4a: pre_execute swallows ledger error",
         bb4().test_pre_execute_swallows_ledger_error),
        ("BB4b: post_execute swallows ledger error",
         bb4().test_post_execute_swallows_ledger_error),
        ("BB4c: on_error swallows ledger error",
         bb4().test_on_error_swallows_ledger_error),
        ("BB4d: on_correction swallows ledger error",
         bb4().test_on_correction_swallows_ledger_error),
        # WB1
        ("WB1a: metadata.priority == 5",
         wb1().test_metadata_priority),
        ("WB1b: class-level metadata.priority == 5",
         wb1().test_class_level_metadata_priority),
        ("WB1c: metadata.name == cold_ledger",
         wb1().test_metadata_name),
        # WB2
        ("WB2a: pre_execute returns same dict",
         wb2().test_returns_same_dict),
        ("WB2b: pre_execute does not mutate payload",
         wb2().test_does_not_mutate_payload),
        ("WB2c: pre_execute returns payload even on ledger fail",
         wb2().test_returns_payload_even_when_ledger_fails),
        # WB3
        ("WB3a: on_error uses type(error).__name__",
         wb3().test_uses_type_name_not_repr),
        ("WB3b: on_error nested exception type",
         wb3().test_nested_exception_type),
        # WB4
        ("WB4a: on_correction attempt forwarded",
         wb4().test_attempt_forwarded),
        ("WB4b: on_correction attempt defaults to 0",
         wb4().test_attempt_defaults_to_zero_when_absent),
        ("WB4c: on_correction event_type=dispatch_correction",
         wb4().test_event_type_is_dispatch_correction),
        ("WB4d: on_correction returns payload unchanged",
         wb4().test_correction_payload_returned_unchanged),
        # WB5
        ("WB5a: is subclass of BaseInterceptor",
         wb5().test_is_subclass),
        ("WB5b: instance is BaseInterceptor",
         wb5().test_instance_is_base_interceptor),
        ("WB5c: no sqlite3 import in source",
         wb5().test_no_sqlite3_import_in_source),
        # Package
        ("PKG: importable from core.storage",
         pkg().test_importable_from_package),
        ("PKG: __all__ includes ColdLedgerInterceptor",
         pkg().test_all_includes_cold_ledger_interceptor),
        # Fallback
        ("FALLBACK: pre_execute unknown session_id",
         fallback().test_pre_execute_unknown_session),
        ("FALLBACK: on_error unknown session_id",
         fallback().test_on_error_unknown_session),
    ]

    ok_count = 0
    bad_count = 0
    for label, case in cases:
        try:
            case()
        except Exception as err:
            # Report the failure with its traceback and keep going.
            print(f"  [FAIL] {label}: {err}")
            traceback.print_exc()
            bad_count += 1
        else:
            print(f"  [PASS] {label}")
            ok_count += 1

    print(f"\n{ok_count}/{ok_count + bad_count} tests passed")
    if bad_count:
        # Non-zero exit status signals failure to the calling shell / CI.
        sys.exit(1)
    print("ALL TESTS PASSED -- Story 5.06 (Track B): ColdLedgerInterceptor")
