#!/usr/bin/env python3
"""
Tests for Story 6.03 (Track B): TaskDAGPusher — Redis Streams Task Queue

Black Box tests (BB): verify the public contract from the outside —
    push_dag return type/count, create_consumer_group idempotency,
    get_stream_length delegation.
White Box tests (WB): verify internal mechanics — XADD called once per task,
    entry IDs vs task_id UUIDs, payload JSON encoding, STREAM_KEY constant.

Story: 6.03
File under test: core/coherence/task_dag_pusher.py

ALL tests use mocks — NO real Redis connection is made.
NO SQLite anywhere in this module.
"""

from __future__ import annotations

import sys
sys.path.insert(0, "/mnt/e/genesis-system")

import asyncio
import json
import uuid
import pytest
from unittest.mock import AsyncMock, MagicMock, call


# ---------------------------------------------------------------------------
# Module under test
# ---------------------------------------------------------------------------

from core.coherence.task_dag_pusher import (
    TaskDAGPusher,
    STREAM_KEY,
    DEFAULT_GROUP,
)


# ---------------------------------------------------------------------------
# Async test helper
# ---------------------------------------------------------------------------

def run(coro):
    """Run a coroutine to completion and return its result.

    Uses ``asyncio.run`` (fresh event loop per call) instead of the
    deprecated ``asyncio.get_event_loop().run_until_complete`` pattern,
    which emits a DeprecationWarning on Python 3.10+ and raises on 3.12+
    when no event loop is already running in the main thread.
    pytest-asyncio is still not required.
    """
    return asyncio.run(coro)


# ---------------------------------------------------------------------------
# Mock builders
# ---------------------------------------------------------------------------

def _make_redis(stream_len: int = 5) -> MagicMock:
    """
    Build a minimal mock Redis client that supports:
        xadd()         — returns incrementing fake stream entry IDs
        xgroup_create()— returns True by default
        xlen()         — returns stream_len
    """
    mock_redis = MagicMock()

    # xadd returns unique fake stream entry IDs
    _counter = [0]

    async def _xadd(key, fields):
        _counter[0] += 1
        return f"16960000000{_counter[0]:02d}-0"

    mock_redis.xadd = AsyncMock(side_effect=_xadd)
    mock_redis.xgroup_create = AsyncMock(return_value=True)
    mock_redis.xlen = AsyncMock(return_value=stream_len)

    return mock_redis


def _make_redis_busygroup() -> MagicMock:
    """
    Mock Redis client whose xgroup_create raises a BUSYGROUP error,
    simulating a consumer group that already exists.
    """
    client = _make_redis()
    busy_error = Exception("BUSYGROUP Consumer Group name already exists")
    client.xgroup_create = AsyncMock(side_effect=busy_error)
    return client


def _make_redis_xgroup_error() -> MagicMock:
    """
    Mock Redis client whose xgroup_create raises a non-BUSYGROUP error
    (e.g. a connection failure), which callers must re-raise.
    """
    client = _make_redis()
    conn_error = Exception("CONNECTION_REFUSED Cannot connect to Redis")
    client.xgroup_create = AsyncMock(side_effect=conn_error)
    return client


# ===========================================================================
# BLACK BOX TESTS
# ===========================================================================


def test_bb1_push_dag_two_tasks_calls_xadd_twice_returns_two_ids():
    """BB1: push_dag with 2 tasks → redis.xadd called twice → returns 2 entry IDs."""
    redis_mock = _make_redis()
    pusher = TaskDAGPusher(redis_mock)

    dag = [
        {"task_type": "research", "payload": {"query": "AI"}, "tier": "T2", "priority": "high"},
        {"task_type": "synthesize", "payload": {"topic": "AI"}, "tier": "T1", "priority": "normal"},
    ]

    ids = run(pusher.push_dag("sess-bb1", dag))

    # One XADD per task, and push_dag echoes back the stream entry IDs.
    assert redis_mock.xadd.call_count == 2, (
        f"Expected xadd called twice, got {redis_mock.xadd.call_count}"
    )
    assert len(ids) == 2, f"Expected 2 entry IDs, got {len(ids)}"


