#!/usr/bin/env python3
"""
Tests for Story 6.04 (Track B): SwarmWorkerBase — XREADGROUP Consumer

Black Box tests (BB): verify public contract — ACK on success, no ACK on
    failure, stop() exits loop, _reclaim_pending runs on startup.
White Box tests (WB): verify internal mechanics — XREADGROUP uses ">" for
    new messages, XAUTOCLAIM called with min_idle_time=60000, bytes decoded,
    process() is abstract, PEL_TIMEOUT_MS constant, staging.submit_delta called.

Story: 6.04
File under test: core/coherence/swarm_worker_base.py

ALL tests use mocks — NO real Redis connection is made.
NO SQLite anywhere in this module.
"""

from __future__ import annotations

import sys
sys.path.insert(0, "/mnt/e/genesis-system")

import asyncio
import pytest
from unittest.mock import AsyncMock, MagicMock, call


# ---------------------------------------------------------------------------
# Module under test
# ---------------------------------------------------------------------------

from core.coherence.swarm_worker_base import (
    SwarmWorkerBase,
    PEL_TIMEOUT_MS,
)
from core.coherence.task_dag_pusher import STREAM_KEY, DEFAULT_GROUP


# ---------------------------------------------------------------------------
# Async test helper
# ---------------------------------------------------------------------------

def run(coro):
    """Execute *coro* to completion and return its result (no pytest-asyncio needed).

    Each call gets a brand-new event loop, which is closed afterwards even when
    the coroutine raises, so no loop state leaks between tests.
    """
    event_loop = asyncio.new_event_loop()
    try:
        outcome = event_loop.run_until_complete(coro)
    finally:
        event_loop.close()
    return outcome


# ---------------------------------------------------------------------------
# Concrete test subclass (SwarmWorkerBase is abstract)
# ---------------------------------------------------------------------------

class EchoWorker(SwarmWorkerBase):
    """
    Minimal concrete worker for the tests (SwarmWorkerBase is abstract).

    Every task handed to process() is recorded in ``self.processed``. By
    default process() echoes the task back unchanged. When ``raise_exc`` is
    given, that exception is raised instead, after the task has been recorded.
    """

    def __init__(self, redis_client, staging_area=None, *, raise_exc=None):
        super().__init__(redis_client, staging_area)
        self.raise_exc = raise_exc
        # Tasks in the order process() received them.
        self.processed: list = []

    async def process(self, task: dict):
        self.processed.append(task)
        failure = self.raise_exc
        if failure is None:
            return task
        raise failure


# ---------------------------------------------------------------------------
# Mock builder helpers
# ---------------------------------------------------------------------------

def _make_single_shot_redis(
    entry_id="1700000000001-0",
    fields=None,
):
    """
    Build a mock Redis that delivers exactly one stream entry.

    - xautoclaim -> AsyncMock returning ("0-0", [], []) (nothing to reclaim)
    - xreadgroup -> first call returns one entry (id encoded to bytes when a
                    str is given). Every later call returns [] and, if a
                    worker is bound in ``_state["worker"]``, calls its stop()
                    so that the loop can exit without timeouts.
    - xack       -> AsyncMock returning 1

    The call counter and worker slot are exposed as ``mock._state`` so the
    caller can bind the worker after constructing it.
    """
    payload = (
        {b"task_type": b"test", b"session_id": b"sess-single"}
        if fields is None
        else fields
    )
    raw_id = entry_id.encode() if isinstance(entry_id, str) else entry_id
    first_batch = [(STREAM_KEY.encode(), [(raw_id, payload)])]

    fake = MagicMock()
    fake.xautoclaim = AsyncMock(return_value=("0-0", [], []))
    fake.xack = AsyncMock(return_value=1)

    state = {"calls": 0, "worker": None}

    async def _xreadgroup(group, consumer_id, streams, count=1, block=5000):
        state["calls"] += 1
        if state["calls"] == 1:
            return first_batch
        # Every call after the first: ask the bound worker to stop.
        bound = state["worker"]
        if bound is not None:
            bound.stop()
        return []

    fake.xreadgroup = AsyncMock(side_effect=_xreadgroup)
    fake._state = state
    return fake


