#!/usr/bin/env python3
"""
Tests for Story 5.03: IntentClassifier — Redis Cache Layer
AIVA RLM Nexus PRD v2 — Track A, Module 5 (Intent + Routing)

Black box tests (BB1-BB3): cache hit / UNKNOWN not cached / TTL-expired behaviour.
White box tests (WB1-WB3): SHA256 key, JSON serialization, UNKNOWN guard placement.

All Redis calls are fully mocked — no real Redis, no real Gemini.
"""

import asyncio
import hashlib
import inspect
import json
import sys

import pytest

sys.path.insert(0, "/mnt/e/genesis-system")

from unittest.mock import AsyncMock, MagicMock, call, patch
from datetime import datetime, timezone
from core.intent import IntentClassifier, IntentSignal, IntentType
from core.intent.intent_classifier import INTENT_CACHE_TTL, _INTENT_CACHE_KEY_TEMPLATE


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _run(coro):
    """Run a coroutine synchronously for test purposes."""
    return asyncio.get_event_loop().run_until_complete(coro)


def _valid_json(
    intent_type: str = "book_job",
    confidence: float = 0.92,
    entities: dict = None,
    requires_swarm: bool = True,
) -> str:
    return json.dumps({
        "intent_type": intent_type,
        "confidence": confidence,
        "extracted_entities": entities or {"service": "plumbing"},
        "requires_swarm": requires_swarm,
        "reasoning": "test",
    })


def _async_gemini_client(text: str) -> MagicMock:
    """Build an async mock Gemini client that returns `text`."""
    client = MagicMock()
    response = MagicMock()
    response.text = text

    async def _fake_async(prompt):
        return response

    client.generate_content_async = _fake_async
    return client


def _mock_redis(stored: dict = None) -> MagicMock:
    """
    Build a mock async Redis client.

    stored: dict mapping cache keys to JSON strings already in Redis.
             Pass None (default) for an empty cache.
    """
    store = dict(stored or {})
    redis = MagicMock()

    async def fake_get(key):
        return store.get(key)

    async def fake_setex(key, ttl, value):
        store[key] = value

    redis.get = fake_get
    redis.setex = AsyncMock(side_effect=fake_setex)
    return redis, store


def _make_classifier(gemini_text: str, redis_client=None) -> IntentClassifier:
    """Create an IntentClassifier whose fake Gemini always answers *gemini_text*."""
    fake_gemini = _async_gemini_client(gemini_text)
    return IntentClassifier(fake_gemini, redis_client)


def _sha256(utterance: str) -> str:
    return hashlib.sha256(utterance.encode("utf-8")).hexdigest()


def _expected_cache_key(utterance: str) -> str:
    """Return the Redis key the classifier is expected to use for *utterance*."""
    digest = _sha256(utterance)
    return _INTENT_CACHE_KEY_TEMPLATE.format(utterance_hash=digest)


# ===========================================================================
# BB1: classify() called twice with same utterance → Gemini called only once
# ===========================================================================

