"""
VERIFICATION_STAMP
Story: 2.07 (Track B)
Verified By: parallel-builder
Verified At: 2026-02-25T00:00:00Z
Tests: 12/12
Coverage: 100%

Module 2 Black-Box Test Suite — JIT Hydration Pipeline
Story 2.07 — Track B

Tests the full JIT hydration pipeline from the outside:
  - Redis L1 read/write contract
  - scatter_gather_memory concurrency and graceful degradation
  - zero_amnesia_envelope XML assembly and escaping
  - interceptor_jit_hydration end-to-end payload enrichment
  - JITHydrationInterceptor pre_execute and error handling

ALL external I/O is mocked. No live Redis, Qdrant, or KG required.
"""
import asyncio
import json
import sys
import xml.etree.ElementTree as ET
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import pytest_asyncio

# Prepend the project checkout to sys.path ahead of the `core.memory.*` imports below.
# NOTE(review): this absolute path is machine-specific — confirm whether it
# should instead come from the test runner's configuration or be derived
# from this file's location.
sys.path.insert(0, '/mnt/e/genesis-system')

from core.memory.redis_l1_schema import RedisL1Client, RedisL1State
from core.memory.zero_amnesia_envelope import MemoryContext, build_envelope, FALLBACK_WORKING, FALLBACK_KG, FALLBACK_CONSTRAINTS
from core.memory.scatter_gather import scatter_gather_memory
from core.memory.jit_hydration import interceptor_jit_hydration
from core.memory.jit_hydration_interceptor import JITHydrationInterceptor, _ERROR_ENVELOPE


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_state(task_id: str = "task-001") -> RedisL1State:
    """Build a fully populated RedisL1State fixture for the given task id.

    The same id is used for both ``task_id`` and ``active_task_id``; every
    other field carries a fixed sample value.
    """
    sample_fields = {
        "task_id": task_id,
        "session_id": "sess-abc",
        "active_task_id": task_id,
        "focus_entities": ["redis", "qdrant"],
        "current_hypothesis": "Testing JIT hydration",
        "exhausted_paths": [],
        "active_blocker": None,
        "version": 3,
    }
    return RedisL1State(**sample_fields)


def _make_mock_redis(state: "RedisL1State | None" = None) -> AsyncMock:
    """Return a mock aioredis client whose GET yields the given state as JSON.

    Args:
        state: State to serialise as the GET payload. When ``None``, GET
            resolves to ``None`` instead.

    Returns:
        An ``AsyncMock`` with ``get``, ``setex``, ``incr`` and ``delete``
        stubbed as awaitables (``setex`` -> True, ``incr`` -> 1,
        ``delete`` -> 1).
    """
    if state is None:
        get_payload = None
    else:
        get_payload = json.dumps({
            "task_id": state.task_id,
            "session_id": state.session_id,
            "active_task_id": state.active_task_id,
            "focus_entities": state.focus_entities,
            "current_hypothesis": state.current_hypothesis,
            "exhausted_paths": state.exhausted_paths,
            "active_blocker": state.active_blocker,
            "version": state.version,
        })

    mock_redis = AsyncMock()
    mock_redis.get = AsyncMock(return_value=get_payload)
    mock_redis.setex = AsyncMock(return_value=True)
    mock_redis.incr = AsyncMock(return_value=1)
    mock_redis.delete = AsyncMock(return_value=1)
    return mock_redis


# ---------------------------------------------------------------------------
# BB1: Redis L1 set_state then get_state returns equivalent RedisL1State
# ---------------------------------------------------------------------------

