#!/usr/bin/env python3
"""
Tests for Story 2.04 (Track B): scatter_gather_memory() — Concurrent 3-Layer Fetch

Black-box tests:
  BB1: All 3 fetches succeed (mock) → all 3 fields populated in MemoryContext
  BB2: Redis times out → working_state=None, other fields still populated
  BB3: All 3 time out → all fields None, latency_ms still recorded (> 0)
  BB4: Empty target_entities → kg_topology=None (short-circuit)

White-box tests:
  WB1: asyncio.gather called once (not 3 sequential awaits) — verify via mock
  WB2: Each fetch wrapped in asyncio.wait_for (individual timeout enforcement)
  WB3: latency_ms > 0 always (measures wall clock)
  WB4: Partial failure (1 of 3 layers raises) still returns MemoryContext with the other 2 fields populated
"""
import asyncio
import sys
import time
from pathlib import Path
from typing import Optional, List
from unittest.mock import AsyncMock, patch, MagicMock

# Put the repo root first on sys.path so the `core.memory...` imports below
# resolve when this file is run directly.
# NOTE(review): this is a hard-coded, machine-specific path. Consider deriving it
# from Path(__file__) (Path is already imported) once this file's location inside
# the repo is confirmed.
sys.path.insert(0, '/mnt/e/genesis-system')

from core.memory.zero_amnesia_envelope import MemoryContext

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _run(coro):
    """Run async coroutine in test context."""
    return asyncio.get_event_loop().run_until_complete(coro)


# ---------------------------------------------------------------------------
# BB1: All 3 fetches succeed → all 3 fields populated
# ---------------------------------------------------------------------------

def test_bb1_all_layers_succeed():
    """BB1: every layer returns data, so all three MemoryContext fields are set.

    Each private fetcher is patched with a coroutine that answers immediately.
    The result is then checked for its type, for each field being present, and
    for a marker substring from each canned payload.
    """
    from core.memory import scatter_gather

    async def fake_redis(task_id: str) -> Optional[str]:
        return "Task: t1, Focus: redis, Hypothesis: test hypothesis"

    async def fake_kg(entities: List[str]) -> Optional[str]:
        return "KG Topology:\nentity_1: type — description"

    async def fake_qdrant(intent: str) -> Optional[str]:
        return "Learned Constraints:\nScar abc: past mistake (score=0.95)"

    call_kwargs = dict(
        task_id="t1",
        target_entities=["redis"],
        intent_string="test intent",
        timeout_ms=500,
    )
    redis_patch = patch.object(scatter_gather, '_fetch_redis_l1', side_effect=fake_redis)
    kg_patch = patch.object(scatter_gather, '_fetch_kg_l2', side_effect=fake_kg)
    qdrant_patch = patch.object(scatter_gather, '_fetch_qdrant_l3', side_effect=fake_qdrant)

    with redis_patch, kg_patch, qdrant_patch:
        ctx = _run(scatter_gather.scatter_gather_memory(**call_kwargs))

    assert isinstance(ctx, MemoryContext), f"Expected MemoryContext, got {type(ctx)}"

    # Presence first, in a fixed order, so the first missing field is reported.
    presence_checks = (
        ("working_state", "BB1: working_state should be populated"),
        ("kg_topology", "BB1: kg_topology should be populated"),
        ("learned_constraints", "BB1: learned_constraints should be populated"),
    )
    for field_name, failure_msg in presence_checks:
        assert getattr(ctx, field_name) is not None, failure_msg

    # Then content: each field must carry data from its own layer.
    assert "redis" in ctx.working_state.lower(), "BB1: working_state should contain task data"
    assert "KG Topology" in ctx.kg_topology, "BB1: kg_topology should contain KG data"
    assert "Learned Constraints" in ctx.learned_constraints, "BB1: learned_constraints should contain scar data"

    print("BB1 PASSED — all 3 fields populated when all layers succeed")


# ---------------------------------------------------------------------------
# BB2: Redis times out → working_state=None, other 2 fields populated
# ---------------------------------------------------------------------------

