#!/usr/bin/env python3
"""
Tests for Story 4.05 (Track B): RoutingTelemetry — Tier Distribution Analytics

Black Box tests (BB): verify the public contract from the outside.
White Box tests (WB): verify internal implementation choices (Redis INCR,
    no TTL, reset behaviour, zero-division guard).

Story: 4.05
File under test: core/routing/routing_telemetry.py
"""
import json
import sys
sys.path.insert(0, '/mnt/e/genesis-system')

import pytest
from pathlib import Path
from unittest.mock import MagicMock, call, patch

from core.routing.routing_telemetry import (
    RoutingTelemetry,
    T0_KEY,
    T1_KEY,
    T2_KEY,
    EVENTS_LOG_PATH,
)
from core.routing.tier_classifier import RoutingDecision


# ---------------------------------------------------------------------------
# Helpers / fixtures
# ---------------------------------------------------------------------------

def _decision(tier: str) -> RoutingDecision:
    """Return a minimal RoutingDecision for *tier* ("T0", "T1" or "T2").

    The model name comes from a fixed per-tier table, so any other tier value
    raises KeyError. The rationale is ``"test <tier>"``.
    """
    models_by_tier = {
        "T0": "python_function",
        "T1": "gemini-flash",
        "T2": "claude-opus-4-6",
    }
    chosen_model = models_by_tier[tier]
    return RoutingDecision(
        tier=tier,
        model=chosen_model,
        rationale=f"test {tier}",
    )


@pytest.fixture
def telemetry_no_redis():
    """Provide a RoutingTelemetry constructed without a Redis client.

    With ``redis_client=None`` the instance keeps its counters in local memory.
    """
    instance = RoutingTelemetry(redis_client=None)
    return instance


@pytest.fixture
def mock_redis():
    """Return a MagicMock standing in for a ``redis.Redis`` client.

    ``incr``, ``get`` and ``delete`` are wired to a shared backing dict, so the
    mock behaves like a tiny Redis:

    * ``incr`` returns the new integer value of the key,
    * ``get`` returns the value as bytes, or None when the key is absent,
    * ``delete`` returns how many of the given keys were actually removed.

    The dict is exposed as ``r._store`` so tests can inspect raw counter state.
    Every other attribute (``set``, ``setex``, ``expire``, ...) is a plain
    auto-created MagicMock attribute, so tests can assert it was never called.
    """
    r = MagicMock()
    # Backing store shared by the side-effect functions below.
    _store: dict[str, int] = {}

    # Parameter names and defaults mirror redis-py:
    # incr(name, amount=1), get(name), delete(*names).
    def _incr(name, amount=1):
        _store[name] = _store.get(name, 0) + amount
        return _store[name]

    def _get(name):
        val = _store.get(name)
        # redis-py (without decode_responses) hands back bytes, not int/str.
        return str(val).encode() if val is not None else None

    def _delete(*names):
        # Real DEL reports how many of the given keys existed and were removed.
        removed = 0
        for name in names:
            if _store.pop(name, None) is not None:
                removed += 1
        return removed

    r.incr.side_effect = _incr
    r.get.side_effect = _get
    r.delete.side_effect = _delete
    r._store = _store  # expose for assertions
    return r


@pytest.fixture
def telemetry_redis(mock_redis):
    """Provide a RoutingTelemetry whose Redis client is the ``mock_redis`` mock.

    Tests that request both fixtures receive the same mock instance, so they
    can make call assertions against it.
    """
    client = mock_redis
    return RoutingTelemetry(redis_client=client)


# ===========================================================================
# Black Box tests
# ===========================================================================


def test_bb1_t0_t1_pct_83_3_for_8_2_2_distribution(telemetry_no_redis):
    """BB1: an 8/2/2 split across T0/T1/T2 yields t0_t1_pct ≈ 83.3%."""
    plan = (("T0", 8), ("T1", 2), ("T2", 2))
    for tier, repeats in plan:
        for _ in range(repeats):
            telemetry_no_redis.record(_decision(tier))

    result = telemetry_no_redis.get_distribution()
    assert (result["t0"], result["t1"], result["t2"]) == (8, 2, 2)
    # Share of cheap tiers: (8 + 2) / 12 * 100 = 83.33… → 83.3 after rounding
    assert result["t0_t1_pct"] == pytest.approx(83.3, abs=0.1)