def _make_empty_redis(stop_after_calls=2):
    """
    Build a mock Redis whose stream never yields any entries.

    - xautoclaim -> AsyncMock returning ("0-0", [], []) (nothing to reclaim)
    - xreadgroup -> always returns []. From the ``stop_after_calls``-th call
                    onward, it also calls stop() on the worker bound in
                    ``_state["worker"]``, so the loop exits.
    - xack       -> AsyncMock returning 1

    Caller must bind the worker via ``redis._state["worker"]``. Until a worker
    is bound, xreadgroup stops nothing, and a running loop would keep polling.

    Args:
        stop_after_calls: 1-based xreadgroup call number at which stop() is
            first invoked on the bound worker.

    Returns:
        A MagicMock exposing ``_state = {"calls": int, "worker": object|None}``.
    """
    mock_redis = MagicMock()
    mock_redis.xautoclaim = AsyncMock(return_value=("0-0", [], []))
    mock_redis.xack = AsyncMock(return_value=1)

    _state = {"calls": 0, "worker": None}

    async def _xreadgroup(group, consumer_id, streams, count=1, block=5000):
        _state["calls"] += 1
        if _state["calls"] >= stop_after_calls and _state["worker"] is not None:
            _state["worker"].stop()
        return []

    mock_redis.xreadgroup = AsyncMock(side_effect=_xreadgroup)
    mock_redis._state = _state

    return mock_redis


# ===========================================================================
# BLACK BOX TESTS
# ===========================================================================


def test_bb1_successful_process_sends_xack():
    """BB1: a task whose process() call succeeds is acknowledged via XACK."""
    entry_id = "1700000000001-0"
    fake_redis = _make_single_shot_redis(entry_id=entry_id)
    worker = EchoWorker(fake_redis)
    # Bind the worker so the mock can stop the loop after the single delivery.
    fake_redis._state["worker"] = worker

    async def _drive():
        await worker.run_worker_loop(group=DEFAULT_GROUP, consumer_id="test-c1")

    run(_drive())

    fake_redis.xack.assert_called_once_with(STREAM_KEY, DEFAULT_GROUP, entry_id)


def test_bb2_failed_process_does_not_send_xack():
    """BB2: if process() raises, the entry is left unacknowledged (no XACK)."""
    fake_redis = _make_single_shot_redis(entry_id="1700000000002-0")
    failing_worker = EchoWorker(fake_redis, raise_exc=RuntimeError("boom"))
    fake_redis._state["worker"] = failing_worker

    async def _drive():
        await failing_worker.run_worker_loop(group=DEFAULT_GROUP, consumer_id="test-c2")

    run(_drive())

    fake_redis.xack.assert_not_called()


def test_bb3_stop_sets_running_false():
    """BB3: stop() flips _running back to False."""
    worker = EchoWorker(_make_empty_redis())

    # A freshly constructed worker is not running yet.
    assert worker._running is False

    worker._running = True  # pretend the loop has been started
    worker.stop()

    assert worker._running is False


def test_bb4_reclaim_pending_called_on_startup():
    """BB4: run_worker_loop calls XAUTOCLAIM on startup.

    XAUTOCLAIM must be called at least once, and its first call must come
    before the first XREADGROUP. That is, pending entries are reclaimed before
    any new messages are read.
    """
    redis = _make_empty_redis(stop_after_calls=1)

    # Record how many XREADGROUP calls had already happened each time
    # XAUTOCLAIM is invoked. A startup reclaim sees 0.
    reads_seen_at_claim = []

    async def _xautoclaim(*args, **kwargs):
        reads_seen_at_claim.append(redis.xreadgroup.call_count)
        return ("0-0", [], [])

    redis.xautoclaim = AsyncMock(side_effect=_xautoclaim)

    worker = EchoWorker(redis)
    redis._state["worker"] = worker

    async def _run():
        await worker.run_worker_loop(group=DEFAULT_GROUP, consumer_id="test-c4")

    run(_run())

    assert redis.xautoclaim.call_count >= 1
    assert reads_seen_at_claim[0] == 0, (
        f"XAUTOCLAIM first ran after {reads_seen_at_claim[0]} XREADGROUP call(s); "
        "expected it to run before the first read"
    )