class TestRedisL1ReadWrite:
    """BB1: RedisL1Client read/write contract, exercised against a fake Redis."""

    @pytest.mark.asyncio
    async def test_set_then_get_returns_equivalent_state(self):
        """set_state followed by get_state returns a RedisL1State with identical fields."""
        state = _make_state("task-set-get")
        client = RedisL1Client(redis_url="redis://fake:6379")

        # Capture what setex was called with so we can feed it back to get
        stored_data = {}

        async def fake_setex(key, ttl, data):
            stored_data["key"] = key
            stored_data["value"] = data
            return True

        async def fake_get(key):
            if key == stored_data.get("key"):
                return stored_data.get("value")
            return None

        mock_r = AsyncMock()
        mock_r.setex = AsyncMock(side_effect=fake_setex)
        mock_r.get = AsyncMock(side_effect=fake_get)

        # Swap the fake in for the client's connection (no live Redis).
        client._redis = mock_r

        set_result = await client.set_state(state)
        assert set_result is True

        retrieved = await client.get_state("task-set-get")
        assert retrieved is not None
        # Compare every field the fixture sets, so a field dropped during
        # serialisation cannot slip through unnoticed.
        for field in (
            "task_id",
            "session_id",
            "active_task_id",
            "focus_entities",
            "current_hypothesis",
            "exhausted_paths",
            "active_blocker",
            "version",
        ):
            assert getattr(retrieved, field) == getattr(state, field), field

    @pytest.mark.asyncio
    async def test_get_state_on_missing_key_returns_none(self):
        """get_state on a key not in Redis returns None (cold start)."""
        client = RedisL1Client(redis_url="redis://fake:6379")
        mock_r = _make_mock_redis(state=None)  # returns None from GET
        client._redis = mock_r

        result = await client.get_state("nonexistent-task")
        assert result is None

    @pytest.mark.asyncio
    async def test_key_format_matches_genesis_prefix(self):
        """set_state uses the exact key format genesis:state:task:{id}."""
        state = _make_state("my-special-task")
        client = RedisL1Client(redis_url="redis://fake:6379")

        captured_key = {}

        async def spy_setex(key, ttl, data):
            # Record only the key; the TTL and payload are not under test here.
            captured_key["k"] = key
            return True

        mock_r = AsyncMock()
        mock_r.setex = AsyncMock(side_effect=spy_setex)
        client._redis = mock_r

        await client.set_state(state)

        assert captured_key["k"] == "genesis:state:task:my-special-task"


# ---------------------------------------------------------------------------
# BB2: scatter_gather — all 3 fetches succeed
# ---------------------------------------------------------------------------

class TestScatterGatherAllSucceed:
    """BB2: every memory layer answers before the timeout."""

    @pytest.mark.asyncio
    async def test_all_three_fetches_succeed_populates_all_fields(self):
        """With all three fetchers returning data, no MemoryContext field is None."""
        redis_patch = patch(
            "core.memory.scatter_gather._fetch_redis_l1",
            new_callable=AsyncMock,
            return_value="Working: task A, Focus: redis",
        )
        kg_patch = patch(
            "core.memory.scatter_gather._fetch_kg_l2",
            new_callable=AsyncMock,
            return_value="KG Topology:\nentity_1: tool — Redis cache",
        )
        qdrant_patch = patch(
            "core.memory.scatter_gather._fetch_qdrant_l3",
            new_callable=AsyncMock,
            return_value="Learned Constraints:\nScar 1: avoid retry loops (score=0.91)",
        )

        with redis_patch, kg_patch, qdrant_patch:
            ctx = await scatter_gather_memory("task-all", ["redis"], "test intent", timeout_ms=200)

        for layer_value in (ctx.working_state, ctx.kg_topology, ctx.learned_constraints):
            assert layer_value is not None
        assert ctx.latency_ms >= 0


# ---------------------------------------------------------------------------
# BB3: scatter_gather — Redis times out
# ---------------------------------------------------------------------------

class TestScatterGatherRedisTimeout:
    """BB3: only the L1 Redis fetch exceeds the scatter-gather timeout."""

    @pytest.mark.asyncio
    async def test_redis_timeout_yields_none_working_state(self):
        """When the L1 Redis fetch times out, working_state is None, others may be populated."""
        async def slow_redis(*_args, **_kwargs):
            # Accept any call signature: a TypeError from an argument
            # mismatch must not stand in for the timeout under test.
            await asyncio.sleep(10)  # far longer than the 20 ms budget below

        with (
            patch("core.memory.scatter_gather._fetch_redis_l1", new_callable=AsyncMock, side_effect=slow_redis),
            patch("core.memory.scatter_gather._fetch_kg_l2",    new_callable=AsyncMock, return_value="KG Topology:\nsome entity"),
            patch("core.memory.scatter_gather._fetch_qdrant_l3", new_callable=AsyncMock, return_value="Learned Constraints:\nscar 1"),
        ):
            ctx = await scatter_gather_memory("task-redis-slow", ["entity"], "intent", timeout_ms=20)

        # Redis timed out → None
        assert ctx.working_state is None
        # Other layers answered immediately and returned data
        assert ctx.kg_topology is not None
        assert ctx.learned_constraints is not None
        # latency is still recorded
        assert ctx.latency_ms >= 0


