#!/usr/bin/env python3
"""
Tests for Story 5.04 (Track B): SwarmSagaWriter — Saga Lifecycle Recorder

Black Box tests (BB): verify the public contract from the caller's perspective —
    open_saga creates RUNNING saga, record_proposed_delta appends correctly,
    close_saga writes terminal status, invalid status raises ValueError.

White Box tests (WB): verify internals — open_saga returns UUID4, each delta
    entry has submitted_at ISO timestamp, record_proposed_delta issues a
    SQL-level append (not read-modify-write).

ALL tests use mocks — NO real Postgres connection is required.

Story: 5.04
File under test: core/storage/saga_writer.py
"""

from __future__ import annotations

import sys
sys.path.insert(0, "/mnt/e/genesis-system")

import ast
import json
import pathlib
import re
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, call, patch

import pytest

# ---------------------------------------------------------------------------
# Modules under test
# ---------------------------------------------------------------------------

from core.storage.saga_writer import SwarmSagaWriter, _VALID_CLOSE_STATUSES
from core.storage.cold_ledger import ColdLedger, SwarmSaga
from core.storage import SwarmSagaWriter as SwarmSagaWriterFromPackage


# ---------------------------------------------------------------------------
# Mock factory helpers  (mirrors the style from test_story_5_02.py)
# ---------------------------------------------------------------------------

def _make_pool_and_conn(fetchone_return=None, fetchall_return=None):
    """Return (mock_pool, mock_conn, mock_cursor) wired for full mock operation."""
    mock_conn = MagicMock()
    mock_cursor = MagicMock()

    mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
    mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)

    mock_cursor.fetchone.return_value = fetchone_return
    mock_cursor.fetchall.return_value = fetchall_return if fetchall_return is not None else []

    mock_pool = MagicMock()
    mock_pool.getconn.return_value = mock_conn

    return mock_pool, mock_conn, mock_cursor


def _make_ledger(fetchone=None, fetchall=None):
    """Build a ColdLedger whose connection pool is replaced by mocks.

    ``psycopg2.pool.ThreadedConnectionPool`` is patched for the duration of the
    constructor call (assumes ColdLedger creates its pool through that name —
    confirm against cold_ledger.py). The pool, connection and cursor mocks are
    then attached to the ledger as ``_mock_pool`` / ``_mock_conn`` /
    ``_mock_cursor`` so tests can inspect executed SQL and pool usage.
    """
    pool, conn, cursor = _make_pool_and_conn(fetchone, fetchall)
    db_config = {
        "host": "localhost",
        "port": 5432,
        "user": "u",
        "password": "p",
        "dbname": "genesis",
    }
    with patch("psycopg2.pool.ThreadedConnectionPool", return_value=pool):
        ledger = ColdLedger(db_config)
    for attr_name, handle in (
        ("_mock_pool", pool),
        ("_mock_conn", conn),
        ("_mock_cursor", cursor),
    ):
        setattr(ledger, attr_name, handle)
    return ledger


def _make_writer(fetchone=None, fetchall=None) -> tuple[SwarmSagaWriter, ColdLedger]:
    """Build a SwarmSagaWriter on top of a freshly mocked ColdLedger.

    Returns the ``(writer, ledger)`` pair. The ledger carries the
    ``_mock_pool`` / ``_mock_conn`` / ``_mock_cursor`` handles that
    ``_make_ledger`` attaches.
    """
    mocked_ledger = _make_ledger(fetchone=fetchone, fetchall=fetchall)
    return SwarmSagaWriter(mocked_ledger), mocked_ledger


# ===========================================================================
# Black Box tests
# ===========================================================================


