#!/usr/bin/env python3
"""
Tests for Story 6.06 (Track B): OCCCommitEngine — Barrier Sync + OCC Write

Black Box tests (BB): verify public contract from the outside.
    BB1: 3 workers submit → all collected → merged → committed → success=True
    BB2: Version conflict on first attempt → retry succeeds on attempt 2 → retries=1
    BB3: All 3 retries fail → OccCommitResult(success=False), saga_status="conflict_exhausted"
    BB4: Merge fails → OccCommitResult(success=False), saga_status="merge_failed"
    BB5: saga_writer.close_saga called on both success and failure paths

White Box tests (WB): verify internal mechanics.
    WB1: merge_interceptor.merge called with correct args (deltas, current_state, version)
    WB2: staging.wait_for_all called with expected_workers and timeout_ms=60000
    WB3: MAX_RETRIES constant equals 3
    WB4: commit_patch called with correct version for OCC check

Story: 6.06
File under test: core/coherence/occ_commit.py

ALL tests use mocks — NO real Redis connection is made.
NO SQLite anywhere in this module.
"""

from __future__ import annotations

import sys
sys.path.insert(0, "/mnt/e/genesis-system")

import asyncio
from unittest.mock import AsyncMock, MagicMock, call, patch

import pytest


# ---------------------------------------------------------------------------
# Module under test
# ---------------------------------------------------------------------------

from core.coherence.occ_commit import (
    OCCCommitEngine,
    OccCommitResult,
    MAX_RETRIES,
)


# ---------------------------------------------------------------------------
# Async helper
# ---------------------------------------------------------------------------

def run(coro):
    """Drive *coro* to completion on a private, throwaway event loop.

    A brand-new loop is created for every call and is always closed
    afterwards (even if the coroutine raises), so the pytest-asyncio
    plugin is not required.
    """
    private_loop = asyncio.new_event_loop()
    try:
        outcome = private_loop.run_until_complete(coro)
    finally:
        private_loop.close()
    return outcome


# ---------------------------------------------------------------------------
# Mock builder helpers
# ---------------------------------------------------------------------------

def _make_staging(deltas: list) -> MagicMock:
    """Build a fake StagingArea whose awaited ``wait_for_all`` yields *deltas*."""
    return MagicMock(wait_for_all=AsyncMock(return_value=deltas))


def _make_merge_ok(merged_patch: list) -> MagicMock:
    """Build a fake merge interceptor whose awaited ``merge`` reports success.

    The merge result it returns has ``success=True`` and carries
    *merged_patch* as its ``merged_patch`` attribute.
    """
    successful_merge = MagicMock(success=True, merged_patch=merged_patch)
    return MagicMock(merge=AsyncMock(return_value=successful_merge))


def _make_merge_fail() -> MagicMock:
    """Build a fake merge interceptor whose awaited ``merge`` reports failure.

    The merge result it returns has ``success=False`` and an empty
    ``merged_patch`` list.
    """
    failed_merge = MagicMock(success=False, merged_patch=[])
    return MagicMock(merge=AsyncMock(return_value=failed_merge))


def _make_master_state(
    version: int = 5,
    data: dict = None,
    commit_ok: bool = True,
) -> MagicMock:
    """
    Return a mock RedisMasterState.

    get_snapshot always returns (version, data).
    commit_patch returns a CommitResult-like mock with
        success = commit_ok
        new_version = version + 1 (if commit_ok)
        conflict = not commit_ok
    """
    if data is None:
        data = {"status": "running"}

    commit_result = MagicMock()
    commit_result.success = commit_ok
    commit_result.new_version = version + 1 if commit_ok else 0
    commit_result.conflict = not commit_ok

    master = MagicMock()
    master.get_snapshot = AsyncMock(return_value=(version, data))
    master.commit_patch = AsyncMock(return_value=commit_result)
    return master


def _make_master_conflict_then_ok(version: int = 5, data: dict = None) -> MagicMock:
    """
    RedisMasterState that fails commit on attempt 0 (conflict), then succeeds.
    get_snapshot always returns the same (version, data).
    commit_patch alternates: fail on call 1, succeed on call 2.
    """
    if data is None:
        data = {}

    fail_result = MagicMock()
    fail_result.success = False
    fail_result.new_version = 0
    fail_result.conflict = True

    ok_result = MagicMock()
    ok_result.success = True
    ok_result.new_version = version + 1
    ok_result.conflict = False

    master = MagicMock()
    master.get_snapshot = AsyncMock(return_value=(version, data))
    master.commit_patch = AsyncMock(side_effect=[fail_result, ok_result])
    return master