def test_bb2_stream_entry_contains_all_required_fields():
    """BB2: Stream entry contains all required fields — session_id, task_id,
    task_type, payload, tier, priority."""
    redis_mock = _make_redis()
    pusher = TaskDAGPusher(redis_mock)

    run(pusher.push_dag(
        "sess-bb2",
        [{"task_type": "audit", "payload": {"url": "https://example.com"}, "tier": "T3", "priority": "low"}],
    ))

    assert redis_mock.xadd.call_count == 1
    # xadd(STREAM_KEY, fields) — positional args
    fields = redis_mock.xadd.call_args[0][1]

    expected = {"session_id", "task_id", "task_type", "payload", "tier", "priority"}
    missing = expected - set(fields.keys())
    assert not missing, f"Missing required fields in stream entry: {missing}"


def test_bb3_create_consumer_group_succeeds_on_first_call():
    """BB3: create_consumer_group succeeds on first call (no exception) → returns True."""
    redis_mock = _make_redis()
    outcome = run(TaskDAGPusher(redis_mock).create_consumer_group())

    assert outcome is True
    assert redis_mock.xgroup_create.call_count == 1


def test_bb4_create_consumer_group_busygroup_returns_true_idempotent():
    """BB4: create_consumer_group with BUSYGROUP error → returns True (idempotent)."""
    redis_mock = _make_redis_busygroup()
    outcome = run(TaskDAGPusher(redis_mock).create_consumer_group())

    # BUSYGROUP means the group already exists — that is success, not failure.
    assert outcome is True, (
        "create_consumer_group must return True when BUSYGROUP error is raised"
    )


def test_bb5_get_stream_length_returns_xlen_value():
    """BB5: get_stream_length returns the value from redis.xlen."""
    redis_mock = _make_redis(stream_len=42)
    result = run(TaskDAGPusher(redis_mock).get_stream_length())

    assert result == 42
    redis_mock.xlen.assert_called_once_with(STREAM_KEY)


# ===========================================================================
# WHITE BOX TESTS
# ===========================================================================


def test_wb1_xadd_called_once_per_task_not_batch():
    """WB1: XADD is called once per individual task, never batched."""
    redis_mock = _make_redis()
    pusher = TaskDAGPusher(redis_mock)

    dag = [
        {"task_type": "t1", "payload": {}, "tier": "T1", "priority": "normal"},
        {"task_type": "t2", "payload": {}, "tier": "T1", "priority": "normal"},
        {"task_type": "t3", "payload": {}, "tier": "T2", "priority": "high"},
    ]

    run(pusher.push_dag("sess-wb1", dag))

    # Each task maps to exactly one XADD call
    assert redis_mock.xadd.call_count == len(dag), (
        f"Expected {len(dag)} xadd calls (one per task), "
        f"got {redis_mock.xadd.call_count}"
    )

    # Every call targets the correct stream key
    for key_arg in (invocation[0][0] for invocation in redis_mock.xadd.call_args_list):
        assert key_arg == STREAM_KEY, f"Unexpected stream key: {key_arg!r}"


def test_wb2_returns_list_of_stream_entry_ids_not_task_uuids():
    """WB2: push_dag returns the list of stream entry IDs (from XADD return
    values), not the generated task_id UUIDs."""
    redis_mock = _make_redis()
    pusher = TaskDAGPusher(redis_mock)

    dag = [
        {"task_type": "ping", "payload": {}, "tier": "T1", "priority": "normal"},
        {"task_type": "pong", "payload": {}, "tier": "T1", "priority": "normal"},
    ]

    ids = run(pusher.push_dag("sess-wb2", dag))

    # The fake xadd returns IDs like "1696000000001-0"
    for eid in ids:
        assert "-" in eid, (
            f"Entry ID should look like a Redis stream ID (e.g. '1234-0'), got: {eid!r}"
        )
        # Confirm it is NOT a bare UUID (which would have 5 hyphen-separated groups)
        assert len(eid.split("-")) == 2, (
            f"Entry ID should be a Redis stream ID (two '-'-separated parts), got: {eid!r}"
        )