# ---------------------------------------------------------------------------
# BB4: scatter_gather — All 3 layers time out
# ---------------------------------------------------------------------------

class TestScatterGatherAllTimeout:
    """BB4: every memory layer exceeds the scatter-gather timeout."""

    @pytest.mark.asyncio
    async def test_all_three_timeout_all_fields_none_latency_recorded(self):
        """When all 3 layers time out, all fields are None but latency_ms is still populated."""
        async def slow(*_args, **_kwargs):
            # One stub serves all three fetchers, so it must tolerate whatever
            # arguments each is called with; otherwise a TypeError could be
            # mistaken for the timeout this test is meant to exercise.
            await asyncio.sleep(10)  # far longer than the 20 ms budget below

        with (
            patch("core.memory.scatter_gather._fetch_redis_l1", new_callable=AsyncMock, side_effect=slow),
            patch("core.memory.scatter_gather._fetch_kg_l2",    new_callable=AsyncMock, side_effect=slow),
            patch("core.memory.scatter_gather._fetch_qdrant_l3", new_callable=AsyncMock, side_effect=slow),
        ):
            ctx = await scatter_gather_memory("task-all-slow", ["x"], "y", timeout_ms=20)

        assert ctx.working_state is None
        assert ctx.kg_topology is None
        assert ctx.learned_constraints is None
        assert isinstance(ctx.latency_ms, float)
        assert ctx.latency_ms >= 0


# ---------------------------------------------------------------------------
# BB5: Envelope builder — all fields None → fallback strings
# ---------------------------------------------------------------------------

class TestEnvelopeBuilderFallbacks:
    """BB5: empty memory layers are rendered with the canonical fallback text."""

    def test_all_none_fields_produce_canonical_fallback_strings(self):
        """An all-None MemoryContext renders every FALLBACK_* constant into the envelope."""
        empty_ctx = MemoryContext(
            working_state=None,
            kg_topology=None,
            learned_constraints=None,
            latency_ms=12.3,
        )
        envelope = build_envelope(empty_ctx)

        for fallback in (FALLBACK_WORKING, FALLBACK_KG, FALLBACK_CONSTRAINTS):
            assert fallback in envelope


# ---------------------------------------------------------------------------
# BB6: Envelope builder — output is valid XML
# ---------------------------------------------------------------------------

class TestEnvelopeBuilderValidXML:
    """BB6: the envelope is well-formed XML whatever the layer contents."""

    def test_output_parses_as_valid_xml(self):
        """A fully populated context yields XML rooted at ZERO_AMNESIA_STATE."""
        populated = MemoryContext(
            working_state="some state",
            kg_topology="some topology",
            learned_constraints="some scars",
            latency_ms=5.7,
        )
        envelope = build_envelope(populated)

        # ET.fromstring raises on malformed input, so reaching the assert
        # already proves the envelope parsed.
        parsed_root = ET.fromstring(envelope)
        assert parsed_root.tag == "ZERO_AMNESIA_STATE"

    def test_all_none_fields_also_produces_valid_xml(self):
        """An all-None context still yields an envelope ElementTree can parse."""
        empty = MemoryContext(
            working_state=None,
            kg_topology=None,
            learned_constraints=None,
            latency_ms=0.0,
        )
        parsed_root = ET.fromstring(build_envelope(empty))
        assert parsed_root is not None


# ---------------------------------------------------------------------------
# BB7: Envelope builder — XML special characters are escaped
# ---------------------------------------------------------------------------