def _make_master_always_conflict(version: int = 5, data: dict = None) -> MagicMock:
    """RedisMasterState that ALWAYS returns a conflict on commit_patch."""
    if data is None:
        data = {}

    fail_result = MagicMock()
    fail_result.success = False
    fail_result.new_version = 0
    fail_result.conflict = True

    master = MagicMock()
    master.get_snapshot = AsyncMock(return_value=(version, data))
    master.commit_patch = AsyncMock(return_value=fail_result)
    return master


def _make_saga_writer() -> MagicMock:
    """Build a fake saga writer exposing an awaitable ``close_saga``."""
    return MagicMock(close_saga=AsyncMock())


def _make_cold_ledger() -> MagicMock:
    """Build a fake cold ledger exposing a plain (synchronous) ``write_event``."""
    return MagicMock(write_event=MagicMock())


def _sample_deltas(
    count: int = 3,
    *,
    session_id: str = "sess-test",
    version_at_read: int = 5,
) -> list:
    """
    Return `count` minimal delta dicts (as StagingArea would return).

    Delta i comes from agent ``agent-{i}`` and carries one JSON-Patch-like
    "add" op that sets ``/key{i}`` to ``val{i}``.

    Args:
        count: number of deltas to build (0 yields an empty list).
        session_id: session id stamped on every delta.
        version_at_read: master version each delta claims to have read.
            Pass the mock master's version to keep fixtures self-consistent.
    """
    return [
        {
            "agent_id": f"agent-{i}",
            "session_id": session_id,
            "version_at_read": version_at_read,
            "patch": [{"op": "add", "path": f"/key{i}", "value": f"val{i}"}],
        }
        for i in range(count)
    ]


# ===========================================================================
# BLACK BOX TESTS
# ===========================================================================


def test_bb1_three_workers_all_collected_merged_committed_success():
    """
    BB1: three workers' deltas are collected from staging, merged, and
         committed on the first try → OccCommitResult(success=True).
    """
    expected_patch = [{"op": "add", "path": "/merged", "value": True}]
    engine = OCCCommitEngine(
        _make_staging(_sample_deltas(3)),
        _make_merge_ok(expected_patch),
        _make_master_state(version=5, commit_ok=True),
        _make_saga_writer(),
        _make_cold_ledger(),
    )

    outcome = run(engine.execute_commit("sess-test", expected_workers=3))

    assert outcome.success is True
    assert outcome.merged_patch == expected_patch
    assert outcome.version == 5 + 1          # snapshot version bumped once
    assert outcome.retries == 0              # first attempt won
    assert outcome.saga_status == "completed"


def test_bb2_version_conflict_first_attempt_retry_succeeds():
    """
    BB2: the first commit_patch hits a version conflict and the second wins,
         so result.retries == 1 (zero-based index of the winning attempt).
    """
    expected_patch = [{"op": "replace", "path": "/state", "value": "merged"}]
    redis_master = _make_master_conflict_then_ok(version=7)
    engine = OCCCommitEngine(
        _make_staging(_sample_deltas(2)),
        _make_merge_ok(expected_patch),
        redis_master,
        _make_saga_writer(),
        _make_cold_ledger(),
    )

    outcome = run(engine.execute_commit("sess-retry", expected_workers=2))

    assert outcome.success is True
    assert outcome.merged_patch == expected_patch
    assert outcome.version == 7 + 1
    assert outcome.retries == 1
    assert outcome.saga_status == "completed"

    # One conflicting attempt followed by one successful attempt.
    assert redis_master.commit_patch.call_count == 2