def test_bb2_get_distribution_reflects_recorded_counts(telemetry_no_redis):
    """BB2: the distribution echoes back exactly what record() was fed."""
    sequence = ["T0", "T0", "T1", "T2", "T2", "T2"]
    for tier in sequence:
        telemetry_no_redis.record(_decision(tier))

    snapshot = telemetry_no_redis.get_distribution()
    expected_counts = {"t0": 2, "t1": 1, "t2": 3}
    for key, count in expected_counts.items():
        assert snapshot[key] == count
    # Three of six decisions are T0/T1: 3 / 6 * 100 = 50.0
    assert snapshot["t0_t1_pct"] == pytest.approx(50.0, abs=0.1)


def test_bb3_log_tier_report_writes_json_to_events_jsonl(telemetry_no_redis, tmp_path):
    """BB3: log_tier_report() appends a valid JSON entry to events.jsonl.

    EVENTS_LOG_PATH is patched to a per-test temp file so the real events log
    is never touched. One T0 and one T1 decision are recorded, so every
    decision is T0/T1 and the reported t0_t1_pct must be 100.0.
    """
    log_file = tmp_path / "events.jsonl"

    with patch("core.routing.routing_telemetry.EVENTS_LOG_PATH", log_file):
        telemetry_no_redis.record(_decision("T0"))
        telemetry_no_redis.record(_decision("T1"))
        telemetry_no_redis.log_tier_report()

    assert log_file.exists(), "events.jsonl must be created"
    # splitlines() yields [] for an empty file instead of [""]
    lines = log_file.read_text().splitlines()
    assert len(lines) == 1, "Exactly one JSON line must be written"

    entry = json.loads(lines[0])
    assert entry["event"] == "tier_distribution_report"
    assert "timestamp" in entry
    assert entry["t0"] == 1
    assert entry["t1"] == 1
    assert entry["t2"] == 0
    # (1 + 1) / 2 * 100 = 100.0 — pin the value, not just the key's presence
    assert entry["t0_t1_pct"] == pytest.approx(100.0, abs=0.1)


def test_bb4_no_redis_local_counters_work_correctly(telemetry_no_redis):
    """BB4: Without Redis, local in-memory counters accumulate correctly."""
    # Eight record() calls (5 × T0, then 3 × T2); no Redis client involved.
    for _ in range(5):
        telemetry_no_redis.record(_decision("T0"))
    for _ in range(3):
        telemetry_no_redis.record(_decision("T2"))

    dist = telemetry_no_redis.get_distribution()
    assert dist["t0"] == 5
    assert dist["t1"] == 0
    assert dist["t2"] == 3
    # Mirror the (t0 + t1) / total formula rather than hard-coding 5 / total.
    t0, t1, t2 = 5, 0, 3
    expected_pct = round((t0 + t1) / (t0 + t1 + t2) * 100, 1)
    assert dist["t0_t1_pct"] == pytest.approx(expected_pct, abs=0.1)


# ===========================================================================
# White Box tests
# ===========================================================================


def test_wb1_redis_uses_incr_not_set(mock_redis, telemetry_redis):
    """WB1: record() calls redis.incr() (atomic), never redis.set().

    Also checks that each tier's decision lands on that tier's own key, so an
    implementation that increments the wrong key cannot pass on call count
    alone.
    """
    telemetry_redis.record(_decision("T0"))
    telemetry_redis.record(_decision("T1"))
    telemetry_redis.record(_decision("T2"))

    # incr must have been called exactly 3 times (once per tier) ...
    assert mock_redis.incr.call_count == 3
    # ... and each tier's counter key must have been bumped exactly once.
    assert mock_redis._store == {T0_KEY: 1, T1_KEY: 1, T2_KEY: 1}
    # set must never be called — using INCR, not SET
    mock_redis.set.assert_not_called()
    # setex must never be called — no TTL on counters
    mock_redis.setex.assert_not_called()


def test_wb2_no_ttl_on_counter_keys(mock_redis, telemetry_redis):
    """WB2: Counters have no TTL — no expiry command is ever issued.

    Covers both the second-resolution commands (setex/expire/expireat) and
    their millisecond counterparts (psetex/pexpire/pexpireat).
    """
    for _ in range(4):
        telemetry_redis.record(_decision("T0"))

    # Guard against a vacuous pass: the records must actually reach Redis.
    assert mock_redis.incr.call_count == 4

    ttl_commands = (
        "setex",
        "psetex",
        "expire",
        "expireat",
        "pexpire",
        "pexpireat",
    )
    for command in ttl_commands:
        getattr(mock_redis, command).assert_not_called()