class TestBB1_CacheHit:
    """
    BB1: once a first classify() call has populated the cache, a second call
    with the identical utterance must be served from the cache without
    reaching Gemini again.
    """

    @staticmethod
    def _counting_gemini():
        """
        Return ``(client, calls)``: an answer_faq/0.88 Gemini mock plus a list
        that records each prompt passed to ``generate_content_async``.
        """
        client = _async_gemini_client(_valid_json("answer_faq", confidence=0.88))
        calls = []
        delegate = client.generate_content_async

        async def recording_async(prompt):
            calls.append(prompt)
            return await delegate(prompt)

        client.generate_content_async = recording_async
        return client, calls

    @staticmethod
    def _second_of_two_identical_calls():
        """Classify the same utterance twice (Redis enabled); return the 2nd result."""
        redis, _ = _mock_redis()
        clf = _make_classifier(_valid_json("answer_faq", confidence=0.88), redis)
        _run(clf.classify("What are your hours?", session_id="s1"))
        return _run(clf.classify("What are your hours?", session_id="s2"))

    def test_gemini_called_once_for_two_identical_utterances(self):
        redis, _ = _mock_redis()
        gemini, calls = self._counting_gemini()
        clf = IntentClassifier(gemini, redis)

        for session in ("s1", "s2"):
            _run(clf.classify("What are your hours?", session_id=session))

        assert len(calls) == 1, (
            f"Gemini must be called exactly once for duplicate utterances, got {len(calls)}"
        )

    def test_second_call_returns_intent_signal(self):
        assert isinstance(self._second_of_two_identical_calls(), IntentSignal)

    def test_second_call_has_correct_intent_type(self):
        assert self._second_of_two_identical_calls().intent_type is IntentType.ANSWER_FAQ

    def test_second_call_has_correct_confidence(self):
        cached = self._second_of_two_identical_calls()
        assert abs(cached.confidence - 0.88) < 1e-6

    def test_different_utterances_both_call_gemini(self):
        redis, _ = _mock_redis()
        gemini, calls = self._counting_gemini()
        clf = IntentClassifier(gemini, redis)

        for utterance, session in (("What are your hours?", "s1"), ("How do I book?", "s2")):
            _run(clf.classify(utterance, session_id=session))

        assert len(calls) == 2, (
            f"Different utterances must each call Gemini; got {len(calls)} calls"
        )

    def test_no_redis_always_calls_gemini_twice(self):
        """Without Redis, every classify() call hits Gemini."""
        gemini, calls = self._counting_gemini()
        clf = IntentClassifier(gemini, redis_client=None)

        for session in ("s1", "s2"):
            _run(clf.classify("What are your hours?", session_id=session))

        assert len(calls) == 2, (
            "Without Redis, every call must reach Gemini"
        )


# ===========================================================================
# BB2: UNKNOWN result → not cached; Gemini called again next time
# ===========================================================================

class TestBB2_UnknownNotCached:
    """
    BB2: When IntentType.UNKNOWN is returned, it must NOT be written to Redis.
    A subsequent identical call must call Gemini again.
    """

    def test_unknown_not_cached(self):
        redis, store = _mock_redis()
        clf = _make_classifier("{not valid json}", redis)

        # First call → UNKNOWN (malformed JSON)
        first = _run(clf.classify("mumbling", session_id="s1"))
        assert first.intent_type is IntentType.UNKNOWN

        # Redis setex must NOT have been called, and nothing may land in the store
        redis.setex.assert_not_called()
        assert store == {}

    def test_gemini_called_twice_when_unknown(self):
        """UNKNOWN result skips cache → next call is a fresh Gemini hit."""
        redis, _ = _mock_redis()
        gemini = _async_gemini_client("{bad json}")
        call_count = 0

        original_async = gemini.generate_content_async

        async def counting_async(prompt):
            nonlocal call_count
            call_count += 1
            return await original_async(prompt)

        gemini.generate_content_async = counting_async
        clf = IntentClassifier(gemini, redis)

        _run(clf.classify("mumbling", session_id="s1"))
        _run(clf.classify("mumbling", session_id="s2"))

        assert call_count == 2, (
            f"UNKNOWN results must not be cached; Gemini should be called twice, got {call_count}"
        )

    def test_subsequent_valid_call_caches_correctly(self):
        """After an UNKNOWN, a valid utterance from the same session IS cached."""
        redis, store = _mock_redis()
        gemini = MagicMock()
        # First call: bad JSON → UNKNOWN
        # Second call: valid JSON → ANSWER_FAQ
        responses = iter(["{bad json}", _valid_json("answer_faq", confidence=0.9)])

        async def fake_async(prompt):
            r = MagicMock()
            r.text = next(responses)
            return r

        gemini.generate_content_async = fake_async
        clf = IntentClassifier(gemini, redis)

        _run(clf.classify("mumbling", session_id="s1"))    # UNKNOWN → not cached
        _run(clf.classify("What are your hours?", session_id="s1"))  # valid → cached

        # setex called exactly once (for the valid utterance only)...
        assert redis.setex.call_count == 1
        # ...and that single write is keyed on the valid utterance, not the UNKNOWN one.
        assert _expected_cache_key("What are your hours?") in store
        assert _expected_cache_key("mumbling") not in store