# ===========================================================================
# WHITE BOX TESTS
# ===========================================================================


def test_wb1_xreadgroup_uses_greater_than_for_new_messages():
    """WB1: XREADGROUP is called with streams={STREAM_KEY: '>'} for new messages.

    The streams mapping is taken from either the third positional argument or
    the ``streams=`` keyword. The test therefore pins the value, not the
    calling convention.
    """
    redis = _make_empty_redis(stop_after_calls=1)
    worker = EchoWorker(redis)
    redis._state["worker"] = worker

    async def _run():
        await worker.run_worker_loop(group="genesis_workers", consumer_id="test-wb1")

    run(_run())

    assert redis.xreadgroup.call_count >= 1
    # Inspect the first XREADGROUP call made by the loop.
    args, kwargs = redis.xreadgroup.call_args_list[0]
    if "streams" in kwargs:
        streams_arg = kwargs["streams"]
    else:
        assert len(args) > 2, (
            f"XREADGROUP called without a streams argument: args={args!r}, kwargs={kwargs!r}"
        )
        streams_arg = args[2]
    assert STREAM_KEY in streams_arg, (
        f"STREAM_KEY not in streams argument: {streams_arg}"
    )
    assert streams_arg[STREAM_KEY] == ">", (
        f"Expected '>' for new messages, got {streams_arg[STREAM_KEY]!r}"
    )


def test_wb2_xautoclaim_called_with_pel_timeout_60000():
    """WB2: _reclaim_pending calls XAUTOCLAIM with the keyword min_idle_time=60000."""
    redis = _make_empty_redis()
    worker = EchoWorker(redis)

    run(worker._reclaim_pending("genesis_workers", "test-wb2"))

    redis.xautoclaim.assert_called_once()
    _, kwargs = redis.xautoclaim.call_args
    # .get() makes a missing or positionally-passed argument fail with a
    # readable assertion message instead of a bare KeyError.
    min_idle_time = kwargs.get("min_idle_time")
    assert min_idle_time == 60000, (
        f"Expected min_idle_time=60000, got {min_idle_time!r}"
    )


def test_wb3_bytes_decoded_from_redis():
    """WB3: process() receives str keys and str values, never the raw bytes from Redis."""
    raw_fields = {
        b"task_type": b"research",
        b"session_id": b"sess-bytes",
        b"payload": b'{"query": "test"}',
    }
    fake_redis = _make_single_shot_redis(entry_id="1700000000003-0", fields=raw_fields)
    worker = EchoWorker(fake_redis)
    fake_redis._state["worker"] = worker

    async def _drive():
        await worker.run_worker_loop(group=DEFAULT_GROUP, consumer_id="test-wb3")

    run(_drive())

    assert len(worker.processed) == 1
    task = worker.processed[0]
    for key, value in task.items():
        assert isinstance(key, str), f"Key {key!r} should be str, got {type(key).__name__}"
        assert isinstance(value, str), f"Value {value!r} should be str, got {type(value).__name__}"

    assert task["task_type"] == "research"
    assert task["session_id"] == "sess-bytes"


def test_wb4_process_is_abstract_cannot_instantiate_base():
    """WB4: constructing SwarmWorkerBase itself raises TypeError (process() is abstract)."""
    fake_redis = _make_empty_redis()
    with pytest.raises(TypeError):
        SwarmWorkerBase(fake_redis)  # type: ignore[abstract]


def test_wb5_pel_timeout_ms_constant_is_60000():
    """WB5: the module-level PEL_TIMEOUT_MS equals 60000 ms (one minute)."""
    expected_ms = 60000
    assert PEL_TIMEOUT_MS == expected_ms, (
        f"PEL_TIMEOUT_MS should be 60000, got {PEL_TIMEOUT_MS}"
    )


def test_wb6_staging_submit_delta_called_when_result_not_none():
    """WB6: with a staging area configured and process() returning a non-None
    result, staging.submit_delta is called exactly once."""
    fake_redis = _make_single_shot_redis(entry_id="1700000000006-0")

    staging_area = MagicMock()
    staging_area.submit_delta = AsyncMock()

    worker = EchoWorker(fake_redis, staging_area=staging_area)
    fake_redis._state["worker"] = worker

    async def _drive():
        await worker.run_worker_loop(group=DEFAULT_GROUP, consumer_id="test-wb6")

    run(_drive())

    staging_area.submit_delta.assert_called_once()


