#!/usr/bin/env python3
"""
Tests for Story 2.02 (Track B): interceptor_jit_hydration — Master Function

BB1: After hydration, task_payload["system_injection"] present and contains "<ZERO_AMNESIA_STATE>"
BB2: On Redis timeout → hydration still completes with fallback content
BB3: Cold start (empty memory) → envelope has fallback strings
BB4: Original payload fields preserved after hydration

WB1: fast_extract called before scatter_gather_memory (verify call order via mock)
WB2: scatter_gather_memory called with timeout_ms=45
WB3: build_envelope called with MemoryContext from scatter_gather
"""
import asyncio
import sys
from unittest.mock import AsyncMock, MagicMock, patch, call

sys.path.insert(0, '/mnt/e/genesis-system')


# ---------------------------------------------------------------------------
# Black-box tests
# ---------------------------------------------------------------------------

def test_bb1_system_injection_present_and_valid():
    """BB1: the hydrated payload carries a well-formed <ZERO_AMNESIA_STATE> system_injection.

    scatter_gather_memory is stubbed to return an all-empty MemoryContext, so the
    outcome depends only on interceptor_jit_hydration and the envelope builder.
    """
    from core.memory.jit_hydration import interceptor_jit_hydration
    from core.memory.zero_amnesia_envelope import MemoryContext

    stub_scatter = AsyncMock(
        return_value=MemoryContext(
            working_state=None,
            kg_topology=None,
            learned_constraints=None,
            latency_ms=5.0,
        )
    )
    request = {
        "task_id": "bb1-test",
        "prompt": "Fix the redis connection in core/memory/redis_l1_schema.py",
    }

    with patch("core.memory.jit_hydration.scatter_gather_memory", new=stub_scatter):
        hydrated = asyncio.run(interceptor_jit_hydration(request))

    assert "system_injection" in hydrated, "system_injection key must be present in returned payload"
    envelope = hydrated["system_injection"]
    # Both the opening and the closing tag must be present for the envelope to be well-formed.
    assert "<ZERO_AMNESIA_STATE>" in envelope, f"Expected <ZERO_AMNESIA_STATE> tag, got: {envelope[:200]}"
    assert "</ZERO_AMNESIA_STATE>" in envelope, "Expected closing </ZERO_AMNESIA_STATE> tag"
    print("BB1 PASSED: system_injection present and contains <ZERO_AMNESIA_STATE>")


def test_bb2_redis_timeout_completes_with_fallback():
    """BB2: a Redis timeout still yields an envelope, populated with fallback text.

    scatter_gather_memory handles timeouts internally and always returns, so the
    timeout is modelled as a MemoryContext whose layers are all None and whose
    latency sits at 45 ms.
    """
    from core.memory.jit_hydration import interceptor_jit_hydration
    from core.memory.zero_amnesia_envelope import MemoryContext, FALLBACK_WORKING

    timed_out = MemoryContext(
        working_state=None,        # L1 / Redis timed out
        kg_topology=None,          # KG lookup missed as well
        learned_constraints=None,
        latency_ms=45.0,
    )
    request = {"task_id": "bb2-timeout", "prompt": "Deploy the redis cache layer"}

    with patch(
        "core.memory.jit_hydration.scatter_gather_memory",
        new=AsyncMock(return_value=timed_out),
    ):
        hydrated = asyncio.run(interceptor_jit_hydration(request))

    assert "system_injection" in hydrated
    envelope = hydrated["system_injection"]
    assert "<ZERO_AMNESIA_STATE>" in envelope
    # Every layer was empty, so some fallback wording has to appear.
    has_fallback = FALLBACK_WORKING in envelope or "Cold start" in envelope
    assert has_fallback, f"Expected fallback content for cold/timeout state, got: {envelope}"
    print("BB2 PASSED: Timeout scenario still returns envelope with fallback content")


