"""
tests/track_b/test_story_9_04.py

Story 9.04: ConversationAggregator — Weekly Summary

Black Box Tests (BB1–BB4):
    BB1  10 sagas in mock DB → total_tasks=10
    BB2  lookback_days=1 → 1 is passed as the interval query parameter
    BB3  Snippets containing "API_KEY" → sanitized to [REDACTED]
    BB4  No pg_connection → returns zeroed summary

White Box Tests (WB1–WB4):
    WB1  Postgres query uses parameterized interval (not string concat)
    WB2  Events query passes MAX_SNIPPETS (20) as its SQL LIMIT parameter
    WB3  Sanitization catches all SENSITIVE_PATTERNS
    WB4  period_start and period_end are UTC datetimes
"""

from __future__ import annotations

import os
import sys
from datetime import datetime, timezone, timedelta
from unittest.mock import MagicMock, call, patch

import pytest

# ---------------------------------------------------------------------------
# Path setup
# ---------------------------------------------------------------------------

# Repo root that contains the `core` package. Defaults to the original
# workstation path; set the GENESIS_ROOT environment variable to run the
# tests from a different checkout. An empty value falls back to the default
# so "" (the CWD) is never silently put on sys.path.
GENESIS_ROOT = os.environ.get("GENESIS_ROOT") or "/mnt/e/genesis-system"
if GENESIS_ROOT not in sys.path:
    sys.path.insert(0, GENESIS_ROOT)

# ---------------------------------------------------------------------------
# Imports under test
# ---------------------------------------------------------------------------

from core.epoch.conversation_aggregator import (  # noqa: E402
    ConversationAggregator,
    WeeklyConversationSummary,
    SENSITIVE_PATTERNS,
    MAX_SNIPPETS,
)

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_mock_pg(total_tasks: int = 0, failed_tasks: int = 0, snippets=None):
    """
    Build a mock psycopg2 connection whose cursor returns prescribed values.

    Call order matters — cursor.fetchone() is called twice (total, failed) then
    cursor.fetchall() is called once for snippets.
    """
    if snippets is None:
        snippets = []

    cursor = MagicMock()
    # fetchone() calls: first = total_tasks, second = failed_tasks
    cursor.fetchone.side_effect = [(total_tasks,), (failed_tasks,)]
    cursor.fetchall.return_value = snippets

    conn = MagicMock()
    conn.cursor.return_value = cursor
    return conn, cursor


# ===========================================================================
# BB Tests — Black Box
# ===========================================================================


def test_bb1_ten_sagas_returns_total_tasks_10():
    """BB1: a DB reporting 10 sagas (2 of them failed) → total_tasks=10, failed_tasks=2."""
    mock_conn, _ = _make_mock_pg(total_tasks=10, failed_tasks=2, snippets=[])

    result = ConversationAggregator(pg_connection=mock_conn).aggregate(lookback_days=7)

    assert isinstance(result, WeeklyConversationSummary)
    assert result.total_tasks == 10
    assert result.failed_tasks == 2


def test_bb2_lookback_days_1_uses_1_day_interval():
    """BB2: aggregate(lookback_days=1) binds 1 as a leading query parameter."""
    conn, cursor = _make_mock_pg(total_tasks=3, failed_tasks=0)
    ConversationAggregator(pg_connection=conn).aggregate(lookback_days=1)

    # For every execute(sql, params) call that supplied a non-empty params
    # argument, take the first bound value. At least one of them must be 1.
    leading_params = [
        positional[1][0]
        for positional, _kw in cursor.execute.call_args_list
        if len(positional) >= 2 and positional[1]
    ]

    assert 1 in leading_params, (
        f"Expected interval value 1 in execute params, got: {leading_params}"
    )


def test_bb3_snippet_with_api_key_is_sanitized():
    """BB3: a snippet containing 'API_KEY=...' comes back with [REDACTED] instead."""
    rows = [
        ("TASK_COMPLETE", "ran with API_KEY=sk-abc123"),
        ("TASK_START", "normal event description"),
    ]
    conn, _ = _make_mock_pg(total_tasks=5, failed_tasks=0, snippets=rows)

    snippets = ConversationAggregator(pg_connection=conn).aggregate().conversation_snippets

    assert len(snippets) == 2
    dirty, clean = snippets
    # The secret is gone and the redaction marker is in its place.
    assert "API_KEY=sk-abc123" not in dirty
    assert "[REDACTED]" in dirty
    # The unrelated snippet passes through intact.
    assert "normal event description" in clean