# ===========================================================================
# Package export tests
# ===========================================================================


def test_package_exports_swarm_worker_base():
    """Package level: core.coherence re-exports the very same SwarmWorkerBase class."""
    from core.coherence import SwarmWorkerBase as exported_base
    assert exported_base is SwarmWorkerBase


def test_package_exports_pel_timeout_ms():
    """Package level: core.coherence re-exports PEL_TIMEOUT_MS with value 60000."""
    from core.coherence import PEL_TIMEOUT_MS as exported_timeout
    assert exported_timeout == 60000


def test_reclaim_pending_returns_count_of_reclaimed_entries():
    """_reclaim_pending returns how many entries XAUTOCLAIM handed back."""
    stale_entries = [
        ("1700000000010-0", {b"task_type": b"old"}),
        ("1700000000011-0", {b"task_type": b"stale"}),
    ]
    fake_redis = _make_empty_redis()
    # XAUTOCLAIM reply shape: (next cursor, claimed entries, deleted ids).
    fake_redis.xautoclaim = AsyncMock(return_value=("0-0", stale_entries, []))

    worker = EchoWorker(fake_redis)
    reclaimed = run(worker._reclaim_pending("genesis_workers", "worker-reclaim"))

    assert reclaimed == 2, f"Expected 2 reclaimed entries, got {reclaimed}"


def test_reclaim_pending_returns_zero_on_xautoclaim_exception():
    """_reclaim_pending swallows an XAUTOCLAIM failure and reports 0 reclaimed."""
    fake_redis = _make_empty_redis()
    fake_redis.xautoclaim = AsyncMock(side_effect=Exception("NOGROUP No such group"))

    reclaimed = run(EchoWorker(fake_redis)._reclaim_pending("genesis_workers", "worker-err"))

    assert reclaimed == 0


def test_xreadgroup_exception_does_not_crash_loop():
    """xreadgroup raising an exception must not crash the loop; it keeps polling.

    The first XREADGROUP raises ConnectionError. The loop must survive, so
    run() must not raise, and it must call XREADGROUP again. The second call
    stops the worker and returns empty.
    """
    redis = _make_empty_redis()

    _calls = [0]

    async def _xreadgroup_fails_first(group, consumer_id, streams, count=1, block=5000):
        _calls[0] += 1
        if _calls[0] == 1:
            raise ConnectionError("Redis connection lost")
        # Second call: stop and return empty
        redis._state["worker"].stop()
        return []

    redis.xreadgroup = AsyncMock(side_effect=_xreadgroup_fails_first)

    worker = EchoWorker(redis)
    redis._state["worker"] = worker

    async def _run():
        await worker.run_worker_loop(group=DEFAULT_GROUP, consumer_id="test-exc")

    # Must not raise
    run(_run())

    # ">= 1" would already be satisfied by the failing call. Requiring a second
    # call proves the loop continued rather than silently exiting.
    assert redis.xreadgroup.call_count >= 2, (
        "loop did not call XREADGROUP again after the ConnectionError; "
        f"call_count={redis.xreadgroup.call_count}"
    )


def test_staging_not_called_when_process_returns_none():
    """A None result from process() skips staging.submit_delta, yet the entry
    is still acknowledged, because process() completed without raising."""

    class NoneWorker(SwarmWorkerBase):
        """Worker whose process() never produces a result."""

        async def process(self, task: dict):
            return None

    fake_redis = _make_single_shot_redis(entry_id="1700000000099-0")

    staging_area = MagicMock()
    staging_area.submit_delta = AsyncMock()

    worker = NoneWorker(fake_redis, staging_area=staging_area)
    fake_redis._state["worker"] = worker

    async def _drive():
        await worker.run_worker_loop(group=DEFAULT_GROUP, consumer_id="test-none")

    run(_drive())

    staging_area.submit_delta.assert_not_called()
    fake_redis.xack.assert_called_once()