def test_bb3_all_retries_fail_conflict_exhausted():
    """
    BB3: every one of the MAX_RETRIES (3) OCC attempts conflicts, so the
         engine gives up with success=False, saga_status="conflict_exhausted".
    """
    redis_master = _make_master_always_conflict(version=3)
    engine = OCCCommitEngine(
        _make_staging(_sample_deltas(1)),
        _make_merge_ok([{"op": "add", "path": "/x", "value": 1}]),
        redis_master,
        _make_saga_writer(),
        _make_cold_ledger(),
    )

    outcome = run(engine.execute_commit("sess-exhaust", expected_workers=1))

    assert outcome.success is False
    assert outcome.merged_patch == []        # no patch reported on failure
    assert outcome.retries == MAX_RETRIES
    assert outcome.saga_status == "conflict_exhausted"

    # Exactly one commit_patch per allowed attempt.
    assert redis_master.commit_patch.call_count == MAX_RETRIES


def test_bb4_merge_fails_returns_merge_failed_status():
    """
    BB4: when merge_interceptor.merge reports success=False the engine stops
         early with success=False and saga_status="merge_failed".
    """
    redis_master = _make_master_state(version=5, commit_ok=True)
    engine = OCCCommitEngine(
        _make_staging(_sample_deltas(2)),
        _make_merge_fail(),
        redis_master,
        _make_saga_writer(),
    )

    outcome = run(engine.execute_commit("sess-merge-fail", expected_workers=2))

    assert outcome.success is False
    assert outcome.merged_patch == []
    assert outcome.saga_status == "merge_failed"
    # The OCC write phase must never be reached after a failed merge.
    redis_master.commit_patch.assert_not_called()


def test_bb5_saga_writer_called_on_success():
    """
    BB5a: on success, saga_writer.close_saga is awaited exactly once with
          (session_id, "completed").
    """
    session = "sess-saga-ok"
    saga = _make_saga_writer()
    engine = OCCCommitEngine(
        _make_staging(_sample_deltas(1)),
        _make_merge_ok([{"op": "add", "path": "/k", "value": 1}]),
        _make_master_state(version=2, commit_ok=True),
        saga,
    )

    run(engine.execute_commit(session, expected_workers=1))

    saga.close_saga.assert_awaited_once_with(session, "completed")


def test_bb5b_saga_writer_called_on_failure():
    """
    BB5b: close_saga is still awaited exactly once, with "conflict_exhausted",
          when every commit attempt conflicts.
    """
    session = "sess-saga-fail"
    saga = _make_saga_writer()
    engine = OCCCommitEngine(
        _make_staging(_sample_deltas(1)),
        _make_merge_ok([{"op": "add", "path": "/k", "value": 1}]),
        _make_master_always_conflict(version=2),
        saga,
    )

    run(engine.execute_commit(session, expected_workers=1))

    saga.close_saga.assert_awaited_once_with(session, "conflict_exhausted")


def test_bb5c_saga_writer_called_on_merge_failure():
    """
    BB5c: close_saga is awaited exactly once, with "merge_failed", when the
          merge step fails.
    """
    session = "sess-saga-merge-fail"
    saga = _make_saga_writer()
    engine = OCCCommitEngine(
        _make_staging(_sample_deltas(1)),
        _make_merge_fail(),
        _make_master_state(version=2, commit_ok=True),
        saga,
    )

    run(engine.execute_commit(session, expected_workers=1))

    saga.close_saga.assert_awaited_once_with(session, "merge_failed")


# ===========================================================================
# WHITE BOX TESTS
# ===========================================================================


def test_wb1_merge_called_with_correct_args():
    """
    WB1: merge_interceptor.merge receives exactly (deltas, current_state,
         version): the deltas from staging plus the snapshot from get_snapshot.
    """
    snapshot_version, snapshot_data = 9, {"agent": "state"}
    collected = _sample_deltas(2)
    merger = _make_merge_ok([])
    engine = OCCCommitEngine(
        _make_staging(collected),
        merger,
        _make_master_state(
            version=snapshot_version, data=snapshot_data, commit_ok=True
        ),
    )

    run(engine.execute_commit("sess-wb1", expected_workers=2))

    merger.merge.assert_awaited_once_with(collected, snapshot_data, snapshot_version)