def test_bb2_redis_timeout_partial_result():
    """BB2: a Redis fetch that outlives the deadline leaves working_state=None.

    The KG and Qdrant fakes answer instantly, so their fields must still be
    populated. latency_ms must be recorded even though one layer timed out.
    """
    from core.memory import scatter_gather

    async def hanging_redis(task_id: str) -> Optional[str]:
        # Sleeps far longer than the timeout_ms=50 budget, so the deadline wins.
        await asyncio.sleep(10)
        return "should not reach"

    async def quick_kg(entities: List[str]) -> Optional[str]:
        return "KG Topology:\nsome_entity: type — some description"

    async def quick_qdrant(intent: str) -> Optional[str]:
        return "Learned Constraints:\nScar x: known issue (score=0.80)"

    redis_patch = patch.object(scatter_gather, '_fetch_redis_l1', side_effect=hanging_redis)
    kg_patch = patch.object(scatter_gather, '_fetch_kg_l2', side_effect=quick_kg)
    qdrant_patch = patch.object(scatter_gather, '_fetch_qdrant_l3', side_effect=quick_qdrant)

    with redis_patch, kg_patch, qdrant_patch:
        request = scatter_gather.scatter_gather_memory(
            task_id="t2",
            target_entities=["entity"],
            intent_string="some intent",
            timeout_ms=50,
        )
        ctx = _run(request)

    # The timed-out layer is empty...
    assert ctx.working_state is None, "BB2: Redis timed out → working_state should be None"
    # ...while the fast layers are untouched by it.
    assert ctx.kg_topology is not None, "BB2: kg_topology should still be populated"
    assert ctx.learned_constraints is not None, "BB2: learned_constraints should still be populated"
    assert ctx.latency_ms > 0, "BB2: latency_ms must always be recorded"

    print("BB2 PASSED — Redis timeout → working_state=None, other fields populated")


# ---------------------------------------------------------------------------
# BB3: All 3 time out → all fields None, latency_ms > 0
# ---------------------------------------------------------------------------

def test_bb3_all_layers_timeout():
    """BB3: All 3 fetches timeout → all fields None, latency_ms still recorded.

    Each fake fetcher sleeps far longer than ``timeout_ms=30``.

    The fake accepts any positional/keyword arguments. A fetcher that raises is
    expected to yield ``None`` (see WB4). With a one-positional-argument fake,
    a keyword-argument call from the code under test would raise ``TypeError``
    and produce ``None`` without ever exercising the timeout path this test is
    meant to cover.
    """
    from core.memory import scatter_gather

    async def slow(*_args, **_kwargs) -> Optional[str]:
        await asyncio.sleep(10)  # far beyond the timeout_ms=30 budget
        return "should not reach"

    with patch.object(scatter_gather, '_fetch_redis_l1', side_effect=slow), \
         patch.object(scatter_gather, '_fetch_kg_l2', side_effect=slow), \
         patch.object(scatter_gather, '_fetch_qdrant_l3', side_effect=slow):

        ctx = _run(scatter_gather.scatter_gather_memory(
            task_id="t3",
            target_entities=["any"],
            intent_string="any intent",
            timeout_ms=30,
        ))

    assert ctx.working_state is None, "BB3: All timed out → working_state should be None"
    assert ctx.kg_topology is None, "BB3: All timed out → kg_topology should be None"
    assert ctx.learned_constraints is None, "BB3: All timed out → learned_constraints should be None"
    assert ctx.latency_ms > 0, "BB3: latency_ms must be > 0 even when all layers fail"

    print("BB3 PASSED — all layers timeout → all fields None, latency_ms > 0")


# ---------------------------------------------------------------------------
# BB4: Empty target_entities → kg_topology=None (short-circuit)
# ---------------------------------------------------------------------------

def test_bb4_empty_entities_short_circuits_kg():
    """BB4: with no target entities, the KG layer contributes nothing.

    Only the Redis and Qdrant fetchers are faked. ``_fetch_kg_l2`` is left as
    the real implementation, which is expected to return None immediately for
    an empty entity list. The other two fields must still be populated.
    """
    from core.memory import scatter_gather

    async def canned_redis(task_id: str) -> Optional[str]:
        return "Task: t4, Focus: none, Hypothesis: empty"

    async def canned_qdrant(intent: str) -> Optional[str]:
        return "Learned Constraints:\nScar y: edge case (score=0.70)"

    redis_patch = patch.object(scatter_gather, '_fetch_redis_l1', side_effect=canned_redis)
    qdrant_patch = patch.object(scatter_gather, '_fetch_qdrant_l3', side_effect=canned_qdrant)

    with redis_patch, qdrant_patch:
        ctx = _run(scatter_gather.scatter_gather_memory(
            task_id="t4",
            target_entities=[],  # nothing to look up in the KG
            intent_string="some intent",
            timeout_ms=500,
        ))

    kg_msg = "BB4: Empty target_entities → kg_topology must be None (short-circuit)"
    assert ctx.kg_topology is None, kg_msg
    assert ctx.working_state is not None, "BB4: working_state should still be populated"
    assert ctx.learned_constraints is not None, "BB4: learned_constraints should still be populated"

    print("BB4 PASSED — empty target_entities → kg_topology=None")


