#!/usr/bin/env python3
"""
Tests for Story 3.05: RLMCaptureAgent — Async Loop + Lifecycle
AIVA RLM Nexus PRD v2 — Track A

Black box tests (BB1-BB5): verify the public contract from the outside.
White box tests (WB1-WB6): verify internal state, constants, event shape,
                            Redis key format, and fail-safe behaviour.

ALL external dependencies (Redis, filesystem) are fully mocked so the
suite runs without any live infrastructure.
"""
import asyncio
import json
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

sys.path.insert(0, "/mnt/e/genesis-system")

from core.agents.rlm_capture_agent import (
    EVENTS_DIR,
    MAX_RUNTIME_SECONDS,
    POLL_INTERVAL_SECONDS,
    RLMCaptureAgent,
)

# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

# Fixed identifiers handed to every agent built by _make_agent(); tests also
# compare logged records against _SESSION_ID.
_SESSION_ID: str = "sess-3-05-test"
_CALL_CTRL_ID: str = "cc-abc-xyz-001"


def _make_agent(redis=None) -> RLMCaptureAgent:
    """Build a brand-new agent wired to the shared test identifiers.

    Args:
        redis: Redis client to inject (usually an ``AsyncMock``); ``None``
            constructs the agent without any Redis client.
    """
    agent_kwargs = {
        "session_id": _SESSION_ID,
        "call_control_id": _CALL_CTRL_ID,
        "redis_client": redis,
    }
    return RLMCaptureAgent(**agent_kwargs)


def _make_redis_ended() -> AsyncMock:
    """Return a fake async Redis whose ``get`` yields an 'ended' call state."""
    ended_payload = json.dumps({"status": "ended"}).encode()
    client = AsyncMock()
    client.get = AsyncMock(return_value=ended_payload)
    return client


def _make_redis_active() -> AsyncMock:
    """Return a fake async Redis whose ``get`` yields None (key missing, call live)."""
    client = AsyncMock()
    client.get = AsyncMock(return_value=None)
    return client


def _make_redis_error() -> AsyncMock:
    """Return a fake async Redis whose ``get`` raises on every call."""
    failure = Exception("Redis connection refused")
    client = AsyncMock()
    client.get = AsyncMock(side_effect=failure)
    return client


def _run(coro):
    """Run *coro* to completion on a private event loop and return its result.

    A fresh loop is created for every call and is always closed afterwards,
    so tests never share loop state and no loop is leaked.
    ``asyncio.get_event_loop()`` is not used here: calling it outside a running
    loop is deprecated and no longer creates a loop implicitly on newer Python
    versions.

    Exceptions raised by *coro* propagate to the caller unchanged.
    """
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(coro)
    finally:
        try:
            # Finalize any async generators the coroutine left suspended.
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            loop.close()


# ---------------------------------------------------------------------------
# Black-box tests (BB) — treat the class as a black box
# ---------------------------------------------------------------------------


class TestBB1_EndedRedisState_AgentRunCompletes:
    """BB1: Redis reports 'ended' → run() leaves its loop and returns."""

    @staticmethod
    def _drive_to_completion():
        """Run a fresh agent whose Redis says 'ended'; return (agent, redis)."""
        redis = _make_redis_ended()
        agent = _make_agent(redis=redis)
        with patch.object(agent, "_log_event"), patch(
            "asyncio.sleep", new_callable=AsyncMock
        ):
            _run(agent.run())
        return agent, redis

    def test_run_completes_without_exception(self):
        # Returning from the driver without raising is the whole assertion.
        self._drive_to_completion()

    def test_running_is_false_after_run(self):
        agent, _ = self._drive_to_completion()
        assert agent._running is False

    def test_redis_get_was_called(self):
        _, redis = self._drive_to_completion()
        redis.get.assert_called()


class TestBB2_NoRedisKey_IsCallEndedReturnsFalse:
    """BB2: No usable Redis value → _is_call_ended() is False, agent keeps going."""

    @staticmethod
    def _is_call_ended_for(raw_value):
        """Evaluate _is_call_ended() against a Redis that returns *raw_value*."""
        client = AsyncMock()
        client.get = AsyncMock(return_value=raw_value)
        return _run(_make_agent(redis=client)._is_call_ended())

    def test_missing_key_returns_false(self):
        agent = _make_agent(redis=_make_redis_active())
        assert _run(agent._is_call_ended()) is False

    def test_none_raw_value_returns_false(self):
        assert self._is_call_ended_for(None) is False

    def test_empty_bytes_returns_false(self):
        assert self._is_call_ended_for(b"") is False