def test_wb2_staging_wait_for_all_called_with_correct_args():
    """
    WB2: the barrier call is staging.wait_for_all(session_id,
         expected_workers, timeout_ms=60000).
    """
    session, workers = "sess-wb2", 3
    staging = _make_staging(_sample_deltas(workers))
    engine = OCCCommitEngine(
        staging,
        _make_merge_ok([]),
        _make_master_state(version=1, commit_ok=True),
    )

    run(engine.execute_commit(session, expected_workers=workers))

    staging.wait_for_all.assert_awaited_once_with(session, workers, timeout_ms=60_000)


def test_wb3_max_retries_constant_is_3():
    """WB3: the module-level MAX_RETRIES constant is exactly 3."""
    expected_retries = 3
    assert MAX_RETRIES == expected_retries, f"MAX_RETRIES should be 3, got {MAX_RETRIES}"


def test_wb4_commit_patch_called_with_correct_version():
    """
    WB4: commit_patch receives the very version get_snapshot returned
         (no off-by-one, no stale value) together with the merged patch.
    """
    snapshot_version = 42
    expected_patch = [{"op": "add", "path": "/v", "value": snapshot_version}]
    redis_master = _make_master_state(version=snapshot_version, commit_ok=True)
    engine = OCCCommitEngine(
        _make_staging(_sample_deltas(1)),
        _make_merge_ok(expected_patch),
        redis_master,
    )

    run(engine.execute_commit("sess-wb4", expected_workers=1))

    redis_master.commit_patch.assert_awaited_once_with(
        "sess-wb4", snapshot_version, expected_patch
    )


# ===========================================================================
# ADDITIONAL COVERAGE TESTS
# ===========================================================================


def test_cold_ledger_write_event_called_on_success():
    """On success, cold_ledger.write_event(session_id, 'saga_committed', {}) is called once."""
    ledger = _make_cold_ledger()
    engine = OCCCommitEngine(
        _make_staging(_sample_deltas(1)),
        _make_merge_ok([{"op": "add", "path": "/k", "value": 1}]),
        _make_master_state(version=3, commit_ok=True),
        cold_ledger=ledger,
    )

    run(engine.execute_commit("sess-ledger", expected_workers=1))

    ledger.write_event.assert_called_once_with("sess-ledger", "saga_committed", {})


def test_cold_ledger_not_called_on_failure():
    """When every commit attempt conflicts, cold_ledger.write_event is never called."""
    ledger = _make_cold_ledger()
    engine = OCCCommitEngine(
        _make_staging(_sample_deltas(1)),
        _make_merge_ok([{"op": "add", "path": "/k", "value": 1}]),
        _make_master_always_conflict(version=3),
        cold_ledger=ledger,
    )

    run(engine.execute_commit("sess-ledger-fail", expected_workers=1))

    ledger.write_event.assert_not_called()


def test_no_saga_writer_does_not_raise():
    """With no saga_writer supplied, a successful commit completes without raising."""
    engine = OCCCommitEngine(
        _make_staging(_sample_deltas(1)),
        _make_merge_ok([]),
        _make_master_state(version=1, commit_ok=True),
    )  # saga_writer intentionally omitted

    outcome = run(engine.execute_commit("sess-no-saga", expected_workers=1))

    assert outcome.success is True


def test_no_cold_ledger_does_not_raise():
    """With no cold_ledger supplied, a successful commit completes without raising."""
    engine = OCCCommitEngine(
        _make_staging(_sample_deltas(1)),
        _make_merge_ok([]),
        _make_master_state(version=1, commit_ok=True),
    )  # cold_ledger intentionally omitted

    outcome = run(engine.execute_commit("sess-no-ledger", expected_workers=1))

    assert outcome.success is True


def test_get_snapshot_called_on_each_retry():
    """get_snapshot is re-read on every OCC attempt rather than cached once."""
    redis_master = _make_master_always_conflict(version=2)
    engine = OCCCommitEngine(
        _make_staging(_sample_deltas(1)),
        _make_merge_ok([{"op": "add", "path": "/k", "value": 1}]),
        redis_master,
    )

    run(engine.execute_commit("sess-snapshot-calls", expected_workers=1))

    # Every attempt conflicts, so each of the MAX_RETRIES attempts must read
    # a fresh snapshot.
    assert redis_master.get_snapshot.call_count == MAX_RETRIES