def test_wb3_task_id_in_entry_is_valid_uuid4():
    """WB3: The task_id field in every stream entry is a valid UUID4 string."""
    redis_mock = _make_redis()
    pusher = TaskDAGPusher(redis_mock)

    dag = [
        {"task_type": "job-a", "payload": {}, "tier": "T1", "priority": "normal"},
        {"task_type": "job-b", "payload": {}, "tier": "T2", "priority": "high"},
    ]

    run(pusher.push_dag("sess-wb3", dag))

    assert redis_mock.xadd.call_count == 2
    for invocation in redis_mock.xadd.call_args_list:
        task_id_str = invocation[0][1]["task_id"]
        # Parse and validate as UUID; check version==4
        try:
            parsed = uuid.UUID(task_id_str)
        except ValueError:
            pytest.fail(f"task_id {task_id_str!r} is not a valid UUID")
        assert parsed.version == 4, (
            f"task_id {task_id_str!r} is not UUID version 4 (got v{parsed.version})"
        )


def test_wb4_payload_field_is_json_encoded_string_not_raw_dict():
    """WB4: The payload field in the stream entry is a JSON-encoded string,
    not a raw Python dict."""
    redis_mock = _make_redis()
    pusher = TaskDAGPusher(redis_mock)

    original_payload = {"url": "https://genesis.ai", "depth": 3}
    run(pusher.push_dag(
        "sess-wb4",
        [{"task_type": "crawl", "payload": original_payload, "tier": "T2", "priority": "normal"}],
    ))

    payload_field = redis_mock.xadd.call_args[0][1]["payload"]

    # Must be a string, not a dict
    assert isinstance(payload_field, str), (
        f"payload field must be a JSON string, got {type(payload_field).__name__}"
    )

    # Must be valid JSON that round-trips to the original dict
    decoded = json.loads(payload_field)
    assert decoded == original_payload, (
        f"Decoded payload mismatch: expected {original_payload!r}, got {decoded!r}"
    )


def test_wb5_stream_key_constant_is_genesis_swarm_tasks():
    """WB5: STREAM_KEY module-level constant equals 'genesis:swarm:tasks'."""
    expected_key = "genesis:swarm:tasks"
    assert STREAM_KEY == expected_key, (
        f"STREAM_KEY should be 'genesis:swarm:tasks', got {STREAM_KEY!r}"
    )


# ===========================================================================
# Package export tests + additional behavior/edge-case tests
# ===========================================================================


def test_package_exports_task_dag_pusher():
    """Package level: TaskDAGPusher importable from core.coherence."""
    from core.coherence import TaskDAGPusher as exported
    assert exported is TaskDAGPusher


def test_package_exports_stream_key():
    """Package level: STREAM_KEY importable from core.coherence."""
    from core.coherence import STREAM_KEY as exported_key
    assert exported_key == "genesis:swarm:tasks"


def test_package_exports_default_group():
    """Package level: DEFAULT_GROUP importable from core.coherence."""
    from core.coherence import DEFAULT_GROUP as exported_group
    assert exported_group == "genesis_workers"


def test_create_consumer_group_non_busygroup_error_reraises():
    """create_consumer_group re-raises non-BUSYGROUP errors."""
    pusher = TaskDAGPusher(_make_redis_xgroup_error())

    # Only BUSYGROUP is swallowed; anything else must propagate.
    with pytest.raises(Exception) as exc_info:
        run(pusher.create_consumer_group())

    assert "CONNECTION_REFUSED" in str(exc_info.value)


def test_push_dag_empty_tasks_returns_empty_list():
    """push_dag with empty task list returns empty list, xadd not called."""
    redis_mock = _make_redis()
    result = run(TaskDAGPusher(redis_mock).push_dag("sess-empty", []))

    assert result == []
    redis_mock.xadd.assert_not_called()


def test_push_dag_defaults_applied_for_missing_task_fields():
    """Task with no optional fields → defaults: task_type='unknown', tier='T1',
    priority='normal', payload='{}'."""
    redis_mock = _make_redis()
    run(TaskDAGPusher(redis_mock).push_dag("sess-defaults", [{}]))

    fields = redis_mock.xadd.call_args[0][1]
    assert fields["task_type"] == "unknown"
    assert fields["tier"] == "T1"
    assert fields["priority"] == "normal"
    assert json.loads(fields["payload"]) == {}


