#!/usr/bin/env python3
"""
Tests for Story 5.02 (Track B): ColdLedger — L4 Postgres Read/Write Client

Black Box tests (BB): verify the public contract from the caller's perspective —
    correct UUIDs returned, filtering, ordering, None on missing records.
White Box tests (WB): verify internals — connection pool getconn/putconn usage,
    no SQLite import, parameterised queries only, UUID4 generation, close().

ALL tests use mocks — NO real Postgres connection is required.

Story: 5.02
File under test: core/storage/cold_ledger.py
"""

from __future__ import annotations

import sys
sys.path.insert(0, "/mnt/e/genesis-system")

import importlib
import json
import re
import uuid
from datetime import datetime
from unittest.mock import MagicMock, call, patch, PropertyMock

import pytest

# ---------------------------------------------------------------------------
# Module under test
# ---------------------------------------------------------------------------

from core.storage.cold_ledger import ColdLedger, SwarmSaga
from core.storage import ColdLedger as ColdLedgerFromPackage
from core.storage import SwarmSaga as SwarmSagaFromPackage


# ---------------------------------------------------------------------------
# Mock-connection factory helpers
# ---------------------------------------------------------------------------

def _make_pool_and_conn(fetchone_return=None, fetchall_return=None):
    """
    Returns (mock_pool, mock_conn) where:
      - mock_pool.getconn() → mock_conn
      - mock_pool.putconn(conn) is tracked
      - mock_conn.cursor() supports context manager
      - cursor.fetchone() → fetchone_return
      - cursor.fetchall() → fetchall_return or []
    """
    mock_conn = MagicMock()
    mock_cursor = MagicMock()

    # Context-manager cursor
    mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
    mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)

    # Default query results
    mock_cursor.fetchone.return_value = fetchone_return
    mock_cursor.fetchall.return_value = fetchall_return if fetchall_return is not None else []

    mock_pool = MagicMock()
    mock_pool.getconn.return_value = mock_conn

    return mock_pool, mock_conn, mock_cursor


def _make_ledger(fetchone=None, fetchall=None):
    """
    Build a ColdLedger whose connection pool is entirely mocked.

    ``fetchone`` / ``fetchall`` become the canned results of the mocked
    cursor.  The mocks are attached to the returned ledger as ``_mock_pool``,
    ``_mock_conn`` and ``_mock_cursor`` so tests can assert against them.
    """
    pool, conn, cursor = _make_pool_and_conn(fetchone, fetchall)
    connection_params = {
        "host": "localhost",
        "port": 5432,
        "user": "u",
        "password": "p",
        "dbname": "genesis",
    }

    # The pool class is only patched for the duration of construction.
    with patch("psycopg2.pool.ThreadedConnectionPool", return_value=pool):
        ledger = ColdLedger(connection_params)

    ledger._mock_cursor = cursor
    ledger._mock_conn = conn
    ledger._mock_pool = pool
    return ledger


def _make_saga(
    saga_id: str | None = None,
    session_id: str | None = None,
    status: str = "RUNNING",
    resolved: dict | None = None,
) -> SwarmSaga:
    """
    Build a SwarmSaga with fixed sample content for use in tests.

    A falsy ``saga_id`` or ``session_id`` is replaced by a freshly generated
    UUID4 string.  ``resolved`` is stored as ``resolved_state``; the DAG,
    deltas and creation timestamp are constant sample values.
    """
    effective_saga_id = saga_id if saga_id else str(uuid.uuid4())
    effective_session_id = session_id if session_id else str(uuid.uuid4())

    saga_fields = {
        "saga_id": effective_saga_id,
        "session_id": effective_session_id,
        "orchestrator_dag": {"step": 1},
        "proposed_deltas": [{"delta": "a"}],
        "resolved_state": resolved,
        "status": status,
        "created_at": datetime(2026, 2, 25, 10, 0, 0),
    }
    return SwarmSaga(**saga_fields)


# ===========================================================================
# Black Box tests
# ===========================================================================