# ===========================================================================
# BB3: TTL expiry → Gemini called again
# ===========================================================================

class TestBB3_TTLExpiry:
    """
    BB3: When the Redis key has expired (get returns None), Gemini is called again.

    The mock Redis does not enforce TTLs, so expiry is simulated in one of two
    ways: the store starts empty (cold cache), or a key warmed by a first call
    is evicted from the store before the next identical call.
    """

    @staticmethod
    def _counting_gemini(text: str):
        """Return ``(client, calls)`` where ``calls`` records every Gemini request."""
        client = _async_gemini_client(text)
        calls = []
        delegate = client.generate_content_async

        async def counting_async(prompt):
            calls.append(prompt)
            return await delegate(prompt)

        client.generate_content_async = counting_async
        return client, calls

    def test_cache_miss_after_ttl_calls_gemini(self):
        """
        If Redis returns None (key expired or never set), Gemini is called.
        """
        redis, _ = _mock_redis(stored={})   # empty store → simulates expired TTL
        gemini, calls = self._counting_gemini(_valid_json("answer_faq", confidence=0.88))
        clf = IntentClassifier(gemini, redis)

        _run(clf.classify("What are your hours?", session_id="s1"))

        assert len(calls) == 1, "Should call Gemini when cache is empty (TTL expired)"

    def test_expired_key_triggers_fresh_gemini_call(self):
        """A warmed key that later expires must send the next identical call to Gemini."""
        redis, store = _mock_redis()
        gemini, calls = self._counting_gemini(_valid_json("answer_faq", confidence=0.88))
        clf = IntentClassifier(gemini, redis)

        _run(clf.classify("What are your hours?", session_id="s1"))
        assert len(calls) == 1
        assert store, "First classify() should have warmed the cache"

        store.clear()  # simulate TTL expiry: Redis GET now returns None

        _run(clf.classify("What are your hours?", session_id="s2"))
        assert len(calls) == 2, (
            f"Expired cache entry must trigger a fresh Gemini call; got {len(calls)} calls"
        )

    def test_ttl_constant_is_60(self):
        """The cache TTL constant must be exactly 60 seconds."""
        assert INTENT_CACHE_TTL == 60, (
            f"INTENT_CACHE_TTL must be 60, got {INTENT_CACHE_TTL}"
        )

    def test_setex_called_with_correct_ttl(self):
        """setex must be called with INTENT_CACHE_TTL (60) as the expiry."""
        redis, _ = _mock_redis()
        clf = _make_classifier(_valid_json("answer_faq", confidence=0.88), redis)

        _run(clf.classify("What are your hours?", session_id="s1"))

        assert redis.setex.call_count == 1
        # setex(key, ttl, value) — positional
        _key, ttl, _payload = redis.setex.call_args[0]
        assert ttl == INTENT_CACHE_TTL, (
            f"setex TTL must be {INTENT_CACHE_TTL}, got {ttl}"
        )


# ===========================================================================
# WB1: Cache key is SHA256 hash, not plain utterance text
# ===========================================================================

class TestWB1_CacheKeyIsSHA256:
    """WB1: cache keys are derived from the SHA256 of the utterance, never the raw text."""

    @staticmethod
    def _clf() -> IntentClassifier:
        """A classifier with a throwaway Gemini mock and no Redis client."""
        return IntentClassifier(MagicMock())

    def test_cache_key_is_sha256_hex(self):
        utterance = "I need a plumber"
        key = self._clf()._cache_key(utterance)
        assert key == "intent:cache:" + _sha256(utterance)

    def test_cache_key_does_not_contain_raw_utterance(self):
        canary = "UNIQUE_UTTERANCE_CANARY_12345"
        assert canary not in self._clf()._cache_key(canary), (
            "Raw utterance text must not appear verbatim in the cache key"
        )

    def test_utterance_hash_is_64_char_hex(self):
        digest = self._clf()._utterance_hash("some utterance")
        assert len(digest) == 64
        assert set(digest) <= set("0123456789abcdef")

    def test_different_utterances_produce_different_keys(self):
        clf = self._clf()
        assert clf._cache_key("I need a plumber") != clf._cache_key("What are your hours?")

    def test_same_utterance_produces_same_key(self):
        clf = self._clf()
        first = clf._cache_key("I need a plumber")
        second = clf._cache_key("I need a plumber")
        assert first == second

    def test_hashlib_sha256_used_in_source(self):
        import core.intent.intent_classifier as mod
        assert "hashlib.sha256" in inspect.getsource(mod), (
            "Cache key must use hashlib.sha256"
        )