class TestBB3_ZombieGuard_AgentStopsAfterMaxRuntime:
    """BB3: Zombie guard triggers → agent stops after MAX_RUNTIME_SECONDS.

    Each test shrinks the agent module's MAX_RUNTIME_SECONDS to 0 so the guard
    fires on the first loop iteration. ``unittest.mock.patch`` restores the
    real value when the ``with`` block exits, even if the test fails.
    """

    # Dotted path of the module-level constant read by run().
    _MAX_RUNTIME_TARGET = "core.agents.rlm_capture_agent.MAX_RUNTIME_SECONDS"

    def test_zombie_guard_stops_agent(self):
        """With no Redis and a zero runtime budget, run() must still stop."""
        agent = _make_agent(redis=None)

        with patch(self._MAX_RUNTIME_TARGET, 0):
            with patch.object(agent, "_log_event"):
                with patch("asyncio.sleep", new_callable=AsyncMock):
                    _run(agent.run())

        assert agent._running is False

    def test_zombie_guard_logs_event(self):
        """Tripping the guard must log a 'zombie_guard_triggered' event."""
        agent = _make_agent(redis=None)
        logged = []

        def capture(event_type, extra=None):
            logged.append(event_type)

        with patch(self._MAX_RUNTIME_TARGET, 0):
            with patch.object(agent, "_log_event", side_effect=capture):
                with patch("asyncio.sleep", new_callable=AsyncMock):
                    _run(agent.run())

        assert "zombie_guard_triggered" in logged


class TestBB4_StopSetsRunningFalse:
    """BB4: stop() flips _running to False and returns nothing."""

    def test_stop_sets_running_false(self):
        agent = _make_agent()
        # A freshly constructed agent starts out in the running state.
        assert agent._running is True
        _run(agent.stop())
        assert agent._running is False

    def test_stop_returns_none(self):
        assert _run(_make_agent().stop()) is None


class TestBB5_MultipleStopCalls_Idempotent:
    """BB5: Repeated stop() calls are harmless and leave the agent stopped."""

    def test_double_stop_does_not_raise(self):
        agent = _make_agent()
        for _ in range(3):
            try:
                _run(agent.stop())
            except Exception as exc:
                pytest.fail(f"Multiple stop() calls raised: {exc}")

    def test_double_stop_leaves_running_false(self):
        agent = _make_agent()
        for _ in range(2):
            _run(agent.stop())
        assert agent._running is False


# ---------------------------------------------------------------------------
# White-box tests (WB) — verify internals
# ---------------------------------------------------------------------------


class TestWB1_RedisFail_IsCallEndedReturnsFalse:
    """WB1: A failing or missing Redis makes _is_call_ended() report False.

    This is the fail-safe direction: uncertainty keeps the agent running.
    """

    def test_redis_exception_returns_false(self):
        failing_agent = _make_agent(redis=_make_redis_error())
        assert _run(failing_agent._is_call_ended()) is False

    def test_redis_exception_does_not_propagate(self):
        failing_agent = _make_agent(redis=_make_redis_error())
        try:
            _run(failing_agent._is_call_ended())
        except Exception as exc:
            pytest.fail(f"Exception escaped _is_call_ended: {exc}")

    def test_no_redis_client_returns_false(self):
        assert _run(_make_agent(redis=None)._is_call_ended()) is False


class TestWB2_StartedAtSetWhenRunBegins:
    """WB2: _started_at is unset until run() begins, then holds a float."""

    def test_started_at_is_none_before_run(self):
        assert _make_agent()._started_at is None

    def test_started_at_set_after_run(self):
        agent = _make_agent(redis=_make_redis_ended())

        with patch.object(agent, "_log_event"), patch(
            "asyncio.sleep", new_callable=AsyncMock
        ):
            _run(agent.run())

        started_at = agent._started_at
        assert started_at is not None
        assert isinstance(started_at, float)
        assert isinstance(agent._started_at, float)