class TestBB1WriteEventReturnsUUID:
    """BB1: write_event returns a valid UUID string."""

    def test_returns_string(self):
        event_id = _make_ledger().write_event("session-1", "dispatch_start", {"k": "v"})
        assert isinstance(event_id, str)

    def test_returned_value_is_valid_uuid(self):
        event_id = _make_ledger().write_event("session-1", "dispatch_start", {})
        # uuid.UUID raises on malformed input; the round-trip additionally
        # pins that the returned string is already in canonical form.
        assert str(uuid.UUID(event_id)) == event_id

    def test_two_calls_return_different_uuids(self):
        ledger = _make_ledger()
        first_id, second_id = (
            ledger.write_event("s1", event_type, {})
            for event_type in ("type_a", "type_b")
        )
        assert first_id != second_id


class TestBB2GetEventsFilterByType:
    """BB2: get_events with event_type filter returns matching records."""

    def test_filter_applied(self):
        # The mocked cursor's fetchall() yields exactly this one row.
        stored_event = {
            "id": str(uuid.uuid4()),
            "session_id": "sess-1",
            "event_type": "dispatch_start",
            "payload": {"a": 1},
            "created_at": datetime(2026, 2, 25),
        }
        ledger = _make_ledger(fetchall=[stored_event])

        events = ledger.get_events("sess-1", event_type="dispatch_start")

        assert len(events) == 1
        first_event = events[0]
        assert first_event["event_type"] == "dispatch_start"

    def test_no_filter_uses_different_sql_branch(self):
        """Calling without event_type must succeed; no rows means an empty list."""
        # NOTE(review): presumably the unfiltered call takes a separate SQL
        # branch inside get_events — not verifiable from this test alone.
        ledger = _make_ledger(fetchall=[])
        assert ledger.get_events("sess-1") == []


class TestBB3TwoWritesSameSession:
    """BB3: two writes with same session_id are both retrievable."""

    def test_get_events_returns_both(self):
        session_id = str(uuid.uuid4())
        # Two rows sharing one session, differing only in id and event type.
        stored_events = [
            {
                "id": str(uuid.uuid4()),
                "session_id": session_id,
                "event_type": event_type,
                "payload": {},
                "created_at": datetime(2026, 2, 25),
            }
            for event_type in ("start", "end")
        ]
        events = _make_ledger(fetchall=stored_events).get_events(session_id)
        assert len(events) == 2


class TestBB4WriteSagaThenGetSaga:
    """BB4: write_saga then get_saga returns matching SwarmSaga."""

    def test_returned_saga_has_correct_fields(self):
        saga = _make_saga(status="COMPLETED")
        # Row the mocked cursor's fetchone() hands back, mirroring the saga
        # that is about to be written.
        db_row = {
            "saga_id": saga.saga_id,
            "session_id": saga.session_id,
            "orchestrator_dag": saga.orchestrator_dag,
            "proposed_deltas": saga.proposed_deltas,
            "resolved_state": None,
            "status": "COMPLETED",
            "created_at": saga.created_at,
        }
        ledger = _make_ledger(fetchone=db_row)
        # Run the real write path against the mocked pool. Patching write_saga
        # on the instance would only exercise the patch, not the ledger.
        ledger.write_saga(saga)
        fetched = ledger.get_saga(saga.saga_id)
        assert fetched is not None
        assert fetched.saga_id == saga.saga_id
        assert fetched.status == "COMPLETED"

    def test_saga_fields_match_original(self):
        saga = _make_saga(status="RUNNING")
        db_row = {
            "saga_id": saga.saga_id,
            "session_id": saga.session_id,
            "orchestrator_dag": {"step": 1},
            "proposed_deltas": [{"delta": "a"}],
            "resolved_state": None,
            "status": "RUNNING",
            "created_at": saga.created_at,
        }
        ledger = _make_ledger(fetchone=db_row)
        fetched = ledger.get_saga(saga.saga_id)
        # Fail with a clear assertion instead of an AttributeError on None.
        assert fetched is not None
        assert fetched.session_id == saga.session_id
        assert fetched.orchestrator_dag == {"step": 1}
        assert fetched.proposed_deltas == [{"delta": "a"}]
        assert fetched.resolved_state is None