def test_occ_commit_result_dataclass_fields():
    """OccCommitResult needs only `success`; every other field has a known default."""
    result = OccCommitResult(success=True)
    assert result.success is True

    expected_defaults = {
        "merged_patch": [],
        "version": 0,
        "retries": 0,
        "saga_status": "unknown",
    }
    for field_name, default in expected_defaults.items():
        assert getattr(result, field_name) == default, field_name


# ===========================================================================
# PACKAGE EXPORT TESTS
# ===========================================================================


def test_package_exports_occ_commit_engine():
    """The core.coherence package re-exports the very same OCCCommitEngine class."""
    from core.coherence import OCCCommitEngine as exported_engine

    assert exported_engine is OCCCommitEngine


def test_package_exports_occ_commit_result():
    """The core.coherence package re-exports the very same OccCommitResult class."""
    from core.coherence import OccCommitResult as exported_result

    assert exported_result is OccCommitResult


def test_package_exports_max_retries():
    """The core.coherence package re-exports MAX_RETRIES with value 3."""
    from core.coherence import MAX_RETRIES as exported_max_retries

    assert exported_max_retries == 3


# ===========================================================================
# Standalone runner (pytest preferred, fallback to direct execution)
# ===========================================================================

if __name__ == "__main__":
    # Fallback runner for direct execution: runs every test in order, prints a
    # PASS/FAIL line per test and a summary, and exits with status 1 if any
    # test failed.
    import traceback

    tests = [
        ("BB1: 3 workers → collected → merged → committed → success=True",
         test_bb1_three_workers_all_collected_merged_committed_success),
        ("BB2: Conflict on attempt 0 → retry succeeds on attempt 1 → retries=1",
         test_bb2_version_conflict_first_attempt_retry_succeeds),
        ("BB3: All 3 retries fail → conflict_exhausted",
         test_bb3_all_retries_fail_conflict_exhausted),
        ("BB4: Merge fails → merge_failed",
         test_bb4_merge_fails_returns_merge_failed_status),
        ("BB5a: saga_writer.close_saga called on success",
         test_bb5_saga_writer_called_on_success),
        ("BB5b: saga_writer.close_saga called on conflict_exhausted",
         test_bb5b_saga_writer_called_on_failure),
        ("BB5c: saga_writer.close_saga called on merge_failed",
         test_bb5c_saga_writer_called_on_merge_failure),
        ("WB1: merge called with (deltas, current_state, version)",
         test_wb1_merge_called_with_correct_args),
        ("WB2: staging.wait_for_all called with correct args + timeout_ms=60000",
         test_wb2_staging_wait_for_all_called_with_correct_args),
        ("WB3: MAX_RETRIES == 3",
         test_wb3_max_retries_constant_is_3),
        ("WB4: commit_patch called with correct version from get_snapshot",
         test_wb4_commit_patch_called_with_correct_version),
        ("cold_ledger.write_event called on success",
         test_cold_ledger_write_event_called_on_success),
        ("cold_ledger NOT called on failure",
         test_cold_ledger_not_called_on_failure),
        ("No saga_writer → no raise",
         test_no_saga_writer_does_not_raise),
        ("No cold_ledger → no raise",
         test_no_cold_ledger_does_not_raise),
        ("get_snapshot called once per retry",
         test_get_snapshot_called_on_each_retry),
        ("OccCommitResult dataclass fields",
         test_occ_commit_result_dataclass_fields),
        ("PKG: OCCCommitEngine importable from core.coherence",
         test_package_exports_occ_commit_engine),
        ("PKG: OccCommitResult importable from core.coherence",
         test_package_exports_occ_commit_result),
        ("PKG: MAX_RETRIES importable from core.coherence",
         test_package_exports_max_retries),
    ]

    passed = 0
    total = len(tests)
    for name, fn in tests:
        try:
            fn()
        except Exception as exc:
            # Most tests use bare `assert`, whose AssertionError has an empty
            # str(); the repr keeps the exception type visible on stdout even
            # though the full traceback goes to stderr.
            print(f"  [FAIL] {name}: {exc!r}")
            traceback.print_exc()
        else:
            print(f"  [PASS] {name}")
            passed += 1

    print(f"\n{passed}/{total} tests passed")
    if passed == total:
        print("ALL TESTS PASSED -- Story 6.06 (Track B)")
    else:
        sys.exit(1)
