"""Tests for RLM Neo-Cortex Component Interface Contracts (Story 0.01).

Black box tests:
    BB1: Import all public names from core.rlm.contracts -- assert no ImportError
    BB2: Create a MemoryRecord with only required fields -- assert defaults populated
    BB3: Verify EntitlementManifest.decay_policy defaults to "moderate"
    BB4: Verify MemoryTier("working") == MemoryTier.WORKING

White box tests:
    WB1: Define a class that only partially implements MemoryGatewayProtocol -- assert structural mismatch
    WB2: Verify PreferencePair.confidence default is 1.0 via dataclasses.fields()

Additional coverage tests:
    AC1: All 6 Protocol classes are importable and inspectable via typing.get_type_hints()
    AC2: Constants have correct values
    AC3: FeedbackSignal enum values are -1, 0, 1
    AC4: CustomerTier has exactly 4 values
    AC5: MemoryRecord serialization round-trip via dataclasses.asdict()
    AC6: Package-level imports from core.rlm work correctly
"""
from __future__ import annotations

import dataclasses
import typing
from datetime import datetime
from uuid import UUID, uuid4

import pytest


# ---------------------------------------------------------------------------
# BB1: Import all public names from core.rlm.contracts -- no ImportError
# ---------------------------------------------------------------------------


class TestBB1ImportAllPublicNames:
    """Black box test: every public symbol is importable without error."""

    def test_import_enums(self) -> None:
        from core.rlm.contracts import CustomerTier, FeedbackSignal, MemoryTier

        assert MemoryTier is not None
        assert CustomerTier is not None
        assert FeedbackSignal is not None

    def test_import_dataclasses(self) -> None:
        from core.rlm.contracts import (
            EntitlementManifest,
            MemoryRecord,
            PreferencePair,
        )

        assert MemoryRecord is not None
        assert EntitlementManifest is not None
        assert PreferencePair is not None

    def test_import_protocols(self) -> None:
        from core.rlm.contracts import (
            DecaySchedulerProtocol,
            EntitlementLedgerProtocol,
            FeedbackCollectorProtocol,
            MemoryGatewayProtocol,
            SurpriseIntegrationProtocol,
            TenantPartitionProtocol,
        )

        assert MemoryGatewayProtocol is not None
        assert EntitlementLedgerProtocol is not None
        assert SurpriseIntegrationProtocol is not None
        assert DecaySchedulerProtocol is not None
        assert FeedbackCollectorProtocol is not None
        assert TenantPartitionProtocol is not None

    def test_import_constants(self) -> None:
        from core.rlm.contracts import (
            DEFAULT_DECAY_HALF_LIFE,
            EMBEDDING_DIM,
            MAX_CONTENT_LENGTH,
            REDIS_TTL_ENTITLEMENT,
            SURPRISE_THRESHOLD,
        )

        assert EMBEDDING_DIM is not None
        assert DEFAULT_DECAY_HALF_LIFE is not None
        assert SURPRISE_THRESHOLD is not None
        assert MAX_CONTENT_LENGTH is not None
        assert REDIS_TTL_ENTITLEMENT is not None


# ---------------------------------------------------------------------------
# BB2: Create MemoryRecord with only required fields -- defaults populated
# ---------------------------------------------------------------------------


class TestBB2MemoryRecordDefaults:
    """Black box test: MemoryRecord populates defaults correctly."""

    def test_required_fields_only(self) -> None:
        from core.rlm.contracts import MemoryRecord, MemoryTier

        tid = uuid4()
        record = MemoryRecord(
            tenant_id=tid,
            content="Test memory content",
            source="test_source",
            domain="test_domain",
        )
        assert record.tenant_id == tid
        assert record.content == "Test memory content"
        assert record.source == "test_source"
        assert record.domain == "test_domain"
        assert record.surprise_score == 0.0
        assert record.memory_tier == MemoryTier.WORKING
        assert record.metadata == {}
        assert isinstance(record.created_at, datetime)
        assert record.vector_id is None
        assert record.pg_id is None

    def test_explicit_overrides(self) -> None:
        from core.rlm.contracts import MemoryRecord, MemoryTier

        tid = uuid4()
        record = MemoryRecord(
            tenant_id=tid,
            content="High surprise content",
            source="api",
            domain="knowledge",
            surprise_score=0.92,
            memory_tier=MemoryTier.SEMANTIC,
            metadata={"key": "value"},
            vector_id="vec-123",
            pg_id=42,
        )
        assert record.surprise_score == 0.92
        assert record.memory_tier == MemoryTier.SEMANTIC
        assert record.metadata == {"key": "value"}
        assert record.vector_id == "vec-123"
        assert record.pg_id == 42