class TestBB5GetSagasBySession:
    """BB5: get_sagas_by_session returns all sagas for that session."""

    def test_returns_multiple_sagas(self):
        session_id = str(uuid.uuid4())
        now = datetime(2026, 2, 25)
        # Two rows for the same session that differ in saga_id and status.
        saga_rows = [
            {
                "saga_id": str(uuid.uuid4()),
                "session_id": session_id,
                "orchestrator_dag": {},
                "proposed_deltas": [],
                "resolved_state": None,
                "status": status,
                "created_at": now,
            }
            for status in ("RUNNING", "COMPLETED")
        ]
        sagas = _make_ledger(fetchall=saga_rows).get_sagas_by_session(session_id)
        assert len(sagas) == 2
        assert all(isinstance(saga, SwarmSaga) for saga in sagas)

    def test_returns_empty_list_when_none(self):
        unknown_session = str(uuid.uuid4())
        ledger = _make_ledger(fetchall=[])
        assert ledger.get_sagas_by_session(unknown_session) == []


# ===========================================================================
# White Box tests
# ===========================================================================


class TestWB1ConnectionPoolPattern:
    """WB1: getconn/putconn pattern — connection always returned in finally."""

    @staticmethod
    def _assert_acquired_once_and_released(ledger):
        """Assert one getconn() and one putconn() with the mocked connection."""
        pool = ledger._mock_pool
        pool.getconn.assert_called_once()
        pool.putconn.assert_called_once_with(ledger._mock_conn)

    def test_write_event_calls_getconn_then_putconn(self):
        ledger = _make_ledger()
        ledger.write_event("sess", "type", {})
        self._assert_acquired_once_and_released(ledger)

    def test_get_events_calls_getconn_then_putconn(self):
        ledger = _make_ledger(fetchall=[])
        ledger.get_events("sess")
        self._assert_acquired_once_and_released(ledger)

    def test_write_saga_calls_getconn_then_putconn(self):
        ledger = _make_ledger()
        ledger.write_saga(_make_saga())
        self._assert_acquired_once_and_released(ledger)

    def test_get_saga_calls_getconn_then_putconn(self):
        ledger = _make_ledger(fetchone=None)
        ledger.get_saga(str(uuid.uuid4()))
        self._assert_acquired_once_and_released(ledger)

    def test_get_sagas_by_session_calls_getconn_then_putconn(self):
        ledger = _make_ledger(fetchall=[])
        ledger.get_sagas_by_session(str(uuid.uuid4()))
        self._assert_acquired_once_and_released(ledger)

    def test_putconn_called_even_when_execute_raises(self):
        """putconn must be called in finally even if the cursor raises."""
        ledger = _make_ledger()
        # Every execute() on the mocked cursor now fails.
        ledger._mock_cursor.execute.side_effect = RuntimeError("DB error")
        with pytest.raises(RuntimeError):
            ledger.write_event("sess", "type", {})
        # The connection must have gone back to the pool despite the error.
        ledger._mock_pool.putconn.assert_called_once_with(ledger._mock_conn)


class TestWB2NoSQLiteImport:
    """WB2: cold_ledger.py must not import sqlite3 anywhere."""

    def test_no_sqlite3_import_in_source(self):
        import pathlib
        import core.storage.cold_ledger as mod

        # Read the file of the module that was actually imported, so the check
        # inspects the code under test wherever the checkout lives.
        source = pathlib.Path(mod.__file__).read_text(encoding="utf-8")
        assert "import sqlite3" not in source, (
            "cold_ledger.py must NOT import sqlite3 — Genesis Rule 7 (no SQLite)"
        )
        # The substring check above misses the ``from sqlite3 import ...`` form.
        from_import = re.search(
            r"^[ \t]*from[ \t]+sqlite3(\.\w+)*[ \t]+import\b", source, re.MULTILINE
        )
        assert from_import is None, (
            "cold_ledger.py must NOT import from sqlite3 — Genesis Rule 7 (no SQLite)"
        )

    def test_sqlite3_not_in_module_namespace(self):
        import core.storage.cold_ledger as mod
        assert not hasattr(mod, "sqlite3"), (
            "sqlite3 must not be present in cold_ledger module namespace"
        )


