"""
AIVA Queen - Reward Modeling System for RLHF
=============================================

Complete production implementation of Reward Modeling for Reinforcement Learning
from Human Feedback (RLHF). This module provides the foundation for training
AIVA to align with human preferences.

Components:
- RewardModel: Neural network for reward prediction
- PreferenceDataset: Dataset of human preference pairs
- RewardTrainer: Training loop for reward model
- RewardInference: Real-time reward scoring
- HumanFeedbackCollector: Interface for gathering feedback

Author: AIVA Queen System
Version: 1.0.0
"""

from __future__ import annotations

import asyncio
import hashlib
import json
import logging
import math
import os
import pickle
import random
import sys
import time
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, asdict
from datetime import datetime, timedelta
from enum import Enum, auto
from pathlib import Path
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Generic,
    Iterator,
    List,
    Optional,
    Protocol,
    Tuple,
    TypeVar,
    Union,
)

import numpy as np
import psycopg2
import psycopg2.extras

# Genesis memory configuration (Elestio PostgreSQL) lives outside the package,
# so its directory must be on sys.path before elestio_config can be imported
sys.path.append('/mnt/e/genesis-system/data/genesis-memory')
from elestio_config import PostgresConfig  # noqa: E402

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("AIVA.RewardModel")


# =============================================================================
# TYPE DEFINITIONS AND PROTOCOLS
# =============================================================================

T = TypeVar("T")
EmbeddingType = Union[List[float], np.ndarray]


class EmbeddingProvider(Protocol):
    """Protocol for embedding generation."""

    async def embed(self, text: str) -> EmbeddingType:
        """Generate embedding for text."""
        ...

    async def embed_batch(self, texts: List[str]) -> List[EmbeddingType]:
        """Generate embeddings for multiple texts."""
        ...


class FeedbackSource(Enum):
    """Sources of human feedback."""
    DIRECT_RATING = auto()
    COMPARISON = auto()
    IMPLICIT_SIGNAL = auto()
    EXPERT_ANNOTATION = auto()
    CROWDSOURCE = auto()
    AUTOMATED = auto()


class PreferenceType(Enum):
    """Types of preference signals."""
    STRONGLY_PREFER_A = auto()
    PREFER_A = auto()
    SLIGHTLY_PREFER_A = auto()
    EQUAL = auto()
    SLIGHTLY_PREFER_B = auto()
    PREFER_B = auto()
    STRONGLY_PREFER_B = auto()


# =============================================================================
# DATA STRUCTURES
# =============================================================================

@dataclass
class Response:
    """Represents a model response."""
    id: str
    prompt: str
    text: str
    metadata: Dict[str, Any] = field(default_factory=dict)
    embedding: Optional[EmbeddingType] = None
    timestamp: float = field(default_factory=time.time)
    model_id: Optional[str] = None

    def __hash__(self) -> int:
        return hash(self.id)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary."""
        return {
            "id": self.id,
            "prompt": self.prompt,
            "text": self.text,
            "metadata": self.metadata,
            "timestamp": self.timestamp,
            "model_id": self.model_id
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Response":
        """Create from dictionary."""
        return cls(
            id=data["id"],
            prompt=data["prompt"],
            text=data["text"],
            metadata=data.get("metadata", {}),
            timestamp=data.get("timestamp", time.time()),
            model_id=data.get("model_id")
        )


@dataclass
class PreferencePair:
    """A pair of responses with preference annotation."""
    id: str
    prompt: str
    response_a: Response
    response_b: Response
    preference: PreferenceType
    annotator_id: str
    source: FeedbackSource
    confidence: float = 1.0
    reasoning: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    timestamp: float = field(default_factory=time.time)
    validated: bool = False

    @property
    def margin(self) -> float:
        """Get preference margin for Bradley-Terry model."""
        margins = {
            PreferenceType.STRONGLY_PREFER_A: 1.0,
            PreferenceType.PREFER_A: 0.7,
            PreferenceType.SLIGHTLY_PREFER_A: 0.4,
            PreferenceType.EQUAL: 0.0,
            PreferenceType.SLIGHTLY_PREFER_B: -0.4,
            PreferenceType.PREFER_B: -0.7,
            PreferenceType.STRONGLY_PREFER_B: -1.0,
        }
        return margins.get(self.preference, 0.0)

    @property
    def label(self) -> float:
        """Get binary label (probability that A is preferred)."""
        margin = self.margin
        return (margin + 1.0) / 2.0  # Map [-1, 1] to [0, 1]

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization."""
        return {
            "id": self.id,
            "prompt": self.prompt,
            "response_a": self.response_a.to_dict(),
            "response_b": self.response_b.to_dict(),
            "preference": self.preference.name,
            "annotator_id": self.annotator_id,
            "source": self.source.name,
            "confidence": self.confidence,
            "reasoning": self.reasoning,
            "metadata": self.metadata,
            "timestamp": self.timestamp,
            "validated": self.validated
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PreferencePair":
        """Create from dictionary."""
        return cls(
            id=data["id"],
            prompt=data["prompt"],
            response_a=Response.from_dict(data["response_a"]),
            response_b=Response.from_dict(data["response_b"]),
            preference=PreferenceType[data["preference"]],
            annotator_id=data["annotator_id"],
            source=FeedbackSource[data["source"]],
            confidence=data.get("confidence", 1.0),
            reasoning=data.get("reasoning"),
            metadata=data.get("metadata", {}),
            timestamp=data.get("timestamp", time.time()),
            validated=data.get("validated", False)
        )
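

# Illustrative sketch (not used elsewhere in this module): how a preference
# annotation maps onto the Bradley-Terry margin and the training label that
# the reward model consumes.
def _example_preference_pair_label() -> None:
    """Build a PreferencePair by hand and inspect its margin and label."""
    prompt = "Explain RLHF in one sentence."
    pair = PreferencePair(
        id="example-pair",
        prompt=prompt,
        response_a=Response(id="a", prompt=prompt, text="RLHF tunes a model with a learned reward."),
        response_b=Response(id="b", prompt=prompt, text="RLHF is a training method."),
        preference=PreferenceType.PREFER_A,
        annotator_id="example-annotator",
        source=FeedbackSource.DIRECT_RATING,
    )
    # PREFER_A has margin 0.7, so label = (0.7 + 1) / 2 = 0.85
    assert pair.margin == 0.7
    assert abs(pair.label - 0.85) < 1e-9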


@dataclass
class RewardScore:
    """Reward score output."""
    score: float
    confidence: float
    components: Dict[str, float] = field(default_factory=dict)
    explanation: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)


@dataclass
class TrainingMetrics:
    """Training metrics for reward model."""
    epoch: int
    loss: float
    accuracy: float
    agreement_rate: float
    learning_rate: float
    batch_size: int
    timestamp: float = field(default_factory=time.time)
    additional_metrics: Dict[str, float] = field(default_factory=dict)


# =============================================================================
# PREFERENCE DATASET
# =============================================================================

class PreferenceDataset:
    """
    Dataset for managing human preference pairs.

    Features:
    - Persistent storage with PostgreSQL (Elestio)
    - Efficient batch iteration
    - Stratified sampling by preference type
    - Quality filtering and validation
    - Cross-annotator agreement tracking
    """

    def __init__(
        self,
        db_path: Optional[str] = None,
        min_confidence: float = 0.5,
        require_validation: bool = False
    ):
        """
        Initialize preference dataset.

        Args:
            db_path: Deprecated (ignored). Uses Elestio PostgreSQL.
            min_confidence: Minimum confidence threshold
            require_validation: Whether to require validated pairs only
        """
        self.min_confidence = min_confidence
        self.require_validation = require_validation
        self._pairs: Dict[str, PreferencePair] = {}
        self._annotator_stats: Dict[str, Dict[str, int]] = {}
        self._prompt_index: Dict[str, List[str]] = {}

        self._init_database()
        self._load_from_database()

    def _get_conn(self):
        """Get a PostgreSQL connection from Elestio."""
        return psycopg2.connect(**PostgresConfig.get_connection_params())

    def _init_database(self) -> None:
        """Ensure PostgreSQL tables exist (tables should already exist in PG)."""
        conn = self._get_conn()
        try:
            with conn.cursor() as cur:
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS rlm_preference_pairs (
                        id TEXT PRIMARY KEY,
                        prompt TEXT NOT NULL,
                        response_a_id TEXT NOT NULL,
                        response_a_text TEXT NOT NULL,
                        response_b_id TEXT NOT NULL,
                        response_b_text TEXT NOT NULL,
                        preference TEXT NOT NULL,
                        annotator_id TEXT NOT NULL,
                        source TEXT NOT NULL,
                        confidence REAL DEFAULT 1.0,
                        reasoning TEXT,
                        metadata TEXT,
                        timestamp DOUBLE PRECISION,
                        validated INTEGER DEFAULT 0
                    )
                """)

                cur.execute("""
                    CREATE TABLE IF NOT EXISTS rlm_annotator_stats (
                        annotator_id TEXT PRIMARY KEY,
                        total_annotations INTEGER DEFAULT 0,
                        agreement_rate REAL DEFAULT 0.0,
                        avg_confidence REAL DEFAULT 0.0,
                        last_active DOUBLE PRECISION
                    )
                """)
            conn.commit()
        finally:
            conn.close()

    def _load_from_database(self) -> None:
        """Load existing pairs from database."""
        conn = self._get_conn()
        try:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute("SELECT * FROM rlm_preference_pairs")

                for row in cur:
                    pair = PreferencePair(
                        id=row["id"],
                        prompt=row["prompt"],
                        response_a=Response(
                            id=row["response_a_id"],
                            prompt=row["prompt"],
                            text=row["response_a_text"]
                        ),
                        response_b=Response(
                            id=row["response_b_id"],
                            prompt=row["prompt"],
                            text=row["response_b_text"]
                        ),
                        preference=PreferenceType[row["preference"]],
                        annotator_id=row["annotator_id"],
                        source=FeedbackSource[row["source"]],
                        confidence=row["confidence"],
                        reasoning=row["reasoning"],
                        metadata=json.loads(row["metadata"]) if row["metadata"] else {},
                        timestamp=row["timestamp"],
                        validated=bool(row["validated"])
                    )
                    self._pairs[pair.id] = pair

                    # Update prompt index
                    if pair.prompt not in self._prompt_index:
                        self._prompt_index[pair.prompt] = []
                    self._prompt_index[pair.prompt].append(pair.id)
        finally:
            conn.close()

    def add(self, pair: PreferencePair) -> str:
        """
        Add a preference pair to the dataset.

        Args:
            pair: Preference pair to add

        Returns:
            ID of the added pair
        """
        self._pairs[pair.id] = pair

        # Update prompt index
        if pair.prompt not in self._prompt_index:
            self._prompt_index[pair.prompt] = []
        self._prompt_index[pair.prompt].append(pair.id)

        # Update annotator stats (total count and per-preference-type counts)
        stats = self._annotator_stats.setdefault(
            pair.annotator_id, {"total": 0, "by_type": {}}
        )
        stats["total"] += 1
        stats["by_type"][pair.preference.name] = (
            stats["by_type"].get(pair.preference.name, 0) + 1
        )

        # Persist to database
        self._save_pair(pair)

        logger.info(f"Added preference pair {pair.id}")
        return pair.id

    def _save_pair(self, pair: PreferencePair) -> None:
        """Save pair to database."""
        conn = self._get_conn()
        try:
            with conn.cursor() as cur:
                cur.execute("""
                    INSERT INTO rlm_preference_pairs
                    (id, prompt, response_a_id, response_a_text, response_b_id,
                     response_b_text, preference, annotator_id, source, confidence,
                     reasoning, metadata, timestamp, validated)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                    ON CONFLICT (id) DO UPDATE SET
                        prompt = EXCLUDED.prompt,
                        response_a_id = EXCLUDED.response_a_id,
                        response_a_text = EXCLUDED.response_a_text,
                        response_b_id = EXCLUDED.response_b_id,
                        response_b_text = EXCLUDED.response_b_text,
                        preference = EXCLUDED.preference,
                        annotator_id = EXCLUDED.annotator_id,
                        source = EXCLUDED.source,
                        confidence = EXCLUDED.confidence,
                        reasoning = EXCLUDED.reasoning,
                        metadata = EXCLUDED.metadata,
                        timestamp = EXCLUDED.timestamp,
                        validated = EXCLUDED.validated
                """, (
                    pair.id,
                    pair.prompt,
                    pair.response_a.id,
                    pair.response_a.text,
                    pair.response_b.id,
                    pair.response_b.text,
                    pair.preference.name,
                    pair.annotator_id,
                    pair.source.name,
                    pair.confidence,
                    pair.reasoning,
                    json.dumps(pair.metadata),
                    pair.timestamp,
                    int(pair.validated)
                ))
            conn.commit()
        finally:
            conn.close()

    def get(self, pair_id: str) -> Optional[PreferencePair]:
        """Get a preference pair by ID."""
        return self._pairs.get(pair_id)

    def get_by_prompt(self, prompt: str) -> List[PreferencePair]:
        """Get all preference pairs for a prompt."""
        pair_ids = self._prompt_index.get(prompt, [])
        return [self._pairs[pid] for pid in pair_ids if pid in self._pairs]

    def remove(self, pair_id: str) -> bool:
        """Remove a preference pair."""
        if pair_id not in self._pairs:
            return False

        pair = self._pairs.pop(pair_id)

        # Update prompt index
        if pair.prompt in self._prompt_index:
            self._prompt_index[pair.prompt] = [
                pid for pid in self._prompt_index[pair.prompt]
                if pid != pair_id
            ]

        # Remove from database
        conn = self._get_conn()
        try:
            with conn.cursor() as cur:
                cur.execute("DELETE FROM rlm_preference_pairs WHERE id = %s", (pair_id,))
            conn.commit()
        finally:
            conn.close()

        return True

    def __len__(self) -> int:
        return len(self._pairs)

    def __iter__(self) -> Iterator[PreferencePair]:
        """Iterate over all pairs with filtering."""
        for pair in self._pairs.values():
            if pair.confidence >= self.min_confidence:
                if not self.require_validation or pair.validated:
                    yield pair

    def get_training_pairs(
        self,
        batch_size: int = 32,
        shuffle: bool = True,
        stratify: bool = True
    ) -> Iterator[List[PreferencePair]]:
        """
        Get training batches of preference pairs.

        Args:
            batch_size: Number of pairs per batch
            shuffle: Whether to shuffle pairs
            stratify: Whether to stratify by preference type

        Yields:
            Batches of preference pairs
        """
        pairs = list(self)

        if stratify:
            # Group by preference type
            by_type: Dict[PreferenceType, List[PreferencePair]] = {}
            for pair in pairs:
                if pair.preference not in by_type:
                    by_type[pair.preference] = []
                by_type[pair.preference].append(pair)

            # Optionally shuffle within each group, then interleave the types so
            # batches carry a balanced mix of preference labels (a global shuffle
            # afterwards would undo the interleaving)
            if shuffle:
                for group in by_type.values():
                    random.shuffle(group)

            pairs = []
            max_len = max(len(v) for v in by_type.values()) if by_type else 0
            for i in range(max_len):
                for ptype in by_type:
                    if i < len(by_type[ptype]):
                        pairs.append(by_type[ptype][i])
        elif shuffle:
            random.shuffle(pairs)

        # Yield batches
        for i in range(0, len(pairs), batch_size):
            yield pairs[i:i + batch_size]

    def get_statistics(self) -> Dict[str, Any]:
        """Get dataset statistics."""
        pairs = list(self)

        if not pairs:
            return {"total": 0}

        # Count by preference type
        by_type = {}
        for pair in pairs:
            if pair.preference.name not in by_type:
                by_type[pair.preference.name] = 0
            by_type[pair.preference.name] += 1

        # Count by source
        by_source = {}
        for pair in pairs:
            if pair.source.name not in by_source:
                by_source[pair.source.name] = 0
            by_source[pair.source.name] += 1

        # Annotator statistics
        annotator_counts = {}
        for pair in pairs:
            if pair.annotator_id not in annotator_counts:
                annotator_counts[pair.annotator_id] = 0
            annotator_counts[pair.annotator_id] += 1

        return {
            "total": len(pairs),
            "by_preference_type": by_type,
            "by_source": by_source,
            "unique_prompts": len(set(p.prompt for p in pairs)),
            "unique_annotators": len(annotator_counts),
            "avg_confidence": sum(p.confidence for p in pairs) / len(pairs),
            "validated_count": sum(1 for p in pairs if p.validated),
            "annotator_distribution": annotator_counts
        }

    def compute_inter_annotator_agreement(
        self,
        prompt: str
    ) -> Optional[float]:
        """
        Compute inter-annotator agreement for a prompt.

        Uses a variance-based agreement measure over the preference margins
        (a lightweight proxy for formal statistics such as Fleiss' kappa).

        Args:
            prompt: The prompt to analyze

        Returns:
            Agreement score (0-1) or None if insufficient data
        """
        pairs = self.get_by_prompt(prompt)

        if len(pairs) < 2:
            return None

        # Map preferences to numeric scale
        preference_values = [p.margin for p in pairs]

        # Compute variance-based agreement
        mean_pref = sum(preference_values) / len(preference_values)
        variance = sum((p - mean_pref) ** 2 for p in preference_values) / len(preference_values)

        # Normalize (max variance for binary choices is 1.0)
        agreement = 1.0 - min(variance, 1.0)

        return agreement

    def export_to_json(self, filepath: str) -> None:
        """Export dataset to JSON file."""
        pairs = [p.to_dict() for p in self._pairs.values()]
        with open(filepath, "w") as f:
            json.dump({
                "pairs": pairs,
                "statistics": self.get_statistics(),
                "exported_at": datetime.now().isoformat()
            }, f, indent=2)

    @classmethod
    def import_from_json(
        cls,
        filepath: str,
        db_path: Optional[str] = None
    ) -> "PreferenceDataset":
        """Import dataset from JSON file."""
        with open(filepath, "r") as f:
            data = json.load(f)

        dataset = cls(db_path=db_path)

        for pair_data in data.get("pairs", []):
            pair = PreferencePair.from_dict(pair_data)
            dataset.add(pair)

        return dataset
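

# Illustrative sketch, assuming the Elestio PostgreSQL instance configured in
# elestio_config is reachable: record one human comparison and log the
# resulting dataset statistics. Not called anywhere in this module.
def _example_record_preference() -> None:
    dataset = PreferenceDataset(min_confidence=0.5)
    prompt = "Explain RLHF in one sentence."
    pair = PreferencePair(
        id=str(uuid.uuid4()),
        prompt=prompt,
        response_a=Response(id=str(uuid.uuid4()), prompt=prompt, text="Answer A"),
        response_b=Response(id=str(uuid.uuid4()), prompt=prompt, text="Answer B"),
        preference=PreferenceType.SLIGHTLY_PREFER_B,
        annotator_id="reviewer-1",
        source=FeedbackSource.COMPARISON,
        confidence=0.9,
    )
    dataset.add(pair)
    logger.info("Dataset statistics: %s", dataset.get_statistics())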


# =============================================================================
# REWARD MODEL
# =============================================================================

class RewardModelLayer(ABC):
    """Abstract base class for reward model layers."""

    @abstractmethod
    def forward(self, x: np.ndarray, training: bool = True) -> np.ndarray:
        """Forward pass (training toggles dropout-style behavior in subclasses)."""
        pass

    @abstractmethod
    def backward(self, grad: np.ndarray) -> np.ndarray:
        """Backward pass."""
        pass

    @abstractmethod
    def get_params(self) -> Dict[str, np.ndarray]:
        """Get layer parameters."""
        pass

    @abstractmethod
    def set_params(self, params: Dict[str, np.ndarray]) -> None:
        """Set layer parameters."""
        pass


class DenseLayer(RewardModelLayer):
    """Dense (fully connected) layer."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        activation: str = "relu",
        dropout: float = 0.0
    ):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.activation = activation
        self.dropout = dropout

        # Xavier initialization
        scale = np.sqrt(2.0 / (input_dim + output_dim))
        self.W = np.random.randn(input_dim, output_dim) * scale
        self.b = np.zeros(output_dim)

        # For backward pass
        self._input: Optional[np.ndarray] = None
        self._pre_activation: Optional[np.ndarray] = None
        self._dropout_mask: Optional[np.ndarray] = None

        # Gradients
        self.dW: Optional[np.ndarray] = None
        self.db: Optional[np.ndarray] = None

    def _activate(self, x: np.ndarray) -> np.ndarray:
        """Apply activation function."""
        if self.activation == "relu":
            return np.maximum(0, x)
        elif self.activation == "gelu":
            return 0.5 * x * (1 + np.tanh(
                np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)
            ))
        elif self.activation == "sigmoid":
            return 1 / (1 + np.exp(-np.clip(x, -500, 500)))
        elif self.activation == "tanh":
            return np.tanh(x)
        elif self.activation == "none" or self.activation is None:
            return x
        else:
            return x

    def _activate_grad(self, x: np.ndarray) -> np.ndarray:
        """Compute gradient of activation function."""
        if self.activation == "relu":
            return (x > 0).astype(float)
        elif self.activation == "sigmoid":
            sig = 1 / (1 + np.exp(-np.clip(x, -500, 500)))
            return sig * (1 - sig)
        elif self.activation == "tanh":
            return 1 - np.tanh(x) ** 2
        elif self.activation == "gelu":
            # Approximate GELU gradient
            return 0.5 * (1 + np.tanh(
                np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)
            )) + 0.5 * x * (1 - np.tanh(
                np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)
            ) ** 2) * np.sqrt(2 / np.pi) * (1 + 0.134145 * x ** 2)
        else:
            return np.ones_like(x)

    def forward(self, x: np.ndarray, training: bool = True) -> np.ndarray:
        """Forward pass."""
        self._input = x
        self._pre_activation = x @ self.W + self.b
        output = self._activate(self._pre_activation)

        # Apply dropout during training
        if training and self.dropout > 0:
            self._dropout_mask = (
                np.random.rand(*output.shape) > self.dropout
            ).astype(float) / (1 - self.dropout)
            output = output * self._dropout_mask
        else:
            self._dropout_mask = None

        return output

    def backward(self, grad: np.ndarray) -> np.ndarray:
        """Backward pass."""
        if self._input is None or self._pre_activation is None:
            raise RuntimeError("Forward pass must be called before backward")

        # Apply dropout mask if used
        if self._dropout_mask is not None:
            grad = grad * self._dropout_mask

        # Gradient through activation
        grad = grad * self._activate_grad(self._pre_activation)

        # Compute parameter gradients. The upstream loss gradient is already
        # averaged over the batch, so accumulate with a sum here rather than
        # averaging a second time.
        self.dW = self._input.T @ grad
        self.db = np.sum(grad, axis=0)

        # Compute input gradient
        return grad @ self.W.T

    def get_params(self) -> Dict[str, np.ndarray]:
        return {"W": self.W.copy(), "b": self.b.copy()}

    def set_params(self, params: Dict[str, np.ndarray]) -> None:
        self.W = params["W"].copy()
        self.b = params["b"].copy()
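

# Illustrative sketch: shape check for a single DenseLayer forward/backward
# pass on a random mini-batch; dW and db are populated by backward().
def _example_dense_layer_shapes() -> None:
    layer = DenseLayer(input_dim=8, output_dim=4, activation="relu", dropout=0.0)
    x = np.random.randn(2, 8)
    out = layer.forward(x, training=False)        # (2, 4)
    grad_in = layer.backward(np.ones_like(out))   # (2, 8)
    assert out.shape == (2, 4) and grad_in.shape == (2, 8)
    assert layer.dW.shape == (8, 4) and layer.db.shape == (4,)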


class LayerNorm(RewardModelLayer):
    """Layer normalization."""

    def __init__(self, dim: int, eps: float = 1e-5):
        self.dim = dim
        self.eps = eps
        self.gamma = np.ones(dim)
        self.beta = np.zeros(dim)

        self._input: Optional[np.ndarray] = None
        self._mean: Optional[np.ndarray] = None
        self._var: Optional[np.ndarray] = None
        self._normalized: Optional[np.ndarray] = None

        self.dgamma: Optional[np.ndarray] = None
        self.dbeta: Optional[np.ndarray] = None

    def forward(self, x: np.ndarray, training: bool = True) -> np.ndarray:
        self._input = x
        self._mean = np.mean(x, axis=-1, keepdims=True)
        self._var = np.var(x, axis=-1, keepdims=True)
        self._normalized = (x - self._mean) / np.sqrt(self._var + self.eps)
        return self.gamma * self._normalized + self.beta

    def backward(self, grad: np.ndarray) -> np.ndarray:
        if self._normalized is None:
            raise RuntimeError("Forward must be called before backward")

        self.dgamma = np.sum(grad * self._normalized, axis=0)
        self.dbeta = np.sum(grad, axis=0)

        N = grad.shape[-1]
        dnorm = grad * self.gamma
        dvar = np.sum(
            dnorm * (self._input - self._mean) * -0.5 *
            (self._var + self.eps) ** -1.5,
            axis=-1, keepdims=True
        )
        dmean = (
            np.sum(dnorm * -1 / np.sqrt(self._var + self.eps), axis=-1, keepdims=True)
            + dvar * np.sum(-2 * (self._input - self._mean), axis=-1, keepdims=True) / N
        )

        dx = (
            dnorm / np.sqrt(self._var + self.eps)
            + dvar * 2 * (self._input - self._mean) / N
            + dmean / N
        )
        return dx

    def get_params(self) -> Dict[str, np.ndarray]:
        return {"gamma": self.gamma.copy(), "beta": self.beta.copy()}

    def set_params(self, params: Dict[str, np.ndarray]) -> None:
        self.gamma = params["gamma"].copy()
        self.beta = params["beta"].copy()


class RewardModel:
    """
    Neural network for reward prediction based on human preferences.

    Architecture:
    - Input: Concatenated embeddings of prompt and response
    - Multiple dense layers with residual connections
    - Layer normalization
    - Scalar reward output

    Training:
    - Bradley-Terry preference model
    - Pairwise ranking loss
    - L2 regularization
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        hidden_dims: Optional[List[int]] = None,
        dropout: float = 0.1,
        l2_reg: float = 0.01,
        use_layer_norm: bool = True
    ):
        """
        Initialize reward model.

        Args:
            embedding_dim: Dimension of input embeddings
            hidden_dims: List of hidden layer dimensions
            dropout: Dropout rate
            l2_reg: L2 regularization strength
            use_layer_norm: Whether to use layer normalization
        """
        self.embedding_dim = embedding_dim
        self.hidden_dims = hidden_dims or [512, 256, 128]
        self.dropout = dropout
        self.l2_reg = l2_reg
        self.use_layer_norm = use_layer_norm

        self.layers: List[RewardModelLayer] = []
        self.layer_norms: List[Optional[LayerNorm]] = []

        self._build_network()

        # Training state
        self.training_mode = True
        self._optimizer_state: Dict[str, Any] = {}
        self._residual_applied: List[bool] = []

    def _build_network(self) -> None:
        """Build the neural network architecture."""
        input_dim = self.embedding_dim * 2  # Prompt + response

        prev_dim = input_dim
        for i, hidden_dim in enumerate(self.hidden_dims):
            activation = "gelu" if i < len(self.hidden_dims) - 1 else "tanh"
            layer = DenseLayer(
                prev_dim, hidden_dim,
                activation=activation,
                dropout=self.dropout if i < len(self.hidden_dims) - 1 else 0
            )
            self.layers.append(layer)

            if self.use_layer_norm and i < len(self.hidden_dims) - 1:
                self.layer_norms.append(LayerNorm(hidden_dim))
            else:
                self.layer_norms.append(None)

            prev_dim = hidden_dim

        # Output layer (scalar reward)
        self.output_layer = DenseLayer(
            self.hidden_dims[-1], 1,
            activation="none",
            dropout=0
        )

    def forward(
        self,
        prompt_embedding: np.ndarray,
        response_embedding: np.ndarray
    ) -> np.ndarray:
        """
        Forward pass to compute reward.

        Args:
            prompt_embedding: Prompt embedding [batch, embedding_dim]
            response_embedding: Response embedding [batch, embedding_dim]

        Returns:
            Reward scores [batch, 1]
        """
        # Concatenate prompt and response embeddings
        x = np.concatenate([prompt_embedding, response_embedding], axis=-1)

        # Pass through layers, recording where residual connections are used so
        # the backward pass can route gradients through the skip paths as well
        self._residual_applied = []
        for i, layer in enumerate(self.layers):
            residual = x if x.shape[-1] == self.hidden_dims[i] else None
            x = layer.forward(x, training=self.training_mode)

            if self.layer_norms[i] is not None:
                x = self.layer_norms[i].forward(x, training=self.training_mode)

            # Residual connection if dimensions match
            if residual is not None:
                x = x + residual
            self._residual_applied.append(residual is not None)

        # Output layer
        reward = self.output_layer.forward(x, training=self.training_mode)

        return reward

    def compute_preference_loss(
        self,
        reward_a: np.ndarray,
        reward_b: np.ndarray,
        labels: np.ndarray
    ) -> Tuple[float, np.ndarray, np.ndarray]:
        """
        Compute Bradley-Terry preference loss.

        Loss = -log(sigmoid(reward_preferred - reward_rejected))

        Args:
            reward_a: Rewards for response A [batch, 1]
            reward_b: Rewards for response B [batch, 1]
            labels: Probability that A is preferred [batch, 1]

        Returns:
            Tuple of (loss, grad_a, grad_b)
        """
        # Compute log-likelihood under Bradley-Terry model
        diff = reward_a - reward_b
        log_prob_a = -np.logaddexp(0, -diff)  # log(sigmoid(diff))
        log_prob_b = -np.logaddexp(0, diff)   # log(sigmoid(-diff))

        # Weighted combination based on labels
        loss = -np.mean(
            labels * log_prob_a + (1 - labels) * log_prob_b
        )

        # Compute gradients
        sigmoid_diff = 1 / (1 + np.exp(-np.clip(diff, -500, 500)))
        grad = labels - sigmoid_diff

        grad_a = -grad / diff.shape[0]
        grad_b = grad / diff.shape[0]

        return float(loss), grad_a, grad_b

    def backward(self, grad: np.ndarray) -> None:
        """
        Backward pass through the network.

        Must be called immediately after the forward pass that produced the
        corresponding reward, because each layer only caches its most recent
        activations.

        Args:
            grad: Gradient w.r.t. the reward output of the last forward pass
        """
        # Backward through output layer
        grad = self.output_layer.backward(grad)

        # Backward through hidden layers
        for i in range(len(self.layers) - 1, -1, -1):
            # Gradient carried by the residual skip connection, if one was used
            skip_grad = grad if self._residual_applied[i] else None

            if self.layer_norms[i] is not None:
                grad = self.layer_norms[i].backward(grad)
            grad = self.layers[i].backward(grad)

            if skip_grad is not None:
                grad = grad + skip_grad

    def get_all_params(self) -> Dict[str, np.ndarray]:
        """Get all model parameters."""
        params = {}

        for i, layer in enumerate(self.layers):
            for name, value in layer.get_params().items():
                params[f"layer_{i}_{name}"] = value

        for i, ln in enumerate(self.layer_norms):
            if ln is not None:
                for name, value in ln.get_params().items():
                    params[f"layernorm_{i}_{name}"] = value

        for name, value in self.output_layer.get_params().items():
            params[f"output_{name}"] = value

        return params

    def set_all_params(self, params: Dict[str, np.ndarray]) -> None:
        """Set all model parameters."""
        for i, layer in enumerate(self.layers):
            layer_params = {}
            for name in ["W", "b"]:
                key = f"layer_{i}_{name}"
                if key in params:
                    layer_params[name] = params[key]
            if layer_params:
                layer.set_params(layer_params)

        for i, ln in enumerate(self.layer_norms):
            if ln is not None:
                ln_params = {}
                for name in ["gamma", "beta"]:
                    key = f"layernorm_{i}_{name}"
                    if key in params:
                        ln_params[name] = params[key]
                if ln_params:
                    ln.set_params(ln_params)

        output_params = {}
        for name in ["W", "b"]:
            key = f"output_{name}"
            if key in params:
                output_params[name] = params[key]
        if output_params:
            self.output_layer.set_params(output_params)

    def compute_l2_regularization(self) -> Tuple[float, Dict[str, np.ndarray]]:
        """Compute L2 regularization loss and gradients."""
        l2_loss = 0.0
        l2_grads = {}

        params = self.get_all_params()
        for name, value in params.items():
            if "W" in name:  # Only regularize weights, not biases
                l2_loss += 0.5 * self.l2_reg * np.sum(value ** 2)
                l2_grads[name] = self.l2_reg * value

        return l2_loss, l2_grads

    def save(self, filepath: str) -> None:
        """Save model to file."""
        state = {
            "params": self.get_all_params(),
            "config": {
                "embedding_dim": self.embedding_dim,
                "hidden_dims": self.hidden_dims,
                "dropout": self.dropout,
                "l2_reg": self.l2_reg,
                "use_layer_norm": self.use_layer_norm
            }
        }
        with open(filepath, "wb") as f:
            pickle.dump(state, f)
        logger.info(f"Saved model to {filepath}")

    @classmethod
    def load(cls, filepath: str) -> "RewardModel":
        """Load model from file."""
        with open(filepath, "rb") as f:
            state = pickle.load(f)

        model = cls(**state["config"])
        model.set_all_params(state["params"])
        logger.info(f"Loaded model from {filepath}")
        return model

    def train(self) -> None:
        """Set model to training mode."""
        self.training_mode = True

    def eval(self) -> None:
        """Set model to evaluation mode."""
        self.training_mode = False
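

# Illustrative sketch: run an untrained RewardModel on random embeddings and
# compute the Bradley-Terry preference loss for a "prefer A" batch. Random
# vectors stand in for real prompt/response embeddings.
def _example_reward_model_loss() -> None:
    model = RewardModel(embedding_dim=32, hidden_dims=[64, 32])
    model.eval()
    prompt_emb = np.random.randn(4, 32)
    resp_a_emb = np.random.randn(4, 32)
    resp_b_emb = np.random.randn(4, 32)
    reward_a = model.forward(prompt_emb, resp_a_emb)   # (4, 1)
    reward_b = model.forward(prompt_emb, resp_b_emb)
    labels = np.full((4, 1), 0.85)  # label corresponding to PreferenceType.PREFER_A
    loss, grad_a, grad_b = model.compute_preference_loss(reward_a, reward_b, labels)
    assert reward_a.shape == (4, 1) and grad_a.shape == (4, 1) and loss > 0.0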


# =============================================================================
# REWARD TRAINER
# =============================================================================

class AdamOptimizer:
    """Adam optimizer implementation."""

    def __init__(
        self,
        learning_rate: float = 1e-4,
        beta1: float = 0.9,
        beta2: float = 0.999,
        epsilon: float = 1e-8,
        weight_decay: float = 0.0
    ):
        self.lr = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.weight_decay = weight_decay

        self.m: Dict[str, np.ndarray] = {}  # First moment
        self.v: Dict[str, np.ndarray] = {}  # Second moment
        self.t = 0  # Time step

    def step(
        self,
        params: Dict[str, np.ndarray],
        grads: Dict[str, np.ndarray]
    ) -> Dict[str, np.ndarray]:
        """
        Perform optimization step.

        Args:
            params: Model parameters
            grads: Parameter gradients

        Returns:
            Updated parameters
        """
        self.t += 1

        updated_params = {}
        for name, param in params.items():
            if name not in grads:
                updated_params[name] = param
                continue

            grad = grads[name]

            # Apply weight decay
            if self.weight_decay > 0 and "W" in name:
                grad = grad + self.weight_decay * param

            # Initialize moments
            if name not in self.m:
                self.m[name] = np.zeros_like(param)
                self.v[name] = np.zeros_like(param)

            # Update moments
            self.m[name] = self.beta1 * self.m[name] + (1 - self.beta1) * grad
            self.v[name] = self.beta2 * self.v[name] + (1 - self.beta2) * grad ** 2

            # Bias correction
            m_hat = self.m[name] / (1 - self.beta1 ** self.t)
            v_hat = self.v[name] / (1 - self.beta2 ** self.t)

            # Update parameters
            updated_params[name] = param - self.lr * m_hat / (np.sqrt(v_hat) + self.epsilon)

        return updated_params
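

# Illustrative sketch: a few Adam steps on the toy objective f(w) = ||w||^2,
# whose gradient is 2w; the parameter should shrink toward zero.
def _example_adam_steps() -> None:
    optimizer = AdamOptimizer(learning_rate=0.1)
    params = {"layer_0_W": np.array([1.0, -2.0])}
    for _ in range(50):
        grads = {"layer_0_W": 2.0 * params["layer_0_W"]}
        params = optimizer.step(params, grads)
    assert np.all(np.abs(params["layer_0_W"]) < 1.0)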


class RewardTrainer:
    """
    Training loop for reward model.

    Features:
    - Mini-batch gradient descent
    - Learning rate scheduling
    - Early stopping
    - Checkpointing
    - Comprehensive logging
    """

    def __init__(
        self,
        model: RewardModel,
        dataset: PreferenceDataset,
        embedding_provider: Optional[EmbeddingProvider] = None,
        learning_rate: float = 1e-4,
        batch_size: int = 32,
        epochs: int = 10,
        validation_split: float = 0.1,
        early_stopping_patience: int = 3,
        checkpoint_dir: Optional[str] = None,
        log_interval: int = 10
    ):
        """
        Initialize trainer.

        Args:
            model: Reward model to train
            dataset: Preference dataset
            embedding_provider: Provider for generating embeddings
            learning_rate: Initial learning rate
            batch_size: Training batch size
            epochs: Number of training epochs
            validation_split: Fraction of data for validation
            early_stopping_patience: Epochs to wait for improvement
            checkpoint_dir: Directory for checkpoints
            log_interval: Steps between logging
        """
        self.model = model
        self.dataset = dataset
        self.embedding_provider = embedding_provider
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.epochs = epochs
        self.validation_split = validation_split
        self.early_stopping_patience = early_stopping_patience
        self.checkpoint_dir = checkpoint_dir
        self.log_interval = log_interval

        self.optimizer = AdamOptimizer(learning_rate=learning_rate)
        self.metrics_history: List[TrainingMetrics] = []

        self._best_val_loss = float("inf")
        self._patience_counter = 0

        if checkpoint_dir:
            Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)

    def _get_embedding(self, text: str) -> np.ndarray:
        """Get embedding for text (placeholder for async version)."""
        # In production, use embedding_provider
        # For now, return random embedding
        return np.random.randn(self.model.embedding_dim).astype(np.float32)

    async def _get_embedding_async(self, text: str) -> np.ndarray:
        """Get embedding for text asynchronously."""
        if self.embedding_provider:
            emb = await self.embedding_provider.embed(text)
            return np.array(emb, dtype=np.float32)
        return np.random.randn(self.model.embedding_dim).astype(np.float32)

    def _prepare_batch(
        self,
        pairs: List[PreferencePair]
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Prepare a batch of preference pairs for training.

        Returns:
            Tuple of (prompt_emb, response_a_emb, response_b_emb, labels)
        """
        prompt_embs = []
        response_a_embs = []
        response_b_embs = []
        labels = []

        for pair in pairs:
            prompt_embs.append(self._get_embedding(pair.prompt))
            response_a_embs.append(self._get_embedding(pair.response_a.text))
            response_b_embs.append(self._get_embedding(pair.response_b.text))
            labels.append(pair.label)

        return (
            np.stack(prompt_embs),
            np.stack(response_a_embs),
            np.stack(response_b_embs),
            np.array(labels).reshape(-1, 1)
        )

    def _split_data(
        self
    ) -> Tuple[List[PreferencePair], List[PreferencePair]]:
        """Split dataset into train and validation sets."""
        all_pairs = list(self.dataset)
        random.shuffle(all_pairs)

        split_idx = int(len(all_pairs) * (1 - self.validation_split))
        return all_pairs[:split_idx], all_pairs[split_idx:]

    def train_epoch(
        self,
        train_pairs: List[PreferencePair]
    ) -> Tuple[float, float]:
        """
        Train for one epoch.

        Returns:
            Tuple of (average loss, accuracy)
        """
        self.model.train()

        random.shuffle(train_pairs)
        total_loss = 0.0
        correct = 0
        total = 0

        def collect_grads() -> Dict[str, np.ndarray]:
            """Gather the parameter gradients cached by each layer after backward()."""
            grads: Dict[str, np.ndarray] = {}
            for j, layer in enumerate(self.model.layers):
                if layer.dW is not None:
                    grads[f"layer_{j}_W"] = layer.dW
                    grads[f"layer_{j}_b"] = layer.db
            for j, ln in enumerate(self.model.layer_norms):
                if ln is not None and ln.dgamma is not None:
                    grads[f"layernorm_{j}_gamma"] = ln.dgamma
                    grads[f"layernorm_{j}_beta"] = ln.dbeta
            if self.model.output_layer.dW is not None:
                grads["output_W"] = self.model.output_layer.dW
                grads["output_b"] = self.model.output_layer.db
            return grads

        for i in range(0, len(train_pairs), self.batch_size):
            batch = train_pairs[i:i + self.batch_size]
            prompt_emb, resp_a_emb, resp_b_emb, labels = self._prepare_batch(batch)

            # Forward pass for both responses
            reward_a = self.model.forward(prompt_emb, resp_a_emb)
            reward_b = self.model.forward(prompt_emb, resp_b_emb)

            # Compute loss
            loss, grad_a, grad_b = self.model.compute_preference_loss(
                reward_a, reward_b, labels
            )

            # Add L2 regularization
            l2_loss, l2_grads = self.model.compute_l2_regularization()
            loss += l2_loss

            # Backward pass for response B (the layer caches currently hold the
            # activations from the most recent forward pass, i.e. response B)
            self.model.backward(grad_b)
            grads = collect_grads()

            # Re-run the forward pass for response A to refresh the caches, then
            # backpropagate its gradient and accumulate into the shared weights
            # (the refreshed dropout mask is an acceptable approximation)
            self.model.forward(prompt_emb, resp_a_emb)
            self.model.backward(grad_a)
            for name, grad in collect_grads().items():
                grads[name] = grads.get(name, 0.0) + grad

            # Add L2 gradients
            for name, grad in l2_grads.items():
                if name in grads:
                    grads[name] = grads[name] + grad

            # Optimizer step
            params = self.model.get_all_params()
            updated_params = self.optimizer.step(params, grads)
            self.model.set_all_params(updated_params)

            total_loss += loss * len(batch)

            # Compute accuracy
            predictions = (reward_a > reward_b).astype(float)
            correct += np.sum((predictions > 0.5) == (labels > 0.5))
            total += len(batch)

            if (i // self.batch_size) % self.log_interval == 0:
                logger.debug(f"Batch {i // self.batch_size}: loss={loss:.4f}")

        avg_loss = total_loss / len(train_pairs)
        accuracy = correct / total

        return avg_loss, accuracy

    def validate(
        self,
        val_pairs: List[PreferencePair]
    ) -> Tuple[float, float]:
        """
        Validate model on validation set.

        Returns:
            Tuple of (average loss, accuracy)
        """
        self.model.eval()

        total_loss = 0.0
        correct = 0
        total = 0

        for i in range(0, len(val_pairs), self.batch_size):
            batch = val_pairs[i:i + self.batch_size]
            prompt_emb, resp_a_emb, resp_b_emb, labels = self._prepare_batch(batch)

            reward_a = self.model.forward(prompt_emb, resp_a_emb)
            reward_b = self.model.forward(prompt_emb, resp_b_emb)

            loss, _, _ = self.model.compute_preference_loss(
                reward_a, reward_b, labels
            )

            total_loss += loss * len(batch)

            predictions = (reward_a > reward_b).astype(float)
            correct += np.sum((predictions > 0.5) == (labels > 0.5))
            total += len(batch)

        avg_loss = total_loss / len(val_pairs) if val_pairs else 0
        accuracy = correct / total if total > 0 else 0

        return avg_loss, accuracy

    def train(self) -> List[TrainingMetrics]:
        """
        Run full training loop.

        Returns:
            List of training metrics per epoch
        """
        train_pairs, val_pairs = self._split_data()

        logger.info(
            f"Starting training: {len(train_pairs)} train, "
            f"{len(val_pairs)} validation pairs"
        )

        for epoch in range(self.epochs):
            start_time = time.time()

            # Train epoch
            train_loss, train_acc = self.train_epoch(train_pairs)

            # Validate
            val_loss, val_acc = self.validate(val_pairs)

            elapsed = time.time() - start_time

            # Compute agreement rate
            agreement_rate = self._compute_agreement_rate(val_pairs)

            metrics = TrainingMetrics(
                epoch=epoch + 1,
                loss=train_loss,
                accuracy=train_acc,
                agreement_rate=agreement_rate,
                learning_rate=self.learning_rate,
                batch_size=self.batch_size,
                additional_metrics={
                    "val_loss": val_loss,
                    "val_accuracy": val_acc,
                    "epoch_time": elapsed
                }
            )
            self.metrics_history.append(metrics)

            logger.info(
                f"Epoch {epoch + 1}/{self.epochs}: "
                f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
                f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}"
            )

            # Early stopping check
            if val_loss < self._best_val_loss:
                self._best_val_loss = val_loss
                self._patience_counter = 0

                if self.checkpoint_dir:
                    self._save_checkpoint(epoch + 1, is_best=True)
            else:
                self._patience_counter += 1
                if self._patience_counter >= self.early_stopping_patience:
                    logger.info(
                        f"Early stopping at epoch {epoch + 1}"
                    )
                    break

            # Regular checkpoint
            if self.checkpoint_dir and (epoch + 1) % 5 == 0:
                self._save_checkpoint(epoch + 1)

        return self.metrics_history

    def _compute_agreement_rate(
        self,
        pairs: List[PreferencePair]
    ) -> float:
        """Compute agreement rate with human preferences."""
        if not pairs:
            return 0.0

        self.model.eval()
        correct = 0

        for pair in pairs:
            prompt_emb = self._get_embedding(pair.prompt).reshape(1, -1)
            resp_a_emb = self._get_embedding(pair.response_a.text).reshape(1, -1)
            resp_b_emb = self._get_embedding(pair.response_b.text).reshape(1, -1)

            reward_a = self.model.forward(prompt_emb, resp_a_emb)
            reward_b = self.model.forward(prompt_emb, resp_b_emb)

            model_prefers_a = reward_a[0, 0] > reward_b[0, 0]
            human_prefers_a = pair.label > 0.5

            if model_prefers_a == human_prefers_a:
                correct += 1

        return correct / len(pairs)

    def _save_checkpoint(
        self,
        epoch: int,
        is_best: bool = False
    ) -> None:
        """Save model checkpoint."""
        if not self.checkpoint_dir:
            return

        filename = f"checkpoint_epoch_{epoch}.pkl"
        if is_best:
            filename = "best_model.pkl"

        filepath = os.path.join(self.checkpoint_dir, filename)
        self.model.save(filepath)

        # Save training state
        state = {
            "epoch": epoch,
            "metrics_history": [asdict(m) for m in self.metrics_history],
            "best_val_loss": self._best_val_loss,
            "optimizer_t": self.optimizer.t
        }
        state_path = os.path.join(self.checkpoint_dir, "training_state.json")
        with open(state_path, "w") as f:
            json.dump(state, f, indent=2)
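

# Illustrative sketch, assuming the Elestio PostgreSQL database is reachable
# and rlm_preference_pairs already contains annotated pairs: wire the dataset,
# model, and trainer together and log the per-epoch metrics.
def _example_training_run() -> None:
    dataset = PreferenceDataset(min_confidence=0.5)
    model = RewardModel(embedding_dim=768)
    trainer = RewardTrainer(model, dataset, batch_size=16, epochs=3)
    for metrics in trainer.train():
        logger.info(
            "Epoch %d: loss=%.4f accuracy=%.4f",
            metrics.epoch, metrics.loss, metrics.accuracy
        )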


# =============================================================================
# REWARD INFERENCE
# =============================================================================

class RewardInference:
    """
    Real-time reward scoring for inference.

    Features:
    - Batch scoring
    - Caching
    - Confidence estimation
    - Component-wise scoring
    """

    def __init__(
        self,
        model: RewardModel,
        embedding_provider: Optional[EmbeddingProvider] = None,
        cache_size: int = 1000,
        confidence_method: str = "ensemble"
    ):
        """
        Initialize inference engine.

        Args:
            model: Trained reward model
            embedding_provider: Provider for generating embeddings
            cache_size: Maximum cache size
            confidence_method: Method for confidence estimation
        """
        self.model = model
        self.embedding_provider = embedding_provider
        self.cache_size = cache_size
        self.confidence_method = confidence_method

        self._cache: Dict[str, RewardScore] = {}
        self._cache_order: List[str] = []

        self.model.eval()

    def _get_cache_key(self, prompt: str, response: str) -> str:
        """Generate cache key."""
        content = f"{prompt}|||{response}"
        return hashlib.md5(content.encode()).hexdigest()

    def _get_embedding(self, text: str) -> np.ndarray:
        """Get embedding synchronously (placeholder)."""
        return np.random.randn(self.model.embedding_dim).astype(np.float32)

    async def _get_embedding_async(self, text: str) -> np.ndarray:
        """Get embedding asynchronously."""
        if self.embedding_provider:
            emb = await self.embedding_provider.embed(text)
            return np.array(emb, dtype=np.float32)
        return np.random.randn(self.model.embedding_dim).astype(np.float32)

    def score(
        self,
        prompt: str,
        response: str,
        use_cache: bool = True
    ) -> RewardScore:
        """
        Score a single response.

        Args:
            prompt: The prompt
            response: The response to score
            use_cache: Whether to use caching

        Returns:
            RewardScore with score and metadata
        """
        cache_key = self._get_cache_key(prompt, response)

        if use_cache and cache_key in self._cache:
            return self._cache[cache_key]

        prompt_emb = self._get_embedding(prompt).reshape(1, -1)
        response_emb = self._get_embedding(response).reshape(1, -1)

        reward = self.model.forward(prompt_emb, response_emb)
        score_value = float(reward[0, 0])

        # Estimate confidence
        confidence = self._estimate_confidence(prompt_emb, response_emb)

        result = RewardScore(
            score=score_value,
            confidence=confidence,
            components=self._get_score_components(prompt_emb, response_emb),
            metadata={
                "prompt_length": len(prompt),
                "response_length": len(response)
            }
        )

        # Update cache
        if use_cache:
            self._update_cache(cache_key, result)

        return result

    async def score_async(
        self,
        prompt: str,
        response: str,
        use_cache: bool = True
    ) -> RewardScore:
        """
        Score a single response asynchronously.

        Args:
            prompt: The prompt
            response: The response to score
            use_cache: Whether to use caching

        Returns:
            RewardScore with score and metadata
        """
        cache_key = self._get_cache_key(prompt, response)

        if use_cache and cache_key in self._cache:
            return self._cache[cache_key]

        prompt_emb = await self._get_embedding_async(prompt)
        response_emb = await self._get_embedding_async(response)

        prompt_emb = prompt_emb.reshape(1, -1)
        response_emb = response_emb.reshape(1, -1)

        reward = self.model.forward(prompt_emb, response_emb)
        score_value = float(reward[0, 0])

        confidence = self._estimate_confidence(prompt_emb, response_emb)

        result = RewardScore(
            score=score_value,
            confidence=confidence,
            components=self._get_score_components(prompt_emb, response_emb),
            metadata={
                "prompt_length": len(prompt),
                "response_length": len(response)
            }
        )

        if use_cache:
            self._update_cache(cache_key, result)

        return result

    async def score_batch_async(
        self,
        pairs: List[Tuple[str, str]],
        use_cache: bool = True
    ) -> List[RewardScore]:
        """
        Score multiple prompt-response pairs asynchronously.

        Args:
            pairs: List of (prompt, response) tuples
            use_cache: Whether to use caching

        Returns:
            List of RewardScores
        """
        tasks = [
            self.score_async(prompt, response, use_cache)
            for prompt, response in pairs
        ]
        return await asyncio.gather(*tasks)

    def _estimate_confidence(
        self,
        prompt_emb: np.ndarray,
        response_emb: np.ndarray
    ) -> float:
        """
        Estimate confidence in the reward score.

        When confidence_method is "ensemble", uses Monte Carlo dropout
        (several stochastic forward passes) for uncertainty estimation;
        otherwise falls back to a fixed default confidence.
        """
        if self.confidence_method == "ensemble":
            # Run multiple forward passes with dropout
            self.model.train()  # Enable dropout

            scores = []
            for _ in range(5):
                reward = self.model.forward(prompt_emb, response_emb)
                scores.append(float(reward[0, 0]))

            self.model.eval()

            # Confidence inversely related to variance
            variance = np.var(scores)
            confidence = 1.0 / (1.0 + variance)

            return float(confidence)
        else:
            return 0.8  # Default confidence
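
    # Worked example for _estimate_confidence (illustrative numbers): five
    # Monte Carlo dropout passes scoring [0.80, 0.90, 0.85, 0.80, 0.90] have
    # variance 0.002, so confidence = 1 / (1 + 0.002) ~= 0.998, while widely
    # spread scores [0.1, 0.9, 0.2, 0.8, 0.5] have variance 0.1 and
    # confidence = 1 / 1.1 ~= 0.909: more disagreement, less confidence.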

    def _get_score_components(
        self,
        prompt_emb: np.ndarray,
        response_emb: np.ndarray
    ) -> Dict[str, float]:
        """
        Get component-wise scores for interpretability.

        Returns contributions from different parts of the network.
        """
        components = {}

        # Concatenated input
        x = np.concatenate([prompt_emb, response_emb], axis=-1)

        # Track layer-wise activations
        for i, layer in enumerate(self.model.layers):
            x = layer.forward(x, training=False)
            if self.model.layer_norms[i] is not None:
                x = self.model.layer_norms[i].forward(x, training=False)

            components[f"layer_{i}_mean"] = float(np.mean(x))
            components[f"layer_{i}_std"] = float(np.std(x))

        return components

    def _update_cache(
        self,
        key: str,
        value: RewardScore
    ) -> None:
        """Update cache with LRU eviction."""
        if key in self._cache:
            self._cache_order.remove(key)
        elif len(self._cache) >= self.cache_size:
            oldest = self._cache_order.pop(0)
            del self._cache[oldest]

        self._cache[key] = value
        self._cache_order.append(key)

    def clear_cache(self) -> None:
        """Clear the scoring cache."""
        self._cache.clear()
        self._cache_order.clear()

    def rank_responses(
        self,
        prompt: str,
        responses: List[str]
    ) -> List[Tuple[str, RewardScore]]:
        """
        Rank multiple responses by reward score.

        Args:
            prompt: The prompt
            responses: List of responses to rank

        Returns:
            List of (response, score) tuples sorted by score descending
        """
        scored = [
            (response, self.score(prompt, response))
            for response in responses
        ]
        return sorted(scored, key=lambda x: x[1].score, reverse=True)

    async def rank_responses_async(
        self,
        prompt: str,
        responses: List[str]
    ) -> List[Tuple[str, RewardScore]]:
        """
        Rank multiple responses asynchronously.

        Args:
            prompt: The prompt
            responses: List of responses to rank

        Returns:
            List of (response, score) tuples sorted by score descending
        """
        pairs = [(prompt, response) for response in responses]
        scores = await self.score_batch_async(pairs)

        scored = list(zip(responses, scores))
        return sorted(scored, key=lambda x: x[1].score, reverse=True)
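

# Usage sketch (illustrative only, not part of the production API). Assumes a
# trained RewardModel; the helper name below is hypothetical and mirrors the
# demo flow in main().
async def _example_rank_with_inference(model: RewardModel) -> None:
    """Hedged example: score with caching, then rank candidate responses."""
    inference = RewardInference(model=model)

    prompt = "Summarize the water cycle."
    candidates = [
        "Water evaporates, condenses into clouds, and falls as precipitation.",
        "I prefer not to answer.",
    ]

    # A repeated (prompt, response) pair should be served from the LRU cache
    # and return the same RewardScore object.
    first = await inference.score_async(prompt, candidates[0])
    again = await inference.score_async(prompt, candidates[0])
    logger.info("Cache hit returned same object: %s", again is first)

    ranked = await inference.rank_responses_async(prompt, candidates)
    for text, reward in ranked:
        logger.info("  %.4f  %s", reward.score, text)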


# =============================================================================
# HUMAN FEEDBACK COLLECTOR
# =============================================================================

class HumanFeedbackCollector:
    """
    Interface for gathering human feedback on AI responses.

    Features:
    - Multiple feedback formats (ratings, comparisons, free-form)
    - Quality control (attention checks, spam detection)
    - Annotator tracking and reliability scoring
    - Async batch collection
    """

    def __init__(
        self,
        dataset: PreferenceDataset,
        min_annotators_per_pair: int = 2,
        attention_check_frequency: float = 0.1,
        spam_threshold: float = 0.3
    ):
        """
        Initialize feedback collector.

        Args:
            dataset: Dataset to store collected preferences
            min_annotators_per_pair: Minimum annotators per comparison
            attention_check_frequency: Frequency of attention checks
            spam_threshold: Threshold for flagging spam annotators
        """
        self.dataset = dataset
        self.min_annotators_per_pair = min_annotators_per_pair
        self.attention_check_frequency = attention_check_frequency
        self.spam_threshold = spam_threshold

        self._annotator_reliability: Dict[str, float] = {}
        self._pending_comparisons: Dict[str, Dict[str, Any]] = {}
        self._attention_check_results: Dict[str, List[bool]] = {}

    def create_comparison(
        self,
        prompt: str,
        response_a: str,
        response_b: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Create a new comparison task.

        Args:
            prompt: The prompt
            response_a: First response
            response_b: Second response
            metadata: Additional metadata

        Returns:
            Comparison ID
        """
        comparison_id = str(uuid.uuid4())

        self._pending_comparisons[comparison_id] = {
            "prompt": prompt,
            "response_a": Response(
                id=str(uuid.uuid4()),
                prompt=prompt,
                text=response_a
            ),
            "response_b": Response(
                id=str(uuid.uuid4()),
                prompt=prompt,
                text=response_b
            ),
            "metadata": metadata or {},
            "annotations": [],
            "created_at": time.time()
        }

        return comparison_id

    async def submit_preference(
        self,
        comparison_id: str,
        annotator_id: str,
        preference: PreferenceType,
        confidence: float = 1.0,
        reasoning: Optional[str] = None,
        time_spent_seconds: float = 0.0
    ) -> bool:
        """
        Submit a preference annotation.

        Args:
            comparison_id: ID of the comparison
            annotator_id: ID of the annotator
            preference: The preference choice
            confidence: Annotator's confidence
            reasoning: Optional reasoning
            time_spent_seconds: Time spent on annotation

        Returns:
            Whether submission was accepted
        """
        if comparison_id not in self._pending_comparisons:
            logger.warning(f"Unknown comparison: {comparison_id}")
            return False

        comparison = self._pending_comparisons[comparison_id]

        # Check for duplicate annotation
        existing_annotators = [
            a["annotator_id"] for a in comparison["annotations"]
        ]
        if annotator_id in existing_annotators:
            logger.warning(
                f"Duplicate annotation from {annotator_id}"
            )
            return False

        # Spam detection based on annotation time: implausibly fast
        # submissions lower the annotator's reliability, but the annotation
        # itself is still recorded below.
        if time_spent_seconds < 2.0:
            self._record_spam_signal(annotator_id)
            logger.warning(f"Possible spam from {annotator_id}: too fast")

        # Record annotation
        annotation = {
            "annotator_id": annotator_id,
            "preference": preference,
            "confidence": confidence,
            "reasoning": reasoning,
            "time_spent": time_spent_seconds,
            "timestamp": time.time()
        }
        comparison["annotations"].append(annotation)

        # Check if we have enough annotations
        if len(comparison["annotations"]) >= self.min_annotators_per_pair:
            await self._finalize_comparison(comparison_id)

        return True

    async def _finalize_comparison(self, comparison_id: str) -> None:
        """
        Finalize a comparison and add to dataset.

        Aggregates multiple annotations into a single preference.
        """
        comparison = self._pending_comparisons.pop(comparison_id)

        # Aggregate preferences using weighted voting
        preference_scores = {pt: 0.0 for pt in PreferenceType}
        total_weight = 0.0

        for annotation in comparison["annotations"]:
            annotator_id = annotation["annotator_id"]
            reliability = self._annotator_reliability.get(annotator_id, 0.5)
            weight = reliability * annotation["confidence"]

            preference_scores[annotation["preference"]] += weight
            total_weight += weight

        # Select winning preference (ties go to the earliest PreferenceType in
        # declaration order, since max() keeps the first maximum it sees)
        winning_preference = max(preference_scores, key=preference_scores.get)

        # Compute aggregate confidence
        if total_weight > 0:
            winning_score = preference_scores[winning_preference]
            aggregate_confidence = winning_score / total_weight
        else:
            aggregate_confidence = 0.5

        # Create and add preference pair
        pair = PreferencePair(
            id=str(uuid.uuid4()),
            prompt=comparison["prompt"],
            response_a=comparison["response_a"],
            response_b=comparison["response_b"],
            preference=winning_preference,
            annotator_id="aggregated",
            source=FeedbackSource.CROWDSOURCE,
            confidence=aggregate_confidence,
            metadata={
                "num_annotators": len(comparison["annotations"]),
                "annotations": comparison["annotations"],
                "original_comparison_id": comparison_id
            },
            validated=len(comparison["annotations"]) >= self.min_annotators_per_pair
        )

        self.dataset.add(pair)

        # Update annotator reliability
        await self._update_annotator_reliability(
            comparison["annotations"],
            winning_preference
        )

        logger.info(
            f"Finalized comparison {comparison_id}: "
            f"preference={winning_preference.name}"
        )
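
    # Worked aggregation example (illustrative numbers): annotators with
    # reliabilities 0.8 and 0.4 voting PREFER_A at confidence 0.9 and
    # PREFER_B at confidence 1.0 contribute weights 0.72 and 0.40, so
    # PREFER_A wins with aggregate confidence 0.72 / 1.12 ~= 0.64.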

    async def _update_annotator_reliability(
        self,
        annotations: List[Dict[str, Any]],
        final_preference: PreferenceType
    ) -> None:
        """Update annotator reliability based on agreement."""
        for annotation in annotations:
            annotator_id = annotation["annotator_id"]

            if annotator_id not in self._annotator_reliability:
                self._annotator_reliability[annotator_id] = 0.5

            # Reward agreement, penalize disagreement
            agreed = annotation["preference"] == final_preference
            delta = 0.05 if agreed else -0.05

            self._annotator_reliability[annotator_id] = max(
                0.1,
                min(1.0, self._annotator_reliability[annotator_id] + delta)
            )

    def _record_spam_signal(self, annotator_id: str) -> None:
        """Record a spam signal for an annotator."""
        if annotator_id not in self._annotator_reliability:
            self._annotator_reliability[annotator_id] = 0.5

        self._annotator_reliability[annotator_id] = max(
            0.1,
            self._annotator_reliability[annotator_id] - 0.1
        )

    def create_attention_check(self) -> Dict[str, Any]:
        """
        Create an attention check comparison.

        Returns a comparison where the correct answer is known.
        """
        checks = [
            {
                "prompt": "What is 2 + 2?",
                "response_a": "The answer is 4.",
                "response_b": "The answer is purple elephant.",
                "correct": PreferenceType.STRONGLY_PREFER_A
            },
            {
                "prompt": "Write a greeting.",
                "response_a": "Hello! How can I help you today?",
                "response_b": "sjdkfhskjdfh random nonsense",
                "correct": PreferenceType.STRONGLY_PREFER_A
            }
        ]

        return random.choice(checks)
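
    # Note: attention_check_frequency is stored on the collector but is not
    # enforced here; callers presumably interleave create_attention_check()
    # comparisons at roughly that rate.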

    async def verify_attention_check(
        self,
        annotator_id: str,
        check: Dict[str, Any],
        response: PreferenceType
    ) -> bool:
        """
        Verify an attention check response.

        Args:
            annotator_id: The annotator
            check: The attention check data
            response: The annotator's response

        Returns:
            Whether the check passed
        """
        passed = response == check["correct"]

        if annotator_id not in self._attention_check_results:
            self._attention_check_results[annotator_id] = []

        self._attention_check_results[annotator_id].append(passed)

        # Update reliability from recent attention checks: once at least 3
        # exist, evaluate the pass rate over a sliding window of up to 5
        checks = self._attention_check_results[annotator_id]
        if len(checks) >= 3:
            pass_rate = sum(checks[-5:]) / len(checks[-5:])
            if pass_rate < 0.5:
                self._annotator_reliability[annotator_id] = max(
                    0.1,
                    self._annotator_reliability.get(annotator_id, 0.5) - 0.2
                )

        return passed

    def get_annotator_stats(
        self,
        annotator_id: str
    ) -> Dict[str, Any]:
        """
        Get statistics for an annotator.

        Args:
            annotator_id: The annotator ID

        Returns:
            Dictionary of statistics
        """
        attention_checks = self._attention_check_results.get(annotator_id, [])

        return {
            "reliability": self._annotator_reliability.get(annotator_id, 0.5),
            "attention_check_pass_rate": (
                sum(attention_checks) / len(attention_checks)
                if attention_checks else None
            ),
            "total_attention_checks": len(attention_checks),
            "is_trusted": self._annotator_reliability.get(annotator_id, 0.5) > 0.7
        }

    def get_pending_comparisons(
        self,
        limit: int = 10
    ) -> List[Dict[str, Any]]:
        """
        Get pending comparisons that need more annotations.

        Args:
            limit: Maximum number to return

        Returns:
            List of pending comparisons
        """
        pending = []
        for comp_id, comp in self._pending_comparisons.items():
            if len(comp["annotations"]) < self.min_annotators_per_pair:
                pending.append({
                    "id": comp_id,
                    "prompt": comp["prompt"],
                    "response_a": comp["response_a"].text,
                    "response_b": comp["response_b"].text,
                    "current_annotations": len(comp["annotations"]),
                    "needed_annotations": self.min_annotators_per_pair
                })

                if len(pending) >= limit:
                    break

        return pending


# =============================================================================
# MAIN ENTRY POINT
# =============================================================================

async def main():
    """Demonstrate the reward modeling system."""

    logger.info("Initializing AIVA Queen Reward Modeling System")

    # Create preference dataset (uses Elestio PostgreSQL)
    dataset = PreferenceDataset()

    # Add some sample preference pairs
    for i in range(10):
        pair = PreferencePair(
            id=str(uuid.uuid4()),
            prompt=f"Sample prompt {i}",
            response_a=Response(
                id=str(uuid.uuid4()),
                prompt=f"Sample prompt {i}",
                text=f"Good response to prompt {i}"
            ),
            response_b=Response(
                id=str(uuid.uuid4()),
                prompt=f"Sample prompt {i}",
                text=f"Bad response to prompt {i}"
            ),
            preference=PreferenceType.PREFER_A if i % 2 == 0 else PreferenceType.PREFER_B,
            annotator_id="demo_annotator",
            source=FeedbackSource.DIRECT_RATING,
            confidence=0.8
        )
        dataset.add(pair)

    logger.info(f"Dataset statistics: {dataset.get_statistics()}")

    # Create reward model
    model = RewardModel(
        embedding_dim=768,
        hidden_dims=[512, 256, 128],
        dropout=0.1,
        l2_reg=0.01
    )

    logger.info("Created reward model")

    # Create trainer
    trainer = RewardTrainer(
        model=model,
        dataset=dataset,
        learning_rate=1e-4,
        batch_size=4,
        epochs=3,
        validation_split=0.2,
        checkpoint_dir="reward_checkpoints"
    )

    # Train
    logger.info("Starting training...")
    metrics = trainer.train()

    for m in metrics:
        logger.info(
            f"Epoch {m.epoch}: loss={m.loss:.4f}, "
            f"accuracy={m.accuracy:.4f}"
        )

    # Create inference engine
    inference = RewardInference(model=model)

    # Score some responses
    score = inference.score(
        prompt="What is machine learning?",
        response="Machine learning is a subset of artificial intelligence."
    )
    logger.info(f"Reward score: {score.score:.4f}, confidence: {score.confidence:.4f}")

    # Rank responses
    prompt = "Explain quantum computing"
    responses = [
        "Quantum computing uses quantum mechanics for computation.",
        "I don't know what quantum computing is.",
        "Quantum computing is the future of technology."
    ]

    ranked = inference.rank_responses(prompt, responses)
    logger.info("Ranked responses:")
    for ranked_response, ranked_score in ranked:
        logger.info(f"  Score {ranked_score.score:.4f}: {ranked_response[:50]}...")

    # Human feedback collector
    collector = HumanFeedbackCollector(dataset=dataset)

    # Create and submit a comparison
    comp_id = collector.create_comparison(
        prompt="How do I learn Python?",
        response_a="Start with tutorials and practice coding daily.",
        response_b="Just buy a computer."
    )

    await collector.submit_preference(
        comparison_id=comp_id,
        annotator_id="annotator_1",
        preference=PreferenceType.STRONGLY_PREFER_A,
        confidence=0.9,
        time_spent_seconds=15.0
    )

    await collector.submit_preference(
        comparison_id=comp_id,
        annotator_id="annotator_2",
        preference=PreferenceType.PREFER_A,
        confidence=0.8,
        time_spent_seconds=12.0
    )

    logger.info(f"Final dataset size: {len(dataset)}")
    logger.info("Reward modeling demonstration complete")


if __name__ == "__main__":
    asyncio.run(main())
