#!/usr/bin/env python3
"""
Tests for Story 6.05 (Track B): StagingArea — ProposedDelta Collection Hub

Black Box tests (BB): verify the public contract from the outside —
    submit_delta stores a delta, collect_all returns all deltas,
    clear removes the hash, wait_for_all polls until complete or timeout.

White Box tests (WB): verify internal mechanics — correct Redis key pattern,
    EXPIRE called with 600s TTL, bytes are decoded, partial results on timeout,
    STAGING_KEY_PREFIX constant value.

Story: 6.05
File under test: core/coherence/staging_area.py

ALL tests use mocks — NO real Redis connection is made.
"""

from __future__ import annotations

import sys
sys.path.insert(0, "/mnt/e/genesis-system")

import asyncio
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, call, patch

import pytest

# ---------------------------------------------------------------------------
# Module under test
# ---------------------------------------------------------------------------

from core.coherence.staging_area import (
    StagingArea,
    STAGING_KEY_PREFIX,
    STAGING_TTL_SECONDS,
)
from core.coherence.state_delta import StateDelta


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def run(coro):
    """Run a coroutine to completion synchronously and return its result.

    Lets the plain (non-async) test functions in this module drive async
    code, so pytest-asyncio is not required.

    Uses ``asyncio.run`` instead of
    ``asyncio.get_event_loop().run_until_complete``: calling
    ``get_event_loop()`` when no loop is current is deprecated and raises
    ``RuntimeError`` on recent Python versions. ``asyncio.run`` also gives
    every call a fresh loop that is closed afterwards, so no loop state
    leaks from one test into the next.

    Args:
        coro: The coroutine object to execute.

    Returns:
        Whatever the coroutine returns.
    """
    return asyncio.run(coro)


def _make_delta(
    agent_id: str = "agent-alpha",
    session_id: str = "sess-001",
    version_at_read: int = 0,
    patch_ops: list | None = None,
) -> StateDelta:
    """Construct a StateDelta test fixture with a fixed submission time.

    When ``patch_ops`` is ``None`` a single "add /status = active"
    operation is used. The operations are handed to StateDelta as a tuple.
    """
    default_ops = [{"op": "add", "path": "/status", "value": "active"}]
    ops = default_ops if patch_ops is None else patch_ops
    # Fixed timestamp keeps the serialized form of the fixture deterministic.
    fixed_timestamp = datetime(2026, 2, 25, 12, 0, 0, tzinfo=timezone.utc)
    return StateDelta(
        agent_id=agent_id,
        session_id=session_id,
        version_at_read=version_at_read,
        patch=tuple(ops),
        submitted_at=fixed_timestamp,
    )


def _serialized(delta: StateDelta) -> str:
    """Return the JSON text the tests expect submit_delta to store for ``delta``.

    Keys are emitted in the order agent_id, session_id, version_at_read,
    patch, submitted_at; the patch becomes a list and the timestamp is
    rendered with ``isoformat()``.
    """
    scalar_fields = ("agent_id", "session_id", "version_at_read")
    payload = {field: getattr(delta, field) for field in scalar_fields}
    payload["patch"] = list(delta.patch)
    payload["submitted_at"] = delta.submitted_at.isoformat()
    return json.dumps(payload)


def _make_redis_empty() -> MagicMock:
    """Build a mock Redis client that behaves like an empty staging hash.

    Each Redis command used by the tests is an AsyncMock with a canned
    return value: hset -> 1, expire -> True, hgetall -> {}, hlen -> 0,
    delete -> 1.
    """
    canned_results = {
        "hset": 1,
        "expire": True,
        "hgetall": {},
        "hlen": 0,
        "delete": 1,
    }
    client = MagicMock()
    for command, result in canned_results.items():
        setattr(client, command, AsyncMock(return_value=result))
    return client