class TestWB3ParameterisedQueries:
    """WB3: all SQL in cold_ledger.py uses %s placeholders — no f-string SQL."""

    def test_no_fstring_sql_in_source(self):
        import pathlib
        import core.storage.cold_ledger as mod

        # Read the file of the module that was actually imported, so the check
        # inspects the code under test wherever the checkout lives.
        source = pathlib.Path(mod.__file__).read_text(encoding="utf-8")

        # Heuristic: an f-string opener followed, on the same line, by a SQL
        # keyword and a closing quote.  ``.`` does not match newlines here, so
        # f-strings spanning several lines are not detected.
        fstring_sql = re.findall(r'f["\'].*?(SELECT|INSERT|UPDATE|DELETE|WHERE).*?["\']', source)
        assert not fstring_sql, (
            f"Found f-string SQL (injection risk): {fstring_sql}"
        )

    def test_write_event_sql_executed_with_params_tuple(self):
        """write_event must call cur.execute(sql, params) — not bare cur.execute(sql)."""
        ledger = _make_ledger()
        ledger.write_event("sess-123", "type_a", {"k": "v"})
        cur = ledger._mock_cursor
        assert cur.execute.call_count >= 1
        # Every execute call must carry a second positional argument (params).
        for call_args in cur.execute.call_args_list:
            args = call_args[0]
            assert len(args) >= 2, (
                "cur.execute() must be called with (sql, params) — not just (sql,)"
            )


class TestWB4WriteEventGeneratesUUID4:
    """WB4: write_event generates a UUID4 for event_id."""

    def test_event_id_passed_to_execute_is_uuid4(self):
        ledger = _make_ledger()
        with patch("core.storage.cold_ledger.uuid.uuid4") as mock_uuid4:
            fixed_id = uuid.UUID("12345678-1234-4234-b234-123456789abc")
            mock_uuid4.return_value = fixed_id
            returned_id = ledger.write_event("sess", "t", {})

        # The returned id must be the string form of what uuid4() produced.
        assert returned_id == str(fixed_id)
        # Verify uuid4 was called (not uuid1/uuid3/uuid5)
        mock_uuid4.assert_called()

    def test_uuid4_is_used_not_uuid1(self):
        import pathlib
        import core.storage.cold_ledger as mod

        # Read the file of the module that was actually imported, so the check
        # inspects the code under test wherever the checkout lives.
        source = pathlib.Path(mod.__file__).read_text(encoding="utf-8")
        assert "uuid.uuid4()" in source, "write_event must use uuid.uuid4() for event IDs"
        assert "uuid.uuid1()" not in source, "Must not use uuid1 — use uuid4"


class TestWB5GetSagaUnknownIdReturnsNone:
    """WB5: get_saga with an unknown ID returns None (not an exception)."""

    def test_unknown_saga_id_returns_none(self):
        ledger = _make_ledger(fetchone=None)
        result = ledger.get_saga("00000000-0000-0000-0000-000000000000")
        assert result is None, "get_saga must return None for unknown saga_id"

    def test_does_not_raise_on_missing_row(self):
        ledger = _make_ledger(fetchone=None)
        # Only the call itself is guarded: AssertionError is an Exception, so
        # asserting inside the try would misreport a wrong return value as
        # "get_saga raised unexpectedly".
        try:
            result = ledger.get_saga(str(uuid.uuid4()))
        except Exception as exc:
            pytest.fail(f"get_saga raised unexpectedly: {exc}")
        assert result is None


# ===========================================================================
# Package export tests
# ===========================================================================


class TestPackageExports:
    """ColdLedger and SwarmSaga must be importable from core.storage."""

    def test_cold_ledger_importable_from_package(self):
        # Identity, not equality: the package must re-export the same class.
        assert ColdLedger is ColdLedgerFromPackage

    def test_swarm_saga_importable_from_package(self):
        assert SwarmSaga is SwarmSagaFromPackage

    def test_all_includes_cold_ledger(self):
        from core.storage import __all__ as exported_names
        assert "ColdLedger" in exported_names

    def test_all_includes_swarm_saga(self):
        from core.storage import __all__ as exported_names
        assert "SwarmSaga" in exported_names