# ---------------------------------------------------------------------------
# BB3: EntitlementManifest.decay_policy defaults to "moderate"
# ---------------------------------------------------------------------------


class TestBB3EntitlementManifestDefaults:
    """Black box test: default values on EntitlementManifest."""

    def test_decay_policy_default(self) -> None:
        from core.rlm.contracts import CustomerTier, EntitlementManifest

        manifest = EntitlementManifest(
            tenant_id=uuid4(),
            tier=CustomerTier.STARTER,
        )
        assert manifest.decay_policy == "moderate"

    def test_memory_limit_default(self) -> None:
        from core.rlm.contracts import CustomerTier, EntitlementManifest

        manifest = EntitlementManifest(
            tenant_id=uuid4(),
            tier=CustomerTier.PROFESSIONAL,
        )
        assert manifest.memory_limit_mb == 100
        assert manifest.max_memories_per_day == 500
        assert manifest.allowed_mcp_tools == []
        assert manifest.features == {}


# ---------------------------------------------------------------------------
# BB4: MemoryTier("working") == MemoryTier.WORKING
# ---------------------------------------------------------------------------


class TestBB4MemoryTierStringConstruction:
    """Black box test: enum construction from string value."""

    def test_string_construction(self) -> None:
        from core.rlm.contracts import MemoryTier

        assert MemoryTier("working") == MemoryTier.WORKING

    def test_all_tiers_from_string(self) -> None:
        from core.rlm.contracts import MemoryTier

        assert MemoryTier("discard") == MemoryTier.DISCARD
        assert MemoryTier("working") == MemoryTier.WORKING
        assert MemoryTier("episodic") == MemoryTier.EPISODIC
        assert MemoryTier("semantic") == MemoryTier.SEMANTIC

    def test_memory_tier_has_exactly_four_values(self) -> None:
        from core.rlm.contracts import MemoryTier

        members = list(MemoryTier)
        assert len(members) == 4
        expected_names = {"DISCARD", "WORKING", "EPISODIC", "SEMANTIC"}
        assert {m.name for m in members} == expected_names


# ---------------------------------------------------------------------------
# WB1: Partial implementation of MemoryGatewayProtocol -- structural check
# ---------------------------------------------------------------------------


class TestWB1ProtocolStructuralCheck:
    """White box test: Protocol structural typing enforcement.

    We verify that the protocol classes have the expected methods by
    inspecting their type hints. Python's Protocol is structural --
    runtime isinstance checks need runtime_checkable, but we can verify
    the protocol defines the expected methods.
    """

    def test_memory_gateway_protocol_methods(self) -> None:
        from core.rlm.contracts import MemoryGatewayProtocol

        hints = typing.get_type_hints(MemoryGatewayProtocol)
        # Protocol class-level hints may be empty; check method-level
        expected_methods = [
            "write_memory",
            "read_memories",
            "delete_memory",
            "search_memories",
        ]
        for method_name in expected_methods:
            assert hasattr(MemoryGatewayProtocol, method_name), (
                f"MemoryGatewayProtocol missing method: {method_name}"
            )

    def test_all_protocols_have_expected_methods(self) -> None:
        from core.rlm.contracts import (
            DecaySchedulerProtocol,
            EntitlementLedgerProtocol,
            FeedbackCollectorProtocol,
            MemoryGatewayProtocol,
            SurpriseIntegrationProtocol,
            TenantPartitionProtocol,
        )

        protocol_methods = {
            MemoryGatewayProtocol: [
                "write_memory",
                "read_memories",
                "delete_memory",
                "search_memories",
            ],
            EntitlementLedgerProtocol: [
                "get_manifest",
                "update_tier",
                "check_quota",
            ],
            SurpriseIntegrationProtocol: [
                "score_content",
            ],
            DecaySchedulerProtocol: [
                "run_decay_cycle",
                "get_decay_stats",
            ],
            FeedbackCollectorProtocol: [
                "record_feedback",
                "get_pair_count",
            ],
            TenantPartitionProtocol: [
                "create_tenant",
                "delete_tenant_data",
                "verify_isolation",
            ],
        }

        for protocol_cls, methods in protocol_methods.items():
            for method_name in methods:
                assert hasattr(protocol_cls, method_name), (
                    f"{protocol_cls.__name__} missing method: {method_name}"
                )