def test_bb3_cold_start_fallback_strings():
    """BB3: with no memory at all, every envelope section falls back to its default text.

    The incoming payload is deliberately empty (no task_id, no prompt) to check
    that hydration copes with a bare request.
    """
    from core.memory.jit_hydration import interceptor_jit_hydration
    from core.memory.zero_amnesia_envelope import (
        FALLBACK_CONSTRAINTS,
        FALLBACK_KG,
        FALLBACK_WORKING,
        MemoryContext,
    )

    empty_memory = MemoryContext(
        working_state=None,
        kg_topology=None,
        learned_constraints=None,
        latency_ms=0.5,
    )

    with patch(
        "core.memory.jit_hydration.scatter_gather_memory",
        new=AsyncMock(return_value=empty_memory),
    ):
        hydrated = asyncio.run(interceptor_jit_hydration({}))

    envelope = hydrated["system_injection"]

    # Working state, KG topology and learned constraints each need their own fallback.
    for fallback in (FALLBACK_WORKING, FALLBACK_KG, FALLBACK_CONSTRAINTS):
        assert fallback in envelope, f"Expected '{fallback}' in envelope"
    print("BB3 PASSED: Cold start envelope contains all three fallback strings")


def test_bb4_original_payload_fields_preserved():
    """BB4: every field of the incoming payload survives hydration unchanged.

    The payload handed to the interceptor is a *deep* copy of ``original_extra``.
    A shallow ``dict(...)`` copy would share the nested ``custom_field`` dict and
    ``tags`` list with the reference values. An in-place mutation of those
    containers by the interceptor would then go unnoticed, because the comparison
    would be against the mutated object itself.
    """
    import copy

    from core.memory.jit_hydration import interceptor_jit_hydration
    from core.memory.zero_amnesia_envelope import MemoryContext

    ctx = MemoryContext(working_state="state", kg_topology=None, learned_constraints=None, latency_ms=3.0)

    original_extra = {
        "task_id": "bb4-preserve",
        "prompt": "Refactor the scatter_gather module",
        "model": "gemini-2.0-flash",
        "temperature": 0.7,
        "custom_field": {"nested": True},
        "tags": ["redis", "qdrant"],
    }
    # Deep copy so nested containers are not shared with the expected values.
    payload = copy.deepcopy(original_extra)

    with patch("core.memory.jit_hydration.scatter_gather_memory", new=AsyncMock(return_value=ctx)):
        result = asyncio.run(interceptor_jit_hydration(payload))

    # All original keys must survive with equal values (nested structures included).
    for key, value in original_extra.items():
        assert key in result, f"Original key '{key}' missing from result"
        assert result[key] == value, f"Value for '{key}' changed: expected {value!r}, got {result[key]!r}"

    # system_injection added
    assert "system_injection" in result
    print("BB4 PASSED: All original payload fields preserved, system_injection added")


# ---------------------------------------------------------------------------
# White-box tests
# ---------------------------------------------------------------------------

def test_wb1_fast_extract_called_before_scatter_gather():
    """WB1: entity extraction runs before the memory scatter/gather.

    Both collaborators are replaced by recording fakes that append their name to a
    shared trace. The recorded sequence must be exactly extract-then-gather.
    """
    from core.memory.jit_hydration import interceptor_jit_hydration
    from core.memory.zero_amnesia_envelope import MemoryContext

    trace = []

    # Parameter names mirror the ones the original fakes used, in case the
    # interceptor passes arguments by keyword.
    def fake_fast_extract(payload):
        trace.append("fast_extract")
        return (["redis"], "fix redis")

    async def fake_scatter_gather(task_id, target_entities, intent_string, timeout_ms):
        trace.append("scatter_gather")
        return MemoryContext(
            working_state=None,
            kg_topology=None,
            learned_constraints=None,
            latency_ms=1.0,
        )

    request = {"task_id": "wb1", "prompt": "test redis"}

    with patch("core.memory.jit_hydration.fast_extract", side_effect=fake_fast_extract):
        with patch("core.memory.jit_hydration.scatter_gather_memory", side_effect=fake_scatter_gather):
            asyncio.run(interceptor_jit_hydration(request))

    expected_order = ["fast_extract", "scatter_gather"]
    assert trace == expected_order, f"Expected ['fast_extract', 'scatter_gather'], got {trace}"
    print("WB1 PASSED: fast_extract is called before scatter_gather_memory")