def _make_redis_with_one_delta(delta: StateDelta) -> MagicMock:
    """Build a mock Redis client whose staging hash holds exactly ``delta``.

    hgetall yields ``{agent_id: serialized delta}`` and hlen reports 1; the
    remaining commands keep the defaults from ``_make_redis_empty``.
    """
    client = _make_redis_empty()
    staged = {delta.agent_id: _serialized(delta)}
    client.hgetall = AsyncMock(return_value=staged)
    client.hlen = AsyncMock(return_value=1)
    return client


def _make_redis_with_two_deltas(d1: StateDelta, d2: StateDelta) -> MagicMock:
    """Build a mock Redis client whose staging hash holds ``d1`` and ``d2``.

    hgetall yields both serialized deltas keyed by agent id and hlen
    reports 2; the remaining commands keep the defaults from
    ``_make_redis_empty``.
    """
    client = _make_redis_empty()
    staged = {item.agent_id: _serialized(item) for item in (d1, d2)}
    client.hgetall = AsyncMock(return_value=staged)
    client.hlen = AsyncMock(return_value=2)
    return client


# ===========================================================================
# BLACK BOX TESTS
# ===========================================================================


def test_bb1_submit_then_collect_returns_matching_delta():
    """BB1: a submitted delta comes back from collect_all as a matching dict."""
    submitted = _make_delta()
    # The mock's hgetall is pre-loaded with what the hash would hold
    # once the submit has happened.
    redis_mock = _make_redis_with_one_delta(submitted)
    staging = StagingArea(redis_mock)

    run(staging.submit_delta(submitted))
    collected = run(staging.collect_all(submitted.session_id))

    assert len(collected) == 1
    entry = collected[0]
    for field in ("agent_id", "session_id", "version_at_read"):
        assert entry[field] == getattr(submitted, field)
    # The fixture stores the patch as a tuple; the collected dict holds a list.
    assert entry["patch"] == list(submitted.patch)


def test_bb2_multiple_agents_all_appear_in_collect_all():
    """BB2: deltas staged by several agents are all returned by collect_all."""
    alpha = _make_delta(agent_id="agent-alpha")
    beta = _make_delta(agent_id="agent-beta")
    staging = StagingArea(_make_redis_with_two_deltas(alpha, beta))

    collected = run(staging.collect_all(alpha.session_id))

    assert len(collected) == 2
    seen_agents = {entry["agent_id"] for entry in collected}
    assert {"agent-alpha", "agent-beta"} <= seen_agents


def test_bb3_clear_then_collect_returns_empty_list():
    """BB3: once clear has run, collect_all yields an empty list."""
    session_id = _make_delta().session_id
    # An empty mock stands in for the post-clear state: hgetall returns {}.
    redis_mock = _make_redis_empty()
    staging = StagingArea(redis_mock)

    run(staging.clear(session_id))
    collected = run(staging.collect_all(session_id))

    assert collected == []
    # clear must delete exactly the session's staging key, and only once.
    redis_mock.delete.assert_awaited_once_with(f"{STAGING_KEY_PREFIX}{session_id}")


def test_bb4_wait_for_all_timeout_returns_partial_not_error():
    """
    BB4: wait_for_all asked for 2 deltas while only 1 is ever staged,
         timeout 200ms → the single delta is returned and nothing is raised.
    """
    only_delta = _make_delta()
    redis_mock = _make_redis_empty()
    # HLEN is stuck at 1, so the expected count of 2 is never reached.
    redis_mock.hlen = AsyncMock(return_value=1)
    # HGETALL hands back the one delta that is present.
    staged = {only_delta.agent_id: _serialized(only_delta)}
    redis_mock.hgetall = AsyncMock(return_value=staged)
    staging = StagingArea(redis_mock)

    # asyncio.sleep is replaced by a no-op AsyncMock for the duration of the call.
    with patch("core.coherence.staging_area.asyncio.sleep", new=AsyncMock()):
        pending = staging.wait_for_all(
            only_delta.session_id, expected_count=2, timeout_ms=200
        )
        partial = run(pending)

    # A partial list with the one available delta comes back; no exception.
    assert isinstance(partial, list)
    assert len(partial) == 1
    assert partial[0]["agent_id"] == only_delta.agent_id


