#!/usr/bin/env python3
"""
Tests for Story 8.02: OpenClaw Bridge — Redis Queue Writer
AIVA RLM Nexus PRD v2 — Track A, Module 8

Black box tests (BB1-BB3): verify observable contract — correct queue selection,
expired message rejection, and Redis error handling — without inspecting internals.

White box tests (WB1-WB3): verify implementation properties — RPUSH (not LPUSH),
UTC-normalised expiry comparison, and JSON (not pickle) serialisation.

All Redis I/O is mocked via unittest.mock.MagicMock.  No live Redis connection
is required or expected.
"""

from __future__ import annotations

import asyncio
import json
import sys
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, call

import pytest

# Put the project root at the front of the import path so the
# `core.bridge.openclaw_bridge` import below can be resolved.
# NOTE(review): this is a hard-coded absolute path — presumably the checkout
# location on the original author's machine; confirm before running elsewhere.
sys.path.insert(0, "/mnt/e/genesis-system")

from core.bridge.openclaw_bridge import (
    BridgeWriter,
    BRIDGE_QUEUE_AIVA_TO_GENESIS,
    BRIDGE_QUEUE_GENESIS_TO_AIVA,
    MessageDirection,
    OpenClawMessage,
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_msg(**overrides) -> OpenClawMessage:
    """Build an OpenClawMessage from baseline test values.

    Any keyword argument replaces (or adds to) the baseline field of the same
    name before the message is constructed.
    """
    baseline = {
        "message_id": "550e8400-e29b-41d4-a716-446655440000",
        "session_id": "session-test-002",
        "direction": MessageDirection.AIVA_TO_GENESIS,
        "payload": {"intent": "task_request", "body": "test payload 8.02"},
        "priority": 2,
        "created_at": datetime(2026, 2, 25, 10, 0, 0, tzinfo=timezone.utc),
    }
    # Later keys win, so caller-supplied overrides take precedence.
    return OpenClawMessage(**{**baseline, **overrides})


def _make_writer() -> tuple[BridgeWriter, MagicMock]:
    """Create a BridgeWriter backed by a fresh MagicMock Redis client.

    Returns the writer together with the mock so tests can inspect the calls
    the writer made on it.
    """
    redis_stub = MagicMock()
    return BridgeWriter(redis_client=redis_stub), redis_stub


def _run(coro):
    """Execute a coroutine synchronously and return its result.

    Each call drives *coro* to completion on a fresh event loop via
    ``asyncio.run``.  The earlier ``asyncio.get_event_loop()`` approach is
    deprecated when no loop is current, and raises ``RuntimeError`` once the
    current loop has been unset (for example after any prior ``asyncio.run``
    call) or on newer Python versions — which made the tests depend on what
    had run before them.

    Args:
        coro: The coroutine object to execute (e.g. ``writer.send(msg)``).

    Returns:
        Whatever the coroutine returns.
    """
    return asyncio.run(coro)


# ---------------------------------------------------------------------------
# BB1: AIVA_TO_GENESIS message is pushed to bridge:queue:aiva_to_genesis
# ---------------------------------------------------------------------------


class TestBB1_CorrectQueueForAivaToGenesis:
    """BB1: every message direction is delivered to its own dedicated queue key."""

    def test_aiva_to_genesis_uses_correct_queue(self):
        """An AIVA_TO_GENESIS message is pushed once, onto BRIDGE_QUEUE_AIVA_TO_GENESIS."""
        bridge, redis_stub = _make_writer()
        outbound = _make_msg(direction=MessageDirection.AIVA_TO_GENESIS)

        accepted = _run(bridge.send(outbound))

        assert accepted is True
        assert redis_stub.rpush.call_count == 1
        # call_args unpacks to (positional args, keyword args); the key is first.
        push_args, _ = redis_stub.rpush.call_args
        assert push_args[0] == BRIDGE_QUEUE_AIVA_TO_GENESIS

    def test_genesis_to_aiva_uses_correct_queue(self):
        """A GENESIS_TO_AIVA message is pushed onto BRIDGE_QUEUE_GENESIS_TO_AIVA."""
        bridge, redis_stub = _make_writer()
        outbound = _make_msg(direction=MessageDirection.GENESIS_TO_AIVA)

        accepted = _run(bridge.send(outbound))

        assert accepted is True
        push_args, _ = redis_stub.rpush.call_args
        assert push_args[0] == BRIDGE_QUEUE_GENESIS_TO_AIVA

    def test_two_different_directions_hit_different_queues(self):
        """The two queue-key constants must not be equal to each other."""
        assert BRIDGE_QUEUE_GENESIS_TO_AIVA != BRIDGE_QUEUE_AIVA_TO_GENESIS

    def test_return_value_is_true_on_success(self):
        """send() reports True when the (mocked) Redis push does not fail."""
        bridge, _ = _make_writer()
        accepted = _run(bridge.send(_make_msg()))
        assert accepted is True

    def test_payload_ends_up_in_redis_push(self):
        """The pushed value decodes as JSON carrying the message's message_id."""
        bridge, redis_stub = _make_writer()

        _run(bridge.send(_make_msg(message_id="unique-id-xyz")))

        push_args, _ = redis_stub.rpush.call_args
        wire_value = push_args[1]
        assert json.loads(wire_value)["message_id"] == "unique-id-xyz"


# ---------------------------------------------------------------------------
# BB2: Expired message (expires_at in the past) is NOT queued
# ---------------------------------------------------------------------------


class TestBB2_ExpiredMessageRejected:
    """BB2: a message whose expires_at has already passed is dropped, never pushed."""

    def test_expired_message_returns_false(self):
        """A message that expired five minutes ago makes send() return False."""
        bridge, _ = _make_writer()
        five_minutes_ago = datetime.now(timezone.utc) - timedelta(minutes=5)

        outcome = _run(bridge.send(_make_msg(expires_at=five_minutes_ago)))

        assert outcome is False

    def test_expired_message_does_not_call_rpush(self):
        """A message that expired one second ago never reaches rpush."""
        bridge, redis_stub = _make_writer()
        one_second_ago = datetime.now(timezone.utc) - timedelta(seconds=1)

        _run(bridge.send(_make_msg(expires_at=one_second_ago)))

        redis_stub.rpush.assert_not_called()

    def test_non_expired_message_is_queued(self):
        """A message expiring thirty minutes from now is pushed exactly once."""
        bridge, redis_stub = _make_writer()
        in_half_an_hour = datetime.now(timezone.utc) + timedelta(minutes=30)

        outcome = _run(bridge.send(_make_msg(expires_at=in_half_an_hour)))

        assert outcome is True
        assert redis_stub.rpush.call_count == 1

    def test_no_expires_at_is_always_queued(self):
        """expires_at=None means "no expiry": the message is pushed exactly once."""
        bridge, redis_stub = _make_writer()

        outcome = _run(bridge.send(_make_msg(expires_at=None)))

        assert outcome is True
        assert redis_stub.rpush.call_count == 1


# ---------------------------------------------------------------------------
# BB3: Redis error → False returned (no exception propagated)
# ---------------------------------------------------------------------------


class TestBB3_RedisErrorHandled:
    """BB3: a failing Redis push makes send() return False instead of raising."""

    @staticmethod
    def _writer_whose_rpush_raises(error: Exception) -> BridgeWriter:
        """Return a BridgeWriter whose mocked Redis raises *error* on rpush."""
        broken_redis = MagicMock()
        broken_redis.rpush.side_effect = error
        return BridgeWriter(redis_client=broken_redis)

    def test_redis_exception_returns_false(self):
        """A RuntimeError raised by rpush is turned into a False return value."""
        bridge = self._writer_whose_rpush_raises(RuntimeError("connection refused"))

        outcome = _run(bridge.send(_make_msg()))

        assert outcome is False

    def test_redis_exception_does_not_propagate(self):
        """A generic Exception raised by rpush must not escape send()."""
        bridge = self._writer_whose_rpush_raises(Exception("unexpected error"))
        outbound = _make_msg()

        try:
            _run(bridge.send(outbound))
        except Exception as exc:
            pytest.fail(f"send() propagated exception: {exc}")

    def test_connection_error_returns_false(self):
        """A ConnectionError raised by rpush is also turned into False."""
        bridge = self._writer_whose_rpush_raises(ConnectionError("Redis down"))

        outcome = _run(bridge.send(_make_msg()))

        assert outcome is False


# ---------------------------------------------------------------------------
# WB1: RPUSH is used (not LPUSH) — FIFO ordering
# ---------------------------------------------------------------------------


class TestWB1_RPushUsed:
    """WB1: send() appends with redis.rpush() and never uses redis.lpush()."""

    def test_rpush_called(self):
        """A single send() results in exactly one rpush call on the mock."""
        bridge, redis_stub = _make_writer()

        _run(bridge.send(_make_msg()))

        redis_stub.rpush.assert_called_once()

    def test_lpush_never_called(self):
        """A send() must not touch lpush, which would reverse queue order."""
        bridge, redis_stub = _make_writer()

        _run(bridge.send(_make_msg()))

        redis_stub.lpush.assert_not_called()

    def test_rpush_receives_two_positional_args(self):
        """rpush is invoked positionally as rpush(key, value) — two arguments."""
        bridge, redis_stub = _make_writer()

        _run(bridge.send(_make_msg()))

        # call_args unpacks to (positional args, keyword args).
        args, _ = redis_stub.rpush.call_args
        assert len(args) == 2, (
            f"Expected rpush(key, value) — 2 positional args, got {len(args)}"
        )


# ---------------------------------------------------------------------------
# WB2: Expired check uses UTC comparison (handles naive datetimes)
# ---------------------------------------------------------------------------


class TestWB2_UtcExpiryNormalisation:
    """WB2: expires_at comparison must normalise naive datetimes to UTC.

    Naive test values are derived from the current UTC wall clock with the
    tzinfo stripped.  This yields the same value ``datetime.utcnow()`` would,
    without calling that function, which is deprecated since Python 3.12.
    """

    @staticmethod
    def _naive_utc_now() -> datetime:
        """Return the current UTC wall-clock time as a naive datetime."""
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def test_naive_datetime_in_past_is_treated_as_expired(self):
        """A naive expires_at that is clearly in the past must be rejected."""
        writer, mock_redis = _make_writer()
        # naive datetime, 1 hour behind current UTC wall clock
        naive_past = self._naive_utc_now() - timedelta(hours=1)
        msg = _make_msg(expires_at=naive_past)

        result = _run(writer.send(msg))

        assert result is False
        mock_redis.rpush.assert_not_called()

    def test_naive_datetime_in_future_is_queued(self):
        """A naive expires_at clearly in the future must be accepted."""
        writer, _ = _make_writer()
        # naive datetime, 1 hour ahead of current UTC wall clock
        naive_future = self._naive_utc_now() + timedelta(hours=1)
        msg = _make_msg(expires_at=naive_future)

        result = _run(writer.send(msg))

        assert result is True

    def test_aware_datetime_in_past_is_rejected(self):
        """A timezone-aware expires_at in the past must be rejected."""
        writer, _ = _make_writer()
        aware_past = datetime.now(timezone.utc) - timedelta(minutes=10)
        msg = _make_msg(expires_at=aware_past)

        assert _run(writer.send(msg)) is False


# ---------------------------------------------------------------------------
# WB3: JSON serialisation (not pickle) — message fields present and correct
# ---------------------------------------------------------------------------


class TestWB3_JsonSerialisation:
    """WB3: The value pushed to Redis must be a valid JSON string with all 7 fields."""

    def _pushed_json(self, msg: OpenClawMessage) -> dict:
        """Send *msg* through a fresh writer and return the decoded pushed value.

        Raises:
            AssertionError: if send() did not call rpush at all.
            json.JSONDecodeError: if the pushed value is not valid JSON.
        """
        writer, mock_redis = _make_writer()
        _run(writer.send(msg))
        # Without this guard a rejected message surfaces as an opaque
        # "'NoneType' object is not subscriptable" on call_args below.
        assert mock_redis.rpush.called, "send() did not push anything to Redis"
        raw = mock_redis.rpush.call_args[0][1]
        return json.loads(raw)  # raises if not valid JSON

    def test_pushed_value_is_valid_json(self):
        """rpush value must parse as JSON without error."""
        msg = _make_msg()
        decoded = self._pushed_json(msg)
        assert isinstance(decoded, dict)

    def test_all_seven_fields_present(self):
        """JSON payload must contain all 7 message fields."""
        msg = _make_msg(
            expires_at=datetime(2026, 12, 31, 0, 0, 0, tzinfo=timezone.utc)
        )
        decoded = self._pushed_json(msg)
        required_keys = {
            "message_id", "session_id", "direction",
            "payload", "priority", "created_at", "expires_at",
        }
        assert required_keys.issubset(decoded.keys()), (
            f"Missing keys: {required_keys - decoded.keys()}"
        )

    def test_direction_serialised_as_string_value(self):
        """direction must be serialised as its .value string, not the enum repr."""
        msg = _make_msg(direction=MessageDirection.AIVA_TO_GENESIS)
        decoded = self._pushed_json(msg)
        assert decoded["direction"] == "aiva_to_genesis"

    def test_expires_at_none_serialised_as_null(self):
        """expires_at=None must serialise to JSON null (not the string 'None')."""
        msg = _make_msg(expires_at=None)
        decoded = self._pushed_json(msg)
        assert decoded["expires_at"] is None

    def test_expires_at_datetime_serialised_as_isoformat(self):
        """A datetime expires_at must serialise as an ISO 8601 string."""
        expiry = datetime(2026, 6, 15, 12, 30, 0, tzinfo=timezone.utc)
        msg = _make_msg(expires_at=expiry)
        decoded = self._pushed_json(msg)
        assert decoded["expires_at"] == expiry.isoformat()

    def test_payload_dict_preserved(self):
        """Nested payload dict must survive JSON round-trip unchanged."""
        payload = {"intent": "inject_context", "data": [1, 2, 3], "nested": {"k": "v"}}
        msg = _make_msg(payload=payload)
        decoded = self._pushed_json(msg)
        assert decoded["payload"] == payload

    def test_no_pickle_in_source(self):
        """Verify that pickle is not imported in the bridge module (JSON only).

        The module source is parsed and its import statements inspected, so
        both ``import pickle`` and ``from pickle import ...`` are detected,
        while a mere mention of pickle in a comment or docstring is not.
        """
        import ast
        import inspect
        import core.bridge.openclaw_bridge as mod

        tree = ast.parse(inspect.getsource(mod))
        imported_roots = set()
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                imported_roots.update(
                    alias.name.split(".")[0] for alias in node.names
                )
            elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
                # level == 0 restricts this to absolute imports; a relative
                # "from .pickle import x" would be a project-local module.
                imported_roots.add(node.module.split(".")[0])

        assert "pickle" not in imported_roots, (
            "pickle is FORBIDDEN in openclaw_bridge.py — use JSON only"
        )


# ---------------------------------------------------------------------------
# Run summary
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Running this file directly delegates to pytest (verbose output, short
    # tracebacks) and exits with pytest's return code.
    sys.exit(pytest.main([__file__, "-v", "--tb=short"]))
