#!/usr/bin/env python3
"""
Tests for Story 3.06 (Track B): MVFLInterceptor — Wires MVFL Into Chain

BB1: Clean output → post_execute passes through unchanged (no mvfl_corrected key)
BB2: Triggered output → CorrectionLoop called, result has mvfl_corrected=True
BB3: MVFL decision recorded in events.jsonl
BB4: 3-strike escalation → result has mvfl_escalated=True
BB5: on_correction is pure passthrough (returns unchanged payload)

WB1: Priority=90 ensures MVFL runs after JIT hydration (10) and business logic
WB2: pre_execute truly returns unchanged payload (identity)
WB3: VoyagerDefense.score is called (synchronous)
WB4: metadata.name == "mvfl"
WB5: on_error calls trigger.evaluate on error dict
"""
import asyncio
import json
import sys
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

# Put a fixed directory at the front of the module search path.
# NOTE(review): presumably this is the project root that makes the
# `core.mvfl` imports used by the tests resolvable — confirm. The absolute
# path is machine-specific, so the tests only import where this mount exists.
sys.path.insert(0, '/mnt/e/genesis-system')


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def run(coro):
    """Run an awaitable to completion synchronously and return its result.

    A private event loop is created for each call and closed afterwards.
    This avoids ``asyncio.get_event_loop()``, whose implicit loop creation is
    deprecated when no current loop is set (and an error on newer Python
    releases), and which also breaks if an earlier ``asyncio.run()`` cleared
    the current loop.

    Args:
        coro: The coroutine (or other awaitable) to drive to completion.

    Returns:
        Whatever the awaitable returns.

    Raises:
        Any exception raised by the awaitable is propagated unchanged.
    """
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(coro)
    finally:
        # Always release the loop's resources, even when the awaitable raised.
        loop.close()


def _make_trigger_result(triggered: bool, trigger_type: str = "syntax", details: str = "bad"):
    """Create an MVFLTriggerResult representing either a fired or a clean evaluation.

    When ``triggered`` is truthy the result carries the given ``trigger_type``
    and ``details`` with severity 1; otherwise it carries no trigger type,
    severity 0 and the fixed details text "Clean output".
    """
    from core.mvfl.mvfl_trigger import MVFLTriggerResult

    if triggered:
        return MVFLTriggerResult(
            triggered=triggered,
            trigger_type=trigger_type,
            severity=1,
            details=details,
        )
    return MVFLTriggerResult(
        triggered=triggered,
        trigger_type=None,
        severity=0,
        details="Clean output",
    )


def _make_voyager_score(should_block: bool = False, score: float = 0.0):
    """Create a VoyagerScore with the given score and block flag and no matched scars."""
    from core.mvfl.voyager_defense import VoyagerScore

    fields = {
        "score": score,
        "matched_scars": [],
        "should_block": should_block,
    }
    return VoyagerScore(**fields)


def _make_correction_result(success: bool, attempts: int = 1, escalated: bool = False):
    """Create a CorrectionResult whose output dict matches the success flag.

    A successful result wraps a "completed" output for task "t1"; an
    unsuccessful one wraps an "error" output whose error text is
    "MVFL_ESCALATION_REQUIRED". ``attempts`` and ``escalated`` are passed through.
    """
    from core.mvfl.correction_loop import CorrectionResult

    if success:
        payload = {"task_id": "t1", "status": "completed", "output": "corrected"}
    else:
        payload = {"task_id": "t1", "status": "error", "error": "MVFL_ESCALATION_REQUIRED"}

    return CorrectionResult(
        success=success,
        output=payload,
        attempts=attempts,
        escalated=escalated,
    )


def _make_clean_trigger_mock():
    """Build a stand-in MVFLTrigger whose ``evaluate`` yields a non-triggered result."""
    clean_result = _make_trigger_result(False)
    trigger = MagicMock()
    trigger.evaluate.return_value = clean_result
    return trigger