class TestBB1OpenSagaCreatesRunningRecord:
    """BB1: open_saga writes a saga with status='RUNNING' and proposed_deltas=[].

    ``ledger.write_saga`` is patched in every test, so the SwarmSaga built by
    ``open_saga`` can be inspected without running the ledger's own write path.
    """

    @staticmethod
    def _saga_written_by_open(dag):
        """Call ``open_saga("session-abc", dag)`` and return the SwarmSaga passed to write_saga."""
        writer, ledger = _make_writer()
        written = []

        def _remember(saga):
            written.append(saga)
            return saga.saga_id

        with patch.object(ledger, "write_saga", side_effect=_remember):
            writer.open_saga("session-abc", dag)
        return written[-1]

    def test_write_saga_called(self):
        """open_saga must call ledger.write_saga exactly once."""
        writer, ledger = _make_writer()
        with patch.object(ledger, "write_saga", return_value="fake-id") as write_spy:
            writer.open_saga("session-abc", {"step": 1})
        write_spy.assert_called_once()

    def test_written_saga_has_running_status(self):
        """The SwarmSaga passed to write_saga must have status='RUNNING'."""
        assert self._saga_written_by_open({"step": 1}).status == "RUNNING"

    def test_written_saga_has_empty_proposed_deltas(self):
        """The SwarmSaga passed to write_saga must have proposed_deltas=[]."""
        assert self._saga_written_by_open({"step": 1}).proposed_deltas == []

    def test_written_saga_has_none_resolved_state(self):
        """A freshly opened saga must have resolved_state=None."""
        assert self._saga_written_by_open({"step": 1}).resolved_state is None

    def test_orchestrator_dag_stored_correctly(self):
        """The orchestrator_dag argument must be persisted on the saga."""
        dag = {"nodes": ["a", "b"], "edges": [["a", "b"]]}
        assert self._saga_written_by_open(dag).orchestrator_dag == dag


class TestBB2RecordProposedDelta:
    """BB2: record_proposed_delta appends delta with agent_id and submitted_at."""

    @staticmethod
    def _record(saga_id, agent_id, delta):
        """Run record_proposed_delta on a fresh mocked writer and return its ledger."""
        writer, ledger = _make_writer()
        writer.record_proposed_delta(saga_id, agent_id, delta)
        return ledger

    @staticmethod
    def _sql_and_params(ledger):
        """Return the ``(sql, params)`` positional args of the last cursor.execute call."""
        sql, params = ledger._mock_cursor.execute.call_args[0]
        return sql, params

    def test_delta_appended_with_correct_agent_id(self):
        """The SQL executed must embed the correct agent_id in the JSON element."""
        ledger = self._record(str(uuid.uuid4()), "agent-forge", {"key": "value"})
        _sql, params = self._sql_and_params(ledger)
        # params[0] is a JSON array holding the single new element; params[1] is saga_id.
        elements = json.loads(params[0])
        assert len(elements) == 1
        assert elements[0]["agent_id"] == "agent-forge"

    def test_delta_dict_stored_verbatim(self):
        """The delta dict must be stored as-is inside the element."""
        delta = {"mutation": "add_node", "node": "X"}
        ledger = self._record(str(uuid.uuid4()), "agent-1", delta)
        _sql, params = self._sql_and_params(ledger)
        assert json.loads(params[0])[0]["delta"] == delta

    def test_submitted_at_present_in_element(self):
        """Each delta element must carry a submitted_at field."""
        ledger = self._record(str(uuid.uuid4()), "agent-x", {"v": 1})
        _sql, params = self._sql_and_params(ledger)
        assert "submitted_at" in json.loads(params[0])[0]

    def test_saga_id_passed_as_second_param(self):
        """The saga_id must be the second SQL parameter (for WHERE clause)."""
        saga_id = str(uuid.uuid4())
        ledger = self._record(saga_id, "agent-y", {})
        _sql, params = self._sql_and_params(ledger)
        assert params[1] == saga_id

    def test_commit_called_after_append(self):
        """conn.commit() must be called after the delta append."""
        ledger = self._record(str(uuid.uuid4()), "agent-z", {"a": 1})
        ledger._mock_conn.commit.assert_called()