def test_wb2_scatter_gather_called_with_timeout_45():
    """WB2: scatter_gather_memory is awaited exactly once with timeout_ms=45.

    The timeout is accepted either as the ``timeout_ms`` keyword or as the
    fourth positional argument. ``assert_awaited_once`` is checked in addition to
    ``assert_called_once``: a coroutine that is called but never awaited never
    runs its body.
    """
    from core.memory.jit_hydration import interceptor_jit_hydration
    from core.memory.zero_amnesia_envelope import MemoryContext

    ctx = MemoryContext(working_state="w", kg_topology="k", learned_constraints="c", latency_ms=20.0)
    mock_scatter = AsyncMock(return_value=ctx)

    payload = {"task_id": "wb2-task", "prompt": "Deploy redis cache"}

    with patch("core.memory.jit_hydration.scatter_gather_memory", new=mock_scatter):
        asyncio.run(interceptor_jit_hydration(payload))

    # Called exactly once, and the resulting coroutine was actually awaited.
    mock_scatter.assert_called_once()
    mock_scatter.assert_awaited_once()

    # timeout_ms may be passed by keyword or positionally (4th argument).
    call_args = mock_scatter.call_args
    if "timeout_ms" in call_args.kwargs:
        timeout_used = call_args.kwargs["timeout_ms"]
    elif len(call_args.args) > 3:
        timeout_used = call_args.args[3]
    else:
        timeout_used = None

    assert timeout_used == 45, f"Expected timeout_ms=45, got {timeout_used}"
    print("WB2 PASSED: scatter_gather_memory called with timeout_ms=45")


def test_wb3_build_envelope_called_with_memory_context():
    """WB3: the MemoryContext returned by scatter_gather is passed straight to build_envelope.

    build_envelope (as seen by jit_hydration) is wrapped in a spy. The spy
    records its argument and then delegates to the real implementation, so
    hydration still receives a genuine envelope.
    """
    from core.memory.jit_hydration import interceptor_jit_hydration
    from core.memory.zero_amnesia_envelope import MemoryContext
    # Imported from zero_amnesia_envelope, which the patch below does not touch.
    from core.memory.zero_amnesia_envelope import build_envelope as real_build

    gathered = MemoryContext(
        working_state="Working on tradie scraper",
        kg_topology="KG Topology:\nredis: service — in-memory cache",
        learned_constraints="Learned Constraints:\nScar 1: never use free Gemini",
        latency_ms=22.5,
    )
    seen = []

    def spy_build_envelope(ctx):
        seen.append(ctx)
        return real_build(ctx)

    request = {"task_id": "wb3-task", "prompt": "Redis scraper"}

    with patch("core.memory.jit_hydration.scatter_gather_memory", new=AsyncMock(return_value=gathered)):
        with patch("core.memory.jit_hydration.build_envelope", side_effect=spy_build_envelope):
            asyncio.run(interceptor_jit_hydration(request))

    assert len(seen) == 1, "build_envelope should be called exactly once"
    received = seen[0]
    # Identity, not just equality: the very object from scatter_gather must be forwarded.
    assert received is gathered, "build_envelope must receive the MemoryContext returned by scatter_gather"
    assert received.working_state == "Working on tradie scraper"
    assert received.latency_ms == 22.5
    print("WB3 PASSED: build_envelope called with exact MemoryContext from scatter_gather")


# ---------------------------------------------------------------------------
# Run all
# ---------------------------------------------------------------------------

def test_all():
    """Run every BB/WB test above in declaration order, then print a summary.

    Intended as the entry point for direct execution. Note that pytest also
    collects this function because of its ``test_`` prefix, so under pytest each
    test body runs twice.
    """
    suite = (
        test_bb1_system_injection_present_and_valid,
        test_bb2_redis_timeout_completes_with_fallback,
        test_bb3_cold_start_fallback_strings,
        test_bb4_original_payload_fields_preserved,
        test_wb1_fast_extract_called_before_scatter_gather,
        test_wb2_scatter_gather_called_with_timeout_45,
        test_wb3_build_envelope_called_with_memory_context,
    )
    for check in suite:
        check()  # the first failing assertion aborts the run
    print("\nALL 7 TESTS PASSED — Story 2.02 (Track B): interceptor_jit_hydration")


if __name__ == "__main__":
    # Direct execution (running this file as a script) runs the whole suite via test_all().
    test_all()