def test_push_dag_session_id_propagated_to_every_entry():
    """Every stream entry carries the session_id passed to push_dag."""
    redis_mock = _make_redis()
    pusher = TaskDAGPusher(redis_mock)

    session = "sess-propagate-abc"
    dag = [{"task_type": f"t{i}", "payload": {}} for i in range(4)]

    run(pusher.push_dag(session, dag))

    for invocation in redis_mock.xadd.call_args_list:
        fields = invocation[0][1]
        assert fields["session_id"] == session, (
            f"Expected session_id={session!r}, got {fields['session_id']!r}"
        )


def test_create_consumer_group_passes_mkstream_and_id_zero():
    """create_consumer_group calls XGROUP CREATE with id='0' and mkstream=True."""
    redis_mock = _make_redis()
    run(TaskDAGPusher(redis_mock).create_consumer_group("test_group"))

    # id='0' reads the stream from the beginning; mkstream=True creates it
    # if absent — both must be forwarded verbatim.
    redis_mock.xgroup_create.assert_called_once_with(
        STREAM_KEY, "test_group", id="0", mkstream=True
    )


# ===========================================================================
# Standalone runner (pytest preferred, fallback to direct execution)
# ===========================================================================

if __name__ == "__main__":
    import traceback

    tests = [
        ("BB1: push_dag 2 tasks → xadd called twice → 2 entry IDs",
         test_bb1_push_dag_two_tasks_calls_xadd_twice_returns_two_ids),
        ("BB2: Stream entry contains all required fields",
         test_bb2_stream_entry_contains_all_required_fields),
        ("BB3: create_consumer_group succeeds on first call",
         test_bb3_create_consumer_group_succeeds_on_first_call),
        ("BB4: create_consumer_group BUSYGROUP → returns True (idempotent)",
         test_bb4_create_consumer_group_busygroup_returns_true_idempotent),
        ("BB5: get_stream_length returns xlen value",
         test_bb5_get_stream_length_returns_xlen_value),
        ("WB1: XADD called once per task (not batch)",
         test_wb1_xadd_called_once_per_task_not_batch),
        ("WB2: Returns list of stream entry IDs (not task UUIDs)",
         test_wb2_returns_list_of_stream_entry_ids_not_task_uuids),
        ("WB3: task_id in each entry is valid UUID4",
         test_wb3_task_id_in_entry_is_valid_uuid4),
        ("WB4: payload field is JSON-encoded string (not raw dict)",
         test_wb4_payload_field_is_json_encoded_string_not_raw_dict),
        ("WB5: STREAM_KEY constant is 'genesis:swarm:tasks'",
         test_wb5_stream_key_constant_is_genesis_swarm_tasks),
        ("PKG: TaskDAGPusher importable from core.coherence",
         test_package_exports_task_dag_pusher),
        ("PKG: STREAM_KEY importable from core.coherence",
         test_package_exports_stream_key),
        ("PKG: DEFAULT_GROUP importable from core.coherence",
         test_package_exports_default_group),
        ("create_consumer_group re-raises non-BUSYGROUP errors",
         test_create_consumer_group_non_busygroup_error_reraises),
        ("push_dag empty tasks → empty list, xadd not called",
         test_push_dag_empty_tasks_returns_empty_list),
        ("push_dag missing fields → defaults applied",
         test_push_dag_defaults_applied_for_missing_task_fields),
        ("session_id propagated to every stream entry",
         test_push_dag_session_id_propagated_to_every_entry),
        ("create_consumer_group passes id='0' and mkstream=True",
         test_create_consumer_group_passes_mkstream_and_id_zero),
    ]

    # Run each test, tallying failures; any failure exits non-zero.
    failures = 0
    for name, fn in tests:
        try:
            fn()
        except Exception as exc:
            failures += 1
            print(f"  [FAIL] {name}: {exc}")
            traceback.print_exc()
        else:
            print(f"  [PASS] {name}")

    total = len(tests)
    passed = total - failures
    print(f"\n{passed}/{total} tests passed")
    if failures:
        sys.exit(1)
    print("ALL TESTS PASSED -- Story 6.03 (Track B)")