class TestBB3CloseSaga:
    """BB3: close_saga writes the resolved_state and terminal status."""

    @staticmethod
    def _close(resolved_state, status):
        """Close a random saga on a fresh mocked writer and return its ledger."""
        writer, ledger = _make_writer()
        writer.close_saga(str(uuid.uuid4()), resolved_state, status)
        return ledger

    @staticmethod
    def _sql_and_params(ledger):
        """Return the ``(sql, params)`` positional args of the last cursor.execute call."""
        sql, params = ledger._mock_cursor.execute.call_args[0]
        return sql, params

    def test_completed_status_accepted(self):
        """close_saga with status='COMPLETED' must not raise."""
        self._close({"final": True}, "COMPLETED")  # returning normally == pass

    def test_partial_fail_status_accepted(self):
        """close_saga with status='PARTIAL_FAIL' must not raise."""
        self._close({}, "PARTIAL_FAIL")

    def test_failed_status_accepted(self):
        """close_saga with status='FAILED' must not raise."""
        self._close({}, "FAILED")

    def test_resolved_state_passed_as_json_param(self):
        """resolved_state must be serialised to JSON and passed as SQL param."""
        resolved = {"agents_done": 3, "summary": "ok"}
        _sql, params = self._sql_and_params(self._close(resolved, "COMPLETED"))
        # Expected parameter order: (resolved_state JSON, status, saga_id).
        assert json.loads(params[0]) == resolved

    def test_status_passed_as_second_sql_param(self):
        """The status string must be the second SQL parameter."""
        _sql, params = self._sql_and_params(self._close({}, "COMPLETED"))
        assert params[1] == "COMPLETED"

    def test_commit_called_after_close(self):
        """conn.commit() must be called after the close update."""
        ledger = self._close({}, "COMPLETED")
        ledger._mock_conn.commit.assert_called()


class TestBB4InvalidStatusRaisesValueError:
    """BB4: close_saga with an invalid status must raise ValueError.

    Every rejection is also checked to happen before any SQL is executed, so an
    invalid status never reaches the database.
    """

    @staticmethod
    def _assert_rejected(status):
        """Assert ``close_saga(..., status)`` raises ValueError and executes no SQL.

        Returns the pytest ``ExceptionInfo`` so callers can inspect the message.
        """
        writer, ledger = _make_writer()
        with pytest.raises(ValueError) as excinfo:
            writer.close_saga(str(uuid.uuid4()), {}, status)
        ledger._mock_cursor.execute.assert_not_called()
        return excinfo

    def test_bad_status_raises_value_error(self):
        self._assert_rejected("INVALID")

    def test_running_status_not_valid_for_close(self):
        """'RUNNING' is a valid open status but NOT a valid close status."""
        self._assert_rejected("RUNNING")

    def test_empty_string_status_raises(self):
        self._assert_rejected("")

    def test_lowercase_completed_raises(self):
        """Status is case-sensitive — 'completed' must raise."""
        self._assert_rejected("completed")

    def test_error_message_contains_valid_statuses(self):
        """ValueError message must mention every valid close status.

        The previous alternation regex passed as soon as ANY one status appeared;
        the module's own _VALID_CLOSE_STATUSES is now the source of truth.
        """
        message = str(self._assert_rejected("NOPE").value)
        missing = [s for s in sorted(_VALID_CLOSE_STATUSES) if s not in message]
        assert not missing, (
            f"ValueError message {message!r} does not mention: {missing}"
        )


# ===========================================================================
# White Box tests
# ===========================================================================


