"""
Tests for Story 4.06 — TierRouterInterceptor

Black Box (BB) test groups: 5 (18 test cases)
White Box  (WB) test groups: 5 (9 test cases)
Total: 10 test groups (27 test cases)

All router and telemetry dependencies are fully mocked via the constructor so
no Redis, no file I/O, and no actual classification logic is exercised here.

Run:
    pytest tests/track_b/test_story_4_06.py -v
"""

from __future__ import annotations

import pytest
from unittest.mock import MagicMock, AsyncMock, patch

from core.routing.tier_router_interceptor import TierRouterInterceptor
from core.routing.tier_classifier import RoutingDecision


# ---------------------------------------------------------------------------
# Helpers / fixtures
# ---------------------------------------------------------------------------


def _decision(tier: str, model: str, rationale: str = "test", cached: bool = False) -> RoutingDecision:
    """Build a RoutingDecision to be used as a canned router return value.

    ``cached`` is forwarded as the decision's ``is_cached`` field.
    """
    fields = {"tier": tier, "model": model, "rationale": rationale, "is_cached": cached}
    return RoutingDecision(**fields)


def _make_interceptor(
    decision: RoutingDecision | None = None,
    router_raises: Exception | None = None,
    telemetry_raises: Exception | None = None,
) -> tuple[TierRouterInterceptor, MagicMock, MagicMock]:
    """
    Construct a TierRouterInterceptor whose router and telemetry are MagicMocks.

    Router behaviour, in order of precedence:
      * ``router_raises`` given -> ``router.route`` raises it.
      * ``decision`` given      -> ``router.route`` returns it.
      * neither                 -> ``router.route`` returns a T1 / gemini-flash decision.

    When ``telemetry_raises`` is given, ``telemetry.record`` raises it.

    Returns (interceptor, mock_router, mock_telemetry).
    """
    router = MagicMock()
    telemetry = MagicMock()

    if router_raises is None:
        # The default decision is only built when the caller supplied none.
        router.route.return_value = (
            decision if decision is not None else _decision("T1", "gemini-flash")
        )
    else:
        router.route.side_effect = router_raises

    if telemetry_raises is not None:
        telemetry.record.side_effect = telemetry_raises

    return TierRouterInterceptor(router=router, telemetry=telemetry), router, telemetry


# ===========================================================================
# BLACK BOX TESTS
# ===========================================================================


class TestBB1_PayloadHasTierAndModel:
    """BB1: pre_execute stamps the routing decision onto the payload.

    Checks that the 'tier', 'model' and 'routing_rationale' keys appear, and
    that tier/model echo the values produced by the (mocked) router.
    """

    @staticmethod
    async def _route(decision: RoutingDecision, task_type: str):
        """Build an interceptor whose router returns `decision`, then pre_execute a `task_type` task."""
        interceptor = _make_interceptor(decision)[0]
        return await interceptor.pre_execute({"type": task_type})

    @pytest.mark.asyncio
    async def test_tier_present_after_pre_execute(self):
        routed = await self._route(_decision("T0", "python_function"), "health_check")
        assert "tier" in routed

    @pytest.mark.asyncio
    async def test_model_present_after_pre_execute(self):
        routed = await self._route(_decision("T0", "python_function"), "health_check")
        assert "model" in routed

    @pytest.mark.asyncio
    async def test_routing_rationale_present_after_pre_execute(self):
        routed = await self._route(
            _decision("T0", "python_function", "matched T0_TYPES"), "health_check"
        )
        assert "routing_rationale" in routed

    @pytest.mark.asyncio
    async def test_tier_value_matches_router_decision(self):
        routed = await self._route(_decision("T1", "gemini-flash"), "email_draft")
        assert routed["tier"] == "T1"

    @pytest.mark.asyncio
    async def test_model_value_matches_router_decision(self):
        routed = await self._route(_decision("T2", "claude-opus-4-6"), "complex_reasoning")
        assert routed["model"] == "claude-opus-4-6"