# ===========================================================================
# WB2: IntentSignal serialized to JSON for Redis storage
# ===========================================================================

class TestWB2_JSONSerialization:
    """WB2: IntentSignal survives a _signal_to_dict → JSON → _dict_to_signal round trip."""

    def _make_signal(self) -> IntentSignal:
        return IntentSignal(
            session_id="sess-42",
            utterance="I need a plumber",
            intent_type=IntentType.BOOK_JOB,
            confidence=0.92,
            extracted_entities={"service": "plumbing"},
            requires_swarm=True,
            created_at=datetime(2026, 2, 25, 10, 0, 0, tzinfo=timezone.utc),
            raw_gemini_response='{"intent_type": "book_job"}',
        )

    def _as_dict(self) -> dict:
        """Serialize the sample signal with a fresh classifier."""
        return IntentClassifier(MagicMock())._signal_to_dict(self._make_signal())

    def _round_trip(self) -> IntentSignal:
        """Serialize the sample signal to a dict and rebuild it with the same classifier."""
        clf = IntentClassifier(MagicMock())
        as_dict = clf._signal_to_dict(self._make_signal())
        return clf._dict_to_signal(as_dict)

    def test_signal_to_dict_produces_json_serialisable_dict(self):
        # json.dumps must not raise on the serialized form
        encoded = json.dumps(self._as_dict())
        assert isinstance(encoded, str)

    def test_signal_to_dict_intent_type_is_string(self):
        intent_value = self._as_dict()["intent_type"]
        assert isinstance(intent_value, str)
        assert intent_value == "book_job"

    def test_signal_to_dict_created_at_is_isoformat_string(self):
        stamp = self._as_dict()["created_at"]
        assert isinstance(stamp, str)
        # Parsing as ISO 8601 must succeed
        assert isinstance(datetime.fromisoformat(stamp), datetime)

    def test_round_trip_preserves_intent_type(self):
        assert self._round_trip().intent_type is IntentType.BOOK_JOB

    def test_round_trip_preserves_confidence(self):
        assert abs(self._round_trip().confidence - 0.92) < 1e-9

    def test_round_trip_preserves_extracted_entities(self):
        assert self._round_trip().extracted_entities == {"service": "plumbing"}

    def test_round_trip_preserves_requires_swarm(self):
        assert self._round_trip().requires_swarm is True

    def test_setex_stores_valid_json_string(self):
        redis, store = _mock_redis()
        clf = _make_classifier(_valid_json("answer_faq", confidence=0.9), redis)

        _run(clf.classify("What are your hours?", session_id="s1"))

        expected_key = _expected_cache_key("What are your hours?")
        assert expected_key in store
        # The stored value must decode as JSON and carry the intent type
        decoded = json.loads(store[expected_key])
        assert "intent_type" in decoded


# ===========================================================================
# WB3: UNKNOWN check happens before cache write
# ===========================================================================

