#!/usr/bin/env python3
"""
Tests for Story 6.02 (Track B): RedisMasterState — Versioned OCC State Store

Black Box tests (BB): verify the public contract from the outside —
    get_snapshot return types, commit_patch success/conflict, initialize_state.
White Box tests (WB): verify internal mechanics — new_version increment,
    JSON storage format, WatchError → conflict, patch application, JSON decode.

Story: 6.02
File under test: core/coherence/redis_master_state.py

ALL tests use mocks — NO real Redis connection is made.
"""

from __future__ import annotations

import sys
sys.path.insert(0, "/mnt/e/genesis-system")

import json
import pytest
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock

from redis.exceptions import WatchError


# ---------------------------------------------------------------------------
# Module under test
# ---------------------------------------------------------------------------

from core.coherence.redis_master_state import RedisMasterState, CommitResult


# ---------------------------------------------------------------------------
# Async test helpers
# ---------------------------------------------------------------------------

def run(coro):
    """
    Run a coroutine to completion synchronously and return its result.

    Uses ``asyncio.run`` so each call gets a fresh event loop that is closed
    afterwards. The previous ``asyncio.get_event_loop().run_until_complete``
    form relied on an implicitly created "current" loop. That behaviour is
    deprecated when no loop is running and fails on newer Python versions.
    pytest-asyncio is not required.
    """
    return asyncio.run(coro)


# ---------------------------------------------------------------------------
# Mock builders
# ---------------------------------------------------------------------------

def _encode(version: int, data: dict) -> bytes:
    """Serialise ``(version, data)`` into compact JSON bytes, mirroring RedisMasterState's storage format."""
    envelope = {"version": version, "data": data}
    compact_json = json.dumps(envelope, separators=(",", ":"))
    return compact_json.encode()


def _make_redis_fresh():
    """
    Return a mock redis client where get() returns None (fresh session).
    pipeline() returns an async context manager with watch/multi/execute support.
    """
    mock_redis = MagicMock()

    # get() returns None (no existing state)
    mock_redis.get = AsyncMock(return_value=None)

    # set() (for initialize_state) returns truthy (success)
    mock_redis.set = AsyncMock(return_value=True)

    # Build pipeline context manager
    pipe = AsyncMock()
    pipe.watch = AsyncMock()
    pipe.get = AsyncMock(return_value=None)
    pipe.unwatch = AsyncMock()
    pipe.multi = MagicMock()  # synchronous in redis-py
    pipe.set = MagicMock()    # queues the command (synchronous in MULTI mode)
    pipe.execute = AsyncMock(return_value=[True])

    # pipeline() returns an async context manager that yields `pipe`
    pipeline_cm = MagicMock()
    pipeline_cm.__aenter__ = AsyncMock(return_value=pipe)
    pipeline_cm.__aexit__ = AsyncMock(return_value=False)
    mock_redis.pipeline = MagicMock(return_value=pipeline_cm)

    mock_redis._pipe = pipe  # expose for assertions
    return mock_redis


def _make_redis_with_state(version: int, data: dict):
    """
    Return a mock redis client that already has state (version, data) stored.
    """
    mock_redis = MagicMock()
    encoded = _encode(version, data)

    mock_redis.get = AsyncMock(return_value=encoded)
    mock_redis.set = AsyncMock(return_value=None)  # key exists → NX fails → None

    pipe = AsyncMock()
    pipe.watch = AsyncMock()
    pipe.get = AsyncMock(return_value=encoded)
    pipe.unwatch = AsyncMock()
    pipe.multi = MagicMock()
    pipe.set = MagicMock()
    pipe.execute = AsyncMock(return_value=[True])

    pipeline_cm = MagicMock()
    pipeline_cm.__aenter__ = AsyncMock(return_value=pipe)
    pipeline_cm.__aexit__ = AsyncMock(return_value=False)
    mock_redis.pipeline = MagicMock(return_value=pipeline_cm)

    mock_redis._pipe = pipe
    return mock_redis


def _make_redis_with_watch_error(version: int, data: dict):
    """
    Return a mock redis client where execute() raises WatchError (concurrent write).
    """
    mock_redis = MagicMock()
    encoded = _encode(version, data)

    mock_redis.get = AsyncMock(return_value=encoded)
    mock_redis.set = AsyncMock(return_value=True)

    pipe = AsyncMock()
    pipe.watch = AsyncMock()
    pipe.get = AsyncMock(return_value=encoded)
    pipe.unwatch = AsyncMock()
    pipe.multi = MagicMock()
    pipe.set = MagicMock()
    # execute raises WatchError — concurrent writer modified the key
    pipe.execute = AsyncMock(side_effect=WatchError())

    pipeline_cm = MagicMock()
    pipeline_cm.__aenter__ = AsyncMock(return_value=pipe)
    pipeline_cm.__aexit__ = AsyncMock(return_value=False)
    mock_redis.pipeline = MagicMock(return_value=pipeline_cm)

    mock_redis._pipe = pipe
    return mock_redis