class TestBB2_RoutingErrorDefaultsToT2:
    """BB2: when the router raises, pre_execute falls back to T2 / claude-opus-4-6.

    The fallback is also explained in the payload's 'routing_rationale'.
    """

    @staticmethod
    async def _route_with_failure(exc: Exception):
        """Run pre_execute on an 'anything' task through an interceptor whose router raises `exc`."""
        interceptor = _make_interceptor(router_raises=exc)[0]
        return await interceptor.pre_execute({"type": "anything"})

    @pytest.mark.asyncio
    async def test_router_exception_sets_t2_tier(self):
        routed = await self._route_with_failure(RuntimeError("classifier down"))
        assert routed["tier"] == "T2"

    @pytest.mark.asyncio
    async def test_router_exception_sets_opus_model(self):
        routed = await self._route_with_failure(ValueError("bad payload"))
        assert routed["model"] == "claude-opus-4-6"

    @pytest.mark.asyncio
    async def test_router_exception_sets_fallback_rationale(self):
        routed = await self._route_with_failure(RuntimeError("fail"))
        assert "Routing error" in routed["routing_rationale"]


class TestBB3_OnErrorReturnsFallbacks:
    """BB3: on_error yields a dict describing the failure with T2/Opus defaults.

    The dict carries 'tier' and 'model', an 'error' entry containing the error
    message, and the very payload object passed in under 'task_payload'.
    """

    @staticmethod
    async def _handle(error: Exception, payload):
        """Run on_error on a freshly built interceptor and return its result."""
        interceptor = _make_interceptor()[0]
        return await interceptor.on_error(error, payload)

    @pytest.mark.asyncio
    async def test_on_error_returns_dict(self):
        outcome = await self._handle(RuntimeError("boom"), {"type": "x"})
        assert isinstance(outcome, dict)

    @pytest.mark.asyncio
    async def test_on_error_tier_is_t2(self):
        outcome = await self._handle(RuntimeError("boom"), {"type": "x"})
        assert outcome["tier"] == "T2"

    @pytest.mark.asyncio
    async def test_on_error_model_is_opus(self):
        outcome = await self._handle(RuntimeError("boom"), {"type": "x"})
        assert outcome["model"] == "claude-opus-4-6"

    @pytest.mark.asyncio
    async def test_on_error_contains_error_string(self):
        outcome = await self._handle(RuntimeError("something went wrong"), {"type": "x"})
        assert "something went wrong" in outcome["error"]

    @pytest.mark.asyncio
    async def test_on_error_contains_original_payload(self):
        task = {"type": "x", "task_id": "abc"}
        outcome = await self._handle(RuntimeError("err"), task)
        assert outcome["task_payload"] is task


class TestBB4_OnCorrectionPassthrough:
    """BB4: on_correction hands back the very payload it received, untouched."""

    @pytest.mark.asyncio
    async def test_on_correction_returns_same_dict(self):
        interceptor = _make_interceptor()[0]
        correction = {"CORRECTION: retry": True, "tier": "T2"}
        assert await interceptor.on_correction(correction) is correction

    @pytest.mark.asyncio
    async def test_on_correction_does_not_mutate_payload(self):
        interceptor = _make_interceptor()[0]
        correction = {"key": "value", "tier": "T2"}
        snapshot = correction.copy()
        await interceptor.on_correction(correction)
        assert correction == snapshot


class TestBB5_PostExecuteRecordsToTelemetry:
    """BB5: post_execute forwards the last routing decision to telemetry.

    Telemetry failures must be swallowed. When pre_execute never ran (so no
    decision exists), nothing is recorded.
    """

    @pytest.mark.asyncio
    async def test_post_execute_calls_telemetry_record(self):
        decision = _decision("T0", "python_function")
        interceptor, _, mock_telemetry = _make_interceptor(decision)
        # Populate _last_decision by running pre_execute first
        await interceptor.pre_execute({"type": "health_check"})
        await interceptor.post_execute({}, {"type": "health_check"})
        mock_telemetry.record.assert_called_once_with(decision)

    @pytest.mark.asyncio
    async def test_post_execute_does_not_crash_if_telemetry_raises(self):
        decision = _decision("T1", "gemini-flash")
        interceptor, _, mock_telemetry = _make_interceptor(
            decision, telemetry_raises=RuntimeError("redis down")
        )
        await interceptor.pre_execute({"type": "email_draft"})
        # Must not raise
        await interceptor.post_execute({}, {"type": "email_draft"})
        # Guard against a vacuous pass: if post_execute skipped telemetry
        # entirely, the "swallow the exception" path would never be exercised.
        mock_telemetry.record.assert_called_once_with(decision)

    @pytest.mark.asyncio
    async def test_post_execute_skips_telemetry_if_no_decision(self):
        """If pre_execute was never called, _last_decision is None — no call to record."""
        interceptor, _, mock_telemetry = _make_interceptor()
        await interceptor.post_execute({}, {})
        mock_telemetry.record.assert_not_called()