def _make_triggered_trigger_mock(trigger_type: str = "syntax"):
    """Build a stand-in MVFLTrigger whose ``evaluate`` yields a triggered result.

    The result uses the given ``trigger_type`` and the details text "detected error".
    """
    fired_result = _make_trigger_result(True, trigger_type, "detected error")
    trigger = MagicMock()
    trigger.evaluate.return_value = fired_result
    return trigger


def _make_clean_voyager_mock():
    """Build a stand-in VoyagerDefense whose ``score`` yields a non-blocking score."""
    passing_score = _make_voyager_score(should_block=False)
    voyager = MagicMock()
    voyager.score.return_value = passing_score
    return voyager


def _make_blocking_voyager_mock():
    """Build a stand-in VoyagerDefense whose ``score`` yields a blocking score of 0.95."""
    blocking_score = _make_voyager_score(should_block=True, score=0.95)
    voyager = MagicMock()
    voyager.score.return_value = blocking_score
    return voyager


def _make_interceptor(trigger=None, voyager=None, correction_loop=None):
    """Construct an MVFLInterceptor, substituting mocks for omitted dependencies.

    Any dependency that is falsy (e.g. left as ``None``) is replaced: the
    trigger and voyager by their "clean" mocks, the correction loop by a bare
    ``MagicMock``.
    """
    from core.mvfl.mvfl_interceptor import MVFLInterceptor

    dependencies = {
        "trigger": trigger or _make_clean_trigger_mock(),
        "voyager": voyager or _make_clean_voyager_mock(),
        "correction_loop": correction_loop or MagicMock(),
    }
    return MVFLInterceptor(**dependencies)


def _make_task(task_id: str = "task-001"):
    """Return a minimal task payload dict: the given ``task_id`` plus a fixed prompt."""
    return dict(task_id=task_id, prompt="Do the thing.")


# ---------------------------------------------------------------------------
# BB1: Clean output → post_execute passes through unchanged (no mvfl_corrected key)
# ---------------------------------------------------------------------------

def test_bb1_clean_output_passes_through_unchanged():
    """BB1: When trigger is clean and voyager does not block, result is unchanged.

    The assertions inspect the ``result`` dict that was handed to
    ``post_execute`` (its return value is ignored), so the test checks that no
    MVFL marker keys were added to it and that its original values survive.
    """
    interceptor = _make_interceptor(
        trigger=_make_clean_trigger_mock(),
        voyager=_make_clean_voyager_mock(),
    )

    result = {"task_id": "task-001", "status": "completed", "output": "ok"}

    run(interceptor.post_execute(result, _make_task()))

    # No mvfl_corrected or mvfl_escalated keys should have been added
    assert "mvfl_corrected" not in result, f"Unexpected mvfl_corrected in clean result: {result}"
    assert "mvfl_escalated" not in result, f"Unexpected mvfl_escalated in clean result: {result}"
    # Original content must be preserved
    assert result["status"] == "completed"
    assert result["output"] == "ok"
    print("BB1 PASSED — clean output untouched by post_execute")


# ---------------------------------------------------------------------------
# BB2: Triggered output → CorrectionLoop called, result has mvfl_corrected=True
# ---------------------------------------------------------------------------

def test_bb2_triggered_output_corrects_result():
    """BB2: A fired trigger plus a successful correction marks the result as corrected.

    The correction loop is a mock whose awaited ``run`` returns a successful
    one-attempt CorrectionResult; the test expects ``mvfl_corrected`` and
    ``mvfl_attempts`` on the result dict and exactly one await of ``run``.
    """
    from core.mvfl.mvfl_interceptor import MVFLInterceptor

    successful_fix = _make_correction_result(success=True, attempts=1)
    correction_loop = MagicMock()
    correction_loop.run = AsyncMock(return_value=successful_fix)

    interceptor = MVFLInterceptor(
        trigger=_make_triggered_trigger_mock("syntax"),
        voyager=_make_clean_voyager_mock(),
        correction_loop=correction_loop,
    )
    task = _make_task("task-002")
    result = {"task_id": "task-002", "status": "error", "error": "bad format"}

    run(interceptor.post_execute(result, task))

    corrected_flag = result.get("mvfl_corrected")
    assert corrected_flag is True, f"Expected mvfl_corrected=True, got: {result}"
    attempt_count = result.get("mvfl_attempts")
    assert attempt_count == 1, f"Expected mvfl_attempts=1, got: {result}"
    correction_loop.run.assert_awaited_once()
    print("BB2 PASSED — triggered output corrected, mvfl_corrected=True in result")