# ===========================================================================
# SwarmSaga dataclass tests
# ===========================================================================


class TestSwarmSagaDataclass:
    """SwarmSaga must be a proper typed dataclass."""

    def test_instantiation(self):
        assert isinstance(_make_saga(), SwarmSaga)

    def test_all_fields_accessible(self):
        failed_saga = _make_saga(status="FAILED")
        assert failed_saga.status == "FAILED"
        # Container types of the sample values supplied by _make_saga.
        assert isinstance(failed_saga.orchestrator_dag, dict)
        assert isinstance(failed_saga.proposed_deltas, list)
        assert failed_saga.resolved_state is None

    def test_resolved_state_can_be_dict(self):
        resolved_saga = _make_saga(resolved={"final": True})
        assert resolved_saga.resolved_state == {"final": True}


# ===========================================================================
# ColdLedger.close() test
# ===========================================================================


class TestClose:
    """close() must call pool.closeall()."""

    def test_close_calls_closeall(self):
        ledger = _make_ledger()
        closeall = ledger._mock_pool.closeall
        ledger.close()
        closeall.assert_called_once()

    def test_close_is_idempotent_when_pool_is_mock(self):
        ledger = _make_ledger()
        # A repeated close() must not raise; each call reaches closeall().
        for _ in range(2):
            ledger.close()
        assert ledger._mock_pool.closeall.call_count == 2


# ===========================================================================
# ThreadedConnectionPool instantiation test
# ===========================================================================


class TestConnectionPoolInit:
    """ColdLedger must create ThreadedConnectionPool(minconn=2, maxconn=10)."""

    @staticmethod
    def _construct_with_patched_pool(params):
        """Build a ColdLedger under a patched pool class and return that patch mock."""
        with patch("psycopg2.pool.ThreadedConnectionPool") as mock_pool_cls:
            mock_pool_cls.return_value = MagicMock()
            ColdLedger(params)
        return mock_pool_cls

    def test_pool_created_with_correct_min_max(self):
        mock_pool_cls = self._construct_with_patched_pool(
            {"host": "h", "port": 5432, "user": "u", "password": "p", "dbname": "db"}
        )
        # Pool sizes are expected as the first two positional arguments.
        positional = mock_pool_cls.call_args[0]
        assert positional[0] == 2, f"minconn must be 2, got {positional[0]}"
        assert positional[1] == 10, f"maxconn must be 10, got {positional[1]}"

    def test_pool_receives_connection_params(self):
        mock_pool_cls = self._construct_with_patched_pool(
            {
                "host": "myhost",
                "port": 5433,
                "user": "admin",
                "password": "secret",
                "dbname": "genesis_test",
            }
        )
        # Connection parameters are expected to be forwarded as keyword arguments.
        kwargs = mock_pool_cls.call_args[1]
        assert kwargs["host"] == "myhost"
        assert kwargs["dbname"] == "genesis_test"


# ===========================================================================
# Standalone runner
# ===========================================================================