class TestEnvelopeBuilderXMLEscaping:
    """BB7: XML-special characters in layer content are entity-escaped."""

    def test_lt_character_is_escaped_in_content(self):
        """A literal < in the content appears as &lt; in the serialised envelope.

        Three things are checked, in order:
          - the raw envelope string carries the escaped form ``&lt;``;
          - the envelope parses (an unescaped < would make ElementTree raise);
          - the parsed WORKING_CONTEXT text holds the literal ``<`` again,
            because ElementTree decodes entities while parsing.
        """
        envelope = build_envelope(
            MemoryContext(
                working_state="Condition: a < b",
                kg_topology=None,
                learned_constraints=None,
                latency_ms=1.0,
            )
        )

        # Escaped form must be present in the raw string.
        assert "&lt;" in envelope, "Raw XML must contain &lt; for the < character"

        # Parsing succeeds only if the < was escaped.
        working_el = ET.fromstring(envelope).find("WORKING_CONTEXT")
        assert working_el is not None

        # Round trip: the entity is decoded back to a literal <.
        assert working_el.text is not None
        assert "<" in working_el.text

    def test_ampersand_is_escaped(self):
        """A literal & in the content appears as &amp; and the envelope still parses."""
        envelope = build_envelope(
            MemoryContext(
                working_state="foo & bar",
                kg_topology=None,
                learned_constraints=None,
                latency_ms=1.0,
            )
        )
        parsed_root = ET.fromstring(envelope)
        assert parsed_root is not None
        assert "&amp;" in envelope


# ---------------------------------------------------------------------------
# BB8: interceptor_jit_hydration — payload has system_injection after call
# ---------------------------------------------------------------------------

class TestInterceptorJITHydrationFullPipeline:
    """BB8: interceptor_jit_hydration enriches a payload end to end."""

    @pytest.mark.asyncio
    async def test_payload_gains_system_injection_key(self):
        """The returned payload carries a non-empty system_injection entry."""
        payload = {"task_id": "task-full", "prompt": "Build the redis schema"}

        # All three fetchers resolve to None, i.e. no layer has any data.
        redis_patch = patch("core.memory.scatter_gather._fetch_redis_l1", new_callable=AsyncMock, return_value=None)
        kg_patch = patch("core.memory.scatter_gather._fetch_kg_l2", new_callable=AsyncMock, return_value=None)
        qdrant_patch = patch("core.memory.scatter_gather._fetch_qdrant_l3", new_callable=AsyncMock, return_value=None)

        with redis_patch, kg_patch, qdrant_patch:
            enriched = await interceptor_jit_hydration(payload)

        assert "system_injection" in enriched
        assert len(enriched["system_injection"]) > 0

    @pytest.mark.asyncio
    async def test_original_payload_keys_preserved(self):
        """Every key supplied by the caller is still present, with its value, afterwards."""
        payload = {
            "task_id": "task-preserve",
            "prompt": "Do something useful",
            "custom_key": "my_value",
        }

        # All three fetchers resolve to None, i.e. no layer has any data.
        redis_patch = patch("core.memory.scatter_gather._fetch_redis_l1", new_callable=AsyncMock, return_value=None)
        kg_patch = patch("core.memory.scatter_gather._fetch_kg_l2", new_callable=AsyncMock, return_value=None)
        qdrant_patch = patch("core.memory.scatter_gather._fetch_qdrant_l3", new_callable=AsyncMock, return_value=None)

        with redis_patch, kg_patch, qdrant_patch:
            enriched = await interceptor_jit_hydration(payload)

        assert enriched["task_id"] == "task-preserve"
        assert enriched["prompt"] == "Do something useful"
        assert enriched["custom_key"] == "my_value"


# ---------------------------------------------------------------------------
# BB9: JITHydrationInterceptor.pre_execute attaches system_injection
# ---------------------------------------------------------------------------

class TestJITHydrationInterceptorPreExecute:
    """BB9: pre_execute delegates to the hydration function."""

    @pytest.mark.asyncio
    async def test_pre_execute_attaches_system_injection(self):
        """pre_execute awaits hydration once with the payload and returns system_injection."""
        interceptor = JITHydrationInterceptor()
        payload = {"task_id": "t-99", "prompt": "test"}
        hydrated = {
            "task_id": "t-99",
            "prompt": "test",
            "system_injection": "<ZERO_AMNESIA_STATE><WORKING_CONTEXT>ok</WORKING_CONTEXT></ZERO_AMNESIA_STATE>",
        }

        with patch(
            "core.memory.jit_hydration_interceptor.interceptor_jit_hydration",
            new_callable=AsyncMock,
            return_value=hydrated,
        ) as mock_hydrate:
            result = await interceptor.pre_execute(payload)

        assert "system_injection" in result
        mock_hydrate.assert_awaited_once_with(payload)