# ---------------------------------------------------------------------------
# BB3: MVFL decision recorded in events.jsonl
# ---------------------------------------------------------------------------

def test_bb3_decision_recorded_in_events_jsonl():
    """BB3: A post_execute call leaves an mvfl_decision event in events.jsonl.

    The interceptor module's ``EVENTS_DIR`` is temporarily pointed at a
    throwaway directory; the last line of the ``events.jsonl`` written there
    must be a JSON event for the task with the expected fields.
    """
    from core.mvfl.mvfl_interceptor import MVFLInterceptor
    import core.mvfl.mvfl_interceptor as interceptor_mod

    required_fields = ("timestamp", "triggered", "blocked", "voyager_score")

    with tempfile.TemporaryDirectory() as tmpdir:
        events_dir = Path(tmpdir)
        # patch.object restores the module attribute on exit, including when
        # an assertion below fails.
        with patch.object(interceptor_mod, "EVENTS_DIR", events_dir):
            interceptor = MVFLInterceptor(
                trigger=_make_clean_trigger_mock(),
                voyager=_make_clean_voyager_mock(),
                correction_loop=MagicMock(),
            )

            result = {"task_id": "task-003", "status": "completed", "output": "ok"}
            run(interceptor.post_execute(result, _make_task("task-003")))

            events_file = events_dir / "events.jsonl"
            assert events_file.exists(), "events.jsonl was not created"

            lines = events_file.read_text().strip().splitlines()
            assert len(lines) >= 1, "No events written to events.jsonl"

            # Only the most recent line is inspected.
            event = json.loads(lines[-1])
            assert event["event_type"] == "mvfl_decision", (
                f"Expected event_type='mvfl_decision', got: {event['event_type']}"
            )
            assert event["task_id"] == "task-003", f"Wrong task_id in event: {event['task_id']}"
            for field in required_fields:
                assert field in event, f"Missing {field} in event"

    print("BB3 PASSED — mvfl_decision event written to events.jsonl")


# ---------------------------------------------------------------------------
# BB4: 3-strike escalation → result has mvfl_escalated=True
# ---------------------------------------------------------------------------

def test_bb4_three_strike_escalation_sets_mvfl_escalated():
    """BB4: A failed, escalated 3-attempt correction marks the result as escalated.

    The correction loop mock returns an unsuccessful CorrectionResult with
    ``attempts=3`` and ``escalated=True``; the result dict must then carry
    ``mvfl_escalated=True`` and ``mvfl_attempts=3``.
    """
    from core.mvfl.mvfl_interceptor import MVFLInterceptor

    exhausted = _make_correction_result(success=False, attempts=3, escalated=True)
    correction_loop = MagicMock()
    correction_loop.run = AsyncMock(return_value=exhausted)

    interceptor = MVFLInterceptor(
        trigger=_make_triggered_trigger_mock("external_rejection"),
        voyager=_make_clean_voyager_mock(),
        correction_loop=correction_loop,
    )
    task = _make_task("task-004")
    result = {"task_id": "task-004", "status": "error", "error": "API timeout"}

    run(interceptor.post_execute(result, task))

    escalated_flag = result.get("mvfl_escalated")
    assert escalated_flag is True, (
        f"Expected mvfl_escalated=True after 3-strike, got: {result}"
    )
    attempt_count = result.get("mvfl_attempts")
    assert attempt_count == 3, (
        f"Expected mvfl_attempts=3, got: {attempt_count}"
    )
    correction_loop.run.assert_awaited_once()
    print("BB4 PASSED — 3-strike escalation sets mvfl_escalated=True")