def test_wb3_reset_clears_all_counters(telemetry_no_redis):
    """WB3: after reset(), local counters read back as zero and the pct as 0.0."""
    for tier, times in (("T0", 10), ("T2", 5)):
        for _ in range(times):
            telemetry_no_redis.record(_decision(tier))

    telemetry_no_redis.reset()
    zeroed = telemetry_no_redis.get_distribution()

    for counter in ("t0", "t1", "t2"):
        assert zeroed[counter] == 0
    assert zeroed["t0_t1_pct"] == 0.0


def test_wb3_reset_clears_redis_keys(mock_redis, telemetry_redis):
    """WB3 (Redis): reset() removes all three counter keys in a single delete()."""
    for _ in range(3):
        telemetry_redis.record(_decision("T1"))

    telemetry_redis.reset()

    # Exactly one delete() call, carrying every counter key in T0/T1/T2 order.
    expected_keys = (T0_KEY, T1_KEY, T2_KEY)
    mock_redis.delete.assert_called_once_with(*expected_keys)

    # The mock's backing store no longer holds the keys, so reads come back 0.
    after = telemetry_redis.get_distribution()
    for counter in ("t0", "t1", "t2"):
        assert after[counter] == 0
    assert after["t0_t1_pct"] == 0.0


def test_wb4_empty_state_returns_zero_pct_no_division_error(telemetry_no_redis):
    """WB4: a fresh instance reports 0.0 pct instead of dividing by zero."""
    fresh = telemetry_no_redis.get_distribution()  # must not raise
    assert [fresh["t0"], fresh["t1"], fresh["t2"]] == [0, 0, 0]
    assert fresh["t0_t1_pct"] == 0.0


# ---------------------------------------------------------------------------
# Extra: unknown tier is silently ignored
# ---------------------------------------------------------------------------

def test_unknown_tier_is_ignored(telemetry_no_redis):
    """An unrecognised tier is accepted silently and leaves every counter at zero."""
    telemetry_no_redis.record(
        RoutingDecision(tier="TX", model="unknown", rationale="bad tier")
    )
    counts = telemetry_no_redis.get_distribution()
    for counter in ("t0", "t1", "t2"):
        assert counts[counter] == 0


# ---------------------------------------------------------------------------
# Extra: log_tier_report appends multiple entries (not overwrites)
# ---------------------------------------------------------------------------

def test_log_tier_report_appends_not_overwrites(telemetry_no_redis, tmp_path):
    """log_tier_report() appends; calling it twice gives two JSON lines."""
    log_file = tmp_path / "events.jsonl"

    with patch("core.routing.routing_telemetry.EVENTS_LOG_PATH", log_file):
        telemetry_no_redis.log_tier_report()
        telemetry_no_redis.log_tier_report()

    # splitlines() returns [] for an empty file, whereas strip().split("\n")
    # would return [""] and count it as one line.
    lines = log_file.read_text().splitlines()
    assert len(lines) == 2, "Second call must append, not overwrite"
    for line in lines:
        entry = json.loads(line)
        assert entry["event"] == "tier_distribution_report"


# ---------------------------------------------------------------------------
# Export check: RoutingTelemetry available from core.routing package
# ---------------------------------------------------------------------------

def test_routing_telemetry_exported_from_package():
    """The core.routing package must re-export the very same RoutingTelemetry class."""
    from core.routing import RoutingTelemetry as exported_cls

    assert exported_cls is RoutingTelemetry


if __name__ == "__main__":
    # Allow running directly: python tests/track_b/test_story_4_05.py
    # Delegate to pytest so every test in this file runs, including the
    # fixture-based Redis and events.jsonl tests. A hand-copied subset can
    # silently drift from the real tests.
    exit_code = pytest.main([__file__, "-q"])
    if exit_code == 0:
        print("ALL TESTS PASSED — Story 4.05 (Track B)")
    # pytest.main returns an int-compatible ExitCode; propagate it so callers
    # (CI, shell scripts) see a non-zero status on failure.
    sys.exit(int(exit_code))