class TestWB1OpenSagaReturnsUUID:
    """WB1: open_saga returns a UUID4 string (not None, not empty)."""

    def test_returns_string(self):
        writer, ledger = _make_writer()
        result = writer.open_saga("session-1", {})
        assert isinstance(result, str)
        assert result, "open_saga must not return an empty string"

    def test_returns_valid_uuid4(self):
        """The result must be a canonical RFC 4122 version-4 UUID string."""
        writer, ledger = _make_writer()
        result = writer.open_saga("session-1", {})
        # Parse WITHOUT ``version=``: uuid.UUID(..., version=4) overwrites the
        # version/variant bits rather than validating them, so check explicitly.
        parsed = uuid.UUID(result)
        assert parsed.variant == uuid.RFC_4122
        assert parsed.version == 4
        # Round-trip rejects non-canonical spellings (braces, urn: prefix, upper case).
        assert str(parsed) == result

    def test_two_opens_return_different_ids(self):
        """Each open_saga call must generate a unique saga_id."""
        writer, ledger = _make_writer()
        id1 = writer.open_saga("session-1", {})
        id2 = writer.open_saga("session-1", {})
        assert id1 != id2

    def test_uuid4_is_used(self):
        """open_saga must use uuid.uuid4() — not uuid1/uuid3/uuid5."""
        writer, ledger = _make_writer()
        # Patching through the module's ``uuid`` attribute replaces uuid.uuid4
        # itself for the duration of the block.
        with patch("core.storage.saga_writer.uuid.uuid4") as mock_uuid4:
            fixed = uuid.UUID("abcdef12-abcd-4bcd-8bcd-abcdef123456")
            mock_uuid4.return_value = fixed
            result = writer.open_saga("session-1", {})
        assert result == str(fixed)
        mock_uuid4.assert_called_once()


class TestWB2SubmittedAtIsISOTimestamp:
    """WB2: each delta element must have a valid ISO 8601 submitted_at timestamp."""

    @staticmethod
    def _recorded_submitted_at(agent_id, delta):
        """Record one delta on a fresh mocked writer; return the element's raw submitted_at."""
        writer, ledger = _make_writer()
        writer.record_proposed_delta(str(uuid.uuid4()), agent_id, delta)
        _sql, params = ledger._mock_cursor.execute.call_args[0]
        return json.loads(params[0])[0]["submitted_at"]

    @staticmethod
    def _parse_iso(text):
        """Parse an ISO 8601 string, also accepting a trailing 'Z' for UTC.

        datetime.fromisoformat only understands the 'Z' suffix from Python 3.11
        onwards, so normalise it to '+00:00' first.
        """
        if text.endswith("Z"):
            text = text[:-1] + "+00:00"
        return datetime.fromisoformat(text)

    def test_submitted_at_is_iso_string(self):
        submitted_at = self._recorded_submitted_at("agent-ts", {"x": 1})
        # Should parse as a datetime without error
        parsed = self._parse_iso(submitted_at)
        assert isinstance(parsed, datetime)

    def test_submitted_at_includes_timezone_info(self):
        """submitted_at must be timezone-aware (UTC)."""
        submitted_at = self._recorded_submitted_at("agent-tz", {})
        parsed = self._parse_iso(submitted_at)

        assert parsed.tzinfo is not None, (
            f"submitted_at must contain timezone offset, got: {submitted_at!r}"
        )
        # A bare '+' check would also accept e.g. '+05:00'; require a zero offset.
        assert parsed.utcoffset() == timedelta(0), (
            f"submitted_at must be in UTC, got: {submitted_at!r}"
        )


