#!/usr/bin/env python3
"""
Tests for Story 3.04 (Track B): CorrectionLoop — MVFL Re-injection Engine

BB1: Bad output corrected on attempt 2 → success=True, attempts=2, escalated=False
BB2: All 3 attempts fail → escalated=True, attempts=3
BB3: Correction prompt always starts with "CORRECTION: "
BB4: First attempt clean → success=True, attempts=1

WB1: MAX_CORRECTION_ATTEMPTS constant = 3
WB2: Each attempt re-runs MVFLTrigger.evaluate (mock to count calls)
WB3: Escalation function called exactly once on 3-strike
WB4: An event is logged for each attempt and for the escalation (log path redirected to a temp file to count entries)
"""
import sys
import asyncio
import json
import tempfile
import os
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch, call

sys.path.insert(0, '/mnt/e/genesis-system')


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def run(coro):
    """Run *coro* to completion synchronously and return its result.

    Uses ``asyncio.run``, which creates a fresh event loop, runs the coroutine
    and closes the loop afterwards, so every test gets an isolated loop.
    ``asyncio.get_event_loop()`` with no current loop is deprecated (it warns
    on recent Python versions and no longer creates a loop implicitly in the
    newest ones), so it is deliberately not used here.
    """
    return asyncio.run(coro)


def _make_trigger_result(triggered: bool, trigger_type: str = "syntax", details: str = "bad format"):
    """Construct an MVFLTriggerResult to inject into the loop under test.

    A falsy *triggered* yields a "clean" result: no trigger type, severity 0
    and the fixed details text "Clean output"; *trigger_type* and *details*
    are ignored in that case. A truthy *triggered* yields severity 1 with the
    given type and details.
    """
    from core.mvfl.mvfl_trigger import MVFLTriggerResult

    if not triggered:
        return MVFLTriggerResult(
            triggered=triggered,
            trigger_type=None,
            severity=0,
            details="Clean output",
        )
    return MVFLTriggerResult(
        triggered=triggered,
        trigger_type=trigger_type,
        severity=1,
        details=details,
    )


def _clean_output(task_id: str = "t1") -> dict:
    """Build a clean (non-triggering) output dict."""
    return {"task_id": task_id, "status": "completed", "output": "ok"}


def _bad_output(task_id: str = "t1") -> dict:
    """Build an error output that will trigger MVFLTrigger."""
    return {"task_id": task_id, "status": "error", "error": "API_FAIL"}


def _make_trigger_mock(responses):
    """
    Return a mock MVFLTrigger whose evaluate() returns successive MVFLTriggerResults
    from the given list. If more calls than list items, repeats the last.
    """
    mock = MagicMock()
    call_counts = {"n": 0}

    def side_effect(output, payload):
        idx = min(call_counts["n"], len(responses) - 1)
        call_counts["n"] += 1
        return responses[idx]

    mock.evaluate.side_effect = side_effect
    return mock


def _make_voyager_mock(should_block: bool = False):
    """Return a mock VoyagerDefense whose ``score()`` always returns the same VoyagerScore.

    The score is 0.0 with no matched scars. ``should_block`` is copied from
    the argument, so the mock is non-blocking by default but can simulate a
    block when ``should_block=True`` is passed.
    """
    from core.mvfl.voyager_defense import VoyagerScore
    mock = MagicMock()
    mock.score.return_value = VoyagerScore(score=0.0, matched_scars=[], should_block=should_block)
    return mock


# ---------------------------------------------------------------------------
# BB1: Bad output corrected on attempt 2 → success=True, attempts=2, escalated=False
# ---------------------------------------------------------------------------

def test_bb1_corrected_on_attempt_2():
    """Bad output fails attempt 1, succeeds on attempt 2."""
    from core.mvfl.correction_loop import CorrectionLoop

    # Scripted re-evaluations: attempt 1 is still bad, attempt 2 is clean.
    trigger_mock = _make_trigger_mock([
        _make_trigger_result(True, "syntax", "Missing task_id"),   # after attempt 1
        _make_trigger_result(False),                               # after attempt 2
    ])
    voyager_mock = _make_voyager_mock(should_block=False)

    # dispatch returns the same clean-looking output on every attempt; the
    # scripted trigger results above decide whether each attempt passes.
    dispatch = AsyncMock(return_value=_clean_output())

    loop = CorrectionLoop(
        trigger=trigger_mock,
        voyager=voyager_mock,
        dispatch_fn=dispatch,
    )

    initial_trigger = _make_trigger_result(True, "syntax", "Missing task_id")
    result = run(loop.run(
        task_payload={"task_id": "t1", "prompt": "Do something"},
        failed_output=_bad_output(),
        trigger_result=initial_trigger,
    ))

    assert result.success is True, f"Expected success, got: {result}"
    assert result.attempts == 2, f"Expected 2 attempts, got: {result.attempts}"
    assert result.escalated is False, f"Expected not escalated, got: {result}"
    # Each attempt re-dispatches once and re-evaluates once (same contract
    # that BB4 checks for dispatch and WB2 checks for evaluate).
    assert dispatch.await_count == 2, f"Expected 2 dispatch calls, got {dispatch.await_count}"
    assert trigger_mock.evaluate.call_count == 2, (
        f"Expected 2 evaluate calls, got {trigger_mock.evaluate.call_count}"
    )
    print("BB1 PASSED")


