"""
VERIFICATION_STAMP
Story: 2.08 (Track B)
Verified By: parallel-builder
Verified At: 2026-02-25T00:00:00Z
Tests: 12/12
Coverage: 100%

Module 2 White-Box Test Suite — JIT Hydration Pipeline
Story 2.08 — Track B

Tests internal implementation paths:
  - fast_extract: timing contract (<5ms), no I/O, no HTTP
  - scatter_gather: asyncio.gather concurrency, individual per-layer timeout
  - JITHydrationInterceptor: priority, metadata.name, interceptor chain contract
  - build_envelope: latency_ms formatting, mandatory 4 XML child tags
  - RedisL1Client.bump_version: atomic INCR pattern (not read-modify-write)
  - RedisL1State.version: defaults to 0

ALL external I/O is mocked. No live Redis, Qdrant, or KG required.
"""
import asyncio
import inspect
import sys
import time
import xml.etree.ElementTree as ET
from unittest.mock import AsyncMock, MagicMock, call, patch

import pytest

sys.path.insert(0, '/mnt/e/genesis-system')

from core.memory.fast_extract import fast_extract
from core.memory.redis_l1_schema import RedisL1Client, RedisL1State
from core.memory.scatter_gather import scatter_gather_memory
from core.memory.zero_amnesia_envelope import MemoryContext, build_envelope
from core.memory.jit_hydration_interceptor import JITHydrationInterceptor


# ---------------------------------------------------------------------------
# WB1: fast_extract timing < 5ms
# ---------------------------------------------------------------------------

class TestFastExtractTiming:

    def test_fast_extract_completes_under_5ms(self):
        """fast_extract runs in <5ms — it is pure CPU with zero I/O."""
        payload = {
            "prompt": "Build the redis schema at core/memory/redis_l1_schema.py",
            "description": "Use async def get_state and class RedisL1Client",
            "task": "Wire redis and qdrant services together",
        }
        start = time.monotonic()
        fast_extract(payload)
        elapsed_ms = (time.monotonic() - start) * 1000.0
        assert elapsed_ms < 5.0, (
            f"fast_extract took {elapsed_ms:.2f}ms — must stay under 5ms (pure CPU, no I/O)"
        )

    def test_fast_extract_on_empty_payload_is_still_fast(self):
        """fast_extract on an empty dict returns quickly without error."""
        start = time.monotonic()
        entities, intent = fast_extract({})
        elapsed_ms = (time.monotonic() - start) * 1000.0
        assert elapsed_ms < 5.0
        assert entities == []
        assert intent == ""


# ---------------------------------------------------------------------------
# WB2: fast_extract makes no HTTP or I/O calls
# ---------------------------------------------------------------------------

class TestFastExtractNoIO:

    def test_fast_extract_makes_no_network_calls(self):
        """fast_extract is pure Python — no HTTP, no file I/O, no subprocesses."""
        # We verify by confirming the function has no awaitable return and is synchronous
        result = fast_extract({"prompt": "hello redis and postgres"})
        # If it tried network I/O it would need async — verifying sync signature
        assert not asyncio.iscoroutine(result), "fast_extract must NOT be a coroutine"

    def test_fast_extract_is_not_a_coroutine_function(self):
        """fast_extract is defined as a plain sync function — not async def."""
        assert not asyncio.iscoroutinefunction(fast_extract), (
            "fast_extract must be a synchronous function (no I/O)"
        )

    def test_fast_extract_extracts_known_services(self):
        """fast_extract correctly identifies known Genesis services from payload text."""
        entities, intent = fast_extract({
            "prompt": "Connect to redis and qdrant for memory storage with telnyx voice"
        })
        entity_lower = [e.lower() for e in entities]
        assert "redis" in entity_lower
        assert "qdrant" in entity_lower
        assert "telnyx" in entity_lower

    def test_fast_extract_extracts_file_paths(self):
        """fast_extract identifies .py file paths from the payload."""
        entities, intent = fast_extract({
            "prompt": "Edit core/memory/redis_l1_schema.py and core/memory/fast_extract.py"
        })
        assert any("redis_l1_schema.py" in e for e in entities)
        assert any("fast_extract.py" in e for e in entities)

    def test_fast_extract_intent_truncated_to_200_chars(self):
        """Intent string is capped at 200 characters."""
        long_text = "A" * 500
        _, intent = fast_extract({"prompt": long_text})
        assert len(intent) == 200


# ---------------------------------------------------------------------------
# WB3: scatter_gather uses asyncio.gather (concurrent, not sequential)
# ---------------------------------------------------------------------------