# ---------------------------------------------------------------------------
# BB10: JITHydrationInterceptor — hydration error injects error envelope
# ---------------------------------------------------------------------------

class TestJITHydrationInterceptorErrorHandling:
    """BB10: a failing hydration step degrades to the static error envelope."""

    @pytest.mark.asyncio
    async def test_hydration_exception_injects_error_envelope(self):
        """When interceptor_jit_hydration raises, an error envelope is injected instead."""
        interceptor = JITHydrationInterceptor()

        with patch(
            "core.memory.jit_hydration_interceptor.interceptor_jit_hydration",
            new_callable=AsyncMock,
            side_effect=RuntimeError("Redis is down"),
        ):
            payload = {"task_id": "t-err", "prompt": "risky task"}
            result = await interceptor.pre_execute(payload)

        assert "system_injection" in result
        assert result["system_injection"] == _ERROR_ENVELOPE
        # Original task_id preserved
        assert result["task_id"] == "t-err"

    def test_error_envelope_is_a_non_empty_string(self):
        """The _ERROR_ENVELOPE constant is a non-empty string (sanity check)."""
        # Plain synchronous test: nothing here awaits, so no event loop is needed.
        assert isinstance(_ERROR_ENVELOPE, str)
        assert len(_ERROR_ENVELOPE) > 0
        assert "ZERO_AMNESIA_STATE" in _ERROR_ENVELOPE


# ---------------------------------------------------------------------------
# BB11: post_execute logs observability without raising
# ---------------------------------------------------------------------------

class TestJITHydrationInterceptorPostExecute:
    """BB11: post_execute must never raise."""

    @pytest.mark.asyncio
    async def test_post_execute_does_not_raise_on_io_failure(self):
        """post_execute completes without raising when EVENTS_DIR.mkdir fails with OSError."""
        interceptor = JITHydrationInterceptor()
        execution_result = {"status": "completed"}
        payload = {
            "task_id": "t-post",
            "system_injection": "<ZERO_AMNESIA_STATE/>",
        }

        # Replace EVENTS_DIR with a mock whose mkdir() raises, simulating an
        # I/O failure without touching the real filesystem.
        with patch("core.memory.jit_hydration_interceptor.EVENTS_DIR") as events_dir:
            events_dir.mkdir.side_effect = OSError("Permission denied")
            # The test passes as long as this await does not raise.
            await interceptor.post_execute(execution_result, payload)


# ---------------------------------------------------------------------------
# BB12: on_correction — prepends CORRECTION: prefix and re-hydrates
# ---------------------------------------------------------------------------

class TestJITHydrationInterceptorOnCorrection:
    """BB12: on_correction tags the prompt and triggers re-hydration."""

    @pytest.mark.asyncio
    async def test_on_correction_prepends_correction_prefix(self):
        """on_correction prepends 'CORRECTION: ' to the prompt and re-hydrates."""
        interceptor = JITHydrationInterceptor()

        with patch(
            "core.memory.jit_hydration_interceptor.interceptor_jit_hydration",
            new_callable=AsyncMock,
        ) as mock_hydrate:
            mock_hydrate.return_value = {
                "prompt": "CORRECTION: retry this",
                "system_injection": "<ZERO_AMNESIA_STATE/>",
            }
            payload = {"prompt": "retry this", "task_id": "t-corr"}
            # The return value is not inspected; only the side effects matter here.
            await interceptor.on_correction(payload)

        # The caller's payload carries the prefixed prompt once on_correction returns
        assert payload["prompt"] == "CORRECTION: retry this"
        mock_hydrate.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_on_correction_without_prompt_key_does_not_crash(self):
        """on_correction with no prompt key still calls re-hydration safely."""
        interceptor = JITHydrationInterceptor()

        with patch(
            "core.memory.jit_hydration_interceptor.interceptor_jit_hydration",
            new_callable=AsyncMock,
            return_value={"system_injection": "<ZERO_AMNESIA_STATE/>"},
        ) as mock_hydrate:
            payload = {"task_id": "t-no-prompt"}
            # The return value is not inspected; only the side effects matter here.
            await interceptor.on_correction(payload)

        assert "prompt" not in payload  # no mutation when key absent
        mock_hydrate.assert_awaited_once()