# ---------------------------------------------------------------------------
# BB2: All 3 attempts fail → escalated=True, attempts=3
# ---------------------------------------------------------------------------

def test_bb2_all_attempts_fail_escalation():
    """Escalation must kick in when all three retries are still flagged."""
    from core.mvfl.correction_loop import CorrectionLoop

    # Every re-evaluation keeps reporting the same syntax failure.
    still_bad = [_make_trigger_result(True, "syntax", "bad") for _ in range(3)]
    fake_escalation = AsyncMock(return_value={"task_id": "t1", "status": "escalated"})

    correction_loop = CorrectionLoop(
        trigger=_make_trigger_mock(still_bad),
        voyager=_make_voyager_mock(should_block=False),
        dispatch_fn=AsyncMock(return_value=_bad_output()),
        escalation_fn=fake_escalation,
    )

    outcome = run(correction_loop.run(
        task_payload={"task_id": "t1", "prompt": "Do something"},
        failed_output=_bad_output(),
        trigger_result=_make_trigger_result(True, "syntax", "bad"),
    ))

    assert outcome.success is False, f"Expected failure, got: {outcome}"
    assert outcome.escalated is True, f"Expected escalated=True, got: {outcome}"
    assert outcome.attempts == 3, f"Expected 3 attempts, got: {outcome.attempts}"
    print("BB2 PASSED")


# ---------------------------------------------------------------------------
# BB3: Correction prompt always starts with "CORRECTION: "
# ---------------------------------------------------------------------------

def test_bb3_correction_prompt_prefix():
    """A re-dispatched prompt carries the 'CORRECTION: ' prefix plus its context."""
    from core.mvfl.correction_loop import CorrectionLoop

    seen = []

    async def fake_dispatch(payload):
        # Record exactly what the loop sends, then answer with a clean output.
        seen.append(payload)
        return _clean_output()

    # The very first re-evaluation is clean, so only one attempt is made.
    correction_loop = CorrectionLoop(
        trigger=_make_trigger_mock([_make_trigger_result(False)]),
        voyager=_make_voyager_mock(should_block=False),
        dispatch_fn=fake_dispatch,
    )

    run(correction_loop.run(
        task_payload={"task_id": "t2", "prompt": "Analyze this"},
        failed_output=_bad_output("t2"),
        trigger_result=_make_trigger_result(True, "semantic", "contradiction found"),
    ))

    assert len(seen) == 1, f"Expected 1 dispatch call, got {len(seen)}"
    sent_prompt = seen[0]["prompt"]
    assert sent_prompt.startswith("CORRECTION: "), f"Prompt does not start with 'CORRECTION: ': {sent_prompt!r}"
    assert "contradiction found" in sent_prompt, f"Trigger details missing from prompt: {sent_prompt!r}"
    assert "Analyze this" in sent_prompt, f"Original prompt missing: {sent_prompt!r}"
    print("BB3 PASSED")


# ---------------------------------------------------------------------------
# BB4: First attempt clean → success=True, attempts=1
# ---------------------------------------------------------------------------

def test_bb4_first_attempt_clean():
    """A correction that is clean on its first retry ends the loop immediately."""
    from core.mvfl.correction_loop import CorrectionLoop

    fake_dispatch = AsyncMock(return_value=_clean_output())
    correction_loop = CorrectionLoop(
        trigger=_make_trigger_mock([_make_trigger_result(False)]),
        voyager=_make_voyager_mock(should_block=False),
        dispatch_fn=fake_dispatch,
    )

    outcome = run(correction_loop.run(
        task_payload={"task_id": "t3", "prompt": "Quick task"},
        failed_output=_bad_output("t3"),
        trigger_result=_make_trigger_result(True, "syntax", "Missing status"),
    ))

    assert outcome.success is True
    assert outcome.attempts == 1, f"Expected 1 attempt, got {outcome.attempts}"
    assert outcome.escalated is False
    # Exactly one re-dispatch happened.
    fake_dispatch.assert_awaited_once()
    print("BB4 PASSED")


# ---------------------------------------------------------------------------
# WB1: MAX_CORRECTION_ATTEMPTS constant = 3
# ---------------------------------------------------------------------------

def test_wb1_max_correction_attempts_constant():
    """The module-level retry ceiling MAX_CORRECTION_ATTEMPTS is pinned at 3."""
    import core.mvfl.correction_loop as correction_loop_module

    limit = correction_loop_module.MAX_CORRECTION_ATTEMPTS
    assert limit == 3, f"Expected 3, got {limit}"
    print("WB1 PASSED")


# ---------------------------------------------------------------------------
# WB2: Each attempt re-runs MVFLTrigger.evaluate (mock to count calls)
# ---------------------------------------------------------------------------