def _make_redis_with_wrong_version(stored_version: int, data: dict):
    """
    Return a mock redis client where stored version differs from caller's version.
    """
    mock_redis = MagicMock()
    encoded = _encode(stored_version, data)

    mock_redis.get = AsyncMock(return_value=encoded)
    mock_redis.set = AsyncMock(return_value=True)

    pipe = AsyncMock()
    pipe.watch = AsyncMock()
    pipe.get = AsyncMock(return_value=encoded)
    pipe.unwatch = AsyncMock()
    pipe.multi = MagicMock()
    pipe.set = MagicMock()
    pipe.execute = AsyncMock(return_value=[True])

    pipeline_cm = MagicMock()
    pipeline_cm.__aenter__ = AsyncMock(return_value=pipe)
    pipeline_cm.__aexit__ = AsyncMock(return_value=False)
    mock_redis.pipeline = MagicMock(return_value=pipeline_cm)

    mock_redis._pipe = pipe
    return mock_redis


# ===========================================================================
# BLACK BOX TESTS
# ===========================================================================


def test_bb1_get_snapshot_fresh_session_returns_zero_empty():
    """BB1: A session with nothing stored snapshots as version 0 with empty data."""
    store = RedisMasterState(_make_redis_fresh())
    snapshot_version, snapshot_data = run(store.get_snapshot("sess-fresh"))
    assert snapshot_version == 0
    assert snapshot_data == {}


def test_bb2_commit_patch_correct_version_returns_success():
    """BB2: Committing against the current version (0 on a fresh key) succeeds and yields version 1."""
    store = RedisMasterState(_make_redis_fresh())  # nothing stored -> current version is 0
    ops = [{"op": "add", "path": "/status", "value": "active"}]
    outcome = run(store.commit_patch("sess-abc", version=0, patch=ops))

    assert isinstance(outcome, CommitResult)
    assert outcome.success is True
    assert outcome.new_version == 1
    assert outcome.conflict is False


def test_bb3_initialize_state_sets_version_one():
    """BB3: initialize_state on a fresh key succeeds and reports version 1 with no conflict."""
    store = RedisMasterState(_make_redis_fresh())
    seed = {"agent": "genesis", "ready": True}
    outcome = run(store.initialize_state("sess-new", seed))

    assert isinstance(outcome, CommitResult)
    assert outcome.success is True
    assert outcome.new_version == 1
    assert outcome.conflict is False


def test_bb4_commit_patch_wrong_version_returns_conflict():
    """BB4: Committing with a stale version is rejected as a conflict."""
    # Redis holds version 5, while the caller still believes it is at version 3.
    client = _make_redis_with_wrong_version(stored_version=5, data={"x": 1})
    outcome = run(RedisMasterState(client).commit_patch("sess-abc", version=3, patch=[]))

    assert outcome.success is False
    assert outcome.conflict is True


def test_bb5_get_snapshot_existing_state_returns_version_and_data():
    """BB5: Snapshotting a session with stored state returns that exact (version, data)."""
    stored = {"agent": "genesis", "tasks": 7}
    store = RedisMasterState(_make_redis_with_state(version=4, data=stored))
    snapshot_version, snapshot_data = run(store.get_snapshot("sess-existing"))

    assert snapshot_version == 4
    assert snapshot_data == stored


def test_bb6_initialize_state_key_exists_returns_conflict():
    """BB6: initialize_state refuses to overwrite an existing key and reports a conflict at version 0."""
    client = _make_redis_with_state(version=1, data={"existing": True})
    # Model SET ... NX declining to write: redis answers None when the key already exists.
    client.set = AsyncMock(return_value=None)
    outcome = run(RedisMasterState(client).initialize_state("sess-exists", {"new": "data"}))

    assert outcome.success is False
    assert outcome.conflict is True
    assert outcome.new_version == 0


# ===========================================================================
# WHITE BOX TESTS
# ===========================================================================


def test_wb1_new_version_is_version_plus_one_exactly():
    """WB1: A successful commit advances the version by exactly one step."""
    starting_version = 7
    client = _make_redis_with_state(version=starting_version, data={"count": 42})
    outcome = run(
        RedisMasterState(client).commit_patch("sess-wb1", version=starting_version, patch=[])
    )
    assert outcome.new_version == starting_version + 1  # 8: neither +0 nor +2


def test_wb2_state_stored_as_json_with_version_and_data_keys():
    """WB2: The value written via pipe.set is JSON with top-level 'version' and 'data' keys."""
    client = _make_redis_fresh()
    store = RedisMasterState(client)
    ops = [{"op": "add", "path": "/key", "value": "val"}]
    run(store.commit_patch("sess-wb2", version=0, patch=ops))

    # The serialised state is the second positional argument handed to pipe.set.
    recorded = client._pipe.set.call_args
    assert recorded is not None, "pipe.set was not called"
    parsed = json.loads(recorded.args[1])

    for required in ("version", "data"):
        assert required in parsed, f"'{required}' key missing from stored JSON: {parsed}"
    assert parsed["version"] == 1
    assert parsed["data"]["key"] == "val"