# ---------------------------------------------------------------------------
# WB1: asyncio.gather called ONCE (not 3 sequential awaits)
# ---------------------------------------------------------------------------

def test_wb1_gather_called_once():
    """WB1: scatter_gather_memory() must call asyncio.gather exactly once.

    ``asyncio.gather`` is swapped for a wrapper that bumps a counter each time
    it is awaited and then delegates to the real gather. The fan-out itself
    therefore still runs normally.
    """
    from core.memory import scatter_gather

    real_gather = asyncio.gather
    counter = {"gather": 0}

    async def counting_gather(*coros, **kwargs):
        counter["gather"] += 1
        return await real_gather(*coros, **kwargs)

    async def stub_redis(task_id: str) -> Optional[str]:
        return "state"

    async def stub_kg(entities: List[str]) -> Optional[str]:
        return "kg"

    async def stub_qdrant(intent: str) -> Optional[str]:
        return "scars"

    redis_patch = patch.object(scatter_gather, '_fetch_redis_l1', side_effect=stub_redis)
    kg_patch = patch.object(scatter_gather, '_fetch_kg_l2', side_effect=stub_kg)
    qdrant_patch = patch.object(scatter_gather, '_fetch_qdrant_l3', side_effect=stub_qdrant)
    gather_patch = patch('asyncio.gather', side_effect=counting_gather)

    with redis_patch, kg_patch, qdrant_patch, gather_patch:
        _run(scatter_gather.scatter_gather_memory(
            task_id="wb1",
            target_entities=["x"],
            intent_string="test",
            timeout_ms=500,
        ))

    assert counter["gather"] == 1, \
        f"WB1: asyncio.gather should be called exactly once, got {counter['gather']}"

    print("WB1 PASSED — asyncio.gather called exactly once")


# ---------------------------------------------------------------------------
# WB2: Each fetch wrapped in asyncio.wait_for (individual timeout enforcement)
# ---------------------------------------------------------------------------

def test_wb2_wait_for_wraps_each_fetch():
    """WB2: asyncio.wait_for must be called 3 times (once per layer).

    ``asyncio.wait_for`` is replaced by a recorder that logs the timeout it was
    given and then delegates to the real implementation. Every recorded timeout
    must equal the caller's ``timeout_ms`` converted to seconds.
    """
    from core.memory import scatter_gather

    real_wait_for = asyncio.wait_for
    seen_timeouts = []

    async def recording_wait_for(coro, timeout):
        seen_timeouts.append(timeout)
        return await real_wait_for(coro, timeout)

    async def stub_redis(task_id: str) -> Optional[str]:
        return "state"

    async def stub_kg(entities: List[str]) -> Optional[str]:
        return "kg"

    async def stub_qdrant(intent: str) -> Optional[str]:
        return "scars"

    redis_patch = patch.object(scatter_gather, '_fetch_redis_l1', side_effect=stub_redis)
    kg_patch = patch.object(scatter_gather, '_fetch_kg_l2', side_effect=stub_kg)
    qdrant_patch = patch.object(scatter_gather, '_fetch_qdrant_l3', side_effect=stub_qdrant)
    wait_for_patch = patch('asyncio.wait_for', side_effect=recording_wait_for)

    with redis_patch, kg_patch, qdrant_patch, wait_for_patch:
        _run(scatter_gather.scatter_gather_memory(
            task_id="wb2",
            target_entities=["y"],
            intent_string="intent",
            timeout_ms=100,
        ))

    assert len(seen_timeouts) == 3, \
        f"WB2: asyncio.wait_for should be called 3 times, got {len(seen_timeouts)}"

    # Every layer must receive the same budget: 100 ms expressed in seconds.
    expected_timeout_s = 100 / 1000.0
    for idx, observed in enumerate(seen_timeouts):
        assert abs(observed - expected_timeout_s) < 1e-9, \
            f"WB2: wait_for call {idx} has wrong timeout: {observed}"

    print("WB2 PASSED — asyncio.wait_for called 3 times with correct per-fetch timeout")