class TestWB3SQLLevelAppendNorReadModifyWrite:
    """WB3: record_proposed_delta must use SQL JSONB append (||), not read-modify-write."""

    @staticmethod
    def _executed_sql(ledger):
        """Return, as text, the SQL passed to the most recent cursor.execute call."""
        return str(ledger._mock_cursor.execute.call_args[0][0])

    def test_only_one_execute_call_for_append(self):
        """record_proposed_delta must issue exactly one SQL statement — no SELECT first."""
        writer, ledger = _make_writer()
        writer.record_proposed_delta(str(uuid.uuid4()), "agent-1", {"v": 1})
        # Only one execute call: the UPDATE with JSONB append
        assert ledger._mock_cursor.execute.call_count == 1

    def test_sql_contains_jsonb_concat_operator(self):
        """The SQL must use the || operator for JSONB concatenation."""
        writer, ledger = _make_writer()
        writer.record_proposed_delta(str(uuid.uuid4()), "agent-sql", {})

        sql = self._executed_sql(ledger)
        assert "||" in sql, (
            "record_proposed_delta SQL must use || for JSONB append, got: " + sql
        )

    def test_sql_is_update_not_select_then_update(self):
        """The SQL must be UPDATE — no SELECT in record_proposed_delta."""
        writer, ledger = _make_writer()
        writer.record_proposed_delta(str(uuid.uuid4()), "agent-u", {})

        # lstrip(): a triple-quoted statement usually begins with a newline and
        # indentation, which would otherwise defeat the startswith() check.
        sql = self._executed_sql(ledger).lstrip().upper()
        assert sql.startswith("UPDATE"), (
            "record_proposed_delta must use UPDATE, not SELECT+UPDATE"
        )
        assert "SELECT" not in sql, (
            "No SELECT allowed in record_proposed_delta — that would be read-modify-write"
        )

    def test_no_get_saga_called_during_append(self):
        """record_proposed_delta must NOT call ledger.get_saga (no read-before-write)."""
        writer, ledger = _make_writer()
        with patch.object(ledger, "get_saga") as mock_get:
            writer.record_proposed_delta(str(uuid.uuid4()), "agent-v", {})
        mock_get.assert_not_called()

    def test_connection_pool_getconn_putconn_used(self):
        """record_proposed_delta must use ledger pool directly — getconn + putconn."""
        writer, ledger = _make_writer()
        writer.record_proposed_delta(str(uuid.uuid4()), "agent-pool", {"m": 1})

        ledger._mock_pool.getconn.assert_called_once()
        ledger._mock_pool.putconn.assert_called_once_with(ledger._mock_conn)

    def test_putconn_called_even_when_execute_raises(self):
        """putconn must be called in finally even if execute raises."""
        writer, ledger = _make_writer()
        ledger._mock_cursor.execute.side_effect = RuntimeError("DB down")

        with pytest.raises(RuntimeError):
            writer.record_proposed_delta(str(uuid.uuid4()), "agent-err", {})

        ledger._mock_pool.putconn.assert_called_once_with(ledger._mock_conn)


# ===========================================================================
# get_saga delegation tests
# ===========================================================================


class TestGetSagaDelegation:
    """SwarmSagaWriter.get_saga must be a thin pass-through to ColdLedger.get_saga."""

    def test_delegates_to_ledger(self):
        writer, ledger = _make_writer()
        saga_id = str(uuid.uuid4())
        stored = SwarmSaga(
            saga_id=saga_id,
            session_id=str(uuid.uuid4()),
            orchestrator_dag={},
            proposed_deltas=[],
            resolved_state=None,
            status="RUNNING",
            created_at=datetime(2026, 2, 25),
        )
        with patch.object(ledger, "get_saga", return_value=stored) as ledger_get:
            fetched = writer.get_saga(saga_id)
        ledger_get.assert_called_once_with(saga_id)
        # Identity, not equality: the ledger's object must be handed back untouched.
        assert fetched is stored

    def test_returns_none_for_unknown_saga(self):
        # The mocked cursor's fetchone() yields None, i.e. no matching row.
        writer, _ledger = _make_writer(fetchone=None)
        assert writer.get_saga("00000000-0000-0000-0000-000000000000") is None


# ===========================================================================
# Package export tests
# ===========================================================================


class TestPackageExport:
    """The core.storage package must re-export SwarmSagaWriter publicly."""

    def test_importable_from_package(self):
        # Same class object, not merely a same-named copy.
        assert SwarmSagaWriter is SwarmSagaWriterFromPackage

    def test_in_dunder_all(self):
        import core.storage as storage_pkg

        assert "SwarmSagaWriter" in storage_pkg.__all__


# ===========================================================================
# Source code quality checks (no SQLite, no f-string SQL)
# ===========================================================================


