#!/usr/bin/env python3
"""
Tests for Story 3.05 (Track B): MemGPTEscalation — Opus Escalation on 3-Strike

BB1: Escalation builds prompt with all failure details
BB2: Escalation uses ESCALATION_MODEL = "claude-opus-4-6"
BB3: Escalation event written to events.jsonl
BB4: No dispatch function -> returns stub escalated response

WB1: Failure history includes details from all attempts
WB2: On Opus failure (dispatch raises) -> error dict with MVFL_ESCALATION_FAILED
WB3: _build_history_string formats each attempt numbered
WB4: ESCALATION_SYSTEM_PROMPT contains "{failure_history}" and "{attempt_count}" placeholders
"""
import asyncio
import json
import os
import sys
import tempfile
from pathlib import Path

# Make the project's top-level ``core`` package importable. The repository
# root defaults to the original WSL checkout location; set the
# GENESIS_SYSTEM_ROOT environment variable to run from another checkout.
sys.path.insert(0, os.environ.get("GENESIS_SYSTEM_ROOT", "/mnt/e/genesis-system"))


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def run(coro):
    """Run an async coroutine synchronously and return its result.

    Uses ``asyncio.run`` so each call gets a fresh event loop that is closed
    afterwards. Calling ``asyncio.get_event_loop()`` with no running loop is
    deprecated and, from Python 3.14, raises ``RuntimeError`` when no loop
    has been set, so it is deliberately not used here.
    """
    return asyncio.run(coro)


def make_task(task_id="task-001"):
    """Build a minimal task payload (id + prompt) for escalation tests."""
    return dict(task_id=task_id, prompt="Do the thing.")


def make_failed_output(task_id="task-001", reason="bad output"):
    """Build a result dict describing one failed attempt (status='error')."""
    return dict(task_id=task_id, status="error", reason=reason)


# ---------------------------------------------------------------------------
# BB1: Escalation builds prompt with all failure details
# ---------------------------------------------------------------------------

def test_bb1_prompt_contains_failure_details():
    """BB1: The system prompt sent to Opus must contain all failure details.

    A capturing dispatch_fn records what escalate() passes to the model. The
    captured system prompt must mention every failure reason from the
    supplied history, plus the attempt count (3).
    """
    from core.mvfl.memgpt_escalation import MemGPTEscalation

    captured = {}

    async def capture_dispatch(model, system_prompt, task_payload):
        captured["model"] = model
        captured["system_prompt"] = system_prompt
        captured["task_payload"] = task_payload
        return {"task_id": task_payload.get("task_id"), "status": "completed", "output": "fixed"}

    esc = MemGPTEscalation(dispatch_fn=capture_dispatch)

    task = make_task("t-bb1")
    reasons = ["syntax failure", "semantic mismatch", "external rejection"]
    history = [{"task_id": "t-bb1", "status": "error", "reason": reason} for reason in reasons]
    # Only the prompt handed to dispatch_fn is under test; the return value is not inspected.
    run(esc.escalate(task, history[-1], failure_history=history))

    assert "system_prompt" in captured, "dispatch_fn was not called"
    sp = captured["system_prompt"]

    # All three failure reasons must appear in the prompt
    for reason in reasons:
        assert reason in sp, f"Expected '{reason}' in prompt:\n{sp}"

    # attempt_count must be present.
    # NOTE(review): loose check; any '3' anywhere in the prompt satisfies it.
    assert "3" in sp, f"Expected attempt count '3' in prompt:\n{sp}"

    print("BB1 PASSED — prompt contains all failure details")


# ---------------------------------------------------------------------------
# BB2: Escalation uses ESCALATION_MODEL = "claude-opus-4-6"
# ---------------------------------------------------------------------------

def test_bb2_escalation_model_is_opus():
    """BB2: The model dispatched to must be claude-opus-4-6.

    Checks both the module-level ESCALATION_MODEL constant and the model name
    that escalate() actually passes to dispatch_fn.
    """
    from core.mvfl.memgpt_escalation import MemGPTEscalation, ESCALATION_MODEL

    assert ESCALATION_MODEL == "claude-opus-4-6", (
        f"ESCALATION_MODEL must be 'claude-opus-4-6', got '{ESCALATION_MODEL}'"
    )

    dispatched_to = {}

    async def capture_dispatch(model, system_prompt, task_payload):
        dispatched_to["model"] = model
        return {"task_id": task_payload.get("task_id"), "status": "completed", "output": "ok"}

    esc = MemGPTEscalation(dispatch_fn=capture_dispatch)
    run(esc.escalate(make_task(), make_failed_output()))

    # Fail with a clear assertion (not a KeyError) if escalate() never dispatched.
    assert "model" in dispatched_to, "dispatch_fn was not called"
    assert dispatched_to["model"] == ESCALATION_MODEL, (
        f"Dispatched to wrong model: {dispatched_to['model']}"
    )
    print("BB2 PASSED — model is claude-opus-4-6")