if __name__ == "__main__":
    # Standalone runner: executes every test method in this file without the
    # pytest CLI, prints a PASS/FAIL line per test and exits 1 on any failure.
    import traceback

    tests = [
        # BB
        ("BB1a: write_event returns string", TestBB1WriteEventReturnsUUID().test_returns_string),
        ("BB1b: returned value is valid UUID", TestBB1WriteEventReturnsUUID().test_returned_value_is_valid_uuid),
        ("BB1c: two calls → different UUIDs", TestBB1WriteEventReturnsUUID().test_two_calls_return_different_uuids),
        ("BB2a: filter applied", TestBB2GetEventsFilterByType().test_filter_applied),
        ("BB2b: no filter uses different branch", TestBB2GetEventsFilterByType().test_no_filter_uses_different_sql_branch),
        ("BB3: two writes same session", TestBB3TwoWritesSameSession().test_get_events_returns_both),
        ("BB4a: write saga then get saga", TestBB4WriteSagaThenGetSaga().test_returned_saga_has_correct_fields),
        ("BB4b: saga fields match original", TestBB4WriteSagaThenGetSaga().test_saga_fields_match_original),
        ("BB5a: get sagas by session multiple", TestBB5GetSagasBySession().test_returns_multiple_sagas),
        ("BB5b: get sagas empty", TestBB5GetSagasBySession().test_returns_empty_list_when_none),
        # WB
        ("WB1a: write_event getconn/putconn", TestWB1ConnectionPoolPattern().test_write_event_calls_getconn_then_putconn),
        ("WB1b: get_events getconn/putconn", TestWB1ConnectionPoolPattern().test_get_events_calls_getconn_then_putconn),
        ("WB1c: write_saga getconn/putconn", TestWB1ConnectionPoolPattern().test_write_saga_calls_getconn_then_putconn),
        ("WB1d: get_saga getconn/putconn", TestWB1ConnectionPoolPattern().test_get_saga_calls_getconn_then_putconn),
        ("WB1e: get_sagas_by_session getconn/putconn", TestWB1ConnectionPoolPattern().test_get_sagas_by_session_calls_getconn_then_putconn),
        ("WB1f: putconn called even on error", TestWB1ConnectionPoolPattern().test_putconn_called_even_when_execute_raises),
        ("WB2a: no sqlite3 import in source", TestWB2NoSQLiteImport().test_no_sqlite3_import_in_source),
        ("WB2b: sqlite3 not in namespace", TestWB2NoSQLiteImport().test_sqlite3_not_in_module_namespace),
        ("WB3a: no f-string SQL", TestWB3ParameterisedQueries().test_no_fstring_sql_in_source),
        ("WB3b: write_event uses param tuple", TestWB3ParameterisedQueries().test_write_event_sql_executed_with_params_tuple),
        ("WB4a: event_id is UUID4", TestWB4WriteEventGeneratesUUID4().test_event_id_passed_to_execute_is_uuid4),
        ("WB4b: uuid4 not uuid1", TestWB4WriteEventGeneratesUUID4().test_uuid4_is_used_not_uuid1),
        ("WB5a: unknown saga → None", TestWB5GetSagaUnknownIdReturnsNone().test_unknown_saga_id_returns_none),
        ("WB5b: no exception on missing row", TestWB5GetSagaUnknownIdReturnsNone().test_does_not_raise_on_missing_row),
        # Package
        ("PKG: ColdLedger importable", TestPackageExports().test_cold_ledger_importable_from_package),
        ("PKG: SwarmSaga importable", TestPackageExports().test_swarm_saga_importable_from_package),
        ("PKG: __all__ has ColdLedger", TestPackageExports().test_all_includes_cold_ledger),
        ("PKG: __all__ has SwarmSaga", TestPackageExports().test_all_includes_swarm_saga),
        # SwarmSaga dataclass
        ("SAGA: instantiation", TestSwarmSagaDataclass().test_instantiation),
        ("SAGA: all fields accessible", TestSwarmSagaDataclass().test_all_fields_accessible),
        ("SAGA: resolved_state can be dict", TestSwarmSagaDataclass().test_resolved_state_can_be_dict),
        # Pool init
        ("POOL: minconn=2 maxconn=10", TestConnectionPoolInit().test_pool_created_with_correct_min_max),
        ("POOL: params forwarded", TestConnectionPoolInit().test_pool_receives_connection_params),
        # Close
        ("CLOSE: closeall called", TestClose().test_close_calls_closeall),
        ("CLOSE: repeated close does not raise", TestClose().test_close_is_idempotent_when_pool_is_mock),
    ]

    passed = 0
    failed = 0
    for name, fn in tests:
        try:
            fn()
            print(f"  [PASS] {name}")
            passed += 1
        # pytest.fail() raises an exception derived from BaseException, which a
        # plain ``except Exception`` would let escape and abort the whole run.
        except (Exception, pytest.fail.Exception) as exc:
            print(f"  [FAIL] {name}: {exc}")
            traceback.print_exc()
            failed += 1

    print(f"\n{passed}/{passed + failed} tests passed")
    if failed == 0:
        print("ALL TESTS PASSED -- Story 5.02 (Track B): ColdLedger")
    else:
        sys.exit(1)