class TestSourceCodeQuality:
    """saga_writer.py must follow Genesis hardwired code standards."""

    # Read the file of the module actually imported at the top of this test file
    # (no hard-coded absolute path), decoding explicitly as UTF-8 instead of the
    # platform's locale encoding.
    _SOURCE = pathlib.Path(
        sys.modules["core.storage.saga_writer"].__file__
    ).read_text(encoding="utf-8")

    # Upper-case SQL keywords, matched as whole words, that must not appear in f-strings.
    _SQL_KEYWORD = re.compile(r"\b(SELECT|INSERT|UPDATE|DELETE|WHERE)\b")

    def _tree(self):
        """Parse saga_writer.py into an AST."""
        return ast.parse(self._SOURCE)

    def test_no_sqlite3_import(self):
        assert "import sqlite3" not in self._SOURCE, (
            "saga_writer.py must NOT import sqlite3 — Genesis Rule 7"
        )
        # The textual check misses ``from sqlite3 import ...``; inspect the AST too.
        imported_roots = set()
        for node in ast.walk(self._tree()):
            if isinstance(node, ast.Import):
                imported_roots.update(alias.name.split(".")[0] for alias in node.names)
            elif isinstance(node, ast.ImportFrom) and node.module:
                imported_roots.add(node.module.split(".")[0])
        assert "sqlite3" not in imported_roots, (
            "saga_writer.py must NOT import sqlite3 — Genesis Rule 7"
        )

    def test_no_fstring_sql(self):
        # Walk real f-string nodes (ast.JoinedStr) instead of regex-scanning the
        # raw text. This also covers multi-line/triple-quoted and rf/F-prefixed
        # f-strings, and never matches across unrelated strings on one line.
        fstring_sql = []
        for node in ast.walk(self._tree()):
            if not isinstance(node, ast.JoinedStr):
                continue
            # Join literal parts with a space so each {placeholder} acts as a word break.
            literal = " ".join(
                part.value
                for part in node.values
                if isinstance(part, ast.Constant) and isinstance(part.value, str)
            )
            if self._SQL_KEYWORD.search(literal):
                fstring_sql.append((node.lineno, literal))
        assert not fstring_sql, (
            f"Found f-string SQL (injection risk): {fstring_sql}"
        )

    def test_uuid4_used_not_uuid1(self):
        assert "uuid.uuid4()" in self._SOURCE, (
            "open_saga must use uuid.uuid4() for saga_id generation"
        )
        assert "uuid.uuid1()" not in self._SOURCE


# ===========================================================================
# Standalone runner
# ===========================================================================