# ---------------------------------------------------------------------------
# BB3: Escalation event written to events.jsonl
# ---------------------------------------------------------------------------

def test_bb3_escalation_event_written():
    """BB3: escalate() must append an mvfl_escalation record to the events log.

    The module's EVENTS_LOG_PATH is redirected into a scratch directory for
    the duration of the test and restored afterwards, even if an assertion
    fails.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        events_file = Path(scratch_dir) / "events.jsonl"

        # Redirect the module-level log path for this test only.
        import core.mvfl.memgpt_escalation as mod
        saved_path = mod.EVENTS_LOG_PATH
        mod.EVENTS_LOG_PATH = events_file

        try:
            from core.mvfl.memgpt_escalation import MemGPTEscalation
            escalator = MemGPTEscalation(dispatch_fn=None)  # stub path, no real dispatch
            run(escalator.escalate(make_task("t-bb3"), make_failed_output("t-bb3")))

            assert events_file.exists(), "events.jsonl was not created"
            records = events_file.read_text().strip().splitlines()
            assert records, "No events written to log"

            last = json.loads(records[-1])
            assert last["event"] == "mvfl_escalation", f"Wrong event type: {last['event']}"
            assert last["task_id"] == "t-bb3", f"Wrong task_id: {last['task_id']}"
            assert last["model"] == "claude-opus-4-6", f"Wrong model in event: {last['model']}"
            for required in ("timestamp", "attempt_count"):
                assert required in last, f"Missing {required} in event"
        finally:
            mod.EVENTS_LOG_PATH = saved_path

    print("BB3 PASSED — escalation event written to events.jsonl")


# ---------------------------------------------------------------------------
# BB4: No dispatch function -> returns stub escalated response
# ---------------------------------------------------------------------------

def test_bb4_no_dispatch_returns_stub():
    """BB4: With dispatch_fn=None, escalate() returns a stub 'escalated' dict.

    The stub must carry the task id, the escalation model, and both the
    'output' and 'attempt_count' keys.
    """
    from core.mvfl.memgpt_escalation import MemGPTEscalation, ESCALATION_MODEL

    stub = run(
        MemGPTEscalation(dispatch_fn=None).escalate(
            make_task("t-bb4"), make_failed_output("t-bb4")
        )
    )

    assert stub["status"] == "escalated", f"Expected status='escalated', got: {stub['status']}"
    assert stub["task_id"] == "t-bb4", f"Wrong task_id: {stub['task_id']}"
    assert stub["model"] == ESCALATION_MODEL, f"Wrong model in stub: {stub['model']}"
    for key in ("output", "attempt_count"):
        assert key in stub, f"Missing '{key}' key in stub response"

    print("BB4 PASSED — stub response returned when no dispatch function")


# ---------------------------------------------------------------------------
# WB1: Failure history includes details from all attempts
# ---------------------------------------------------------------------------

def test_wb1_history_all_attempts_in_prompt():
    """WB1: _build_history_string numbers every attempt and keeps every field.

    Two dict attempts go in; both numbered headers and all of their
    reason/code values must appear in the rendered string.
    """
    from core.mvfl.memgpt_escalation import MemGPTEscalation

    attempts = [
        {"reason": "alpha_error", "code": 500},
        {"reason": "beta_error", "code": 422},
    ]
    rendered = MemGPTEscalation()._build_history_string(attempts)

    for number in (1, 2):
        header = f"--- Attempt {number} ---"
        assert header in rendered, f"Missing '{header}' header"

    expected_fragments = (
        ("alpha_error", "Missing first attempt reason"),
        ("beta_error", "Missing second attempt reason"),
        ("500", "Missing first attempt code"),
        ("422", "Missing second attempt code"),
    )
    for fragment, failure_msg in expected_fragments:
        assert fragment in rendered, failure_msg

    print("WB1 PASSED — all attempts included in history string")


# ---------------------------------------------------------------------------
# WB2: On Opus failure (dispatch raises) -> error dict with MVFL_ESCALATION_FAILED
# ---------------------------------------------------------------------------

def test_wb2_dispatch_exception_returns_error_dict():
    """WB2: A raising dispatch_fn makes escalate() return an MVFL_ESCALATION_FAILED dict.

    The exception's message must appear in 'error_detail', and the target
    model must still be reported in the error dict.
    """
    from core.mvfl.memgpt_escalation import MemGPTEscalation

    async def failing_dispatch(model, system_prompt, task_payload):
        raise RuntimeError("Network timeout")

    outcome = run(
        MemGPTEscalation(dispatch_fn=failing_dispatch).escalate(
            make_task("t-wb2"), make_failed_output("t-wb2")
        )
    )

    status = outcome["status"]
    assert status == "error", f"Expected status='error', got: {status}"

    error_code = outcome["error"]
    assert error_code == "MVFL_ESCALATION_FAILED", (
        f"Expected error='MVFL_ESCALATION_FAILED', got: {error_code}"
    )

    assert "error_detail" in outcome, "Missing 'error_detail' in error response"
    detail = outcome["error_detail"]
    assert "Network timeout" in detail, f"Exception message not in error_detail: {detail}"

    assert outcome["model"] == "claude-opus-4-6", "Model not preserved in error dict"

    print("WB2 PASSED — dispatch exception produces MVFL_ESCALATION_FAILED error dict")


# ---------------------------------------------------------------------------
# WB3: _build_history_string formats each attempt numbered
# ---------------------------------------------------------------------------

def test_wb3_build_history_string_numbered():
    """WB3: Attempt blocks are numbered sequentially, starting at 1.

    Covers a single dict attempt, a plain-string attempt, and three attempts.
    """
    from core.mvfl.memgpt_escalation import MemGPTEscalation

    builder = MemGPTEscalation()

    # A lone dict attempt must open with the first header.
    single = builder._build_history_string([{"k": "v"}])
    assert single.startswith("--- Attempt 1 ---"), (
        f"First attempt must start with '--- Attempt 1 ---', got:\n{single}"
    )

    # A non-dict (raw string) attempt still gets a header and its text.
    raw = builder._build_history_string(["raw_error_message"])
    assert "--- Attempt 1 ---" in raw, "Missing numbered header for raw attempt"
    assert "raw_error_message" in raw, "Raw string not included in output"

    # Three attempts: headers 1 through 3 must all be present.
    triple = builder._build_history_string([{"x": value} for value in (1, 2, 3)])
    for idx in range(1, 4):
        header = f"--- Attempt {idx} ---"
        assert header in triple, f"Missing '{header}' in:\n{triple}"

    print("WB3 PASSED — _build_history_string formats each attempt numbered correctly")


# ---------------------------------------------------------------------------
# WB4: ESCALATION_SYSTEM_PROMPT contains "{failure_history}" and "{attempt_count}" placeholders
# ---------------------------------------------------------------------------

def test_wb4_system_prompt_has_placeholder():
    """WB4: ESCALATION_SYSTEM_PROMPT must contain {failure_history} and {attempt_count}.

    The template must also render through str.format() with just those two
    keyword arguments.
    """
    from core.mvfl.memgpt_escalation import ESCALATION_SYSTEM_PROMPT as template

    for placeholder in ("{failure_history}", "{attempt_count}"):
        assert placeholder in template, (
            f"ESCALATION_SYSTEM_PROMPT missing '{placeholder}' placeholder"
        )

    # Any stray unescaped brace field in the template would make this raise.
    rendered = template.format(attempt_count=3, failure_history="test history")
    assert "test history" in rendered, "Rendered prompt does not include failure_history"
    assert "3" in rendered, "Rendered prompt does not include attempt_count"

    print("WB4 PASSED — ESCALATION_SYSTEM_PROMPT has correct placeholders and formats cleanly")


# ---------------------------------------------------------------------------
# Test runner
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Minimal standalone runner: execute each test, tally results, and exit
    # non-zero if anything failed.
    suite = [
        test_bb1_prompt_contains_failure_details,
        test_bb2_escalation_model_is_opus,
        test_bb3_escalation_event_written,
        test_bb4_no_dispatch_returns_stub,
        test_wb1_history_all_attempts_in_prompt,
        test_wb2_dispatch_exception_returns_error_dict,
        test_wb3_build_history_string_numbered,
        test_wb4_system_prompt_has_placeholder,
    ]

    n_pass = n_fail = 0
    for test_fn in suite:
        try:
            test_fn()
        except Exception as err:
            print(f"FAILED: {test_fn.__name__} — {err}")
            n_fail += 1
        else:
            n_pass += 1

    print(f"\n{'=' * 50}")
    print(f"Story 3.05 Test Results: {n_pass}/{len(suite)} passed")
    if not n_fail:
        print("ALL TESTS PASSED")
        sys.exit(0)
    print(f"FAILED: {n_fail} test(s)")
    sys.exit(1)