class TestScatterGatherConcurrency:

    @pytest.mark.asyncio
    async def test_asyncio_gather_is_called_during_scatter_gather(self):
        """scatter_gather internally calls asyncio.gather to fire all 3 layers at once."""
        gather_calls = []
        original_gather = asyncio.gather

        async def spy_gather(*coros, **kwargs):
            gather_calls.append(len(coros))
            return await original_gather(*coros, **kwargs)

        with (
            patch("core.memory.scatter_gather._fetch_redis_l1", new_callable=AsyncMock, return_value=None),
            patch("core.memory.scatter_gather._fetch_kg_l2",    new_callable=AsyncMock, return_value=None),
            patch("core.memory.scatter_gather._fetch_qdrant_l3", new_callable=AsyncMock, return_value=None),
            patch("core.memory.scatter_gather.asyncio.gather", side_effect=spy_gather),
        ):
            await scatter_gather_memory("task-conc", ["x"], "intent", timeout_ms=200)

        assert len(gather_calls) >= 1, "asyncio.gather must be called for concurrent fetching"
        # gather was called with 3 coroutines (L1, L2, L3)
        assert gather_calls[0] == 3


# ---------------------------------------------------------------------------
# WB4: scatter_gather enforces individual per-layer timeout via asyncio.wait_for
# ---------------------------------------------------------------------------

class TestScatterGatherIndividualTimeout:

    @pytest.mark.asyncio
    async def test_per_layer_timeout_enforced_independently(self):
        """A slow L1 does not delay L2/L3 beyond the per-layer timeout budget."""
        # L1 sleeps for 1s — with timeout_ms=30, should be cancelled
        # L2 and L3 return immediately
        l2_called = []
        l3_called = []

        async def slow_l1(_task_id):
            await asyncio.sleep(1.0)
            return "never"

        async def fast_l2(_entities):
            l2_called.append(True)
            return "KG Topology:\nfast entity"

        async def fast_l3(_intent):
            l3_called.append(True)
            return "Learned Constraints:\nfast scar"

        with (
            patch("core.memory.scatter_gather._fetch_redis_l1", side_effect=slow_l1),
            patch("core.memory.scatter_gather._fetch_kg_l2",    side_effect=fast_l2),
            patch("core.memory.scatter_gather._fetch_qdrant_l3", side_effect=fast_l3),
        ):
            ctx = await scatter_gather_memory("task-timeout", ["y"], "intent", timeout_ms=50)

        # L2 and L3 were called despite L1 being slow
        assert l2_called, "L2 must execute independently of L1 timeout"
        assert l3_called, "L3 must execute independently of L1 timeout"
        # L1 timed out → working_state is None
        assert ctx.working_state is None
        # Others succeeded
        assert ctx.kg_topology is not None
        assert ctx.learned_constraints is not None


# ---------------------------------------------------------------------------
# WB5: JITHydrationInterceptor.metadata.priority == 10
# ---------------------------------------------------------------------------

class TestInterceptorPriority:

    def test_interceptor_priority_is_10(self):
        """JITHydrationInterceptor.metadata.priority must be exactly 10."""
        interceptor = JITHydrationInterceptor()
        assert interceptor.metadata.priority == 10, (
            f"Expected priority=10, got priority={interceptor.metadata.priority}"
        )


# ---------------------------------------------------------------------------
# WB6: JITHydrationInterceptor.metadata.name == "jit_hydration"
# ---------------------------------------------------------------------------

class TestInterceptorMetadataName:

    def test_interceptor_metadata_name_is_jit_hydration(self):
        """JITHydrationInterceptor.metadata.name must be 'jit_hydration'."""
        interceptor = JITHydrationInterceptor()
        assert interceptor.metadata.name == "jit_hydration", (
            f"Expected name='jit_hydration', got name='{interceptor.metadata.name}'"
        )

    def test_interceptor_metadata_enabled_by_default(self):
        """JITHydrationInterceptor.metadata.enabled is True by default."""
        interceptor = JITHydrationInterceptor()
        assert interceptor.metadata.enabled is True


# ---------------------------------------------------------------------------
# WB7: build_envelope latency_ms formatted to 1 decimal place
# ---------------------------------------------------------------------------