def test_wb3_watch_error_from_pipeline_returns_conflict():
    """WB3: A WatchError from pipeline.execute() is translated into a conflict result at version 0."""
    client = _make_redis_with_watch_error(version=2, data={"k": "v"})
    outcome = run(RedisMasterState(client).commit_patch("sess-wb3", version=2, patch=[]))

    assert outcome.success is False
    assert outcome.conflict is True
    assert outcome.new_version == 0


def test_wb4_patch_operations_correctly_modify_data():
    """WB4: The add, replace and remove ops each transform the stored data dict as expected."""
    client = _make_redis_with_state(version=1, data={"a": 1, "b": 2, "c": 3})
    store = RedisMasterState(client)

    ops = [
        {"op": "add", "path": "/d", "value": 4},       # introduce a new key
        {"op": "replace", "path": "/a", "value": 99},  # overwrite an existing key
        {"op": "remove", "path": "/c"},                # drop a key
    ]
    run(store.commit_patch("sess-wb4", version=1, patch=ops))

    # Decode what was written through the pipeline and inspect its data section.
    written = json.loads(client._pipe.set.call_args.args[1])["data"]

    assert written["d"] == 4, f"'add' op failed: {written}"
    assert written["a"] == 99, f"'replace' op failed: {written}"
    assert "c" not in written, f"'remove' op failed: {written}"
    assert written["b"] == 2, f"untouched key 'b' was changed: {written}"


def test_wb5_get_snapshot_deserializes_json_correctly():
    """WB5: get_snapshot decodes a (non-compact) JSON blob read from Redis into (version, data)."""
    blob = json.dumps({"version": 12, "data": {"alpha": "beta", "count": 42}}).encode()

    client = MagicMock()
    client.get = AsyncMock(return_value=blob)

    snapshot_version, snapshot_data = run(RedisMasterState(client).get_snapshot("sess-wb5"))

    assert snapshot_version == 12
    assert snapshot_data == {"alpha": "beta", "count": 42}


# ===========================================================================
# Package export tests
# ===========================================================================


def test_package_exports_redis_master_state():
    """Package level: core.coherence re-exports the very same RedisMasterState class."""
    from core.coherence import RedisMasterState as exported
    assert exported is RedisMasterState


def test_package_exports_commit_result():
    """Package level: core.coherence re-exports the very same CommitResult class."""
    from core.coherence import CommitResult as exported
    assert exported is CommitResult


def test_commit_result_dataclass_fields():
    """CommitResult accepts and exposes the success, new_version and conflict fields."""
    outcome = CommitResult(success=True, new_version=5, conflict=False)
    assert outcome.success is True
    assert outcome.new_version == 5
    assert outcome.conflict is False


def test_key_prefix_constant():
    """KEY_PREFIX on the class equals the namespaced 'genesis:state:master:' prefix."""
    expected_prefix = "genesis:state:master:"
    assert RedisMasterState.KEY_PREFIX == expected_prefix


# ===========================================================================
# Standalone runner (pytest preferred, fallback to direct execution)
# ===========================================================================

if __name__ == "__main__":
    import traceback

    # (label, test function) pairs, executed in this order without pytest.
    suite = [
        ("BB1: Fresh session → get_snapshot returns (0, {})", test_bb1_get_snapshot_fresh_session_returns_zero_empty),
        ("BB2: commit_patch correct version → success", test_bb2_commit_patch_correct_version_returns_success),
        ("BB3: initialize_state → version=1 success", test_bb3_initialize_state_sets_version_one),
        ("BB4: commit_patch wrong version → conflict", test_bb4_commit_patch_wrong_version_returns_conflict),
        ("BB5: get_snapshot existing state → correct values", test_bb5_get_snapshot_existing_state_returns_version_and_data),
        ("BB6: initialize_state key exists → conflict", test_bb6_initialize_state_key_exists_returns_conflict),
        ("WB1: new_version = version + 1 exactly", test_wb1_new_version_is_version_plus_one_exactly),
        ("WB2: State stored as JSON with version+data keys", test_wb2_state_stored_as_json_with_version_and_data_keys),
        ("WB3: WatchError → conflict=True", test_wb3_watch_error_from_pipeline_returns_conflict),
        ("WB4: Patch operations modify data correctly", test_wb4_patch_operations_correctly_modify_data),
        ("WB5: get_snapshot deserializes JSON correctly", test_wb5_get_snapshot_deserializes_json_correctly),
        ("PKG: RedisMasterState importable from core.coherence", test_package_exports_redis_master_state),
        ("PKG: CommitResult importable from core.coherence", test_package_exports_commit_result),
        ("CommitResult dataclass fields", test_commit_result_dataclass_fields),
        ("KEY_PREFIX constant check", test_key_prefix_constant),
    ]

    ok_count = 0
    for label, test_fn in suite:
        try:
            test_fn()
            print(f"  [PASS] {label}")
            ok_count += 1
        except Exception as err:
            # Report and keep going so one failure doesn't hide the rest.
            print(f"  [FAIL] {label}: {err}")
            traceback.print_exc()

    print(f"\n{ok_count}/{len(suite)} tests passed")
    if ok_count != len(suite):
        sys.exit(1)
    print("ALL TESTS PASSED -- Story 6.02 (Track B)")