# ---------------------------------------------------------------------------
# BB5: on_correction is pure passthrough (returns unchanged payload)
# ---------------------------------------------------------------------------

def test_bb5_on_correction_is_passthrough():
    """BB5: on_correction hands back the very dict it was given, contents intact."""
    payload = {"task_id": "task-005", "prompt": "CORRECTION: fix this", "attempt": 2}
    # Snapshot taken before the call so in-place mutation would be detected.
    expected_contents = dict(payload)

    returned = run(_make_interceptor().on_correction(payload))

    assert returned is payload, "on_correction must return the same dict object (identity)"
    assert returned == expected_contents
    print("BB5 PASSED — on_correction is a pure passthrough")


# ---------------------------------------------------------------------------
# WB1: Priority=90 ensures MVFL runs after JIT hydration (10) and business logic
# ---------------------------------------------------------------------------

def test_wb1_priority_is_90():
    """WB1: The interceptor's metadata reports a priority of 90."""
    priority = _make_interceptor().metadata.priority
    assert priority == 90, f"Expected priority=90, got: {priority}"
    print("WB1 PASSED — priority=90 (runs after priority 10 JIT hydration)")


# ---------------------------------------------------------------------------
# WB2: pre_execute truly returns unchanged payload (identity)
# ---------------------------------------------------------------------------

def test_wb2_pre_execute_is_identity():
    """WB2: pre_execute hands back the very dict it was given, contents intact."""
    payload = {"task_id": "task-wb2", "prompt": "Hello", "custom_key": 42}
    # Snapshot taken before the call so in-place mutation would be detected.
    expected_contents = dict(payload)

    returned = run(_make_interceptor().pre_execute(payload))

    assert returned is payload, "pre_execute must return the same dict object"
    assert returned == expected_contents
    print("WB2 PASSED — pre_execute returns identity (unchanged payload)")


# ---------------------------------------------------------------------------
# WB3: VoyagerDefense.score is called (synchronous call)
# ---------------------------------------------------------------------------

def test_wb3_voyager_score_is_called():
    """WB3: post_execute invokes ``voyager.score`` exactly once, with the result dict."""
    from core.mvfl.mvfl_interceptor import MVFLInterceptor

    voyager = _make_clean_voyager_mock()
    dependencies = {
        "trigger": _make_clean_trigger_mock(),
        "voyager": voyager,
        "correction_loop": MagicMock(),
    }
    interceptor = MVFLInterceptor(**dependencies)

    task = _make_task("task-wb3")
    result = {"task_id": "task-wb3", "status": "completed"}
    run(interceptor.post_execute(result, task))

    voyager.score.assert_called_once_with(result)
    print("WB3 PASSED — voyager.score called once during post_execute")


# ---------------------------------------------------------------------------
# WB4: metadata.name == "mvfl"
# ---------------------------------------------------------------------------

def test_wb4_metadata_name_is_mvfl():
    """WB4: The interceptor's metadata reports the name 'mvfl'."""
    name = _make_interceptor().metadata.name
    assert name == "mvfl", f"Expected metadata.name='mvfl', got: {name!r}"
    print("WB4 PASSED — metadata.name == 'mvfl'")


# ---------------------------------------------------------------------------
# WB5: on_error calls trigger.evaluate on error dict
# ---------------------------------------------------------------------------