class TestEnvelopeLatencyFormatting:

    def test_latency_formatted_to_one_decimal(self):
        """build_envelope formats latency_ms to exactly 1 decimal place."""
        ctx = MemoryContext(
            working_state="ok",
            kg_topology="ok",
            learned_constraints="ok",
            latency_ms=12.3456789,
        )
        xml_str = build_envelope(ctx)
        # Parse and check the HYDRATION_LATENCY_MS element
        root = ET.fromstring(xml_str)
        latency_el = root.find("HYDRATION_LATENCY_MS")
        assert latency_el is not None
        latency_text = latency_el.text
        # Should be "12.3" — exactly one decimal place
        assert latency_text == "12.3", (
            f"Expected latency '12.3', got '{latency_text}'"
        )

    def test_zero_latency_formatted_as_zero_point_zero(self):
        """latency_ms=0.0 renders as '0.0' (not '0', not '0.00')."""
        ctx = MemoryContext(
            working_state=None,
            kg_topology=None,
            learned_constraints=None,
            latency_ms=0.0,
        )
        xml_str = build_envelope(ctx)
        root = ET.fromstring(xml_str)
        latency_el = root.find("HYDRATION_LATENCY_MS")
        assert latency_el.text == "0.0"

    def test_large_latency_formatted_to_one_decimal(self):
        """Large latency values are still formatted to 1 decimal place."""
        ctx = MemoryContext(
            working_state="x",
            kg_topology="x",
            learned_constraints="x",
            latency_ms=999.999,
        )
        xml_str = build_envelope(ctx)
        root = ET.fromstring(xml_str)
        latency_el = root.find("HYDRATION_LATENCY_MS")
        assert latency_el.text == "1000.0"


# ---------------------------------------------------------------------------
# WB8: build_envelope always has exactly 4 XML child tags
# ---------------------------------------------------------------------------

class TestEnvelopeFourTags:

    def test_envelope_always_has_four_child_elements(self):
        """build_envelope produces exactly 4 child elements under ZERO_AMNESIA_STATE."""
        ctx = MemoryContext(
            working_state="ws",
            kg_topology="kg",
            learned_constraints="lc",
            latency_ms=1.0,
        )
        xml_str = build_envelope(ctx)
        root = ET.fromstring(xml_str)
        children = list(root)
        assert len(children) == 4, (
            f"Expected 4 child elements, got {len(children)}: {[c.tag for c in children]}"
        )

    def test_envelope_child_tag_names_are_correct(self):
        """The 4 child elements have the exact expected tag names."""
        ctx = MemoryContext(
            working_state=None,
            kg_topology=None,
            learned_constraints=None,
            latency_ms=0.0,
        )
        xml_str = build_envelope(ctx)
        root = ET.fromstring(xml_str)
        tags = [child.tag for child in root]
        assert "WORKING_CONTEXT" in tags
        assert "TOPOLOGICAL_BLAST_RADIUS" in tags
        assert "LEARNED_CONSTRAINTS" in tags
        assert "HYDRATION_LATENCY_MS" in tags


# ---------------------------------------------------------------------------
# WB9: RedisL1Client.bump_version uses atomic INCR (not read-modify-write)
# ---------------------------------------------------------------------------

class TestRedisL1BumpVersionAtomic:

    @pytest.mark.asyncio
    async def test_bump_version_calls_redis_incr_not_get_set(self):
        """bump_version must use INCR for atomicity — NOT a read-then-write pattern."""
        client = RedisL1Client(redis_url="redis://fake:6379")

        mock_r = AsyncMock()
        mock_r.incr = AsyncMock(return_value=5)
        # Ensure get is NOT called (would indicate a non-atomic pattern)
        mock_r.get = AsyncMock(side_effect=AssertionError("get must NOT be called in bump_version"))

        client._redis = mock_r

        new_version = await client.bump_version("task-atomic")

        assert new_version == 5
        mock_r.incr.assert_awaited_once()
        # Verify the incr key follows the expected pattern
        incr_key = mock_r.incr.call_args[0][0]
        assert "task-atomic" in incr_key
        assert incr_key.endswith(":version")


# ---------------------------------------------------------------------------
# WB10: RedisL1State.version defaults to 0
# ---------------------------------------------------------------------------

class TestRedisL1StateVersionDefault:

    def test_version_default_is_zero(self):
        """RedisL1State.version field defaults to 0 when not provided."""
        state = RedisL1State(
            task_id="t-default",
            session_id="s-1",
            active_task_id="t-default",
        )
        assert state.version == 0, (
            f"Expected version=0 (default), got version={state.version}"
        )

    def test_version_can_be_set_explicitly(self):
        """RedisL1State.version can be explicitly set to any non-negative int."""
        state = RedisL1State(
            task_id="t-versioned",
            session_id="s-1",
            active_task_id="t-versioned",
            version=42,
        )
        assert state.version == 42

    def test_focus_entities_defaults_to_empty_list(self):
        """RedisL1State.focus_entities defaults to an empty list (not None)."""
        state = RedisL1State(
            task_id="t-fe",
            session_id="s-2",
            active_task_id="t-fe",
        )
        assert state.focus_entities == []
        assert isinstance(state.focus_entities, list)

    def test_active_blocker_defaults_to_none(self):
        """RedisL1State.active_blocker defaults to None."""
        state = RedisL1State(
            task_id="t-ab",
            session_id="s-3",
            active_task_id="t-ab",
        )
        assert state.active_blocker is None