def test_bb5_wait_for_all_returns_immediately_when_count_met():
    """
    BB5: when every expected delta is already staged, wait_for_all returns
         without sleeping at all, even though the timeout is a full minute.
    """
    alpha = _make_delta(agent_id="agent-alpha")
    beta = _make_delta(agent_id="agent-beta")
    staging = StagingArea(_make_redis_with_two_deltas(alpha, beta))

    fake_sleep = AsyncMock()
    with patch("core.coherence.staging_area.asyncio.sleep", new=fake_sleep):
        pending = staging.wait_for_all(
            alpha.session_id, expected_count=2, timeout_ms=60_000
        )
        collected = run(pending)

    # The first HLEN check already satisfies the count, so sleep is never awaited.
    fake_sleep.assert_not_awaited()
    assert len(collected) == 2


# ===========================================================================
# WHITE BOX TESTS
# ===========================================================================


def test_wb1_hset_called_with_correct_key_pattern():
    """WB1: HSET called with correct key pattern genesis:staging:{session_id}.

    Verifies that submit_delta awaits hset exactly once with the staging
    key, the agent id as the hash field, and the serialized delta as the
    value.
    """
    delta = _make_delta(session_id="sess-wb1")
    r = _make_redis_empty()
    sa = StagingArea(r)

    run(sa.submit_delta(delta))

    # The key is spelled out literally (not built from STAGING_KEY_PREFIX)
    # so this test pins the concrete key format on its own. It is a plain
    # string: the previous f-string had no placeholders.
    r.hset.assert_awaited_once_with(
        "genesis:staging:sess-wb1",
        delta.agent_id,
        _serialized(delta),
    )


def test_wb2_expire_called_with_600s_ttl_after_submit():
    """WB2: EXPIRE called with 600s TTL after each submit_delta.

    Verifies that a single submit_delta awaits expire exactly once on the
    session's staging key with a TTL of 600 seconds.
    """
    delta = _make_delta(session_id="sess-wb2")
    r = _make_redis_empty()
    sa = StagingArea(r)

    run(sa.submit_delta(delta))

    # Key and TTL are literals so the test pins the concrete values; the
    # key is a plain string because the previous f-string had no placeholders.
    r.expire.assert_awaited_once_with(
        "genesis:staging:sess-wb2",
        600,
    )


def test_wb3_wait_for_all_returns_partial_on_timeout_not_raises():
    """WB3: on timeout wait_for_all hands back what it has (here nothing) without raising."""
    redis_mock = _make_redis_empty()
    # HLEN stays at 0, so expected_count=5 can never be satisfied.
    redis_mock.hlen = AsyncMock(return_value=0)
    redis_mock.hgetall = AsyncMock(return_value={})
    staging = StagingArea(redis_mock)

    with patch("core.coherence.staging_area.asyncio.sleep", new=AsyncMock()):
        pending = staging.wait_for_all("sess-timeout", expected_count=5, timeout_ms=100)
        partial = run(pending)

    # The partial result is simply empty; no exception was raised.
    assert partial == []


def test_wb4_handles_bytes_from_redis():
    """WB4: bytes keys and values coming back from Redis are decoded by collect_all."""
    source = _make_delta(agent_id="agent-bytes", session_id="sess-bytes")
    raw_field = source.agent_id.encode("utf-8")
    raw_value = _serialized(source).encode("utf-8")

    redis_mock = _make_redis_empty()
    # Mimic a client created with decode_responses=False: hgetall yields bytes.
    redis_mock.hgetall = AsyncMock(return_value={raw_field: raw_value})
    staging = StagingArea(redis_mock)

    collected = run(staging.collect_all(source.session_id))

    assert len(collected) == 1
    entry = collected[0]
    assert entry["agent_id"] == "agent-bytes"
    assert entry["session_id"] == "sess-bytes"