def test_wb5_on_error_calls_trigger_evaluate():
    """WB5: on_error evaluates the trigger once on a dict describing the error.

    The first positional argument given to ``trigger.evaluate`` must have
    ``status == "error"`` and an ``error`` value containing the exception
    text; the value returned by ``on_error`` must contain an ``error`` key.
    """
    from core.mvfl.mvfl_interceptor import MVFLInterceptor

    trigger = _make_clean_trigger_mock()
    interceptor = MVFLInterceptor(
        trigger=trigger,
        voyager=_make_clean_voyager_mock(),
        correction_loop=MagicMock(),
    )

    failure = ValueError("Something exploded")
    returned = run(interceptor.on_error(failure, _make_task("task-wb5")))

    # Exactly one evaluation, fed with the error-derived dict.
    evaluate_calls = trigger.evaluate.call_count
    assert evaluate_calls == 1, (
        f"Expected 1 call to trigger.evaluate, got: {evaluate_calls}"
    )

    # Inspect the first positional argument of that single call.
    error_output = trigger.evaluate.call_args.args[0]
    assert error_output["status"] == "error", (
        f"Error output must have status='error', got: {error_output}"
    )
    assert "Something exploded" in error_output["error"], (
        f"Error message must be in error_output['error'], got: {error_output}"
    )

    # With a non-firing trigger the returned dict still has to surface the error.
    assert "error" in returned, f"Return dict must contain 'error' key: {returned}"
    print("WB5 PASSED — on_error calls trigger.evaluate with error dict")


# ---------------------------------------------------------------------------
# Additional edge cases
# ---------------------------------------------------------------------------

def test_bb6_voyager_block_alone_triggers_correction():
    """Extra: A blocking voyager score runs the correction loop even with a clean trigger."""
    from core.mvfl.mvfl_interceptor import MVFLInterceptor

    successful_fix = _make_correction_result(success=True, attempts=2)
    correction_loop = MagicMock()
    correction_loop.run = AsyncMock(return_value=successful_fix)

    # The trigger reports clean output; only the voyager mock asks for a block.
    clean_trigger = _make_clean_trigger_mock()
    blocking_voyager = _make_blocking_voyager_mock()
    interceptor = MVFLInterceptor(
        trigger=clean_trigger,
        voyager=blocking_voyager,
        correction_loop=correction_loop,
    )

    task = _make_task("task-extra")
    result = {"task_id": "task-extra", "status": "completed", "output": "suspicious"}
    run(interceptor.post_execute(result, task))

    correction_loop.run.assert_awaited_once()
    corrected_flag = result.get("mvfl_corrected")
    assert corrected_flag is True, (
        f"Expected mvfl_corrected after voyager block, got: {result}"
    )
    print("BB6 PASSED — voyager block alone triggers correction loop")


def test_wb6_mvfl_interceptor_exported_from_package():
    """WB6: ``core.mvfl`` exposes MVFLInterceptor, constructible without arguments."""
    from core.mvfl import MVFLInterceptor

    assert MVFLInterceptor is not None
    # Built with no arguments here, unlike the other tests which inject mocks.
    name = MVFLInterceptor().metadata.name
    assert name == "mvfl"
    print("WB6 PASSED — MVFLInterceptor exported from core.mvfl.__init__")


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Standalone runner: execute every test in order, report failures with
    # tracebacks, and exit 1 if any test failed, 0 otherwise.
    import traceback

    tests = [
        test_bb1_clean_output_passes_through_unchanged,
        test_bb2_triggered_output_corrects_result,
        test_bb3_decision_recorded_in_events_jsonl,
        test_bb4_three_strike_escalation_sets_mvfl_escalated,
        test_bb5_on_correction_is_passthrough,
        test_wb1_priority_is_90,
        test_wb2_pre_execute_is_identity,
        test_wb3_voyager_score_is_called,
        test_wb4_metadata_name_is_mvfl,
        test_wb5_on_error_calls_trigger_evaluate,
        test_bb6_voyager_block_alone_triggers_correction,
        test_wb6_mvfl_interceptor_exported_from_package,
    ]

    failed_names = []
    for test_fn in tests:
        try:
            test_fn()
        except Exception as exc:
            # Keep going after a failure so one broken test does not hide the rest.
            print(f"FAILED: {test_fn.__name__} — {exc}")
            traceback.print_exc()
            failed_names.append(test_fn.__name__)

    failed = len(failed_names)
    passed = len(tests) - failed

    print(f"\n{'='*60}")
    print(f"Story 3.06 Test Results: {passed}/{len(tests)} passed")
    if not failed:
        print("ALL TESTS PASSED")
        sys.exit(0)
    print(f"FAILED: {failed} test(s)")
    sys.exit(1)