# ===========================================================================
# WHITE BOX TESTS
# ===========================================================================


class TestWB1_PriorityIs20:
    """WB1: the interceptor advertises priority 20 in its metadata."""

    def test_metadata_priority(self):
        meta = _make_interceptor()[0].metadata
        assert meta.priority == 20


class TestWB2_TelemetryRecordCalledInPostNotPre:
    """WB2: telemetry.record fires from post_execute only, never from pre_execute."""

    @pytest.mark.asyncio
    async def test_record_not_called_during_pre_execute(self):
        interceptor, _router, telemetry = _make_interceptor(_decision("T0", "python_function"))
        await interceptor.pre_execute({"type": "health_check"})
        telemetry.record.assert_not_called()

    @pytest.mark.asyncio
    async def test_record_called_during_post_execute(self):
        interceptor, _router, telemetry = _make_interceptor(_decision("T0", "python_function"))
        # Separate dict literals on purpose: pre_execute may annotate its input.
        await interceptor.pre_execute({"type": "health_check"})
        await interceptor.post_execute({}, {"type": "health_check"})
        telemetry.record.assert_called_once()


class TestWB3_MetadataName:
    """WB3: the interceptor's metadata name is 'tier_router'."""

    def test_metadata_name(self):
        meta = _make_interceptor()[0].metadata
        assert meta.name == "tier_router"


class TestWB4_LastDecisionUpdatedAfterPreExecute:
    """WB4: _last_decision is updated after pre_execute.

    It starts as None, holds the router's decision after pre_execute, and holds
    a T2 fallback decision when the router raises.
    """

    def test_last_decision_none_before_pre_execute(self):
        # Plain synchronous test (nothing is awaited), so it carries no asyncio
        # marker: pytest-asyncio warns when @pytest.mark.asyncio decorates a
        # non-async function, which fails under filterwarnings=error.
        interceptor, _, _ = _make_interceptor()
        assert interceptor._last_decision is None

    @pytest.mark.asyncio
    async def test_last_decision_set_after_pre_execute(self):
        decision = _decision("T1", "gemini-flash", "matched T1_TYPES")
        interceptor, _, _ = _make_interceptor(decision)
        await interceptor.pre_execute({"type": "email_draft"})
        assert interceptor._last_decision is not None
        assert interceptor._last_decision.tier == "T1"

    @pytest.mark.asyncio
    async def test_last_decision_set_on_routing_error(self):
        """Even on routing error, _last_decision is populated with fallback."""
        interceptor, _, _ = _make_interceptor(router_raises=RuntimeError("fail"))
        await interceptor.pre_execute({"type": "anything"})
        assert interceptor._last_decision is not None
        assert interceptor._last_decision.tier == "T2"


class TestWB5_PostExecuteDoesNotMutateResult:
    """WB5: post_execute leaves the `result` dict unmodified and returns None."""

    @pytest.mark.asyncio
    async def test_result_unchanged_after_post_execute(self):
        interceptor = _make_interceptor(_decision("T2", "claude-opus-4-6"))[0]
        await interceptor.pre_execute({"type": "deep_analysis"})
        outcome = {"output": "some value", "status": "ok"}
        before = outcome.copy()
        await interceptor.post_execute(outcome, {"type": "deep_analysis"})
        assert outcome == before

    @pytest.mark.asyncio
    async def test_post_execute_returns_none(self):
        interceptor = _make_interceptor(_decision("T0", "python_function"))[0]
        await interceptor.pre_execute({"type": "cache_get"})
        assert await interceptor.post_execute({}, {}) is None