def test_wb5_staging_key_prefix_constant():
    """WB5: the STAGING_KEY_PREFIX constant has the value 'genesis:staging:'."""
    expected_prefix = "genesis:staging:"
    assert STAGING_KEY_PREFIX == expected_prefix


# ===========================================================================
# PACKAGE EXPORT TESTS
# ===========================================================================


def test_package_exports_staging_area():
    """Package level: core.coherence re-exports the very same StagingArea class."""
    from core.coherence import StagingArea as package_level_cls

    # Identity, not equality: the package must expose the same class object.
    assert package_level_cls is StagingArea


def test_package_exports_staging_key_prefix():
    """Package level: core.coherence exposes STAGING_KEY_PREFIX with the expected value."""
    from core.coherence import STAGING_KEY_PREFIX as package_level_prefix

    assert package_level_prefix == "genesis:staging:"


def test_package_exports_staging_ttl_seconds():
    """Package level: core.coherence exposes STAGING_TTL_SECONDS with the value 600."""
    from core.coherence import STAGING_TTL_SECONDS as package_level_ttl

    assert package_level_ttl == 600


def test_staging_ttl_is_600():
    """The STAGING_TTL_SECONDS constant imported from staging_area is 600."""
    expected_ttl = 600
    assert STAGING_TTL_SECONDS == expected_ttl


# ===========================================================================
# Standalone runner (pytest preferred, fallback to direct execution)
# ===========================================================================

if __name__ == "__main__":
    import traceback

    # (label, test function) pairs, executed in the order listed.
    cases = [
        ("BB1: submit_delta then collect_all → matching delta dict", test_bb1_submit_then_collect_returns_matching_delta),
        ("BB2: Multiple agents → all in collect_all", test_bb2_multiple_agents_all_appear_in_collect_all),
        ("BB3: clear → collect_all returns empty list", test_bb3_clear_then_collect_returns_empty_list),
        ("BB4: wait_for_all timeout → partial result (not error)", test_bb4_wait_for_all_timeout_returns_partial_not_error),
        ("BB5: wait_for_all count met → returns immediately", test_bb5_wait_for_all_returns_immediately_when_count_met),
        ("WB1: HSET called with correct key pattern", test_wb1_hset_called_with_correct_key_pattern),
        ("WB2: EXPIRE called with 600s TTL", test_wb2_expire_called_with_600s_ttl_after_submit),
        ("WB3: wait_for_all partial on timeout, no raise", test_wb3_wait_for_all_returns_partial_on_timeout_not_raises),
        ("WB4: Bytes from Redis decoded correctly", test_wb4_handles_bytes_from_redis),
        ("WB5: STAGING_KEY_PREFIX == 'genesis:staging:'", test_wb5_staging_key_prefix_constant),
        ("PKG: StagingArea importable from core.coherence", test_package_exports_staging_area),
        ("PKG: STAGING_KEY_PREFIX importable from core.coherence", test_package_exports_staging_key_prefix),
        ("PKG: STAGING_TTL_SECONDS importable from core.coherence", test_package_exports_staging_ttl_seconds),
        ("STAGING_TTL_SECONDS == 600", test_staging_ttl_is_600),
    ]

    case_total = len(cases)
    ok_count = 0
    for label, test_fn in cases:
        try:
            test_fn()
            print(f"  [PASS] {label}")
            ok_count += 1
        except Exception as exc:
            # Report the failure with its traceback and carry on with the
            # remaining cases so one failure does not hide the others.
            print(f"  [FAIL] {label}: {exc}")
            traceback.print_exc()

    print(f"\n{ok_count}/{case_total} tests passed")
    # A non-zero exit status signals failure to whatever invoked the script.
    if ok_count != case_total:
        sys.exit(1)
    print("ALL TESTS PASSED -- Story 6.05 (Track B)")