# ---------------------------------------------------------------------------
# WB3: latency_ms > 0 always (wall-clock measurement)
# ---------------------------------------------------------------------------

def test_wb3_latency_always_positive():
    """WB3: latency_ms is always > 0 (real wall-clock elapsed).

    Every layer is faked to return ``None`` instantly, and the result must
    still carry a positive float latency.

    The fake accepts any positional/keyword arguments, so the run takes the
    "layer returned nothing" path. A one-positional-argument fake would raise
    ``TypeError`` on a keyword call, and a raising fetcher is expected to map
    to ``None`` (see WB4), which would hide the mismatch.
    """
    from core.memory import scatter_gather

    async def instant_none(*_args, **_kwargs) -> Optional[str]:
        return None

    with patch.object(scatter_gather, '_fetch_redis_l1', side_effect=instant_none), \
         patch.object(scatter_gather, '_fetch_kg_l2', side_effect=instant_none), \
         patch.object(scatter_gather, '_fetch_qdrant_l3', side_effect=instant_none):

        ctx = _run(scatter_gather.scatter_gather_memory(
            task_id="wb3",
            target_entities=[],
            intent_string="",
            timeout_ms=500,
        ))

    assert ctx.latency_ms > 0, \
        f"WB3: latency_ms must always be > 0, got {ctx.latency_ms}"
    assert isinstance(ctx.latency_ms, float), \
        f"WB3: latency_ms must be a float, got {type(ctx.latency_ms)}"

    print(f"WB3 PASSED — latency_ms={ctx.latency_ms:.3f}ms (always positive)")


# ---------------------------------------------------------------------------
# WB4: Partial failure (1 of 3) still returns valid MemoryContext
# ---------------------------------------------------------------------------

def test_wb4_partial_failure_valid_return():
    """WB4: one layer raising must not sink the other two.

    The Redis fake raises RuntimeError while KG and Qdrant succeed. The result
    must have working_state=None, both other fields populated, and a recorded
    latency.
    """
    from core.memory import scatter_gather

    async def broken_redis(task_id: str) -> Optional[str]:
        raise RuntimeError("Redis connection refused")

    async def healthy_kg(entities: List[str]) -> Optional[str]:
        return "KG Topology:\ngood_entity: GoodType — works fine"

    async def healthy_qdrant(intent: str) -> Optional[str]:
        return "Learned Constraints:\nScar z: something learned (score=0.85)"

    call_kwargs = dict(
        task_id="wb4",
        target_entities=["entity"],
        intent_string="test intent",
        timeout_ms=500,
    )
    redis_patch = patch.object(scatter_gather, '_fetch_redis_l1', side_effect=broken_redis)
    kg_patch = patch.object(scatter_gather, '_fetch_kg_l2', side_effect=healthy_kg)
    qdrant_patch = patch.object(scatter_gather, '_fetch_qdrant_l3', side_effect=healthy_qdrant)

    with redis_patch, kg_patch, qdrant_patch:
        ctx = _run(scatter_gather.scatter_gather_memory(**call_kwargs))

    # The raising layer collapses to None; the healthy ones keep their data.
    redis_msg = "WB4: Redis raised → working_state should be None"
    kg_msg = "WB4: KG succeeded → kg_topology should be populated"
    qdrant_msg = "WB4: Qdrant succeeded → learned_constraints should be populated"
    assert ctx.working_state is None, redis_msg
    assert ctx.kg_topology is not None, kg_msg
    assert ctx.learned_constraints is not None, qdrant_msg
    assert ctx.latency_ms > 0, "WB4: latency_ms must still be recorded"

    print("WB4 PASSED — partial failure (1/3) returns valid MemoryContext with 2/3 fields")


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    banner = "=" * 60
    suite = (
        test_bb1_all_layers_succeed,
        test_bb2_redis_timeout_partial_result,
        test_bb3_all_layers_timeout,
        test_bb4_empty_entities_short_circuits_kg,
        test_wb1_gather_called_once,
        test_wb2_wait_for_wraps_each_fetch,
        test_wb3_latency_always_positive,
        test_wb4_partial_failure_valid_return,
    )

    print(banner)
    print("Story 2.04 — scatter_gather_memory() Tests")
    print(banner)

    # Run in declaration order; the first failing assertion aborts the script,
    # so the summary banner below is printed only if every test passed.
    for test_fn in suite:
        test_fn()

    print(banner)
    print("ALL 8 TESTS PASSED — Story 2.04 (Track B)")
    print(banner)