def test_bb4_no_pg_connection_returns_zeroed_summary():
    """BB4: with pg_connection=None, aggregate() returns an all-zero summary."""
    summary = ConversationAggregator(pg_connection=None).aggregate(lookback_days=7)

    assert isinstance(summary, WeeklyConversationSummary)
    for counter_name in ("total_sessions", "total_tasks", "failed_tasks"):
        assert getattr(summary, counter_name) == 0, f"{counter_name} should be 0"
    assert summary.conversation_snippets == []


# ===========================================================================
# WB Tests — White Box
# ===========================================================================


def test_wb1_queries_are_parameterized_not_string_concat():
    """WB1: lookback_days is bound as a query parameter, never formatted into the SQL."""
    conn, cursor = _make_mock_pg(total_tasks=4, failed_tasks=1)
    ConversationAggregator(pg_connection=conn).aggregate(lookback_days=14)

    executed_sql = [positional[0] for positional, _kw in cursor.execute.call_args_list]
    for statement in executed_sql:
        # String interpolation would leave the literal 14 in the statement text.
        assert "14" not in statement, (
            f"lookback_days value '14' was interpolated into SQL string: {statement!r}"
        )
        # A %s placeholder shows the value travels separately as a bound param.
        assert "%s" in statement, f"Expected parameterized query but got: {statement!r}"


def test_wb2_snippets_limited_to_max_snippets():
    """WB2: the events query passes MAX_SNIPPETS as its LIMIT parameter.

    The row cap is enforced by the database through that LIMIT, not in
    Python. So when the mocked fetchall() returns more than MAX_SNIPPETS rows,
    every one of them still shows up in the summary.
    """
    # Always exceed the cap so the scenario stays meaningful if MAX_SNIPPETS changes.
    row_count = MAX_SNIPPETS + 5
    raw_snippets = [(f"EVENT_{i}", f"description {i}") for i in range(row_count)]
    conn, cursor = _make_mock_pg(
        total_tasks=row_count, failed_tasks=0, snippets=raw_snippets
    )
    agg = ConversationAggregator(pg_connection=conn)

    summary = agg.aggregate()

    # Find the first executed statement that mentions the events table.
    events_query_call = None
    for c in cursor.execute.call_args_list:
        args, _ = c
        if "events" in args[0].lower():
            events_query_call = args
            break

    assert events_query_call is not None, "No events query was executed"
    # Fail clearly (rather than with IndexError) if no params were bound.
    assert len(events_query_call) >= 2, (
        f"Events query was executed without parameters: {events_query_call!r}"
    )
    params = events_query_call[1]
    assert MAX_SNIPPETS in params, (
        f"MAX_SNIPPETS ({MAX_SNIPPETS}) not found in events query params: {params}"
    )

    # fetchall() is mocked, so the summary keeps every row it returned;
    # the cap lives in the SQL LIMIT, not in Python.
    assert len(summary.conversation_snippets) == row_count


def test_wb3_sanitize_catches_all_sensitive_patterns():
    """WB3: every entry in SENSITIVE_PATTERNS is redacted by _sanitize."""
    sanitize = ConversationAggregator._sanitize
    for pattern in SENSITIVE_PATTERNS:
        result = sanitize(f"some text {pattern} more text")
        assert "[REDACTED]" in result, (
            f"Pattern '{pattern}' was not sanitized. Got: {result!r}"
        )
        assert pattern not in result, (
            f"Pattern '{pattern}' still present after sanitization. Got: {result!r}"
        )


def test_wb4_period_start_and_end_are_utc_datetimes():
    """WB4: period_start/period_end are tz-aware UTC datetimes with start < end."""
    summary = ConversationAggregator(pg_connection=None).aggregate(lookback_days=7)
    start, end = summary.period_start, summary.period_end

    for label, moment in (("period_start", start), ("period_end", end)):
        assert isinstance(moment, datetime)
        assert moment.tzinfo is not None, f"{label} must be tz-aware"
        # Zero offset: rules out local or other fixed-offset zones.
        assert moment.utcoffset().total_seconds() == 0

    assert start < end


# ===========================================================================
# Additional edge-case tests
# ===========================================================================


def test_period_window_matches_lookback_days():
    """period_end - period_start ≈ lookback_days (tolerance: under 2 seconds)."""
    lookback_days = 14
    agg = ConversationAggregator(pg_connection=None)

    summary = agg.aggregate(lookback_days=lookback_days)

    delta = summary.period_end - summary.period_start
    expected = timedelta(days=lookback_days)
    # Small slack in case the two endpoints come from separate clock reads.
    assert abs(delta - expected) < timedelta(seconds=2), (
        f"Expected ~{int(expected.total_seconds())}s window, got {delta.total_seconds()}s"
    )