class TestWB3_UnknownGuardBeforeCacheWrite:
    """WB3: The UNKNOWN guard in classify() must prevent setex from being called."""

    def test_setex_not_called_for_unknown(self):
        redis, _ = _mock_redis()
        clf = _make_classifier("{not valid json}", redis)

        result = _run(clf.classify("garbage", session_id="s1"))

        assert result.intent_type is IntentType.UNKNOWN
        redis.setex.assert_not_called()

    def test_setex_called_for_known_intent(self):
        redis, _ = _mock_redis()
        clf = _make_classifier(_valid_json("qualify_lead", confidence=0.8), redis)

        result = _run(clf.classify("Can I qualify?", session_id="s1"))

        assert result.intent_type is IntentType.QUALIFY_LEAD
        assert redis.setex.call_count == 1

    def test_get_called_before_setex(self):
        """get() (cache read) must be called before setex() (cache write)."""
        call_order = []
        redis = MagicMock()

        async def fake_get(key):
            call_order.append("get")
            return None  # cache miss

        async def fake_setex(key, ttl, value):
            call_order.append("setex")

        redis.get = fake_get
        redis.setex = AsyncMock(side_effect=fake_setex)

        clf = _make_classifier(_valid_json("answer_faq", confidence=0.9), redis)
        _run(clf.classify("What are your hours?", session_id="s1"))

        assert call_order == ["get", "setex"], (
            f"Expected ['get', 'setex'], got {call_order}"
        )

    def test_unknown_guard_in_source(self):
        """
        Source must reference both IntentType.UNKNOWN and _set_cached.

        This is a presence check only; it does not prove ordering. The guard's
        effect is verified behaviourally by test_setex_not_called_for_unknown.
        """
        import core.intent.intent_classifier as mod
        source = inspect.getsource(mod)
        # Verify the guard pattern exists
        assert "IntentType.UNKNOWN" in source
        assert "_set_cached" in source

    def test_no_sqlite3_in_source(self):
        """
        sqlite3 is FORBIDDEN, in any import form.

        The module is parsed with ``ast``, so ``import sqlite3``,
        ``import sqlite3 as x``, ``import os, sqlite3`` and
        ``from sqlite3 import ...`` are all caught. Comments that merely
        mention sqlite3 are ignored.
        """
        import ast
        import core.intent.intent_classifier as mod
        source = inspect.getsource(mod)

        imported_roots = set()
        for node in ast.walk(ast.parse(source)):
            if isinstance(node, ast.Import):
                imported_roots.update(alias.name.split(".")[0] for alias in node.names)
            elif isinstance(node, ast.ImportFrom) and node.module:
                imported_roots.add(node.module.split(".")[0])

        assert "sqlite3" not in imported_roots, (
            "sqlite3 is FORBIDDEN in intent_classifier.py (Rule 7)"
        )


# ===========================================================================
# Regression guard: 5.02 public interface unchanged
# ===========================================================================

class TestRegression_502:
    """
    Guard that Story 5.02 behaviour survives the 5.03 cache changes.

    Every classification here runs on an IntentClassifier built WITHOUT a
    redis_client, so only the pre-existing (uncached) code path is exercised.
    """

    @staticmethod
    def _classify_without_redis(gemini_text: str, utterance: str) -> IntentSignal:
        """Classify *utterance* in session "s1" using a Redis-less classifier."""
        classifier = _make_classifier(gemini_text, redis_client=None)
        return _run(classifier.classify(utterance, session_id="s1"))

    def test_classify_returns_intent_signal_no_redis(self):
        outcome = self._classify_without_redis(
            _valid_json("book_job", confidence=0.92), "I need a plumber"
        )
        assert isinstance(outcome, IntentSignal)

    def test_classify_book_job_no_redis(self):
        outcome = self._classify_without_redis(
            _valid_json("book_job", confidence=0.92), "I need a plumber"
        )
        assert outcome.intent_type is IntentType.BOOK_JOB

    def test_malformed_json_returns_unknown_no_redis(self):
        outcome = self._classify_without_redis("{not json}", "anything")
        assert outcome.intent_type is IntentType.UNKNOWN

    def test_requires_swarm_forced_book_job_no_redis(self):
        outcome = self._classify_without_redis(
            _valid_json("book_job", requires_swarm=False), "book"
        )
        assert outcome.requires_swarm is True

    def test_confidence_clamped_no_redis(self):
        outcome = self._classify_without_redis(
            _valid_json("qualify_lead", confidence=99.0), "qualify"
        )
        assert outcome.confidence == 1.0

    def test_gemini_model_constant_unchanged(self):
        assert IntentClassifier.GEMINI_MODEL == "gemini-2.0-flash"

    def test_redis_none_by_default(self):
        assert IntentClassifier(MagicMock())._redis is None


# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Run this file's tests verbosely and propagate pytest's exit code.
    sys.exit(pytest.main([__file__, "-v", "--tb=short"]))