def test_wb2_trigger_evaluate_called_each_attempt():
    """Every correction attempt must be re-checked through trigger.evaluate()."""
    from core.mvfl.correction_loop import CorrectionLoop

    # Keep failing on every re-check so the loop runs all of its attempts.
    trigger_mock = _make_trigger_mock(
        [_make_trigger_result(True, "syntax", "err") for _ in range(3)]
    )
    correction_loop = CorrectionLoop(
        trigger=trigger_mock,
        voyager=_make_voyager_mock(should_block=False),
        dispatch_fn=AsyncMock(return_value=_bad_output()),
    )

    run(correction_loop.run(
        task_payload={"task_id": "t4", "prompt": "test"},
        failed_output=_bad_output("t4"),
        trigger_result=_make_trigger_result(True, "syntax", "err"),
    ))

    # One evaluate() per attempt, three attempts in total.
    evaluate_calls = trigger_mock.evaluate.call_count
    assert evaluate_calls == 3, f"Expected 3 evaluate calls, got {evaluate_calls}"
    print("WB2 PASSED")


# ---------------------------------------------------------------------------
# WB3: Escalation function called exactly once on 3-strike
# ---------------------------------------------------------------------------

def test_wb3_escalation_called_exactly_once():
    """A three-strike failure must invoke the escalation callback exactly once."""
    from core.mvfl.correction_loop import CorrectionLoop

    fake_escalation = AsyncMock(return_value={"task_id": "t5", "status": "escalated"})
    correction_loop = CorrectionLoop(
        trigger=_make_trigger_mock(
            [_make_trigger_result(True, "syntax", "err") for _ in range(3)]
        ),
        voyager=_make_voyager_mock(should_block=False),
        dispatch_fn=AsyncMock(return_value=_bad_output()),
        escalation_fn=fake_escalation,
    )

    outcome = run(correction_loop.run(
        task_payload={"task_id": "t5", "prompt": "test"},
        failed_output=_bad_output("t5"),
        trigger_result=_make_trigger_result(True, "syntax", "err"),
    ))

    fake_escalation.assert_awaited_once()
    assert outcome.escalated is True
    print("WB3 PASSED")


# ---------------------------------------------------------------------------
# WB4: Events logged for each attempt (mock log path to count entries)
# ---------------------------------------------------------------------------

def test_wb4_events_logged_per_attempt():
    """Each correction attempt + escalation writes an event to the log file."""
    import core.mvfl.correction_loop as cl_module
    from core.mvfl.correction_loop import CorrectionLoop

    trigger_mock = _make_trigger_mock([
        _make_trigger_result(True, "syntax", "err"),
        _make_trigger_result(True, "syntax", "err"),
        _make_trigger_result(True, "syntax", "err"),
    ])
    voyager_mock = _make_voyager_mock(should_block=False)
    dispatch = AsyncMock(return_value=_bad_output())

    with tempfile.TemporaryDirectory() as tmpdir:
        log_path = Path(tmpdir) / "events.jsonl"

        # patch.object restores the real EVENTS_LOG_PATH even if the run
        # raises. The loop is constructed inside the patch so the redirect
        # also applies if CorrectionLoop reads the path at __init__ time.
        with patch.object(cl_module, "EVENTS_LOG_PATH", log_path):
            loop = CorrectionLoop(
                trigger=trigger_mock,
                voyager=voyager_mock,
                dispatch_fn=dispatch,
            )
            initial_trigger = _make_trigger_result(True, "syntax", "err")
            run(loop.run(
                task_payload={"task_id": "t6", "prompt": "test"},
                failed_output=_bad_output("t6"),
                trigger_result=initial_trigger,
            ))

        # Fail with a clear message instead of FileNotFoundError when nothing was logged.
        assert log_path.exists(), f"No events were written to {log_path}"
        events = [
            json.loads(line)
            for line in log_path.read_text(encoding="utf-8").splitlines()
            if line.strip()  # tolerate blank lines in the JSONL file
        ]

        # .get() so unrelated events without an "event" key don't crash the filter.
        attempt_events = [e for e in events if e.get("event") == "mvfl_correction_attempt"]
        escalation_events = [e for e in events if e.get("event") == "mvfl_escalation"]

        assert len(attempt_events) == 3, f"Expected 3 attempt events, got {len(attempt_events)}"
        assert len(escalation_events) == 1, f"Expected 1 escalation event, got {len(escalation_events)}"
        print("WB4 PASSED")


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Run every test in declaration order; the first failing assert aborts the run.
    for test_fn in (
        test_bb1_corrected_on_attempt_2,
        test_bb2_all_attempts_fail_escalation,
        test_bb3_correction_prompt_prefix,
        test_bb4_first_attempt_clean,
        test_wb1_max_correction_attempts_constant,
        test_wb2_trigger_evaluate_called_each_attempt,
        test_wb3_escalation_called_exactly_once,
        test_wb4_events_logged_per_attempt,
    ):
        test_fn()
    print("\nALL TESTS PASSED — Story 3.04 (Track B): CorrectionLoop")