# ---------------------------------------------------------------------------
# WB2: PreferencePair.confidence default is 1.0 via dataclasses.fields()
# ---------------------------------------------------------------------------


class TestWB2PreferencePairFieldDefaults:
    """White box test: inspect dataclass field metadata."""

    def test_confidence_default_via_fields(self) -> None:
        from core.rlm.contracts import PreferencePair

        fields_map = {f.name: f for f in dataclasses.fields(PreferencePair)}
        confidence_field = fields_map["confidence"]
        assert confidence_field.default == 1.0

    def test_annotator_id_default_via_fields(self) -> None:
        from core.rlm.contracts import PreferencePair

        fields_map = {f.name: f for f in dataclasses.fields(PreferencePair)}
        annotator_field = fields_map["annotator_id"]
        assert annotator_field.default == "telegram_feedback"


# ---------------------------------------------------------------------------
# AC1: All 6 Protocol classes inspectable via typing.get_type_hints()
# ---------------------------------------------------------------------------


class TestAC1ProtocolTypeHints:
    """Additional coverage: protocol type hint inspection."""

    def test_protocols_are_inspectable(self) -> None:
        from core.rlm.contracts import (
            DecaySchedulerProtocol,
            EntitlementLedgerProtocol,
            FeedbackCollectorProtocol,
            MemoryGatewayProtocol,
            SurpriseIntegrationProtocol,
            TenantPartitionProtocol,
        )

        protocols = [
            MemoryGatewayProtocol,
            EntitlementLedgerProtocol,
            SurpriseIntegrationProtocol,
            DecaySchedulerProtocol,
            FeedbackCollectorProtocol,
            TenantPartitionProtocol,
        ]
        for proto in protocols:
            # Should not raise
            hints = typing.get_type_hints(proto)
            assert isinstance(hints, dict)

    def test_protocol_method_return_types(self) -> None:
        """Verify key method return type annotations exist."""
        from core.rlm.contracts import MemoryGatewayProtocol

        write_hints = typing.get_type_hints(
            MemoryGatewayProtocol.write_memory
        )
        assert "return" in write_hints


# ---------------------------------------------------------------------------
# AC2: Constants have correct values
# ---------------------------------------------------------------------------


class TestAC2Constants:
    """Additional coverage: verify constant values match spec."""

    def test_embedding_dim(self) -> None:
        from core.rlm.contracts import EMBEDDING_DIM

        assert EMBEDDING_DIM == 768
        assert isinstance(EMBEDDING_DIM, int)

    def test_default_decay_half_life(self) -> None:
        from core.rlm.contracts import DEFAULT_DECAY_HALF_LIFE

        assert DEFAULT_DECAY_HALF_LIFE == 7.0
        assert isinstance(DEFAULT_DECAY_HALF_LIFE, float)

    def test_surprise_threshold(self) -> None:
        from core.rlm.contracts import SURPRISE_THRESHOLD

        assert SURPRISE_THRESHOLD == 0.50
        assert isinstance(SURPRISE_THRESHOLD, float)

    def test_max_content_length(self) -> None:
        from core.rlm.contracts import MAX_CONTENT_LENGTH

        assert MAX_CONTENT_LENGTH == 32_000
        assert isinstance(MAX_CONTENT_LENGTH, int)

    def test_redis_ttl_entitlement(self) -> None:
        from core.rlm.contracts import REDIS_TTL_ENTITLEMENT

        assert REDIS_TTL_ENTITLEMENT == 300
        assert isinstance(REDIS_TTL_ENTITLEMENT, int)


# ---------------------------------------------------------------------------
# AC3: FeedbackSignal enum values
# ---------------------------------------------------------------------------


class TestAC3FeedbackSignal:
    """Additional coverage: FeedbackSignal enum integrity."""

    def test_values(self) -> None:
        from core.rlm.contracts import FeedbackSignal

        assert FeedbackSignal.NEGATIVE.value == -1
        assert FeedbackSignal.NEUTRAL.value == 0
        assert FeedbackSignal.POSITIVE.value == 1

    def test_member_count(self) -> None:
        from core.rlm.contracts import FeedbackSignal

        assert len(list(FeedbackSignal)) == 3