def test_staging_not_called_when_no_staging_area():
    """With staging_area=None (the default), a non-None result causes no error
    and the entry is still acknowledged."""
    fake_redis = _make_single_shot_redis(entry_id="1700000000098-0")
    worker = EchoWorker(fake_redis, staging_area=None)
    fake_redis._state["worker"] = worker

    async def _drive():
        await worker.run_worker_loop(group=DEFAULT_GROUP, consumer_id="test-nostaging")

    # Must complete without raising even though there is nowhere to stage.
    run(_drive())

    fake_redis.xack.assert_called_once()


def test_xautoclaim_called_with_start_id_zero():
    """_reclaim_pending starts XAUTOCLAIM at start_id='0-0' to scan all PEL entries.

    Only the FIRST call is checked. If the implementation pages through the
    PEL, later calls pass the returned cursor rather than '0-0'.
    """
    redis = _make_empty_redis()
    worker = EchoWorker(redis)

    run(worker._reclaim_pending("genesis_workers", "test-start-id"))

    redis.xautoclaim.assert_called()
    _, kwargs = redis.xautoclaim.call_args_list[0]
    # .get() keeps the assertion message reachable when the keyword is absent.
    start_id = kwargs.get("start_id")
    assert start_id == "0-0", (
        f"Expected start_id='0-0', got {start_id!r}"
    )


# ===========================================================================
# Standalone runner (pytest preferred, fallback to direct execution)
# ===========================================================================

if __name__ == "__main__":
    import traceback

    # (label, test function) pairs, run in order when executed as a script.
    tests = [
        ("BB1: Successful process → XACK sent",
         test_bb1_successful_process_sends_xack),
        ("BB2: Failed process → no XACK sent",
         test_bb2_failed_process_does_not_send_xack),
        ("BB3: stop() sets _running=False",
         test_bb3_stop_sets_running_false),
        ("BB4: XAUTOCLAIM called on startup",
         test_bb4_reclaim_pending_called_on_startup),
        ("WB1: XREADGROUP uses '>' for new messages",
         test_wb1_xreadgroup_uses_greater_than_for_new_messages),
        ("WB2: XAUTOCLAIM called with min_idle_time=60000",
         test_wb2_xautoclaim_called_with_pel_timeout_60000),
        ("WB3: Bytes from Redis decoded to str",
         test_wb3_bytes_decoded_from_redis),
        ("WB4: process() is abstract — cannot instantiate base",
         test_wb4_process_is_abstract_cannot_instantiate_base),
        ("WB5: PEL_TIMEOUT_MS constant is 60000",
         test_wb5_pel_timeout_ms_constant_is_60000),
        ("WB6: staging.submit_delta called when result not None",
         test_wb6_staging_submit_delta_called_when_result_not_none),
        ("PKG: SwarmWorkerBase importable from core.coherence",
         test_package_exports_swarm_worker_base),
        ("PKG: PEL_TIMEOUT_MS importable from core.coherence",
         test_package_exports_pel_timeout_ms),
        ("_reclaim_pending returns count of reclaimed entries",
         test_reclaim_pending_returns_count_of_reclaimed_entries),
        ("_reclaim_pending returns 0 on xautoclaim exception",
         test_reclaim_pending_returns_zero_on_xautoclaim_exception),
        ("xreadgroup exception does not crash loop",
         test_xreadgroup_exception_does_not_crash_loop),
        ("staging NOT called when process returns None",
         test_staging_not_called_when_process_returns_none),
        ("staging NOT called when no staging_area",
         test_staging_not_called_when_no_staging_area),
        ("XAUTOCLAIM called with start_id='0-0'",
         test_xautoclaim_called_with_start_id_zero),
    ]

    failures = 0
    for label, test_fn in tests:
        try:
            test_fn()
        except Exception as exc:
            failures += 1
            print(f"  [FAIL] {label}: {exc}")
            traceback.print_exc()
        else:
            print(f"  [PASS] {label}")

    print(f"\n{len(tests) - failures}/{len(tests)} tests passed")
    if failures:
        sys.exit(1)
    print("ALL TESTS PASSED -- Story 6.04 (Track B)")
