"""
rlm_01_preference_learning.py - Preference Learning System for AIVA Queen RLHF

This module implements a comprehensive preference learning system for AIVA Queen's
Reinforcement Learning from Human Feedback (RLHF) pipeline. It enables the collection,
storage, modeling, and active querying of human preferences.

Key Components:
    - PreferenceCollector: Gather and validate human preference annotations
    - PreferenceDataset: Persistent storage and retrieval of preference pairs
    - BradleyTerryModel: Statistical model for preference prediction
    - PreferenceLearner: Online and batch learning from preferences
    - ActivePreferenceQuery: Information-theoretic selection of comparisons
    - PreferenceIntegrator: Bridge preferences into the training pipeline

Reference Papers:
    - Bradley-Terry Model (1952): Rank analysis of incomplete block designs
    - Christiano et al. (2017): Deep RL from Human Preferences
    - Rafailov et al. (2023): DPO - Direct Preference Optimization

Author: Genesis-OS AIVA Queen RLM Module
Version: 1.0.0
"""

from __future__ import annotations

import asyncio
import hashlib
import json
import logging
import math
import os
import pickle
import random
import sys
sys.path.append('/mnt/e/genesis-system/data/genesis-memory')
from elestio_config import PostgresConfig
import psycopg2
import psycopg2.extras
import statistics
import threading
import time
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from dataclasses import dataclass, field, asdict
from datetime import datetime, timedelta
from enum import Enum, auto
from pathlib import Path
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    Iterator,
    List,
    Optional,
    Protocol,
    Set,
    Tuple,
    TypeVar,
    Union,
)

# Module-wide logging: INFO level, timestamped records tagged with the logger name.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("AIVA.PreferenceLearning")


# =============================================================================
# TYPE DEFINITIONS AND ENUMS
# =============================================================================

T = TypeVar("T")

# Identifier aliases: all plain strings, named only to make signatures self-describing.
ResponseId = AnnotatorId = ComparisonId = str


class PreferenceStrength(Enum):
    """Signed strength of preference between two options, A and B.

    Positive values favour A, negative values favour B, and ``TIE`` (0) means
    no preference. The magnitude (1-3) encodes how strong the preference is;
    ``PreferencePair.margin`` divides the value by 3 to map it into [-1, 1].
    """
    STRONGLY_PREFER_A = 3
    PREFER_A = 2
    SLIGHTLY_PREFER_A = 1
    TIE = 0
    SLIGHTLY_PREFER_B = -1
    PREFER_B = -2
    STRONGLY_PREFER_B = -3


class FeedbackQuality(Enum):
    """Quality assessment of a single feedback annotation.

    ``PreferenceCollector.submit_preference`` only turns HIGH and MEDIUM
    annotations into ``PreferencePair`` objects.
    """
    HIGH = auto()
    MEDIUM = auto()
    LOW = auto()
    SPAM = auto()
    # Assigned when an annotator picks the wrong side of a known-answer check.
    ATTENTION_CHECK_FAILED = auto()


class LearningStrategy(Enum):
    """Strategy controlling when preference-model parameters are updated."""
    ONLINE = auto()  # Update after each preference
    MINI_BATCH = auto()  # Update after N preferences
    FULL_BATCH = auto()  # Update on entire dataset
    STREAMING = auto()  # Continuous streaming updates


class QueryStrategy(Enum):
    """Strategy for choosing which comparison to ask annotators about next."""
    UNCERTAINTY = auto()  # Maximum model uncertainty
    INFORMATION_GAIN = auto()  # Maximum expected information gain
    DIVERSITY = auto()  # Maximum diversity of comparisons
    DISAGREEMENT = auto()  # Maximum annotator disagreement
    HYBRID = auto()  # Combination of strategies


# =============================================================================
# DATA STRUCTURES
# =============================================================================