# ---------------------------------------------------------------------------
# AC4: CustomerTier has exactly 4 values
# ---------------------------------------------------------------------------


class TestAC4CustomerTier:
    """Additional coverage: CustomerTier enum integrity."""

    def test_member_count(self) -> None:
        from core.rlm.contracts import CustomerTier

        members = list(CustomerTier)
        assert len(members) == 4

    def test_expected_values(self) -> None:
        from core.rlm.contracts import CustomerTier

        expected = {"starter", "professional", "enterprise", "queen"}
        actual = {m.value for m in CustomerTier}
        assert actual == expected

    def test_string_construction(self) -> None:
        from core.rlm.contracts import CustomerTier

        assert CustomerTier("starter") == CustomerTier.STARTER
        assert CustomerTier("professional") == CustomerTier.PROFESSIONAL
        assert CustomerTier("enterprise") == CustomerTier.ENTERPRISE
        assert CustomerTier("queen") == CustomerTier.QUEEN


# ---------------------------------------------------------------------------
# AC5: MemoryRecord serialization round-trip via dataclasses.asdict()
# ---------------------------------------------------------------------------


class TestAC5Serialization:
    """Additional coverage: dataclass serialization."""

    def test_memory_record_asdict(self) -> None:
        from core.rlm.contracts import MemoryRecord, MemoryTier

        tid = uuid4()
        record = MemoryRecord(
            tenant_id=tid,
            content="test",
            source="api",
            domain="general",
        )
        d = dataclasses.asdict(record)
        assert d["tenant_id"] == tid
        assert d["content"] == "test"
        assert d["surprise_score"] == 0.0
        assert d["memory_tier"] == MemoryTier.WORKING.value

    def test_preference_pair_asdict(self) -> None:
        from core.rlm.contracts import PreferencePair

        pair = PreferencePair(
            input_text="What is X?",
            chosen_output="X is good",
            rejected_output="X is bad",
        )
        d = dataclasses.asdict(pair)
        assert d["input_text"] == "What is X?"
        assert d["confidence"] == 1.0
        assert d["annotator_id"] == "telegram_feedback"


# ---------------------------------------------------------------------------
# AC6: Package-level imports from core.rlm work correctly
# ---------------------------------------------------------------------------


class TestAC6PackageLevelImports:
    """Additional coverage: importing from core.rlm (not core.rlm.contracts)."""

    def test_import_enums_from_package(self) -> None:
        from core.rlm import CustomerTier, FeedbackSignal, MemoryTier

        assert MemoryTier.WORKING.value == "working"
        assert CustomerTier.STARTER.value == "starter"
        assert FeedbackSignal.POSITIVE.value == 1

    def test_import_dataclasses_from_package(self) -> None:
        from core.rlm import EntitlementManifest, MemoryRecord, PreferencePair

        assert MemoryRecord is not None
        assert EntitlementManifest is not None
        assert PreferencePair is not None

    def test_import_protocols_from_package(self) -> None:
        from core.rlm import (
            DecaySchedulerProtocol,
            EntitlementLedgerProtocol,
            FeedbackCollectorProtocol,
            MemoryGatewayProtocol,
            SurpriseIntegrationProtocol,
            TenantPartitionProtocol,
        )

        assert MemoryGatewayProtocol is not None
        assert EntitlementLedgerProtocol is not None
        assert SurpriseIntegrationProtocol is not None
        assert DecaySchedulerProtocol is not None
        assert FeedbackCollectorProtocol is not None
        assert TenantPartitionProtocol is not None

    def test_import_constants_from_package(self) -> None:
        from core.rlm import (
            DEFAULT_DECAY_HALF_LIFE,
            EMBEDDING_DIM,
            MAX_CONTENT_LENGTH,
            REDIS_TTL_ENTITLEMENT,
            SURPRISE_THRESHOLD,
        )

        assert EMBEDDING_DIM == 768
        assert SURPRISE_THRESHOLD == 0.50
        assert DEFAULT_DECAY_HALF_LIFE == 7.0
        assert MAX_CONTENT_LENGTH == 32_000
        assert REDIS_TTL_ENTITLEMENT == 300