if __name__ == "__main__":
    import traceback

    tests = [
        # BB1
        ("BB1a: open_saga calls write_saga", TestBB1OpenSagaCreatesRunningRecord().test_write_saga_called),
        ("BB1b: status=RUNNING", TestBB1OpenSagaCreatesRunningRecord().test_written_saga_has_running_status),
        ("BB1c: proposed_deltas=[]", TestBB1OpenSagaCreatesRunningRecord().test_written_saga_has_empty_proposed_deltas),
        ("BB1d: resolved_state=None", TestBB1OpenSagaCreatesRunningRecord().test_written_saga_has_none_resolved_state),
        ("BB1e: orchestrator_dag stored", TestBB1OpenSagaCreatesRunningRecord().test_orchestrator_dag_stored_correctly),
        # BB2
        ("BB2a: delta has agent_id", TestBB2RecordProposedDelta().test_delta_appended_with_correct_agent_id),
        ("BB2b: delta dict stored verbatim", TestBB2RecordProposedDelta().test_delta_dict_stored_verbatim),
        ("BB2c: submitted_at present", TestBB2RecordProposedDelta().test_submitted_at_present_in_element),
        ("BB2d: saga_id as second param", TestBB2RecordProposedDelta().test_saga_id_passed_as_second_param),
        ("BB2e: commit called after append", TestBB2RecordProposedDelta().test_commit_called_after_append),
        # BB3
        ("BB3a: COMPLETED accepted", TestBB3CloseSaga().test_completed_status_accepted),
        ("BB3b: PARTIAL_FAIL accepted", TestBB3CloseSaga().test_partial_fail_status_accepted),
        ("BB3c: FAILED accepted", TestBB3CloseSaga().test_failed_status_accepted),
        ("BB3d: resolved_state as JSON param", TestBB3CloseSaga().test_resolved_state_passed_as_json_param),
        ("BB3e: status as second param", TestBB3CloseSaga().test_status_passed_as_second_sql_param),
        ("BB3f: commit called after close", TestBB3CloseSaga().test_commit_called_after_close),
        # BB4
        ("BB4a: INVALID → ValueError", TestBB4InvalidStatusRaisesValueError().test_bad_status_raises_value_error),
        ("BB4b: RUNNING → ValueError", TestBB4InvalidStatusRaisesValueError().test_running_status_not_valid_for_close),
        ("BB4c: empty string → ValueError", TestBB4InvalidStatusRaisesValueError().test_empty_string_status_raises),
        ("BB4d: lowercase → ValueError", TestBB4InvalidStatusRaisesValueError().test_lowercase_completed_raises),
        ("BB4e: error message mentions valid values", TestBB4InvalidStatusRaisesValueError().test_error_message_contains_valid_statuses),
        # WB1
        ("WB1a: returns string", TestWB1OpenSagaReturnsUUID().test_returns_string),
        ("WB1b: valid UUID4", TestWB1OpenSagaReturnsUUID().test_returns_valid_uuid4),
        ("WB1c: unique per call", TestWB1OpenSagaReturnsUUID().test_two_opens_return_different_ids),
        ("WB1d: uuid4 used internally", TestWB1OpenSagaReturnsUUID().test_uuid4_is_used),
        # WB2
        ("WB2a: submitted_at is ISO string", TestWB2SubmittedAtIsISOTimestamp().test_submitted_at_is_iso_string),
        ("WB2b: submitted_at has timezone", TestWB2SubmittedAtIsISOTimestamp().test_submitted_at_includes_timezone_info),
        # WB3
        ("WB3a: one execute call only", TestWB3SQLLevelAppendNorReadModifyWrite().test_only_one_execute_call_for_append),
        ("WB3b: SQL uses || operator", TestWB3SQLLevelAppendNorReadModifyWrite().test_sql_contains_jsonb_concat_operator),
        ("WB3c: UPDATE not SELECT+UPDATE", TestWB3SQLLevelAppendNorReadModifyWrite().test_sql_is_update_not_select_then_update),
        ("WB3d: no get_saga during append", TestWB3SQLLevelAppendNorReadModifyWrite().test_no_get_saga_called_during_append),
        ("WB3e: pool getconn+putconn used", TestWB3SQLLevelAppendNorReadModifyWrite().test_connection_pool_getconn_putconn_used),
        ("WB3f: putconn called even on error", TestWB3SQLLevelAppendNorReadModifyWrite().test_putconn_called_even_when_execute_raises),
        # get_saga delegation
        ("DEL1: delegates to ledger", TestGetSagaDelegation().test_delegates_to_ledger),
        ("DEL2: None for unknown saga", TestGetSagaDelegation().test_returns_none_for_unknown_saga),
        # package export
        ("PKG1: importable from package", TestPackageExport().test_importable_from_package),
        ("PKG2: in __all__", TestPackageExport().test_in_dunder_all),
        # source quality
        ("SRC1: no sqlite3 import", TestSourceCodeQuality().test_no_sqlite3_import),
        ("SRC2: no f-string SQL", TestSourceCodeQuality().test_no_fstring_sql),
        ("SRC3: uuid4 not uuid1", TestSourceCodeQuality().test_uuid4_used_not_uuid1),
    ]

    passed = 0
    failed = 0
    for name, fn in tests:
        try:
            fn()
        # pytest.raises() reports "DID NOT RAISE" through pytest.fail(), whose
        # exception type derives from BaseException rather than Exception.
        # Catch it explicitly so such a failure is counted, not fatal to the run.
        except (Exception, pytest.fail.Exception) as exc:
            print(f"  [FAIL] {name}: {exc}")
            traceback.print_exc()
            failed += 1
        else:
            print(f"  [PASS] {name}")
            passed += 1

    print(f"\n{passed}/{passed + failed} tests passed")
    if failed == 0:
        print("ALL TESTS PASSED — Story 5.04 (Track B): SwarmSagaWriter")
    else:
        sys.exit(1)