@dataclass
class Response:
    """A model response to be compared.

    Equality and hashing use ``id`` only, so two ``Response`` objects with the
    same id are interchangeable as dict keys / set members regardless of text.
    """
    id: ResponseId
    prompt: str
    text: str
    model_id: Optional[str] = None
    timestamp: float = field(default_factory=time.time)
    embedding: Optional[List[float]] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __hash__(self) -> int:
        return hash(self.id)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Response):
            return False
        return self.id == other.id

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (round-trips via ``from_dict``).

        ``embedding`` is included so exports do not silently drop it.
        """
        return {
            "id": self.id,
            "prompt": self.prompt,
            "text": self.text,
            "model_id": self.model_id,
            "timestamp": self.timestamp,
            "embedding": self.embedding,
            "metadata": self.metadata
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Response":
        """Build a ``Response`` from ``to_dict`` output.

        ``id``, ``prompt`` and ``text`` are required; other keys are optional so
        older exports (without ``embedding``) still load.
        """
        return cls(
            id=data["id"],
            prompt=data["prompt"],
            text=data["text"],
            model_id=data.get("model_id"),
            timestamp=data.get("timestamp", time.time()),
            embedding=data.get("embedding"),
            metadata=data.get("metadata", {})
        )


@dataclass
class PreferencePair:
    """One human judgement comparing two responses to the same prompt.

    The sign of ``preference.value`` decides the outcome: positive favours
    ``response_a``, negative favours ``response_b``, zero is a tie.
    """
    id: ComparisonId
    prompt: str
    response_a: Response
    response_b: Response
    preference: PreferenceStrength
    annotator_id: AnnotatorId
    timestamp: float = field(default_factory=time.time)
    confidence: float = 1.0
    reasoning: Optional[str] = None
    quality: FeedbackQuality = FeedbackQuality.MEDIUM
    metadata: Dict[str, Any] = field(default_factory=dict)
    is_attention_check: bool = False

    @property
    def winner(self) -> Optional[Response]:
        """The preferred response; ``None`` when the annotator declared a tie."""
        score = self.preference.value
        if score == 0:
            return None
        return self.response_a if score > 0 else self.response_b

    @property
    def loser(self) -> Optional[Response]:
        """The non-preferred response; ``None`` when the annotator declared a tie."""
        score = self.preference.value
        if score == 0:
            return None
        return self.response_b if score > 0 else self.response_a

    @property
    def margin(self) -> float:
        """Preference strength scaled into [-1, 1] (positive favours A)."""
        return self.preference.value / 3.0

    @property
    def label(self) -> float:
        """Soft training target in [0, 1]: probability that A is preferred."""
        return 0.5 * (1.0 + self.margin)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict; enums are stored by name."""
        return dict(
            id=self.id,
            prompt=self.prompt,
            response_a=self.response_a.to_dict(),
            response_b=self.response_b.to_dict(),
            preference=self.preference.name,
            annotator_id=self.annotator_id,
            timestamp=self.timestamp,
            confidence=self.confidence,
            reasoning=self.reasoning,
            quality=self.quality.name,
            metadata=self.metadata,
            is_attention_check=self.is_attention_check,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PreferencePair":
        """Rebuild a pair from ``to_dict`` output, defaulting optional keys."""
        get = data.get
        return cls(
            id=data["id"],
            prompt=data["prompt"],
            response_a=Response.from_dict(data["response_a"]),
            response_b=Response.from_dict(data["response_b"]),
            preference=PreferenceStrength[data["preference"]],
            annotator_id=data["annotator_id"],
            timestamp=get("timestamp", time.time()),
            confidence=get("confidence", 1.0),
            reasoning=get("reasoning"),
            quality=FeedbackQuality[get("quality", "MEDIUM")],
            metadata=get("metadata", {}),
            is_attention_check=get("is_attention_check", False),
        )


@dataclass
class AnnotatorProfile:
    """Profile tracking annotator quality and statistics.

    Running statistics are updated by ``update_from_annotation``. Note that
    ``agreement_rate`` is not updated in this class and keeps its value unless
    set elsewhere.
    """
    id: AnnotatorId
    total_annotations: int = 0
    agreement_rate: float = 0.5
    attention_check_pass_rate: float = 1.0
    average_confidence: float = 0.5
    average_time_seconds: float = 30.0
    quality_score: float = 0.5
    last_active: float = field(default_factory=time.time)
    is_trusted: bool = False
    preferences_by_type: Dict[str, int] = field(default_factory=dict)
    # Number of attention checks this annotator has answered; it is the
    # denominator of attention_check_pass_rate (NOT total_annotations).
    attention_checks_seen: int = 0

    def update_from_annotation(
        self,
        passed_attention: Optional[bool] = None,
        confidence: float = 1.0,
        time_spent: float = 30.0
    ) -> None:
        """Update the profile after one annotation.

        Args:
            passed_attention: ``True``/``False`` if the annotation was an
                attention check, ``None`` for an ordinary annotation.
            confidence: The annotator's self-reported confidence.
            time_spent: Seconds spent on the annotation.
        """
        self.total_annotations += 1
        self.last_active = time.time()

        # Exponential moving averages (alpha = 0.1).
        alpha = 0.1
        self.average_confidence = (1 - alpha) * self.average_confidence + alpha * confidence
        self.average_time_seconds = (1 - alpha) * self.average_time_seconds + alpha * time_spent

        if passed_attention is not None:
            # Exact running mean over attention checks only; averaging over all
            # annotations would dilute failures and keep bad annotators trusted.
            self.attention_checks_seen += 1
            k = self.attention_checks_seen
            self.attention_check_pass_rate = (
                (k - 1) * self.attention_check_pass_rate + float(passed_attention)
            ) / k

        # Weighted quality score; the time term rewards plausible (5s-300s) pacing.
        self.quality_score = (
            0.3 * self.agreement_rate +
            0.3 * self.attention_check_pass_rate +
            0.2 * min(1.0, self.average_confidence) +
            0.2 * (1.0 if 5 < self.average_time_seconds < 300 else 0.5)
        )

        # Trust threshold
        self.is_trusted = (
            self.total_annotations >= 10 and
            self.quality_score >= 0.7 and
            self.attention_check_pass_rate >= 0.8
        )


# =============================================================================
# PREFERENCE COLLECTOR
# =============================================================================

class PreferenceCollector:
    """
    Collects and validates human preference annotations.

    Features:
    - Pairwise comparison tasks with randomly inserted attention checks
    - Heuristic quality/spam scoring of each annotation
    - Per-annotator reliability profiles
    - Thread-safe bookkeeping (a single re-entrant lock guards all state)
    """

    def __init__(
        self,
        attention_check_frequency: float = 0.1,
        min_time_seconds: float = 3.0,
        max_time_seconds: float = 600.0,
        spam_threshold: float = 0.3,
        require_reasoning: bool = False
    ):
        """
        Initialize preference collector.

        Args:
            attention_check_frequency: Probability that a new comparison is
                replaced by an attention check
            min_time_seconds: Minimum acceptable annotation time
            max_time_seconds: Maximum acceptable annotation time
            spam_threshold: Annotator quality score below which submissions
                are penalized as likely spam
            require_reasoning: Whether to require reasoning for preferences
        """
        self.attention_check_frequency = attention_check_frequency
        self.min_time_seconds = min_time_seconds
        self.max_time_seconds = max_time_seconds
        self.spam_threshold = spam_threshold
        self.require_reasoning = require_reasoning

        self._annotator_profiles: Dict[AnnotatorId, AnnotatorProfile] = {}
        self._pending_comparisons: Dict[ComparisonId, Dict[str, Any]] = {}
        self._attention_checks: List[Dict[str, Any]] = self._generate_attention_checks()
        self._collected_pairs: List[PreferencePair] = []
        # Re-entrant: submit_preference holds the lock while calling helpers
        # (e.g. get_annotator_profile) that acquire it themselves.
        self._lock = threading.RLock()

        logger.info("PreferenceCollector initialized")

    def _generate_attention_checks(self) -> List[Dict[str, Any]]:
        """Generate attention check comparisons with known correct answers."""
        return [
            {
                "prompt": "What is 2 + 2?",
                "response_a": "The answer is 4.",
                "response_b": "The answer is banana.",
                "correct": PreferenceStrength.STRONGLY_PREFER_A
            },
            {
                "prompt": "Write a greeting.",
                "response_a": "Hello! Nice to meet you.",
                "response_b": "asdfghjkl qwerty",
                "correct": PreferenceStrength.STRONGLY_PREFER_A
            },
            {
                "prompt": "Explain what water is.",
                "response_a": "I don't know anything.",
                "response_b": "Water is H2O, a molecule consisting of two hydrogen atoms and one oxygen atom.",
                "correct": PreferenceStrength.STRONGLY_PREFER_B
            }
        ]

    def get_annotator_profile(self, annotator_id: AnnotatorId) -> AnnotatorProfile:
        """Get or create (atomically) the profile for ``annotator_id``."""
        with self._lock:
            profile = self._annotator_profiles.get(annotator_id)
            if profile is None:
                profile = AnnotatorProfile(id=annotator_id)
                self._annotator_profiles[annotator_id] = profile
            return profile

    def create_comparison(
        self,
        prompt: str,
        response_a: Response,
        response_b: Response,
        metadata: Optional[Dict[str, Any]] = None
    ) -> ComparisonId:
        """
        Create a new comparison task.

        With probability ``attention_check_frequency`` the supplied prompt and
        responses are replaced by a known-answer attention check; in that case
        the caller's A/B pair is not recorded under the returned id.

        Args:
            prompt: The prompt
            response_a: First response
            response_b: Second response
            metadata: Additional metadata

        Returns:
            Comparison ID
        """
        comparison_id = str(uuid.uuid4())

        is_attention_check = random.random() < self.attention_check_frequency
        correct_answer: Optional[PreferenceStrength] = None
        if is_attention_check:
            check = random.choice(self._attention_checks)
            prompt = check["prompt"]
            response_a = Response(
                id=str(uuid.uuid4()),
                prompt=prompt,
                text=check["response_a"]
            )
            response_b = Response(
                id=str(uuid.uuid4()),
                prompt=prompt,
                text=check["response_b"]
            )
            correct_answer = check["correct"]

        with self._lock:
            self._pending_comparisons[comparison_id] = {
                "prompt": prompt,
                "response_a": response_a,
                "response_b": response_b,
                "metadata": metadata or {},
                "created_at": time.time(),
                "is_attention_check": is_attention_check,
                "correct_answer": correct_answer,
                "annotations": []
            }

        return comparison_id

    async def submit_preference(
        self,
        comparison_id: ComparisonId,
        annotator_id: AnnotatorId,
        preference: PreferenceStrength,
        confidence: float = 1.0,
        reasoning: Optional[str] = None,
        time_spent_seconds: float = 30.0
    ) -> Tuple[bool, Optional[PreferencePair]]:
        """
        Submit a preference annotation.

        The annotation is always recorded on the comparison (so the same
        annotator cannot resubmit), but a ``PreferencePair`` is only produced
        when the assessed quality is HIGH or MEDIUM.

        Args:
            comparison_id: ID of the comparison
            annotator_id: ID of the annotator
            preference: The preference choice
            confidence: Annotator's confidence
            reasoning: Optional reasoning
            time_spent_seconds: Time spent on annotation

        Returns:
            Tuple of (accepted, resulting_pair)
        """
        # Duplicate check and recording happen under ONE lock acquisition so two
        # concurrent submissions by the same annotator cannot both be accepted.
        with self._lock:
            comparison = self._pending_comparisons.get(comparison_id)
            if comparison is None:
                logger.warning(f"Unknown comparison: {comparison_id}")
                return False, None

            if any(a["annotator_id"] == annotator_id for a in comparison["annotations"]):
                logger.warning(f"Duplicate annotation from {annotator_id}")
                return False, None

            # Assessed against the profile BEFORE this annotation updates it.
            quality = self._assess_quality(
                annotator_id, preference, confidence, reasoning, time_spent_seconds
            )

            passed_attention = None
            if comparison["is_attention_check"]:
                correct = comparison["correct_answer"]
                # Only the side must match; the exact strength may differ.
                # A TIE never passes.
                passed_attention = (
                    (preference.value > 0 and correct.value > 0) or
                    (preference.value < 0 and correct.value < 0)
                )
                if not passed_attention:
                    quality = FeedbackQuality.ATTENTION_CHECK_FAILED
                    logger.warning(f"Annotator {annotator_id} failed attention check")

            profile = self.get_annotator_profile(annotator_id)
            profile.update_from_annotation(
                passed_attention=passed_attention,
                confidence=confidence,
                time_spent=time_spent_seconds
            )

            comparison["annotations"].append({
                "annotator_id": annotator_id,
                "preference": preference,
                "confidence": confidence,
                "reasoning": reasoning,
                "time_spent": time_spent_seconds,
                "quality": quality,
                "timestamp": time.time()
            })

            if quality not in (FeedbackQuality.HIGH, FeedbackQuality.MEDIUM):
                return False, None

            pair = PreferencePair(
                id=str(uuid.uuid4()),
                prompt=comparison["prompt"],
                response_a=comparison["response_a"],
                response_b=comparison["response_b"],
                preference=preference,
                annotator_id=annotator_id,
                confidence=confidence,
                reasoning=reasoning,
                quality=quality,
                metadata=comparison["metadata"],
                is_attention_check=comparison["is_attention_check"]
            )
            self._collected_pairs.append(pair)
            return True, pair

    def _assess_quality(
        self,
        annotator_id: AnnotatorId,
        preference: PreferenceStrength,
        confidence: float,
        reasoning: Optional[str],
        time_spent: float
    ) -> FeedbackQuality:
        """Score an annotation by counting spam signals.

        3+ signals -> SPAM, 2 -> LOW, 1 -> MEDIUM, 0 -> HIGH.
        """
        profile = self.get_annotator_profile(annotator_id)

        spam_signals = 0

        # Too fast
        if time_spent < self.min_time_seconds:
            spam_signals += 2

        # Too slow (might indicate distraction)
        if time_spent > self.max_time_seconds:
            spam_signals += 1

        # Low-confidence tie
        if confidence < 0.3 and preference == PreferenceStrength.TIE:
            spam_signals += 1

        # Missing reasoning when required
        if self.require_reasoning and not reasoning:
            spam_signals += 1

        # Annotator has low quality score
        if profile.quality_score < self.spam_threshold:
            spam_signals += 2

        if spam_signals >= 3:
            return FeedbackQuality.SPAM
        elif spam_signals >= 2:
            return FeedbackQuality.LOW
        elif spam_signals >= 1:
            return FeedbackQuality.MEDIUM
        else:
            return FeedbackQuality.HIGH

    def get_collected_pairs(
        self,
        min_quality: FeedbackQuality = FeedbackQuality.MEDIUM,
        exclude_attention_checks: bool = True
    ) -> List[PreferencePair]:
        """Get collected preference pairs filtered by quality."""
        quality_rank = {
            FeedbackQuality.HIGH: 3,
            FeedbackQuality.MEDIUM: 2,
            FeedbackQuality.LOW: 1,
            FeedbackQuality.SPAM: 0,
            FeedbackQuality.ATTENTION_CHECK_FAILED: 0
        }
        min_rank = quality_rank[min_quality]

        with self._lock:
            pairs = [
                p for p in self._collected_pairs
                if quality_rank[p.quality] >= min_rank
            ]

            if exclude_attention_checks:
                pairs = [p for p in pairs if not p.is_attention_check]

            return pairs

    def get_statistics(self) -> Dict[str, Any]:
        """Get collection statistics computed from a consistent snapshot."""
        with self._lock:
            # Copy so concurrent submissions can't change the list mid-computation.
            pairs = list(self._collected_pairs)
            profiles = list(self._annotator_profiles.values())

        return {
            "total_pairs": len(pairs),
            "total_annotators": len(profiles),
            "quality_distribution": {
                q.name: sum(1 for p in pairs if p.quality == q)
                for q in FeedbackQuality
            },
            "preference_distribution": {
                s.name: sum(1 for p in pairs if p.preference == s)
                for s in PreferenceStrength
            },
            "average_confidence": statistics.mean([p.confidence for p in pairs]) if pairs else 0,
            "trusted_annotators": sum(1 for p in profiles if p.is_trusted)
        }


# =============================================================================
# PREFERENCE DATASET
# =============================================================================

class PreferenceDataset:
    """
    Persistent storage and retrieval of preference pairs.

    Features:
    - PostgreSQL-backed persistence (Elestio), upsert by pair id
    - In-memory indexes by prompt and by response id
    - Filtering and (optionally stratified) sampling
    - JSON export/import
    """

    def __init__(
        self,
        db_path: Optional[str] = None,
        min_confidence: float = 0.5
    ):
        """
        Initialize preference dataset.

        Args:
            db_path: Deprecated (ignored). Uses Elestio PostgreSQL.
            min_confidence: Default minimum confidence threshold for ``sample``
        """
        self.min_confidence = min_confidence
        self._pairs: Dict[ComparisonId, PreferencePair] = {}
        self._by_prompt: Dict[str, List[ComparisonId]] = defaultdict(list)
        self._by_response: Dict[ResponseId, List[ComparisonId]] = defaultdict(list)
        self._lock = threading.Lock()

        self._init_database()
        self._load_from_database()

        logger.info(f"PreferenceDataset initialized with {len(self._pairs)} pairs")

    def _get_conn(self):
        """Get a PostgreSQL connection from Elestio."""
        return psycopg2.connect(**PostgresConfig.get_connection_params())

    def _init_database(self) -> None:
        """Ensure the PostgreSQL table exists (no-op if already created)."""
        conn = self._get_conn()
        try:
            with conn.cursor() as cur:
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS pl_preference_pairs (
                        id TEXT PRIMARY KEY,
                        prompt TEXT NOT NULL,
                        response_a_id TEXT NOT NULL,
                        response_a_text TEXT NOT NULL,
                        response_b_id TEXT NOT NULL,
                        response_b_text TEXT NOT NULL,
                        preference TEXT NOT NULL,
                        annotator_id TEXT NOT NULL,
                        timestamp DOUBLE PRECISION,
                        confidence REAL DEFAULT 1.0,
                        reasoning TEXT,
                        quality TEXT,
                        metadata TEXT,
                        is_attention_check INTEGER DEFAULT 0
                    )
                """)
            conn.commit()
        finally:
            conn.close()

    def _load_from_database(self) -> None:
        """Load existing pairs from the database into memory.

        Only the columns stored in the table are restored; Response fields such
        as ``model_id``, ``embedding`` and ``metadata`` are not persisted.
        """
        conn = self._get_conn()
        try:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute("SELECT * FROM pl_preference_pairs")

                for row in cur:
                    pair = PreferencePair(
                        id=row["id"],
                        prompt=row["prompt"],
                        response_a=Response(
                            id=row["response_a_id"],
                            prompt=row["prompt"],
                            text=row["response_a_text"]
                        ),
                        response_b=Response(
                            id=row["response_b_id"],
                            prompt=row["prompt"],
                            text=row["response_b_text"]
                        ),
                        preference=PreferenceStrength[row["preference"]],
                        annotator_id=row["annotator_id"],
                        timestamp=row["timestamp"],
                        confidence=row["confidence"],
                        reasoning=row["reasoning"],
                        quality=FeedbackQuality[row["quality"]] if row["quality"] else FeedbackQuality.MEDIUM,
                        metadata=json.loads(row["metadata"]) if row["metadata"] else {},
                        is_attention_check=bool(row["is_attention_check"])
                    )
                    self._pairs[pair.id] = pair
                    self._index_pair(pair)
        finally:
            conn.close()

    def _index_pair(self, pair: PreferencePair) -> None:
        """Add ``pair`` to the prompt/response indexes (caller holds the lock)."""
        self._by_prompt[pair.prompt].append(pair.id)
        self._by_response[pair.response_a.id].append(pair.id)
        self._by_response[pair.response_b.id].append(pair.id)

    def _unindex_pair(self, pair: PreferencePair) -> None:
        """Remove ``pair`` from the prompt/response indexes (caller holds the lock)."""
        for index, key in (
            (self._by_prompt, pair.prompt),
            (self._by_response, pair.response_a.id),
            (self._by_response, pair.response_b.id),
        ):
            ids = index.get(key)
            if ids and pair.id in ids:
                ids.remove(pair.id)

    def add(self, pair: PreferencePair) -> str:
        """Add (or replace, by id) a preference pair and persist it.

        Re-adding an existing id replaces the old index entries instead of
        duplicating them, mirroring the database upsert.
        """
        with self._lock:
            previous = self._pairs.get(pair.id)
            if previous is not None:
                self._unindex_pair(previous)
            self._pairs[pair.id] = pair
            self._index_pair(pair)

        self._save_pair(pair)
        return pair.id

    def add_batch(self, pairs: List[PreferencePair]) -> List[str]:
        """Add multiple preference pairs (each persisted individually)."""
        return [self.add(pair) for pair in pairs]

    def _save_pair(self, pair: PreferencePair) -> None:
        """Upsert a pair into the database."""
        conn = self._get_conn()
        try:
            with conn.cursor() as cur:
                cur.execute("""
                    INSERT INTO pl_preference_pairs
                    (id, prompt, response_a_id, response_a_text, response_b_id,
                     response_b_text, preference, annotator_id, timestamp,
                     confidence, reasoning, quality, metadata, is_attention_check)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                    ON CONFLICT (id) DO UPDATE SET
                        prompt = EXCLUDED.prompt,
                        response_a_id = EXCLUDED.response_a_id,
                        response_a_text = EXCLUDED.response_a_text,
                        response_b_id = EXCLUDED.response_b_id,
                        response_b_text = EXCLUDED.response_b_text,
                        preference = EXCLUDED.preference,
                        annotator_id = EXCLUDED.annotator_id,
                        timestamp = EXCLUDED.timestamp,
                        confidence = EXCLUDED.confidence,
                        reasoning = EXCLUDED.reasoning,
                        quality = EXCLUDED.quality,
                        metadata = EXCLUDED.metadata,
                        is_attention_check = EXCLUDED.is_attention_check
                """, (
                    pair.id,
                    pair.prompt,
                    pair.response_a.id,
                    pair.response_a.text,
                    pair.response_b.id,
                    pair.response_b.text,
                    pair.preference.name,
                    pair.annotator_id,
                    pair.timestamp,
                    pair.confidence,
                    pair.reasoning,
                    pair.quality.name,
                    json.dumps(pair.metadata),
                    int(pair.is_attention_check)
                ))
            conn.commit()
        finally:
            conn.close()

    def get(self, pair_id: ComparisonId) -> Optional[PreferencePair]:
        """Get a preference pair by ID."""
        return self._pairs.get(pair_id)

    def get_by_prompt(self, prompt: str) -> List[PreferencePair]:
        """Get all preference pairs for a prompt."""
        with self._lock:
            pair_ids = self._by_prompt.get(prompt, [])
            return [self._pairs[pid] for pid in pair_ids if pid in self._pairs]

    def get_by_response(self, response_id: ResponseId) -> List[PreferencePair]:
        """Get all preference pairs involving a response."""
        with self._lock:
            pair_ids = self._by_response.get(response_id, [])
            return [self._pairs[pid] for pid in pair_ids if pid in self._pairs]

    def sample(
        self,
        n: int,
        stratified: bool = True,
        min_confidence: Optional[float] = None
    ) -> List[PreferencePair]:
        """
        Sample up to ``n`` distinct, non-attention-check preference pairs.

        Args:
            n: Number of pairs to sample (``n <= 0`` returns ``[]``)
            stratified: Whether to draw evenly across preference strengths first
            min_confidence: Minimum confidence threshold; ``None`` uses the
                dataset default (an explicit ``0.0`` is honoured)
        """
        if n <= 0:
            return []

        # `is None`, not `or`: an explicit 0.0 threshold must not fall back to the default.
        min_conf = self.min_confidence if min_confidence is None else min_confidence

        with self._lock:
            eligible = [
                p for p in self._pairs.values()
                if p.confidence >= min_conf and not p.is_attention_check
            ]

        if not eligible:
            return []

        if not stratified:
            return random.sample(eligible, min(n, len(eligible)))

        by_pref: Dict[PreferenceStrength, List[PreferencePair]] = defaultdict(list)
        for p in eligible:
            by_pref[p.preference].append(p)

        # Draw evenly from each preference-strength group.
        per_group = max(1, n // len(by_pref))
        sampled: List[PreferencePair] = []
        for group in by_pref.values():
            sampled.extend(random.sample(group, min(per_group, len(group))))

        if len(sampled) > n:
            # More groups than requested items: choose uniformly among the
            # per-group draws rather than always keeping the first-seen groups.
            return random.sample(sampled, n)

        # Top up with distinct random pairs (id-based membership, O(n)).
        taken = {p.id for p in sampled}
        remaining = [p for p in eligible if p.id not in taken]
        shortfall = min(n - len(sampled), len(remaining))
        sampled.extend(random.sample(remaining, shortfall))
        return sampled

    def __len__(self) -> int:
        return len(self._pairs)

    def __iter__(self) -> Iterator[PreferencePair]:
        # Iterate a snapshot so concurrent add() calls can't raise
        # "dictionary changed size during iteration".
        with self._lock:
            snapshot = list(self._pairs.values())
        return iter(snapshot)

    def export_json(self, filepath: str) -> None:
        """Export dataset to a UTF-8 JSON file."""
        with self._lock:
            pairs = [p.to_dict() for p in self._pairs.values()]

        with open(filepath, "w", encoding="utf-8") as f:
            json.dump({"pairs": pairs, "exported_at": datetime.now().isoformat()}, f, indent=2)

        logger.info(f"Exported {len(pairs)} pairs to {filepath}")

    @classmethod
    def import_json(cls, filepath: str, db_path: Optional[str] = None) -> "PreferenceDataset":
        """Import dataset from a UTF-8 JSON file produced by ``export_json``."""
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)

        dataset = cls(db_path=db_path)
        for pair_data in data.get("pairs", []):
            dataset.add(PreferencePair.from_dict(pair_data))

        return dataset


# =============================================================================
# BRADLEY-TERRY MODEL
# =============================================================================

class BradleyTerryModel:
    """
    Bradley-Terry model for preference prediction.

    The Bradley-Terry model estimates the probability that response A is
    preferred over response B as:
        P(A > B) = sigma(r_A - r_B)

    Where r_A and r_B are learned "strength" parameters for each response,
    and sigma is the sigmoid function.

    Features:
    - Online and batch parameter estimation
    - Regularization (L2 penalty pulling each strength toward ``prior_strength``)
    - Uncertainty estimation via Laplace approximation
    - Efficient incremental updates

    Read-only queries (``predict_preference``, ``get_uncertainty``,
    ``get_ranking``, ``get_strength``) never create entries for unseen
    responses. An unseen response reports ``prior_strength`` and zero
    accumulated curvature. Only ``update`` (and therefore ``fit``) adds
    entries to the model state.
    """

    def __init__(
        self,
        learning_rate: float = 0.1,
        regularization: float = 0.01,
        prior_strength: float = 0.0,
        use_confidence_weighting: bool = True
    ):
        """
        Initialize Bradley-Terry model.

        Args:
            learning_rate: Learning rate for parameter updates
            regularization: L2 regularization strength, applied to the distance
                of each strength from ``prior_strength``
            prior_strength: Prior strength for all responses; also the value
                reported for responses the model has never been updated on
            use_confidence_weighting: Weight updates by annotator confidence
        """
        self.learning_rate = learning_rate
        self.regularization = regularization
        self.prior_strength = prior_strength
        self.use_confidence_weighting = use_confidence_weighting

        # defaultdicts let update() apply `-=`/`+=` to a brand-new response;
        # readers use .get() so that queries never insert keys.
        self._strengths: Dict[ResponseId, float] = defaultdict(lambda: prior_strength)
        self._counts: Dict[ResponseId, int] = defaultdict(int)
        self._pair_count = 0

        # Accumulated diagonal of the log-likelihood Hessian (for uncertainty).
        self._hessian_diag: Dict[ResponseId, float] = defaultdict(float)

    def _strength(self, response_id: ResponseId) -> float:
        """Return a response's current strength without inserting it into the model."""
        return self._strengths.get(response_id, self.prior_strength)

    def _sigmoid(self, x: float) -> float:
        """Numerically stable sigmoid (never calls exp on a positive argument)."""
        if x >= 0:
            z = math.exp(-x)
            return 1 / (1 + z)
        else:
            z = math.exp(x)
            return z / (1 + z)

    def predict_preference(
        self,
        response_a: ResponseId,
        response_b: ResponseId
    ) -> float:
        """
        Predict probability that A is preferred over B.

        Args:
            response_a: ID of response A
            response_b: ID of response B

        Returns:
            P(A > B) in [0, 1]. Unseen responses use ``prior_strength``.
        """
        return self._sigmoid(self._strength(response_a) - self._strength(response_b))

    def get_uncertainty(
        self,
        response_a: ResponseId,
        response_b: ResponseId
    ) -> float:
        """
        Get uncertainty in preference prediction.

        Uses a Laplace approximation: each strength's variance is the inverse
        of its accumulated curvature plus the regularization term.

        Args:
            response_a: ID of response A
            response_b: ID of response B

        Returns:
            Uncertainty score (higher = more uncertain)
        """
        # Variance via Laplace approximation (unseen responses have zero curvature).
        var_a = 1.0 / max(self._hessian_diag.get(response_a, 0.0) + self.regularization, 1e-6)
        var_b = 1.0 / max(self._hessian_diag.get(response_b, 0.0) + self.regularization, 1e-6)

        # Combined variance of the difference
        var_diff = var_a + var_b

        # Scale prediction entropy (max at p=0.5) by the parameter std-dev.
        prob = self.predict_preference(response_a, response_b)
        entropy = -prob * math.log(prob + 1e-10) - (1 - prob) * math.log(1 - prob + 1e-10)

        return entropy * math.sqrt(var_diff)

    def update(self, pair: PreferencePair) -> Dict[str, float]:
        """
        Update model parameters from a single preference pair (one SGD step).

        Args:
            pair: Preference pair to learn from

        Returns:
            Dictionary with update metrics ("loss", "prediction", "error",
            "strength_a", "strength_b"); "prediction" is the pre-update P(A > B).
        """
        response_a = pair.response_a.id
        response_b = pair.response_b.id
        label = pair.label  # P(A preferred)
        weight = pair.confidence if self.use_confidence_weighting else 1.0

        # Current prediction
        pred = self.predict_preference(response_a, response_b)

        # Gradient of the cross-entropy loss w.r.t. (r_a - r_b)
        error = pred - label
        grad = weight * error

        # L2 penalty toward the prior (a Gaussian prior centred at
        # prior_strength), so every response is shrunk toward the same value
        # that unseen responses report.
        reg_grad_a = self.regularization * (self._strength(response_a) - self.prior_strength)
        reg_grad_b = self.regularization * (self._strength(response_b) - self.prior_strength)

        # dL/dr_a = +grad, dL/dr_b = -grad
        self._strengths[response_a] -= self.learning_rate * (grad + reg_grad_a)
        self._strengths[response_b] += self.learning_rate * (grad - reg_grad_b)

        # Update counts
        self._counts[response_a] += 1
        self._counts[response_b] += 1
        self._pair_count += 1

        # Update Hessian diagonal for uncertainty estimation
        hess = weight * pred * (1 - pred)
        self._hessian_diag[response_a] += hess
        self._hessian_diag[response_b] += hess

        # Compute loss
        loss = -label * math.log(pred + 1e-10) - (1 - label) * math.log(1 - pred + 1e-10)

        return {
            "loss": loss,
            "prediction": pred,
            "error": abs(error),
            "strength_a": self._strengths[response_a],
            "strength_b": self._strengths[response_b]
        }

    def fit(
        self,
        pairs: List[PreferencePair],
        epochs: int = 10,
        shuffle: bool = True
    ) -> Dict[str, List[float]]:
        """
        Fit model on a batch of preference pairs.

        Args:
            pairs: List of preference pairs (not modified; each epoch uses a copy)
            epochs: Number of training epochs
            shuffle: Whether to shuffle pairs each epoch

        Returns:
            Dictionary with per-epoch "loss" and "accuracy" lists
        """
        metrics = {"loss": [], "accuracy": []}

        for epoch in range(epochs):
            epoch_pairs = pairs.copy()
            if shuffle:
                random.shuffle(epoch_pairs)

            epoch_loss = 0.0
            correct = 0

            for pair in epoch_pairs:
                result = self.update(pair)
                epoch_loss += result["loss"]

                # Accuracy of the pre-update prediction
                pred_a_wins = result["prediction"] > 0.5
                actual_a_wins = pair.label > 0.5
                if pred_a_wins == actual_a_wins:
                    correct += 1

            avg_loss = epoch_loss / len(pairs) if pairs else 0
            accuracy = correct / len(pairs) if pairs else 0

            metrics["loss"].append(avg_loss)
            metrics["accuracy"].append(accuracy)

            logger.debug(f"Epoch {epoch + 1}: loss={avg_loss:.4f}, accuracy={accuracy:.4f}")

        return metrics

    def get_ranking(self, response_ids: List[ResponseId]) -> List[Tuple[ResponseId, float]]:
        """
        Get ranked list of responses by strength.

        Args:
            response_ids: List of response IDs to rank

        Returns:
            List of (response_id, strength) tuples sorted by strength, highest first
        """
        ranked = [(rid, self._strength(rid)) for rid in response_ids]
        return sorted(ranked, key=lambda x: x[1], reverse=True)

    def get_strength(self, response_id: ResponseId) -> float:
        """Get the strength parameter for a response (``prior_strength`` if unseen)."""
        return self._strength(response_id)

    def save(self, filepath: str) -> None:
        """Save model state and the full constructor configuration to a pickle file."""
        state = {
            "strengths": dict(self._strengths),
            "counts": dict(self._counts),
            "hessian_diag": dict(self._hessian_diag),
            "pair_count": self._pair_count,
            "config": {
                "learning_rate": self.learning_rate,
                "regularization": self.regularization,
                "prior_strength": self.prior_strength,
                # Persisted so a reloaded model keeps the same update rule.
                "use_confidence_weighting": self.use_confidence_weighting
            }
        }
        with open(filepath, "wb") as f:
            pickle.dump(state, f)
        logger.info(f"Model saved to {filepath}")

    @classmethod
    def load(cls, filepath: str) -> "BradleyTerryModel":
        """
        Load a model written by ``save``.

        Files written before ``use_confidence_weighting`` was persisted still
        load; the constructor default is used for the missing key.

        SECURITY: ``pickle.load`` can execute arbitrary code; only load model
        files from trusted locations.
        """
        with open(filepath, "rb") as f:
            state = pickle.load(f)

        model = cls(**state["config"])
        model._strengths = defaultdict(lambda: model.prior_strength, state["strengths"])
        model._counts = defaultdict(int, state["counts"])
        model._hessian_diag = defaultdict(float, state["hessian_diag"])
        model._pair_count = state["pair_count"]

        logger.info(f"Model loaded from {filepath}")
        return model


# =============================================================================
# PREFERENCE LEARNER
# =============================================================================

class PreferenceLearner:
    """
    Learns from human preferences using various strategies.

    Features:
    - Online, mini-batch, and full batch learning
    - Multiple model backends (Bradley-Terry, neural)
    - Curriculum learning support
    - Learning rate scheduling
    - Validation and early stopping

    Pairs fed through ``learn`` in mini-batch mode are buffered until
    ``batch_size`` is reached. ``learn_batch`` trains on its own argument
    and never touches that pending buffer.
    """

    def __init__(
        self,
        model: Optional[BradleyTerryModel] = None,
        strategy: LearningStrategy = LearningStrategy.MINI_BATCH,
        batch_size: int = 32,
        validation_split: float = 0.1,
        patience: int = 5
    ):
        """
        Initialize preference learner.

        Args:
            model: Preference model (creates default if None)
            strategy: Learning strategy
            batch_size: Batch size for mini-batch learning
            validation_split: Fraction for validation
            patience: Early stopping patience (epochs without val improvement)
        """
        self.model = model if model is not None else BradleyTerryModel()
        self.strategy = strategy
        self.batch_size = batch_size
        self.validation_split = validation_split
        self.patience = patience

        self._buffer: List[PreferencePair] = []
        self._training_history: List[Dict[str, float]] = []
        self._best_val_loss = float("inf")
        self._patience_counter = 0

    def learn(self, pair: PreferencePair) -> Optional[Dict[str, float]]:
        """
        Learn from a single preference pair.

        ONLINE updates immediately. MINI_BATCH buffers and trains once
        ``batch_size`` pairs have accumulated. Any other strategy only buffers.

        Args:
            pair: Preference pair to learn from

        Returns:
            Metrics if update occurred, None otherwise
        """
        if self.strategy == LearningStrategy.ONLINE:
            return self.model.update(pair)

        self._buffer.append(pair)

        if self.strategy == LearningStrategy.MINI_BATCH and len(self._buffer) >= self.batch_size:
            # Detach the full buffer before training so it is consumed exactly once.
            batch, self._buffer = self._buffer, []
            return self._train_batch(batch)

        return None

    def learn_batch(
        self,
        pairs: List[PreferencePair],
        epochs: int = 10
    ) -> Dict[str, Any]:
        """
        Learn from a batch of preference pairs with validation and early stopping.

        Args:
            pairs: List of preference pairs (not modified)
            epochs: Number of training epochs

        Returns:
            Training results: per-epoch "train_metrics" and "val_metrics",
            plus "best_epoch" (the epoch with the lowest validation loss)
        """
        # Work on a copy so the caller's list is not reordered.
        pairs = list(pairs)
        random.shuffle(pairs)
        split_idx = int(len(pairs) * (1 - self.validation_split))
        train_pairs = pairs[:split_idx]
        val_pairs = pairs[split_idx:]
        if not train_pairs:
            # Too few pairs to hold any out: train on everything, skip validation.
            train_pairs, val_pairs = pairs, []

        # Early-stopping state is per call. Validation sets drawn in different
        # calls are not comparable, so a previous best must not carry over.
        self._best_val_loss = float("inf")
        self._patience_counter = 0

        results = {
            "train_metrics": [],
            "val_metrics": [],
            "best_epoch": 0
        }

        for epoch in range(epochs):
            train_result = self._train_batch(train_pairs)
            results["train_metrics"].append(train_result)

            val_result = None
            if val_pairs:
                val_result = self._evaluate(val_pairs)
                results["val_metrics"].append(val_result)

            # Record history before a possible early stop so the last epoch is kept.
            self._training_history.append({
                "epoch": epoch,
                "train_loss": train_result["loss"],
                "train_accuracy": train_result["accuracy"],
                "val_loss": val_result["loss"] if val_result is not None else None
            })

            if val_result is None:
                continue

            if val_result["loss"] < self._best_val_loss:
                self._best_val_loss = val_result["loss"]
                self._patience_counter = 0
                results["best_epoch"] = epoch
            else:
                self._patience_counter += 1
                if self._patience_counter >= self.patience:
                    logger.info(f"Early stopping at epoch {epoch}")
                    break

        return results

    def _train_batch(self, pairs: List[PreferencePair]) -> Dict[str, float]:
        """Run one shuffled pass of model updates over ``pairs`` (input list not modified)."""
        total_loss = 0.0
        correct = 0

        batch = list(pairs)
        random.shuffle(batch)
        for pair in batch:
            result = self.model.update(pair)
            total_loss += result["loss"]
            if (result["prediction"] > 0.5) == (pair.label > 0.5):
                correct += 1

        return {
            "loss": total_loss / len(batch) if batch else 0,
            "accuracy": correct / len(batch) if batch else 0,
            "n_pairs": len(batch)
        }

    def _evaluate(self, pairs: List[PreferencePair]) -> Dict[str, float]:
        """Evaluate mean cross-entropy loss and accuracy on ``pairs`` without updating."""
        total_loss = 0.0
        correct = 0

        for pair in pairs:
            pred = self.model.predict_preference(pair.response_a.id, pair.response_b.id)
            label = pair.label

            loss = -label * math.log(pred + 1e-10) - (1 - label) * math.log(1 - pred + 1e-10)
            total_loss += loss

            if (pred > 0.5) == (label > 0.5):
                correct += 1

        return {
            "loss": total_loss / len(pairs) if pairs else 0,
            "accuracy": correct / len(pairs) if pairs else 0
        }

    def get_training_history(self) -> List[Dict[str, float]]:
        """Get training history (one entry per epoch trained via ``learn_batch``)."""
        return self._training_history


# =============================================================================
# ACTIVE PREFERENCE QUERY
# =============================================================================

class ActivePreferenceQuery:
    """
    Selects the most informative comparisons for human annotation.

    Uses information-theoretic criteria to select comparisons that
    will most reduce model uncertainty.

    Features:
    - Multiple query strategies (uncertainty, information gain, diversity)
    - Batch query generation
    - Diversity constraints
    - Budget-aware selection

    Diversity is measured against the *most similar* previously queried
    comparison. A candidate that shares a response with any earlier query
    is penalized, instead of being rewarded for differing from just one of them.
    """

    def __init__(
        self,
        model: BradleyTerryModel,
        strategy: QueryStrategy = QueryStrategy.UNCERTAINTY,
        diversity_weight: float = 0.3
    ):
        """
        Initialize active query selector.

        Args:
            model: Preference model for uncertainty estimation
            strategy: Query selection strategy
            diversity_weight: Weight for diversity in hybrid strategy, and the
                strength of the per-pick similarity penalty in ``select_query``
        """
        self.model = model
        self.strategy = strategy
        self.diversity_weight = diversity_weight

        self._queried_pairs: Set[Tuple[ResponseId, ResponseId]] = set()

    def _was_queried(self, a_id: ResponseId, b_id: ResponseId) -> bool:
        """True if this comparison (in either orientation) was already selected."""
        return (a_id, b_id) in self._queried_pairs or (b_id, a_id) in self._queried_pairs

    def select_query(
        self,
        candidates: List[Tuple[Response, Response]],
        n: int = 1
    ) -> List[Tuple[Response, Response]]:
        """
        Select the most informative comparisons.

        Selected pairs are remembered and will not be proposed again by later calls.

        Args:
            candidates: List of candidate (response_a, response_b) pairs
            n: Number of queries to select

        Returns:
            Up to ``n`` distinct comparison pairs, most informative first
        """
        if not candidates:
            return []

        # Filter already queried pairs (either orientation)
        candidates = [(a, b) for a, b in candidates if not self._was_queried(a.id, b.id)]

        if not candidates:
            return []

        scored = [(a, b, self._compute_query_score(a, b)) for a, b in candidates]

        selected = []
        while len(selected) < n and scored:
            # Highest score wins; max() keeps the earliest candidate on ties.
            best_idx = max(range(len(scored)), key=lambda i: scored[i][2])
            best_a, best_b, _ = scored.pop(best_idx)

            # The same comparison may appear more than once in `candidates`
            # (duplicated or reversed); never return it twice.
            if self._was_queried(best_a.id, best_b.id):
                continue

            selected.append((best_a, best_b))
            self._queried_pairs.add((best_a.id, best_b.id))

            # Reduce scores for candidates overlapping the pick (diversity)
            if self.diversity_weight > 0:
                scored = [
                    (
                        a,
                        b,
                        score * (1 - self.diversity_weight
                                 * self._compute_similarity(best_a.id, best_b.id, a.id, b.id)),
                    )
                    for a, b, score in scored
                ]

        return selected

    def _diversity_score(self, a: Response, b: Response) -> float:
        """1 minus the similarity to the closest previous query (1.0 if none yet)."""
        if not self._queried_pairs:
            return 1.0
        max_similarity = max(
            self._compute_similarity(qa, qb, a.id, b.id)
            for qa, qb in self._queried_pairs
        )
        return 1 - max_similarity

    def _compute_query_score(self, a: Response, b: Response) -> float:
        """Compute query informativeness score (higher = more worth asking)."""
        if self.strategy == QueryStrategy.UNCERTAINTY:
            return self.model.get_uncertainty(a.id, b.id)

        elif self.strategy == QueryStrategy.INFORMATION_GAIN:
            # Expected information gain approximation
            uncertainty = self.model.get_uncertainty(a.id, b.id)
            prob = self.model.predict_preference(a.id, b.id)

            # Entropy of prediction
            entropy = -prob * math.log(prob + 1e-10) - (1 - prob) * math.log(1 - prob + 1e-10)
            return uncertainty * entropy

        elif self.strategy == QueryStrategy.DIVERSITY:
            return self._diversity_score(a, b)

        elif self.strategy == QueryStrategy.HYBRID:
            uncertainty_score = self.model.get_uncertainty(a.id, b.id)
            diversity_score = self._diversity_score(a, b)
            return (1 - self.diversity_weight) * uncertainty_score + self.diversity_weight * diversity_score

        return random.random()

    def _compute_similarity(
        self,
        a1: ResponseId, b1: ResponseId,
        a2: ResponseId, b2: ResponseId
    ) -> float:
        """Similarity of two comparisons: shared responses / 2 (0.0, 0.5 or 1.0)."""
        # Simple overlap-based similarity
        set1 = {a1, b1}
        set2 = {a2, b2}
        overlap = len(set1 & set2)
        return overlap / 2.0

    def generate_candidate_pairs(
        self,
        responses: List[Response],
        max_pairs: int = 1000
    ) -> List[Tuple[Response, Response]]:
        """
        Generate candidate comparison pairs from responses.

        Args:
            responses: List of responses
            max_pairs: Maximum pairs to generate

        Returns:
            List of candidate pairs (each unordered pair once, in input order)
        """
        candidates = []
        for i, a in enumerate(responses):
            for b in responses[i + 1:]:
                if len(candidates) >= max_pairs:
                    return candidates
                candidates.append((a, b))

        return candidates


# =============================================================================
# PREFERENCE INTEGRATOR
# =============================================================================

class PreferenceIntegrator:
    """
    Connects a trained preference model to the RLHF training loop.

    Turns learned Bradley-Terry strengths into scalar rewards, pairwise
    rewards for DPO-style objectives, and chosen/rejected training records
    drawn from the preference dataset.

    Features:
    - Preference-to-reward conversion
    - Batch reward computation
    - Training signal generation for DPO/PPO
    - Model comparison and ranking
    """

    def __init__(
        self,
        model: BradleyTerryModel,
        dataset: PreferenceDataset,
        beta: float = 0.1
    ):
        """
        Store the collaborators used to produce training signals.

        Args:
            model: Trained preference model
            dataset: Preference dataset to sample training pairs from
            beta: Multiplier applied to model strengths to obtain rewards
        """
        self.model = model
        self.dataset = dataset
        self.beta = beta

    def compute_reward(self, response: Response) -> float:
        """
        Reward for one response: ``beta`` times its learned strength.

        Args:
            response: Response to score

        Returns:
            Reward value
        """
        strength = self.model.get_strength(response.id)
        return self.beta * strength

    def compute_pairwise_reward(
        self,
        response_a: Response,
        response_b: Response
    ) -> Tuple[float, float]:
        """
        Rewards for both sides of a comparison, for DPO training.

        Args:
            response_a: First response
            response_b: Second response

        Returns:
            ``(reward_a, reward_b)``, each ``beta`` times the response's strength
        """
        scale = self.beta
        return tuple(scale * self.model.get_strength(r.id) for r in (response_a, response_b))

    def get_training_pairs(
        self,
        n: int,
        min_margin: float = 0.1
    ) -> List[Dict[str, Any]]:
        """
        Build chosen/rejected records whose model-predicted preference is clear.

        ``2 * n`` pairs are sampled from the dataset. A pair is kept when its
        margin ``2 * |P(A > B) - 0.5|`` is at least ``min_margin``. The side
        the model favours becomes "chosen". Collection stops once ``n`` records exist.

        Args:
            n: Number of pairs to retrieve
            min_margin: Minimum preference probability margin

        Returns:
            List of dicts with keys "prompt", "chosen", "rejected",
            "chosen_id", "rejected_id" and "margin"
        """
        candidates = self.dataset.sample(n * 2)  # oversample; some are filtered out
        records: List[Dict[str, Any]] = []

        for candidate in candidates:
            first, second = candidate.response_a, candidate.response_b
            p_first = self.model.predict_preference(first.id, second.id)
            distance = abs(p_first - 0.5)

            if not distance >= min_margin / 2:
                continue

            if p_first > 0.5:
                winner, loser = first, second
            else:
                winner, loser = second, first

            records.append({
                "prompt": candidate.prompt,
                "chosen": winner.text,
                "rejected": loser.text,
                "chosen_id": winner.id,
                "rejected_id": loser.id,
                "margin": distance * 2
            })

            if len(records) >= n:
                break

        return records

    def rank_responses(
        self,
        prompt: str,
        responses: List[Response]
    ) -> List[Tuple[Response, float]]:
        """
        Order responses from strongest to weakest according to the model.

        Args:
            prompt: The prompt (currently not used for ranking)
            responses: Responses to rank

        Returns:
            Ranked list of (response, score) tuples
        """
        by_id = {r.id: r for r in responses}
        ordering = self.model.get_ranking([r.id for r in responses])
        return [(by_id[rid], score) for rid, score in ordering]

    def get_statistics(self) -> Dict[str, Any]:
        """
        Summarize how well the model agrees with the stored preference labels.

        Returns:
            Dict with "total_pairs", "model_agreement" (fraction of pairs where
            the model and the label pick the same side), "unique_responses"
            and "beta"
        """
        pairs = list(self.dataset)

        agreements = sum(
            1
            for p in pairs
            if (self.model.predict_preference(p.response_a.id, p.response_b.id) > 0.5) == (p.label > 0.5)
        )

        seen_ids = set()
        for p in pairs:
            seen_ids.update((p.response_a.id, p.response_b.id))

        if pairs:
            agreement_rate = agreements / len(pairs)
        else:
            agreement_rate = 0

        return {
            "total_pairs": len(pairs),
            "model_agreement": agreement_rate,
            "unique_responses": len(seen_ids),
            "beta": self.beta
        }


# =============================================================================
# MAIN EXECUTION
# =============================================================================

async def main():
    """
    Demonstrate the preference learning system end to end.

    Builds every component, then:
    - simulates annotations for responses that answer the *same* prompt;
    - trains the Bradley-Terry model;
    - selects active queries;
    - produces DPO-style training pairs.

    Returns:
        Dict with the "collector", "dataset", "model", "learner" and
        "integrator" instances used in the demo.
    """
    print("=" * 70)
    print("AIVA Queen - Preference Learning System for RLHF")
    print("=" * 70)

    # 1. Initialize components
    print("\n1. Initializing components...")

    collector = PreferenceCollector(
        attention_check_frequency=0.1,
        min_time_seconds=3.0
    )
    print("   - PreferenceCollector initialized")

    dataset = PreferenceDataset(db_path=":memory:")
    print("   - PreferenceDataset initialized")

    model = BradleyTerryModel(learning_rate=0.1, regularization=0.01)
    print("   - BradleyTerryModel initialized")

    learner = PreferenceLearner(model=model, strategy=LearningStrategy.MINI_BATCH)
    print("   - PreferenceLearner initialized")

    query_selector = ActivePreferenceQuery(model=model, strategy=QueryStrategy.UNCERTAINTY)
    print("   - ActivePreferenceQuery initialized")

    # 2. Generate sample responses
    print("\n2. Generating sample responses...")
    prompts_and_responses = [
        ("What is machine learning?", [
            "Machine learning is a subset of AI that enables systems to learn from data.",
            "ML is when computers learn stuff.",
            "Machine learning involves algorithms that improve through experience."
        ]),
        ("Explain quantum computing.", [
            "Quantum computing uses qubits that can exist in superposition.",
            "Quantum computers are really fast.",
            "Quantum computing leverages quantum mechanics for computation."
        ])
    ]

    all_responses = []
    # Responses grouped per prompt: a preference is only meaningful between
    # two answers to the same prompt.
    responses_by_prompt: Dict[str, List[Response]] = {}
    for prompt, texts in prompts_and_responses:
        group = []
        for text in texts:
            resp = Response(
                id=str(uuid.uuid4()),
                prompt=prompt,
                text=text
            )
            group.append(resp)
            all_responses.append(resp)
        responses_by_prompt[prompt] = group
    print(f"   Generated {len(all_responses)} responses")

    # Adjacent responses within each prompt's group (never across prompts).
    same_prompt_pairs = [
        (resp_a, resp_b)
        for group in responses_by_prompt.values()
        for resp_a, resp_b in zip(group, group[1:])
    ]

    # 3. Collect preferences
    print("\n3. Collecting preferences...")
    annotators = ["annotator_1", "annotator_2", "annotator_3"]

    for resp_a, resp_b in same_prompt_pairs:
        for annotator in annotators[:2]:  # Use 2 annotators per comparison
            comp_id = collector.create_comparison(
                prompt=resp_a.prompt,
                response_a=resp_a,
                response_b=resp_b
            )

            # Simulate preference: the longer (more detailed) response wins.
            if len(resp_a.text) >= len(resp_b.text):
                pref = PreferenceStrength.PREFER_A
            else:
                pref = PreferenceStrength.PREFER_B

            accepted, pair = await collector.submit_preference(
                comparison_id=comp_id,
                annotator_id=annotator,
                preference=pref,
                confidence=random.uniform(0.7, 1.0),
                time_spent_seconds=random.uniform(10, 60)
            )

            if accepted and pair:
                dataset.add(pair)

    print(f"   Collected {len(dataset)} preference pairs")
    print(f"   Statistics: {collector.get_statistics()}")

    # 4. Train preference model
    print("\n4. Training preference model...")
    pairs = list(dataset)
    results = learner.learn_batch(pairs, epochs=5)
    print(f"   Final train accuracy: {results['train_metrics'][-1]['accuracy']:.2%}")

    # 5. Active query selection (candidates restricted to same-prompt pairs)
    print("\n5. Selecting informative queries...")
    candidates = [
        candidate
        for group in responses_by_prompt.values()
        for candidate in query_selector.generate_candidate_pairs(group)
    ]
    selected = query_selector.select_query(candidates, n=3)
    print(f"   Selected {len(selected)} queries for annotation")
    for a, b in selected:
        uncertainty = model.get_uncertainty(a.id, b.id)
        print(f"   - '{a.text[:30]}...' vs '{b.text[:30]}...' (uncertainty: {uncertainty:.3f})")

    # 6. Integration
    print("\n6. Integrating with training pipeline...")
    integrator = PreferenceIntegrator(model=model, dataset=dataset, beta=0.1)

    training_pairs = integrator.get_training_pairs(n=5)
    print(f"   Generated {len(training_pairs)} training pairs")

    for tp in training_pairs[:2]:
        print(f"   - Chosen: '{tp['chosen'][:40]}...' (margin: {tp['margin']:.2f})")

    stats = integrator.get_statistics()
    print(f"   Model agreement with dataset: {stats['model_agreement']:.2%}")

    # 7. Model predictions (same-prompt comparisons only)
    print("\n7. Testing preference predictions...")
    for a, b in same_prompt_pairs[:3]:
        prob = model.predict_preference(a.id, b.id)
        print(f"   P('{a.text[:25]}...' > '{b.text[:25]}...'): {prob:.2%}")

    print("\n" + "=" * 70)
    print("Preference Learning System Demo Complete")
    print("=" * 70)

    return {
        "collector": collector,
        "dataset": dataset,
        "model": model,
        "learner": learner,
        "integrator": integrator
    }


if __name__ == "__main__":
    # Script entry point: asyncio.run creates an event loop and runs the async demo to completion.
    asyncio.run(main())