def test_zero_tasks_gives_minimum_one_session():
    """With total_tasks=0 from the DB, total_sessions is still at least 1.

    Per the story, sessions are estimated as max(1, total_tasks // 5).
    """
    conn, _ = _make_mock_pg(total_tasks=0, failed_tasks=0, snippets=[])

    sessions = ConversationAggregator(pg_connection=conn).aggregate().total_sessions

    assert sessions >= 1, f"Expected total_sessions >= 1, got {sessions}"


def test_weekly_conversation_summary_is_dataclass():
    """WeeklyConversationSummary is a dataclass exposing every required field."""
    import dataclasses

    expected_fields = {
        "total_sessions",
        "total_tasks",
        "failed_tasks",
        "conversation_snippets",
        "period_start",
        "period_end",
    }

    # Check this first: dataclasses.fields() raises TypeError on non-dataclasses.
    assert dataclasses.is_dataclass(WeeklyConversationSummary)
    declared = {fld.name for fld in dataclasses.fields(WeeklyConversationSummary)}
    missing = expected_fields - declared
    assert not missing, f"Missing fields: {missing}"


def test_sanitize_text_before_sensitive_pattern_preserved():
    """_sanitize keeps text preceding a secret and redacts from the secret onward."""
    prefix = "Event started successfully. "
    sanitized = ConversationAggregator._sanitize(prefix + "API_KEY=abc123 was found.")

    assert sanitized.startswith(prefix)
    assert "[REDACTED]" in sanitized
    assert "API_KEY=abc123" not in sanitized


def test_aggregate_returns_correct_snippet_count():
    """aggregate() yields one snippet per fetched row, e.g. 'EVT_A: clean description one'."""
    rows = [
        ("EVT_A", "clean description one"),
        ("EVT_B", "clean description two"),
        ("EVT_C", "clean description three"),
    ]
    conn, _ = _make_mock_pg(total_tasks=len(rows), failed_tasks=0, snippets=rows)

    snippets = ConversationAggregator(pg_connection=conn).aggregate().conversation_snippets

    assert len(snippets) == len(rows)
    assert "EVT_A: clean description one" in snippets


def test_package_init_exports():
    """PKG: the core.epoch package re-exports the aggregator and summary classes."""
    import core.epoch as epoch_pkg

    # Identity (not just equality): the package must re-export the very same objects.
    assert epoch_pkg.ConversationAggregator is ConversationAggregator
    assert epoch_pkg.WeeklyConversationSummary is WeeklyConversationSummary


# ===========================================================================
# Standalone runner
# ===========================================================================

if __name__ == "__main__":
    import traceback

    # (label, test function) pairs, executed in this order without pytest.
    tests = [
        ("BB1: 10 sagas → total_tasks=10", test_bb1_ten_sagas_returns_total_tasks_10),
        ("BB2: lookback_days=1 uses 1-day interval", test_bb2_lookback_days_1_uses_1_day_interval),
        ("BB3: API_KEY in snippet → [REDACTED]", test_bb3_snippet_with_api_key_is_sanitized),
        ("BB4: no pg_connection → zeroed summary", test_bb4_no_pg_connection_returns_zeroed_summary),
        ("WB1: queries are parameterized", test_wb1_queries_are_parameterized_not_string_concat),
        ("WB2: snippets limited to MAX_SNIPPETS", test_wb2_snippets_limited_to_max_snippets),
        ("WB3: sanitize catches all patterns", test_wb3_sanitize_catches_all_sensitive_patterns),
        ("WB4: period dates are UTC-aware", test_wb4_period_start_and_end_are_utc_datetimes),
        ("EDGE: window matches lookback_days", test_period_window_matches_lookback_days),
        ("EDGE: zero tasks → min 1 session", test_zero_tasks_gives_minimum_one_session),
        ("EDGE: WeeklyConversationSummary is dataclass", test_weekly_conversation_summary_is_dataclass),
        ("EDGE: text before pattern preserved", test_sanitize_text_before_sensitive_pattern_preserved),
        ("EDGE: snippet count matches rows", test_aggregate_returns_correct_snippet_count),
        ("PKG: __init__.py exports", test_package_init_exports),
    ]

    passed = 0
    for label, test_fn in tests:
        try:
            test_fn()
        except Exception as exc:
            # Report the failure with its traceback and keep going.
            print(f"  [FAIL] {label}: {exc}")
            traceback.print_exc()
        else:
            print(f"  [PASS] {label}")
            passed += 1

    total = len(tests)
    print(f"\n{passed}/{total} tests passed")
    if passed != total:
        sys.exit(1)
    print("ALL TESTS PASSED -- Story 9.04 (Track B)")