class TestWB3_EventsLoggedOnStartAndStop:
    """WB3: Events logged on agent_started and agent_stopped lifecycle points."""

    @staticmethod
    def _lifecycle_events():
        """Run an agent whose Redis says 'ended'; return the event types it logged."""
        agent = _make_agent(redis=_make_redis_ended())
        logged = []

        def capture(event_type, extra=None):
            logged.append(event_type)

        with patch.object(agent, "_log_event", side_effect=capture):
            with patch("asyncio.sleep", new_callable=AsyncMock):
                _run(agent.run())

        return logged

    @staticmethod
    def _written_chunks(event_type):
        """Call ``_log_event(event_type)`` on a fresh agent; return the raw writes.

        ``builtins.open`` is patched so append-mode opens return an in-memory
        fake whose ``write()`` calls are captured. Every other open is
        forwarded to the real builtin. The real ``open`` is bound *before*
        patching: looking up ``open`` inside the side effect would resolve to
        the mock itself and recurse forever.
        """
        agent = _make_agent()
        written = []
        real_open = open  # captured before builtins.open is patched below

        def fake_open(path, mode="r", *args, **kwargs):
            if "a" not in mode:
                return real_open(path, mode, *args, **kwargs)
            handle = MagicMock()
            handle.__enter__ = lambda s: s
            handle.__exit__ = MagicMock(return_value=False)
            handle.write = lambda data: written.append(data)
            return handle

        with patch("builtins.open", side_effect=fake_open):
            with patch("pathlib.Path.mkdir"):
                agent._log_event(event_type)

        return written

    @classmethod
    def _single_record(cls, event_type):
        """Log *event_type* once and return the single JSON record written."""
        written = cls._written_chunks(event_type)
        assert len(written) == 1, f"expected exactly one write, got {written!r}"
        return json.loads(written[0].strip())

    def test_agent_started_event_logged(self):
        assert "agent_started" in self._lifecycle_events()

    def test_agent_stopped_event_logged(self):
        assert "agent_stopped" in self._lifecycle_events()

    def test_log_event_writes_rlm_capture_prefix(self):
        record = self._single_record("agent_started")
        assert record["event_type"] == "rlm_capture_agent_started"

    def test_log_event_includes_session_id(self):
        record = self._single_record("agent_started")
        assert record["session_id"] == _SESSION_ID

    def test_log_event_timestamp_is_utc(self):
        record = self._single_record("agent_started")
        dt = datetime.fromisoformat(record["timestamp"])
        # The value must be timezone-aware and carry a zero offset, i.e. UTC.
        assert dt.tzinfo is not None
        assert dt.utcoffset() == timezone.utc.utcoffset(None)


class TestWB4_PollIntervalConstant:
    """WB4: POLL_INTERVAL_SECONDS is 30 and run() sleeps for that interval."""

    def test_poll_interval_is_30(self):
        assert POLL_INTERVAL_SECONDS == 30

    def test_sleep_called_with_poll_interval(self):
        """run() must await asyncio.sleep(POLL_INTERVAL_SECONDS) each cycle."""
        ended_payload = json.dumps({"status": "ended"}).encode()
        polls = {"count": 0}

        async def fake_get(key):
            # Poll 1: key absent (call still active). Poll 2+: 'ended', so the
            # loop exits after at least one sleep.
            polls["count"] += 1
            return ended_payload if polls["count"] >= 2 else None

        redis = AsyncMock()
        redis.get = fake_get

        recorded_sleeps = []

        async def fake_sleep(secs):
            recorded_sleeps.append(secs)

        agent = _make_agent(redis=redis)

        with patch.object(agent, "_log_event"), patch(
            "asyncio.sleep", side_effect=fake_sleep
        ):
            _run(agent.run())

        assert POLL_INTERVAL_SECONDS in recorded_sleeps


class TestWB5_MaxRuntimeConstant:
    """WB5: The zombie-guard budget MAX_RUNTIME_SECONDS is one hour (3600 s)."""

    def test_max_runtime_is_3600(self):
        one_hour = 60 * 60
        assert MAX_RUNTIME_SECONDS == one_hour


class TestWB6_RunReturnsNormally:
    """WB6: run() returns normally (no exception) after stopping."""

    # Dotted path of the module-level constant read by run(); patched via
    # unittest.mock so the original value is restored even on failure.
    _MAX_RUNTIME_TARGET = "core.agents.rlm_capture_agent.MAX_RUNTIME_SECONDS"

    def test_run_does_not_raise_on_ended_call(self):
        """Redis reports 'ended' → run() exits cleanly."""
        agent = _make_agent(redis=_make_redis_ended())

        try:
            with patch.object(agent, "_log_event"):
                with patch("asyncio.sleep", new_callable=AsyncMock):
                    _run(agent.run())
        except Exception as exc:
            pytest.fail(f"run() raised unexpectedly: {exc}")

    def test_run_does_not_raise_with_no_redis(self):
        """No Redis + zombie guard set to 0 → run() exits normally with no exception."""
        agent = _make_agent(redis=None)

        try:
            with patch(self._MAX_RUNTIME_TARGET, 0):
                with patch.object(agent, "_log_event"):
                    with patch("asyncio.sleep", new_callable=AsyncMock):
                        _run(agent.run())
        except Exception as exc:
            pytest.fail(f"run() raised unexpectedly: {exc}")


# ---------------------------------------------------------------------------
# Run summary
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Allow running this file directly; exit with pytest's status code.
    sys.exit(pytest.main([__file__, "-v", "--tb=short"]))
