"""
AIVA Queen Active Learning System
=================================

Production-grade active learning system for efficient model improvement.

Components:
1. UncertaintyEstimator - Estimate prediction uncertainty using multiple strategies
2. QueryStrategist - Select most informative queries using acquisition functions
3. OracleInterface - Interface with human oracle for labeling
4. LabelBudgetManager - Manage labeling budget and allocation
5. SampleSelector - Select optimal samples to label
6. ModelUpdater - Update model with new labels incrementally

This system implements pool-based active learning with support for:
- Multiple uncertainty estimation methods (entropy, margin, committee)
- Various query strategies (uncertainty, diversity, expected model change)
- Budget-aware sample selection
- Incremental model updates
- Oracle interaction management

Author: AIVA Queen System
Version: 1.0.0
"""

import abc
import logging
import json
import hashlib
import numpy as np
from datetime import datetime
from dataclasses import dataclass, field, asdict
from typing import (
    Dict, List, Optional, Any, Tuple, Callable,
    Protocol, Union, Set, Iterator
)
from enum import Enum, auto
from pathlib import Path
from collections import defaultdict
import heapq
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed


# Configure logging. NOTE: basicConfig runs at import time and only takes
# effect if the root logger has no handlers yet.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)


# =============================================================================
# Data Classes and Enums
# =============================================================================

class UncertaintyMethod(Enum):
    """Uncertainty estimation methods understood by UncertaintyEstimator."""
    ENTROPY = 1                  # normalized predictive entropy
    MARGIN = 2                   # 1 - (top1 - top2) probability gap
    LEAST_CONFIDENCE = 3         # 1 - max class probability
    COMMITTEE_DISAGREEMENT = 4   # vote entropy across committee models
    BAYESIAN_DROPOUT = 5         # variance across repeated predict_proba calls
    ENSEMBLE_VARIANCE = 6        # variance across ensemble predict_proba outputs


class QueryStrategy(Enum):
    """Acquisition strategies understood by QueryStrategist."""
    UNCERTAINTY_SAMPLING = 1
    DIVERSITY_SAMPLING = 2
    EXPECTED_MODEL_CHANGE = 3
    QUERY_BY_COMMITTEE = 4
    INFORMATION_DENSITY = 5
    BATCH_MODE_SAMPLING = 6
    HYBRID = 7


class LabelStatus(Enum):
    """Where a sample currently sits in the labeling pipeline."""
    UNLABELED = 1
    QUEUED = 2
    IN_PROGRESS = 3
    LABELED = 4
    REJECTED = 5
    UNCERTAIN = 6


@dataclass
class Sample:
    """A single pool item tracked through the active-learning loop.

    Holds the feature vector, the (optional) oracle label, the pipeline
    status, and the scores/predictions written by the uncertainty and
    query components.
    """
    sample_id: str
    features: np.ndarray
    label: Optional[Any] = None
    status: LabelStatus = LabelStatus.UNLABELED
    uncertainty_score: float = 0.0
    diversity_score: float = 0.0
    informativeness_score: float = 0.0
    predicted_label: Optional[Any] = None
    prediction_probabilities: Optional[np.ndarray] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Creation time as a naive local-time ISO-8601 string.
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of this sample.

        ndarray ``features`` are converted with ``tolist()`` (other values
        pass through unchanged). ``prediction_probabilities`` is converted
        with ``tolist()`` if it is an ndarray and becomes ``None`` otherwise.
        ``status`` is reduced to its enum name. ``metadata`` is returned
        by reference, not copied.
        """
        raw_features = self.features
        feature_values = (
            raw_features.tolist() if isinstance(raw_features, np.ndarray) else raw_features
        )

        raw_probs = self.prediction_probabilities
        if isinstance(raw_probs, np.ndarray):
            prob_values = raw_probs.tolist()
        else:
            prob_values = None

        return dict(
            sample_id=self.sample_id,
            features=feature_values,
            label=self.label,
            status=self.status.name,
            uncertainty_score=self.uncertainty_score,
            diversity_score=self.diversity_score,
            informativeness_score=self.informativeness_score,
            predicted_label=self.predicted_label,
            prediction_probabilities=prob_values,
            metadata=self.metadata,
            timestamp=self.timestamp,
        )


@dataclass
class LabelingTask:
    """A unit of oracle work wrapping one Sample.

    ``__lt__`` compares priorities in reverse. A min-heap ordered by ``<``
    (such as ``heapq``) therefore yields the highest-priority task first.
    """
    task_id: str
    sample: Sample
    priority: float
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    assigned_at: Optional[str] = None
    completed_at: Optional[str] = None
    oracle_id: Optional[str] = None
    label_confidence: float = 0.0
    labeling_time_seconds: float = 0.0

    def __lt__(self, other: 'LabelingTask') -> bool:
        """Order tasks so that a larger ``priority`` sorts as 'smaller'."""
        return other.priority < self.priority


@dataclass
class BudgetAllocation:
    """Budget allocation configuration.

    Tracks the total and remaining labeling budget plus the per-cycle size.
    The ``*_allocation`` fields hold fractional splits across selection
    modes. They are not validated or consumed here.
    """
    total_budget: int
    remaining_budget: int
    budget_per_cycle: int
    high_uncertainty_allocation: float = 0.6
    diversity_allocation: float = 0.3
    random_allocation: float = 0.1

    def can_label(self) -> bool:
        """Return True while some budget remains."""
        return self.remaining_budget > 0

    def consume(self, amount: int = 1) -> bool:
        """Deduct ``amount`` from the remaining budget if affordable.

        Args:
            amount: Units to spend. Must be non-negative.

        Returns:
            True if the budget was deducted. False if ``amount`` is negative
            or exceeds the remaining budget; the budget is then unchanged.
        """
        # A negative amount would pass the >= check below and silently
        # *increase* remaining_budget, so reject it outright.
        if amount < 0:
            return False
        if self.remaining_budget >= amount:
            self.remaining_budget -= amount
            return True
        return False


@dataclass
class ModelUpdateResult:
    """Outcome record for one model update: status, sample count, metrics and timing."""
    success: bool
    samples_used: int
    # Metric dictionaries keyed by metric name.
    previous_metrics: Dict[str, float]
    new_metrics: Dict[str, float]
    improvement: Dict[str, float]
    update_duration_seconds: float
    # Populated only when the update failed; None otherwise by default.
    error_message: Optional[str] = field(default=None)


# =============================================================================
# Protocol Definitions
# =============================================================================

class PredictionModel(Protocol):
    """Structural interface a model must satisfy to plug into active learning.

    Any object exposing these four methods (e.g. a scikit-learn-style
    estimator) conforms; no inheritance is required.
    """

    def predict(self, features: np.ndarray) -> np.ndarray:
        """Return predicted labels for ``features``."""

    def predict_proba(self, features: np.ndarray) -> np.ndarray:
        """Return per-class probabilities for ``features``."""

    def fit(self, features: np.ndarray, labels: np.ndarray) -> None:
        """Train the model from scratch on ``features``/``labels``."""

    def partial_fit(self, features: np.ndarray, labels: np.ndarray) -> None:
        """Update the model incrementally with ``features``/``labels``."""


# =============================================================================
# UncertaintyEstimator
# =============================================================================

class UncertaintyEstimator:
    """
    Estimate prediction uncertainty using multiple strategies.

    Supports:
    - Entropy-based uncertainty
    - Margin sampling
    - Least confidence
    - Committee disagreement (Query-by-Committee)
    - Bayesian dropout approximation
    - Ensemble variance

    Every handler has the signature ``(model, features, probabilities) ->
    np.ndarray`` and returns one score per row, where higher means more
    uncertain.
    """

    def __init__(
        self,
        method: UncertaintyMethod = UncertaintyMethod.ENTROPY,
        committee_models: Optional[List[PredictionModel]] = None,
        dropout_iterations: int = 10,
        epsilon: float = 1e-10
    ):
        """
        Initialize uncertainty estimator.

        Args:
            method: Primary uncertainty estimation method
            committee_models: List of models for committee-based methods
            dropout_iterations: Number of predict_proba passes for dropout-based uncertainty
            epsilon: Small value to prevent log(0)
        """
        self.method = method
        self.committee_models = committee_models or []
        self.dropout_iterations = dropout_iterations
        self.epsilon = epsilon

        self._method_handlers: Dict[UncertaintyMethod, Callable] = {
            UncertaintyMethod.ENTROPY: self._entropy_uncertainty,
            UncertaintyMethod.MARGIN: self._margin_uncertainty,
            UncertaintyMethod.LEAST_CONFIDENCE: self._least_confidence_uncertainty,
            UncertaintyMethod.COMMITTEE_DISAGREEMENT: self._committee_disagreement,
            UncertaintyMethod.BAYESIAN_DROPOUT: self._bayesian_dropout,
            UncertaintyMethod.ENSEMBLE_VARIANCE: self._ensemble_variance
        }

        logger.info(f"UncertaintyEstimator initialized with method: {method.name}")

    def _predict_probabilities(
        self,
        model: PredictionModel,
        features: np.ndarray,
        n_samples: int
    ) -> np.ndarray:
        """Return ``model.predict_proba(features)``, or a uniform 2-class matrix if it raises."""
        try:
            return model.predict_proba(features)
        except Exception as e:
            logger.warning(f"Could not get probabilities: {e}. Using uniform.")
            n_classes = 2  # Default assumption
            return np.ones((n_samples, n_classes)) / n_classes

    @staticmethod
    def _annotate_samples(
        samples: List[Sample],
        scores: Any,
        probabilities: np.ndarray
    ) -> None:
        """Write each score and probability row back onto its sample, by position."""
        for sample, score, probs in zip(samples, scores, probabilities):
            sample.uncertainty_score = float(score)
            sample.prediction_probabilities = probs

    def estimate(
        self,
        model: PredictionModel,
        samples: List[Sample],
        method: Optional[UncertaintyMethod] = None
    ) -> List[float]:
        """
        Estimate uncertainty for a list of samples.

        Side effect: sets ``uncertainty_score`` and
        ``prediction_probabilities`` on every sample.

        Args:
            model: Prediction model to use
            samples: List of samples to estimate uncertainty for
            method: Override default method (unknown methods fall back to entropy)

        Returns:
            List of uncertainty scores (higher = more uncertain)
        """
        method = method or self.method

        if not samples:
            return []

        features = np.array([s.features for s in samples])
        probabilities = self._predict_probabilities(model, features, len(samples))

        handler = self._method_handlers.get(method, self._entropy_uncertainty)
        uncertainties = handler(model, features, probabilities)

        # Pair rows with samples positionally. The previous samples.index()
        # lookup was O(n^2) and picked the wrong row for repeated or equal
        # Sample objects.
        self._annotate_samples(samples, uncertainties, probabilities)

        return list(uncertainties)

    def _entropy_uncertainty(
        self,
        model: PredictionModel,
        features: np.ndarray,
        probabilities: np.ndarray
    ) -> np.ndarray:
        """Return predictive entropy divided by log(n_classes), i.e. in [0, 1].

        With a single class, log(1) == 0, so the raw entropy is returned.
        """
        # H(p) = -sum(p * log(p)); clipping avoids log(0).
        probs_clipped = np.clip(probabilities, self.epsilon, 1 - self.epsilon)
        entropy = -np.sum(probs_clipped * np.log(probs_clipped), axis=1)

        # Normalize by max possible entropy
        max_entropy = np.log(probabilities.shape[1])
        return entropy / max_entropy if max_entropy > 0 else entropy

    def _margin_uncertainty(
        self,
        model: PredictionModel,
        features: np.ndarray,
        probabilities: np.ndarray
    ) -> np.ndarray:
        """Return 1 - (top-1 probability - top-2 probability).

        Single-class output has no runner-up, so the result is 0 (certain).
        """
        if probabilities.shape[1] < 2:
            return np.zeros(probabilities.shape[0])
        sorted_probs = np.sort(probabilities, axis=1)[:, ::-1]
        margin = sorted_probs[:, 0] - sorted_probs[:, 1]
        return 1 - margin  # Higher uncertainty = smaller margin

    def _least_confidence_uncertainty(
        self,
        model: PredictionModel,
        features: np.ndarray,
        probabilities: np.ndarray
    ) -> np.ndarray:
        """Return 1 - max class probability."""
        max_probs = np.max(probabilities, axis=1)
        return 1 - max_probs

    def _committee_disagreement(
        self,
        model: PredictionModel,
        features: np.ndarray,
        probabilities: np.ndarray
    ) -> np.ndarray:
        """Return vote entropy (natural log, unnormalized) of committee hard predictions.

        Falls back to entropy when there is no committee or every member fails.
        """
        if not self.committee_models:
            logger.warning("No committee models provided, falling back to entropy")
            return self._entropy_uncertainty(model, features, probabilities)

        # Get predictions from all committee members
        all_predictions = []
        for committee_model in self.committee_models:
            try:
                preds = committee_model.predict(features)
                all_predictions.append(preds)
            except Exception as e:
                logger.warning(f"Committee model prediction failed: {e}")

        if not all_predictions:
            return self._entropy_uncertainty(model, features, probabilities)

        predictions_array = np.array(all_predictions)
        # Only members that actually voted count towards the vote
        # distribution; otherwise probabilities would not sum to 1.
        n_voters = len(all_predictions)

        n_samples = features.shape[0]
        disagreements = np.zeros(n_samples)

        for i in range(n_samples):
            votes = predictions_array[:, i]
            _, counts = np.unique(votes, return_counts=True)
            vote_probs = counts / n_voters
            vote_probs_clipped = np.clip(vote_probs, self.epsilon, 1 - self.epsilon)
            disagreements[i] = -np.sum(vote_probs_clipped * np.log(vote_probs_clipped))

        return disagreements

    def _bayesian_dropout(
        self,
        model: PredictionModel,
        features: np.ndarray,
        probabilities: np.ndarray
    ) -> np.ndarray:
        """Return mean per-class variance across repeated predict_proba calls.

        The supplied ``probabilities`` count as the first pass. If
        ``predict_proba`` is deterministic every pass is identical and the
        result is 0; presumably the model is expected to keep dropout active
        at inference (not checked here).
        """
        all_probs = [probabilities]

        for _ in range(self.dropout_iterations - 1):
            try:
                probs = model.predict_proba(features)
                all_probs.append(probs)
            except Exception:
                # Reuse the first pass so the number of passes stays fixed.
                all_probs.append(probabilities)

        probs_array = np.array(all_probs)
        return np.mean(np.var(probs_array, axis=0), axis=1)

    def _ensemble_variance(
        self,
        model: PredictionModel,
        features: np.ndarray,
        probabilities: np.ndarray
    ) -> np.ndarray:
        """Return mean per-class variance of probabilities across model + committee members.

        Uses the dropout approximation when no committee is configured.
        """
        if not self.committee_models:
            return self._bayesian_dropout(model, features, probabilities)

        all_probs = [probabilities]

        for ensemble_model in self.committee_models:
            try:
                probs = ensemble_model.predict_proba(features)
                all_probs.append(probs)
            except Exception as e:
                logger.warning(f"Ensemble model failed: {e}")

        probs_array = np.array(all_probs)
        return np.mean(np.var(probs_array, axis=0), axis=1)

    def get_combined_uncertainty(
        self,
        model: PredictionModel,
        samples: List[Sample],
        weights: Optional[Dict[UncertaintyMethod, float]] = None
    ) -> List[float]:
        """
        Get combined uncertainty using multiple methods.

        ``predict_proba`` is called once and shared by all methods. Each
        sample's ``uncertainty_score`` is set to the *combined* score.
        Previously it was left holding whichever method ran last.

        Args:
            model: Prediction model
            samples: Samples to evaluate
            weights: Method weights, normalized by their sum
                (default: entropy 0.4, margin 0.3, least confidence 0.3)

        Returns:
            Combined uncertainty scores
        """
        if weights is None:
            weights = {
                UncertaintyMethod.ENTROPY: 0.4,
                UncertaintyMethod.MARGIN: 0.3,
                UncertaintyMethod.LEAST_CONFIDENCE: 0.3
            }

        if not samples:
            return []

        features = np.array([s.features for s in samples])
        probabilities = self._predict_probabilities(model, features, len(samples))

        combined_scores = np.zeros(len(samples))
        total_weight = sum(weights.values())

        for method, weight in weights.items():
            handler = self._method_handlers.get(method, self._entropy_uncertainty)
            scores = np.asarray(handler(model, features, probabilities), dtype=float)
            combined_scores += (weight / total_weight) * scores

        self._annotate_samples(samples, combined_scores, probabilities)

        return list(combined_scores)


# =============================================================================
# QueryStrategist
# =============================================================================

class QueryStrategist:
    """
    Select most informative queries using various acquisition functions.

    Implements multiple query strategies:
    - Uncertainty sampling: Select most uncertain samples
    - Diversity sampling: Select diverse samples to cover feature space
    - Expected model change: Estimate potential model improvement
    - Information density: Balance uncertainty with representativeness
    - Batch mode: Select batches considering redundancy
    - Hybrid: Combine multiple strategies

    Every strategy handler has the signature
    ``(model, unlabeled_pool, labeled_pool, n_queries) -> List[Sample]`` and
    returns samples ordered from most to least informative.
    """

    def __init__(
        self,
        strategy: QueryStrategy = QueryStrategy.UNCERTAINTY_SAMPLING,
        uncertainty_estimator: Optional[UncertaintyEstimator] = None,
        diversity_weight: float = 0.3,
        batch_diversity_factor: float = 0.5
    ):
        """
        Initialize query strategist.

        Args:
            strategy: Primary query strategy
            uncertainty_estimator: Uncertainty estimation component
            diversity_weight: Weight for diversity in hybrid strategies
            batch_diversity_factor: Factor for enforcing batch diversity
                (stored; not read by the strategies in this class)
        """
        self.strategy = strategy
        self.uncertainty_estimator = uncertainty_estimator or UncertaintyEstimator()
        self.diversity_weight = diversity_weight
        self.batch_diversity_factor = batch_diversity_factor

        self._strategy_handlers: Dict[QueryStrategy, Callable] = {
            QueryStrategy.UNCERTAINTY_SAMPLING: self._uncertainty_sampling,
            QueryStrategy.DIVERSITY_SAMPLING: self._diversity_sampling,
            QueryStrategy.EXPECTED_MODEL_CHANGE: self._expected_model_change,
            QueryStrategy.QUERY_BY_COMMITTEE: self._query_by_committee,
            QueryStrategy.INFORMATION_DENSITY: self._information_density,
            QueryStrategy.BATCH_MODE_SAMPLING: self._batch_mode_sampling,
            QueryStrategy.HYBRID: self._hybrid_sampling
        }

        logger.info(f"QueryStrategist initialized with strategy: {strategy.name}")

    def select_queries(
        self,
        model: PredictionModel,
        unlabeled_pool: List[Sample],
        labeled_pool: List[Sample],
        n_queries: int,
        strategy: Optional[QueryStrategy] = None
    ) -> List[Sample]:
        """
        Select most informative samples to query.

        If the chosen strategy raises, the error is logged and uncertainty
        sampling is used instead. Each returned sample's
        ``informativeness_score`` is set from its rank (1.0 for the first).

        Args:
            model: Current prediction model
            unlabeled_pool: Pool of unlabeled samples
            labeled_pool: Already labeled samples (for diversity)
            n_queries: Number of samples to select (<= 0 selects nothing)
            strategy: Override default strategy

        Returns:
            List of selected samples ordered by informativeness
        """
        if not unlabeled_pool:
            return []

        n_queries = min(n_queries, len(unlabeled_pool))
        if n_queries <= 0:
            # A negative count would otherwise slice "all but the last k".
            return []
        strategy = strategy or self.strategy

        handler = self._strategy_handlers.get(strategy, self._uncertainty_sampling)

        try:
            selected = handler(model, unlabeled_pool, labeled_pool, n_queries)
        except Exception as e:
            logger.error(f"Query selection failed: {e}. Falling back to uncertainty.")
            selected = self._uncertainty_sampling(
                model, unlabeled_pool, labeled_pool, n_queries
            )

        for rank, sample in enumerate(selected):
            sample.informativeness_score = 1.0 - (rank / n_queries)

        return selected

    @staticmethod
    def _top_k(pool: List[Sample], scores: Any, k: int) -> List[Sample]:
        """Return the ``k`` highest-scoring samples, best first; ties keep pool order."""
        # sorted(..., reverse=True) is stable, so equal scores keep their original order.
        ranked = sorted(range(len(pool)), key=lambda i: scores[i], reverse=True)
        return [pool[i] for i in ranked[:max(k, 0)]]

    def _uncertainty_sampling(
        self,
        model: PredictionModel,
        unlabeled_pool: List[Sample],
        labeled_pool: List[Sample],
        n_queries: int
    ) -> List[Sample]:
        """Select samples with highest uncertainty (estimator's default method)."""
        uncertainties = self.uncertainty_estimator.estimate(model, unlabeled_pool)
        return self._top_k(unlabeled_pool, uncertainties, n_queries)

    def _diversity_sampling(
        self,
        model: PredictionModel,
        unlabeled_pool: List[Sample],
        labeled_pool: List[Sample],
        n_queries: int
    ) -> List[Sample]:
        """Greedy farthest-point selection to maximize feature-space coverage.

        Each pick maximizes the minimum squared Euclidean distance to the
        reference set. That set is the 10 most recent labeled samples plus
        all earlier picks. Without labeled samples the first pick is random.
        Each pick's ``diversity_score`` is its distance at selection time
        (0.0 for a random first pick).
        """
        n_pool = len(unlabeled_pool)
        n_queries = min(n_queries, n_pool)
        if n_queries <= 0:
            return []

        # Float dtype so the -inf masking below also works for integer features.
        features = np.asarray([s.features for s in unlabeled_pool], dtype=float)

        def squared_distance_to(point: np.ndarray) -> np.ndarray:
            """Squared Euclidean distance from every pool row to ``point``."""
            return np.sum((features - point) ** 2, axis=1)

        if labeled_pool:
            reference = np.asarray(
                [s.features for s in labeled_pool[-10:]], dtype=float  # recent labeled only
            )
            min_distances = np.min(
                np.sum(
                    (features[:, np.newaxis, :] - reference[np.newaxis, :, :]) ** 2,
                    axis=2
                ),
                axis=1
            )
            first_idx = int(np.argmax(min_distances))
            first_score = float(min_distances[first_idx])
        else:
            first_idx = int(np.argmax(np.random.random(n_pool)))
            first_score = 0.0
            min_distances = np.full(n_pool, np.inf)

        selected_indices = [first_idx]
        scores = [first_score]
        chosen = np.zeros(n_pool, dtype=bool)
        chosen[first_idx] = True
        # Maintain the running min distance incrementally: O(n*d) per pick
        # instead of recomputing against every previous pick.
        min_distances = np.minimum(min_distances, squared_distance_to(features[first_idx]))

        while len(selected_indices) < n_queries:
            candidates = np.where(chosen, -np.inf, min_distances)
            next_idx = int(np.argmax(candidates))
            selected_indices.append(next_idx)
            scores.append(float(min_distances[next_idx]))
            chosen[next_idx] = True
            min_distances = np.minimum(min_distances, squared_distance_to(features[next_idx]))

        selected = [unlabeled_pool[i] for i in selected_indices]
        for sample, score in zip(selected, scores):
            sample.diversity_score = score
        return selected

    def _expected_model_change(
        self,
        model: PredictionModel,
        unlabeled_pool: List[Sample],
        labeled_pool: List[Sample],
        n_queries: int
    ) -> List[Sample]:
        """Rank by a heuristic proxy for expected gradient length.

        The score is predictive entropy times the feature-vector norm. Falls
        back to uncertainty sampling if ``predict_proba`` raises.
        """
        features = np.array([s.features for s in unlabeled_pool])

        try:
            probabilities = model.predict_proba(features)
        except Exception:
            return self._uncertainty_sampling(model, unlabeled_pool, labeled_pool, n_queries)

        probabilities = np.asarray(probabilities, dtype=float)
        entropy = -np.sum(probabilities * np.log(probabilities + 1e-10), axis=1)
        feature_norms = np.linalg.norm(features.reshape(len(unlabeled_pool), -1), axis=1)
        expected_changes = entropy * feature_norms

        return self._top_k(unlabeled_pool, expected_changes, n_queries)

    def _query_by_committee(
        self,
        model: PredictionModel,
        unlabeled_pool: List[Sample],
        labeled_pool: List[Sample],
        n_queries: int
    ) -> List[Sample]:
        """Select samples with highest committee disagreement."""
        # Pass the method per call instead of overwriting the shared
        # estimator's default, which used to leak into every later strategy.
        disagreements = self.uncertainty_estimator.estimate(
            model, unlabeled_pool, UncertaintyMethod.COMMITTEE_DISAGREEMENT
        )
        return self._top_k(unlabeled_pool, disagreements, n_queries)

    def _information_density(
        self,
        model: PredictionModel,
        unlabeled_pool: List[Sample],
        labeled_pool: List[Sample],
        n_queries: int
    ) -> List[Sample]:
        """Rank by uncertainty * representativeness.

        Representativeness is the inverse mean Euclidean distance to the
        other pool samples, normalized so the maximum is ~1.
        """
        uncertainties = np.asarray(
            self.uncertainty_estimator.estimate(model, unlabeled_pool), dtype=float
        )

        features = np.asarray([s.features for s in unlabeled_pool], dtype=float)
        n_samples = len(unlabeled_pool)
        representativeness = np.ones(n_samples)

        if n_samples > 1:
            for i in range(n_samples):
                distances = np.linalg.norm(features - features[i], axis=1)
                # Self-distance is 0, so it only matters via the divisor.
                # Duplicates of sample i still count, unlike filtering on > 0,
                # which also produced NaN for all-identical pools.
                avg_distance = distances.sum() / (n_samples - 1)
                representativeness[i] = 1.0 / (avg_distance + 1e-10)
            representativeness = representativeness / (np.max(representativeness) + 1e-10)

        beta = 1.0  # Weight for representativeness
        info_density = uncertainties * (representativeness ** beta)

        return self._top_k(unlabeled_pool, info_density, n_queries)

    def _batch_mode_sampling(
        self,
        model: PredictionModel,
        unlabeled_pool: List[Sample],
        labeled_pool: List[Sample],
        n_queries: int
    ) -> List[Sample]:
        """Take the top 3*n_queries most uncertain samples, then pick a diverse subset."""
        uncertainties = self.uncertainty_estimator.estimate(model, unlabeled_pool)

        n_candidates = min(3 * n_queries, len(unlabeled_pool))
        candidate_pool = self._top_k(unlabeled_pool, uncertainties, n_candidates)

        return self._diversity_sampling(model, candidate_pool, labeled_pool, n_queries)

    def _hybrid_sampling(
        self,
        model: PredictionModel,
        unlabeled_pool: List[Sample],
        labeled_pool: List[Sample],
        n_queries: int
    ) -> List[Sample]:
        """Rank by a weighted sum of max-normalized uncertainty and diversity.

        Diversity is the Euclidean distance to the nearest labeled sample,
        or 1.0 for every sample when nothing is labeled yet.
        ``diversity_weight`` sets the mix.
        """
        uncertainties = np.asarray(
            self.uncertainty_estimator.estimate(model, unlabeled_pool), dtype=float
        )

        if labeled_pool:
            features = np.array([s.features for s in unlabeled_pool])
            labeled_features = np.array([s.features for s in labeled_pool])
            diversity_scores = np.min(
                np.linalg.norm(
                    features[:, np.newaxis, :] - labeled_features[np.newaxis, :, :],
                    axis=2
                ),
                axis=1
            )
        else:
            diversity_scores = np.ones(len(unlabeled_pool))

        uncertainties = uncertainties / (np.max(uncertainties) + 1e-10)
        diversity_scores = diversity_scores / (np.max(diversity_scores) + 1e-10)

        combined_scores = (
            (1 - self.diversity_weight) * uncertainties +
            self.diversity_weight * diversity_scores
        )

        return self._top_k(unlabeled_pool, combined_scores, n_queries)


# =============================================================================
# OracleInterface
# =============================================================================

class OracleInterface:
    """
    Interface for human oracle labeling operations.

    Features:
    - Task queue management (a ``heapq`` heap of ``LabelingTask`` objects)
    - Multi-oracle support
    - Label quality tracking (per-oracle running averages)
    - Labeling history
    - Conflict resolution (highest-confidence label, or weighted majority vote)

    The task queue, pending labels, history and oracle statistics are only
    read or mutated while holding ``self._lock``.
    """

    def __init__(
        self,
        oracle_ids: Optional[List[str]] = None,
        require_consensus: bool = False,
        consensus_threshold: int = 2,
        max_queue_size: int = 1000,
        timeout_seconds: float = 3600.0
    ):
        """
        Initialize oracle interface.

        Args:
            oracle_ids: List of available oracle identifiers; defaults to a
                single ``"default_oracle"``
            require_consensus: Whether to resolve tasks by a vote over several
                labels instead of taking the single most confident label
            consensus_threshold: Minimum number of labels that must have been
                submitted for a task before the vote is evaluated. The labels
                need not all agree; the winner must hold a strict
                confidence-weighted majority.
            max_queue_size: Maximum pending tasks in queue
            timeout_seconds: Task timeout duration (stored; no method of this
                class enforces it)
        """
        self.oracle_ids = oracle_ids or ["default_oracle"]
        self.require_consensus = require_consensus
        self.consensus_threshold = consensus_threshold
        self.max_queue_size = max_queue_size
        self.timeout_seconds = timeout_seconds

        self._task_queue: List[LabelingTask] = []
        # task_id -> list of (label, oracle_id, confidence)
        self._pending_labels: Dict[str, List[Tuple[Any, str, float]]] = defaultdict(list)
        self._labeling_history: List[Dict[str, Any]] = []
        self._oracle_stats: Dict[str, Dict[str, float]] = {
            oid: self._empty_stats() for oid in self.oracle_ids
        }

        self._lock = threading.Lock()

        logger.info(f"OracleInterface initialized with {len(self.oracle_ids)} oracles")

    @staticmethod
    def _empty_stats() -> Dict[str, float]:
        """Return a fresh zeroed statistics record for one oracle."""
        return {
            'tasks_completed': 0,
            'avg_time': 0.0,
            'avg_confidence': 0.0,
            'agreement_rate': 0.0
        }

    def submit_for_labeling(
        self,
        samples: List[Sample],
        priorities: Optional[List[float]] = None
    ) -> List[str]:
        """
        Submit samples to the labeling queue.

        Args:
            samples: Samples to label
            priorities: Priority scores (higher = more urgent); defaults to each
                sample's ``informativeness_score``. Must match ``samples`` in
                length when given.

        Returns:
            List of task IDs, one per sample actually queued. This may be
            shorter than ``samples`` if the queue fills up.

        Raises:
            ValueError: If ``priorities`` and ``samples`` differ in length.
        """
        if priorities is None:
            priorities = [s.informativeness_score for s in samples]
        elif len(priorities) != len(samples):
            # zip() below would otherwise silently drop the unmatched tail.
            raise ValueError(
                f"priorities has {len(priorities)} entries but "
                f"{len(samples)} samples were given"
            )

        task_ids = []

        with self._lock:
            for sample, priority in zip(samples, priorities):
                if len(self._task_queue) >= self.max_queue_size:
                    logger.warning("Task queue full, cannot add more tasks")
                    break

                task_id = self._generate_task_id(sample)
                task = LabelingTask(
                    task_id=task_id,
                    sample=sample,
                    priority=priority
                )

                # Heap order comes from LabelingTask's comparison methods.
                # NOTE(review): presumably highest priority pops first; confirm
                # against LabelingTask.
                heapq.heappush(self._task_queue, task)
                sample.status = LabelStatus.QUEUED
                task_ids.append(task_id)

        logger.info(f"Submitted {len(task_ids)} samples for labeling")
        return task_ids

    def get_next_task(self, oracle_id: str) -> Optional[LabelingTask]:
        """
        Pop the next task off the heap and assign it to an oracle.

        Args:
            oracle_id: Identifier of requesting oracle

        Returns:
            Next labeling task (marked IN_PROGRESS), or None if the queue is empty
        """
        with self._lock:
            if not self._task_queue:
                return None

            task = heapq.heappop(self._task_queue)
            task.assigned_at = datetime.now().isoformat()
            task.oracle_id = oracle_id
            task.sample.status = LabelStatus.IN_PROGRESS

            return task

    def submit_label(
        self,
        task_id: str,
        label: Any,
        oracle_id: str,
        confidence: float = 1.0,
        labeling_time: float = 0.0,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Record a label for a task and update the oracle's running statistics.

        Args:
            task_id: Task identifier (not validated against the queue)
            label: Assigned label (must be hashable for consensus voting)
            oracle_id: Oracle that provided label; unknown IDs get a new record
            confidence: Oracle's confidence in label
            labeling_time: Time spent labeling
            metadata: Additional labeling metadata

        Returns:
            Always True; the label is stored for later ``resolve_label``.
        """
        with self._lock:
            self._pending_labels[task_id].append((label, oracle_id, confidence))

            # Incremental running averages over this oracle's submissions.
            stats = self._oracle_stats.setdefault(oracle_id, self._empty_stats())
            n = stats['tasks_completed']
            stats['tasks_completed'] = n + 1
            stats['avg_time'] = (stats['avg_time'] * n + labeling_time) / (n + 1)
            stats['avg_confidence'] = (stats['avg_confidence'] * n + confidence) / (n + 1)

            self._labeling_history.append({
                'task_id': task_id,
                'label': label,
                'oracle_id': oracle_id,
                'confidence': confidence,
                'labeling_time': labeling_time,
                'timestamp': datetime.now().isoformat(),
                'metadata': metadata
            })

            return True

    def resolve_label(self, task_id: str, sample: Sample) -> Tuple[bool, Optional[Any]]:
        """
        Resolve the final label for a task from its submitted labels.

        Without consensus, the most confident submitted label wins. If several
        labels tie on confidence, the first submitted one wins.

        Args:
            task_id: Task identifier
            sample: Sample being labeled (updated in place on success)

        Returns:
            Tuple of (success, resolved_label)
        """
        with self._lock:
            pending = self._pending_labels.get(task_id, [])

            if not pending:
                return False, None

            if self.require_consensus:
                return self._resolve_with_consensus(task_id, sample, pending)

            # max() returns the first maximal element on ties.
            label, oracle_id, confidence = max(pending, key=lambda x: x[2])
            sample.label = label
            sample.status = LabelStatus.LABELED
            sample.metadata['labeling_confidence'] = confidence
            sample.metadata['labeled_by'] = oracle_id

            del self._pending_labels[task_id]
            return True, label

    def _resolve_with_consensus(
        self,
        task_id: str,
        sample: Sample,
        pending: List[Tuple[Any, str, float]]
    ) -> Tuple[bool, Optional[Any]]:
        """
        Resolve a label by confidence-weighted vote. Caller must hold the lock.

        Nothing happens until at least ``consensus_threshold`` labels exist.
        A label is accepted only if its summed confidence is a strict majority
        of the total. Otherwise the sample is marked UNCERTAIN and its pending
        labels are kept, so more labels can still be added.
        """
        if len(pending) < self.consensus_threshold:
            return False, None

        label_votes: Dict[Any, float] = defaultdict(float)
        for label, _oracle_id, confidence in pending:
            label_votes[label] += confidence

        best_label, best_weight = max(label_votes.items(), key=lambda item: item[1])
        total_weight = sum(label_votes.values())
        consensus_ratio = best_weight / total_weight if total_weight > 0 else 0

        # Strict majority: a 50/50 split is a tie, not a consensus, and must not
        # be resolved in favour of whichever label happened to be seen first.
        if consensus_ratio > 0.5:
            sample.label = best_label
            sample.status = LabelStatus.LABELED
            sample.metadata['consensus_ratio'] = consensus_ratio
            sample.metadata['n_labels'] = len(pending)

            del self._pending_labels[task_id]
            return True, best_label

        sample.status = LabelStatus.UNCERTAIN
        return False, None

    def get_queue_status(self) -> Dict[str, Any]:
        """
        Get a snapshot of queue statistics.

        ``total_labeled`` counts submitted labels (history entries), not
        resolved samples. ``oracle_stats`` holds copies, so callers cannot race
        with concurrent ``submit_label`` updates.
        """
        with self._lock:
            return {
                'queue_size': len(self._task_queue),
                'pending_resolution': len(self._pending_labels),
                'total_labeled': len(self._labeling_history),
                'oracle_stats': {
                    oid: dict(stats) for oid, stats in self._oracle_stats.items()
                }
            }

    def simulate_labeling(
        self,
        label_function: Callable[[Sample], Any],
        n_tasks: Optional[int] = None
    ) -> int:
        """
        Simulate labeling for testing purposes, using the first oracle ID.

        Args:
            label_function: Function to generate labels
            n_tasks: Maximum number of tasks to process. None (the default)
                processes every task queued at call time; 0 processes none.

        Returns:
            Number of samples labeled
        """
        oracle_id = self.oracle_ids[0]
        with self._lock:
            queued = len(self._task_queue)
        # Explicit None check: `n_tasks or ...` would turn 0 into "all tasks".
        limit = queued if n_tasks is None else min(n_tasks, queued)

        labeled_count = 0
        for _ in range(limit):
            task = self.get_next_task(oracle_id)
            if task is None:
                break

            try:
                label = label_function(task.sample)
                self.submit_label(
                    task.task_id,
                    label,
                    oracle_id,
                    confidence=0.95,
                    labeling_time=1.0
                )
                success, _ = self.resolve_label(task.task_id, task.sample)
                if success:
                    labeled_count += 1
            except Exception as e:
                logger.error(f"Simulated labeling failed: {e}")
                task.sample.status = LabelStatus.REJECTED

        return labeled_count

    def _generate_task_id(self, sample: Sample) -> str:
        """Derive a 16-hex-char task ID from the sample ID and current time."""
        content = f"{sample.sample_id}_{datetime.now().isoformat()}"
        # MD5 is used only as a short fingerprint, not for security.
        return hashlib.md5(content.encode()).hexdigest()[:16]


# =============================================================================
# LabelBudgetManager
# =============================================================================

class LabelBudgetManager:
    """
    Manage labeling budget allocation and tracking.

    Features:
    - Budget allocation by strategy (high-uncertainty / diversity / exploration)
    - Cost tracking per label
    - ROI estimation (mean metric delta per unit cost)
    - Budget forecasting
    - Dynamic reallocation
    """

    # Weights used by reallocate_budget for strategies missing from the
    # supplied performance map (they mirror the constructor defaults).
    _DEFAULT_STRATEGY_WEIGHTS: Dict[str, float] = {
        'high_uncertainty': 0.6,
        'diversity': 0.3,
        'exploration': 0.1,
    }

    def __init__(
        self,
        total_budget: int,
        budget_per_cycle: int = 100,
        cost_per_label: float = 1.0,
        high_uncertainty_ratio: float = 0.6,
        diversity_ratio: float = 0.3,
        exploration_ratio: float = 0.1
    ):
        """
        Initialize budget manager.

        Args:
            total_budget: Total labeling budget (number of labels)
            budget_per_cycle: Labels to acquire per active learning cycle
            cost_per_label: Cost per label for ROI calculations
            high_uncertainty_ratio: Budget ratio for high-uncertainty samples
            diversity_ratio: Budget ratio for diversity sampling
            exploration_ratio: Budget ratio for random exploration
        """
        self.allocation = BudgetAllocation(
            total_budget=total_budget,
            remaining_budget=total_budget,
            budget_per_cycle=budget_per_cycle,
            high_uncertainty_allocation=high_uncertainty_ratio,
            diversity_allocation=diversity_ratio,
            random_allocation=exploration_ratio
        )

        self.cost_per_label = cost_per_label
        self._spending_history: List[Dict[str, Any]] = []
        self._roi_history: List[Dict[str, float]] = []
        self._cycle_count: int = 0

        logger.info(
            f"LabelBudgetManager initialized: total={total_budget}, "
            f"per_cycle={budget_per_cycle}"
        )

    def get_cycle_budget(self) -> Dict[str, int]:
        """
        Get budget allocation for current cycle.

        The per-strategy shares are floored with ``int()``, so their sum may be
        slightly below ``'total'``.

        Returns:
            Dictionary mapping strategy to allocated budget
        """
        cycle_budget = min(
            self.allocation.budget_per_cycle,
            self.allocation.remaining_budget
        )

        return {
            'high_uncertainty': int(cycle_budget * self.allocation.high_uncertainty_allocation),
            'diversity': int(cycle_budget * self.allocation.diversity_allocation),
            'exploration': int(cycle_budget * self.allocation.random_allocation),
            'total': cycle_budget
        }

    def consume_budget(
        self,
        amount: int,
        strategy: str,
        metrics_before: Dict[str, float],
        metrics_after: Dict[str, float]
    ) -> bool:
        """
        Consume budget and track spending.

        Args:
            amount: Number of labels consumed
            strategy: Strategy used for selection
            metrics_before: Model metrics before labeling
            metrics_after: Model metrics after labeling

        Returns:
            True if budget was consumed successfully, False if
            ``allocation.consume`` refused the amount
        """
        if not self.allocation.consume(amount):
            logger.warning(f"Insufficient budget for {amount} labels")
            return False

        # Raw delta (after - before) per metric reported after labeling.
        # NOTE(review): lower-is-better metrics (e.g. the 'log_loss' produced by
        # ModelUpdater) count a decrease as negative improvement here; confirm
        # the caller passes only higher-is-better metrics.
        improvement = {
            key: metrics_after.get(key, 0) - metrics_before.get(key, 0)
            for key in metrics_after
        }

        cost = amount * self.cost_per_label
        avg_improvement = np.mean(list(improvement.values())) if improvement else 0
        roi = avg_improvement / cost if cost > 0 else 0

        spending_entry = {
            'cycle': self._cycle_count,
            'amount': amount,
            'strategy': strategy,
            'cost': cost,
            'improvement': improvement,
            'roi': roi,
            'remaining_budget': self.allocation.remaining_budget,
            'timestamp': datetime.now().isoformat()
        }
        self._spending_history.append(spending_entry)
        self._roi_history.append({'cycle': self._cycle_count, 'roi': roi})

        self._cycle_count += 1

        logger.info(
            f"Budget consumed: {amount} labels, ROI: {roi:.4f}, "
            f"remaining: {self.allocation.remaining_budget}"
        )

        return True

    def estimate_remaining_cycles(self) -> int:
        """Estimate number of remaining full active learning cycles."""
        if self.allocation.budget_per_cycle == 0:
            return 0
        return self.allocation.remaining_budget // self.allocation.budget_per_cycle

    def get_roi_trend(self) -> Dict[str, float]:
        """
        Get ROI trend analysis.

        Returns:
            Dict with 'trend' (least-squares slope of ROI vs cycle index; 0.0
            with fewer than two entries), 'average', 'latest', 'max' and 'min'.
            All values are 0.0 when no ROI has been recorded yet.
        """
        rois = [float(entry['roi']) for entry in self._roi_history]
        if not rois:
            return {'trend': 0.0, 'average': 0.0, 'latest': 0.0, 'max': 0.0, 'min': 0.0}

        # A line cannot be fitted through a single point.
        if len(rois) < 2:
            slope = 0.0
        else:
            slope = float(np.polyfit(np.arange(len(rois)), rois, 1)[0])

        return {
            'trend': slope,
            'average': float(np.mean(rois)),
            'latest': rois[-1],
            'max': float(np.max(rois)),
            'min': float(np.min(rois))
        }

    def should_continue(self, min_roi: float = 0.01) -> bool:
        """
        Determine if active learning should continue.

        Args:
            min_roi: Minimum acceptable ROI

        Returns:
            False when the allocation can no longer label, or when the last
            three recorded ROIs are all below ``min_roi``; True otherwise.
        """
        if not self.allocation.can_label():
            return False

        if len(self._roi_history) >= 3:
            recent_rois = [e['roi'] for e in self._roi_history[-3:]]
            if all(r < min_roi for r in recent_rois):
                logger.info("Stopping: ROI below threshold for 3 consecutive cycles")
                return False

        return True

    def reallocate_budget(
        self,
        strategy_performance: Dict[str, float]
    ) -> None:
        """
        Dynamically reallocate budget based on strategy performance.

        Only the keys 'high_uncertainty', 'diversity' and 'exploration' are
        used. Missing ones fall back to the default weights and negative
        scores are clamped to 0. The three weights are then normalized so the
        new ratios sum to 1. An empty map, or weights summing to zero, leaves
        the allocation unchanged.

        Args:
            strategy_performance: Performance scores per strategy
        """
        if not strategy_performance:
            return

        weights = {
            name: max(float(strategy_performance.get(name, default)), 0.0)
            for name, default in self._DEFAULT_STRATEGY_WEIGHTS.items()
        }
        total_perf = sum(weights.values())
        if total_perf <= 0:
            return

        self.allocation.high_uncertainty_allocation = weights['high_uncertainty'] / total_perf
        self.allocation.diversity_allocation = weights['diversity'] / total_perf
        self.allocation.random_allocation = weights['exploration'] / total_perf

        logger.info(f"Budget reallocated: {self.allocation}")

    def get_status(self) -> Dict[str, Any]:
        """Get comprehensive budget status (utilization is 0.0 for a zero budget)."""
        total = self.allocation.total_budget
        remaining = self.allocation.remaining_budget
        return {
            'total_budget': total,
            'remaining_budget': remaining,
            'budget_used': total - remaining,
            # Guard: a zero total budget previously raised ZeroDivisionError.
            'utilization_rate': 1 - (remaining / total) if total else 0.0,
            'cycles_completed': self._cycle_count,
            'estimated_remaining_cycles': self.estimate_remaining_cycles(),
            'roi_trend': self.get_roi_trend(),
            'allocation': {
                'high_uncertainty': self.allocation.high_uncertainty_allocation,
                'diversity': self.allocation.diversity_allocation,
                'exploration': self.allocation.random_allocation
            }
        }


# =============================================================================
# SampleSelector
# =============================================================================

class SampleSelector:
    """
    Select optimal samples for labeling.

    Integrates uncertainty estimation, query strategies, and budget management
    to select the most valuable samples for labeling.
    """

    def __init__(
        self,
        query_strategist: Optional[QueryStrategist] = None,
        budget_manager: Optional[LabelBudgetManager] = None,
        min_uncertainty_threshold: float = 0.1,
        max_samples_per_selection: int = 100
    ):
        """
        Initialize sample selector.

        Args:
            query_strategist: Query strategy component (a default
                ``QueryStrategist()`` is created when omitted)
            budget_manager: Budget management component
            min_uncertainty_threshold: Minimum uncertainty to consider
            max_samples_per_selection: Maximum samples per selection batch
        """
        self.query_strategist = query_strategist or QueryStrategist()
        self.budget_manager = budget_manager
        self.min_uncertainty_threshold = min_uncertainty_threshold
        self.max_samples_per_selection = max_samples_per_selection

        self._selection_history: List[Dict[str, Any]] = []

        logger.info("SampleSelector initialized")

    def select_samples(
        self,
        model: PredictionModel,
        unlabeled_pool: List[Sample],
        labeled_pool: List[Sample],
        n_samples: Optional[int] = None
    ) -> List[Sample]:
        """
        Select samples for labeling with the strategist's current strategy.

        Args:
            model: Current prediction model
            unlabeled_pool: Pool of unlabeled samples
            labeled_pool: Already labeled samples
            n_samples: Number of samples to select. If None, the budget
                manager's cycle total is used, or ``max_samples_per_selection``
                when there is no manager. Always capped by the pool size and
                ``max_samples_per_selection``.

        Returns:
            List of selected samples
        """
        if not unlabeled_pool:
            return []

        if n_samples is None:
            if self.budget_manager is not None:
                n_samples = self.budget_manager.get_cycle_budget()['total']
            else:
                n_samples = self.max_samples_per_selection

        n_samples = min(n_samples, len(unlabeled_pool), self.max_samples_per_selection)

        # estimate() is called for its side effect. The filter below reads
        # s.uncertainty_score, which this call presumably populates; confirm
        # against UncertaintyEstimator.estimate.
        self.query_strategist.uncertainty_estimator.estimate(model, unlabeled_pool)
        filtered_pool = [
            s for s in unlabeled_pool
            if s.uncertainty_score >= self.min_uncertainty_threshold
        ]

        if not filtered_pool:
            # Fall back to the whole pool rather than selecting nothing.
            logger.warning("No samples above uncertainty threshold")
            filtered_pool = unlabeled_pool

        selected = self.query_strategist.select_queries(
            model=model,
            unlabeled_pool=filtered_pool,
            labeled_pool=labeled_pool,
            n_queries=n_samples
        )

        self._selection_history.append({
            'n_selected': len(selected),
            'n_pool': len(unlabeled_pool),
            'n_filtered': len(filtered_pool),
            'avg_uncertainty': np.mean([s.uncertainty_score for s in selected]) if selected else 0,
            'avg_diversity': np.mean([s.diversity_score for s in selected]) if selected else 0,
            'timestamp': datetime.now().isoformat()
        })

        return selected

    @staticmethod
    def _exclude(pool: List[Sample], chosen: List[Sample]) -> List[Sample]:
        """Return the samples of ``pool`` whose sample_id does not appear in ``chosen``."""
        chosen_ids = {s.sample_id for s in chosen}
        return [s for s in pool if s.sample_id not in chosen_ids]

    def select_by_strategy_mix(
        self,
        model: PredictionModel,
        unlabeled_pool: List[Sample],
        labeled_pool: List[Sample]
    ) -> Dict[str, List[Sample]]:
        """
        Select samples using multiple strategies based on budget allocation.

        Picks run in order: high-uncertainty, then diversity, then random
        exploration. Each later stage only sees samples not already picked.
        The strategist's configured strategy is restored afterwards, even if
        a selection raises.

        Args:
            model: Current prediction model
            unlabeled_pool: Pool of unlabeled samples
            labeled_pool: Already labeled samples

        Returns:
            Dictionary mapping strategy name to selected samples. It is
            ``{'default': ...}`` when there is no budget manager.
        """
        if self.budget_manager is None:
            selected = self.select_samples(model, unlabeled_pool, labeled_pool)
            return {'default': selected}

        cycle_budget = self.budget_manager.get_cycle_budget()
        results = {}
        strategist = self.query_strategist
        # The strategist may be shared (e.g. with select_samples). Remember its
        # configured strategy so the temporary overrides below do not leak.
        original_strategy = strategist.strategy

        try:
            n_uncertain = cycle_budget['high_uncertainty']
            if n_uncertain > 0:
                strategist.strategy = QueryStrategy.UNCERTAINTY_SAMPLING
                results['high_uncertainty'] = strategist.select_queries(
                    model, unlabeled_pool, labeled_pool, n_uncertain
                )
                remaining_pool = self._exclude(unlabeled_pool, results['high_uncertainty'])
            else:
                results['high_uncertainty'] = []
                remaining_pool = unlabeled_pool

            n_diverse = cycle_budget['diversity']
            if n_diverse > 0 and remaining_pool:
                strategist.strategy = QueryStrategy.DIVERSITY_SAMPLING
                results['diversity'] = strategist.select_queries(
                    model, remaining_pool, labeled_pool, n_diverse
                )
                remaining_pool = self._exclude(remaining_pool, results['diversity'])
            else:
                results['diversity'] = []
        finally:
            strategist.strategy = original_strategy

        if cycle_budget['exploration'] > 0 and remaining_pool:
            import random
            n_explore = min(cycle_budget['exploration'], len(remaining_pool))
            results['exploration'] = random.sample(remaining_pool, n_explore)
        else:
            results['exploration'] = []

        return results

    def get_selection_stats(self) -> Dict[str, Any]:
        """Aggregate statistics over every ``select_samples`` call recorded so far."""
        if not self._selection_history:
            return {'total_selections': 0}

        return {
            'total_selections': len(self._selection_history),
            'total_samples_selected': sum(
                e['n_selected'] for e in self._selection_history
            ),
            'avg_uncertainty': np.mean([
                e['avg_uncertainty'] for e in self._selection_history
            ]),
            'avg_diversity': np.mean([
                e['avg_diversity'] for e in self._selection_history
            ]),
            'selection_rate': np.mean([
                e['n_selected'] / e['n_pool'] if e['n_pool'] > 0 else 0
                for e in self._selection_history
            ])
        }


# =============================================================================
# ModelUpdater
# =============================================================================

class ModelUpdater:
    """
    Update model with new labeled data.

    Features:
    - Incremental updates (``partial_fit`` when the model provides it)
    - Full retraining support
    - Validation monitoring on a random hold-out split
    - Update scheduling (``should_update`` / ``min_samples_for_update``)
    - JSON checkpoints of best metrics after non-regressing updates (model
      weights are not saved)

    Metric direction: metrics in ``_LOWER_IS_BETTER`` (``log_loss``) improve
    when they decrease. Every other metric improves when it increases.
    """

    # Metrics produced by _evaluate_model for which smaller values are better.
    _LOWER_IS_BETTER = frozenset({'log_loss'})

    def __init__(
        self,
        model: PredictionModel,
        incremental: bool = True,
        validation_split: float = 0.2,
        min_samples_for_update: int = 10,
        checkpoint_dir: Optional[str] = None
    ):
        """
        Initialize model updater.

        Args:
            model: Prediction model to update
            incremental: Use incremental updates if available
            validation_split: Fraction of data for validation
            min_samples_for_update: Minimum new samples to trigger update
            checkpoint_dir: Directory for checkpoints (created if missing)
        """
        self.model = model
        self.incremental = incremental
        self.validation_split = validation_split
        self.min_samples_for_update = min_samples_for_update
        self.checkpoint_dir = Path(checkpoint_dir) if checkpoint_dir else None

        self._update_history: List[ModelUpdateResult] = []
        self._best_metrics: Dict[str, float] = {}
        self._pending_samples: List[Sample] = []

        if self.checkpoint_dir:
            self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"ModelUpdater initialized (incremental={incremental})")

    def add_labeled_samples(self, samples: List[Sample]) -> int:
        """
        Add newly labeled samples to pending update queue.

        Args:
            samples: Candidate samples; only those with status LABELED are kept

        Returns:
            Number of samples added
        """
        labeled = [s for s in samples if s.status == LabelStatus.LABELED]
        self._pending_samples.extend(labeled)
        return len(labeled)

    def should_update(self) -> bool:
        """Check if enough pending samples have accumulated to update."""
        return len(self._pending_samples) >= self.min_samples_for_update

    def _is_better(self, metric: str, candidate: float, reference: float) -> bool:
        """Return True if ``candidate`` strictly beats ``reference`` for ``metric``."""
        if metric in self._LOWER_IS_BETTER:
            return candidate < reference
        return candidate > reference

    def _split_validation(
        self,
        features: np.ndarray,
        labels: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Randomly hold out ``validation_split`` of the rows for evaluation.

        At least one row is always kept for training.

        Returns:
            (train_features, train_labels, val_features, val_labels). The
            validation parts are None when the split holds out no rows.
        """
        n_samples = len(features)
        n_val = min(int(n_samples * self.validation_split), n_samples - 1)
        if n_val <= 0:
            return features, labels, None, None

        indices = np.random.permutation(n_samples)
        train_idx, val_idx = indices[n_val:], indices[:n_val]
        return features[train_idx], labels[train_idx], features[val_idx], labels[val_idx]

    def update(
        self,
        labeled_pool: Optional[List[Sample]] = None,
        force_full_retrain: bool = False
    ) -> ModelUpdateResult:
        """
        Update model with new labeled data.

        A full retrain on ``labeled_pool`` happens when it is given and any of
        these hold: ``force_full_retrain``, incremental mode is off, or there
        are no pending samples. Otherwise the pending samples are applied
        incrementally, with ``partial_fit`` if available and ``fit`` if not.

        Args:
            labeled_pool: Full labeled dataset (for full retraining)
            force_full_retrain: Force full retraining even if incremental is enabled

        Returns:
            Update result with metrics. ``improvement`` is the raw
            ``new - previous`` delta per metric.
        """
        start_time = datetime.now()

        if not self._pending_samples and not labeled_pool:
            return ModelUpdateResult(
                success=False,
                samples_used=0,
                previous_metrics={},
                new_metrics={},
                improvement={},
                update_duration_seconds=0.0,
                error_message="No samples available for update"
            )

        # Without pending samples an "incremental" update would train on an
        # empty array, so fall back to the full pool in that case.
        use_full_retrain = bool(labeled_pool) and (
            force_full_retrain
            or not self.incremental
            or not self._pending_samples
        )
        if use_full_retrain:
            samples_to_use = labeled_pool
            update_type = "full_retrain"
        else:
            samples_to_use = self._pending_samples
            update_type = "incremental"

        features = np.array([s.features for s in samples_to_use])
        labels = np.array([s.label for s in samples_to_use])
        n_samples = len(samples_to_use)

        train_features, train_labels, val_features, val_labels = (
            self._split_validation(features, labels)
        )

        previous_metrics = self._evaluate_model(val_features, val_labels)

        try:
            if update_type == "incremental" and hasattr(self.model, 'partial_fit'):
                self.model.partial_fit(train_features, train_labels)
            else:
                self.model.fit(train_features, train_labels)

            new_metrics = self._evaluate_model(val_features, val_labels)

            improvement = {
                key: new_metrics.get(key, 0) - previous_metrics.get(key, 0)
                for key in new_metrics
            }

            for key, value in new_metrics.items():
                best = self._best_metrics.get(key)
                if best is None or self._is_better(key, value, best):
                    self._best_metrics[key] = value

            # NOTE(review): in incremental mode the held-out validation rows are
            # never trained on, because the pending list is cleared here.
            # Confirm this is intended.
            self._pending_samples = []

            result = ModelUpdateResult(
                success=True,
                samples_used=n_samples,
                previous_metrics=previous_metrics,
                new_metrics=new_metrics,
                improvement=improvement,
                update_duration_seconds=(datetime.now() - start_time).total_seconds()
            )

            self._update_history.append(result)

            # Checkpoint unless some metric measured both before and after got
            # worse, respecting each metric's direction.
            regressed = any(
                self._is_better(key, previous_metrics[key], new_metrics[key])
                for key in new_metrics
                if key in previous_metrics
            )
            if self.checkpoint_dir and not regressed:
                self._save_checkpoint()

            logger.info(
                f"Model updated ({update_type}): {n_samples} samples, "
                f"improvement: {improvement}"
            )

            return result

        except Exception as e:
            logger.error(f"Model update failed: {e}")
            return ModelUpdateResult(
                success=False,
                samples_used=0,
                previous_metrics=previous_metrics,
                new_metrics={},
                improvement={},
                update_duration_seconds=(datetime.now() - start_time).total_seconds(),
                error_message=str(e)
            )

    def _evaluate_model(
        self,
        features: Optional[np.ndarray],
        labels: Optional[np.ndarray]
    ) -> Dict[str, float]:
        """
        Evaluate the model on validation data.

        Returns:
            ``{'accuracy': ...}``, plus ``'log_loss'`` when ``predict_proba``
            succeeds. Returns an empty dict when there is no data or
            prediction fails.
        """
        if features is None or labels is None or len(features) == 0:
            return {}

        try:
            predictions = self.model.predict(features)
            metrics = {'accuracy': float(np.mean(predictions == labels))}

            try:
                probabilities = self.model.predict_proba(features)
                # NOTE(review): assumes labels are integer class indices into each
                # probability row; labels outside the row fall back to p=0.5.
                n_samples = len(labels)
                log_loss_val = 0
                for i in range(n_samples):
                    prob = probabilities[i][labels[i]] if len(probabilities[i]) > labels[i] else 0.5
                    log_loss_val -= np.log(max(prob, 1e-10))
                metrics['log_loss'] = float(log_loss_val / n_samples)
            except Exception as e:
                # Best-effort: models without usable probabilities just skip it.
                logger.debug(f"log_loss not computed: {e}")

            return metrics

        except Exception as e:
            logger.warning(f"Evaluation failed: {e}")
            return {}

    def _save_checkpoint(self) -> None:
        """Write the update count and best metrics to a numbered JSON file."""
        if not self.checkpoint_dir:
            return

        checkpoint_path = self.checkpoint_dir / f"checkpoint_{len(self._update_history)}.json"

        try:
            checkpoint_data = {
                'update_count': len(self._update_history),
                'best_metrics': self._best_metrics,
                'timestamp': datetime.now().isoformat()
            }

            with open(checkpoint_path, 'w', encoding='utf-8') as f:
                json.dump(checkpoint_data, f, indent=2)

            logger.info(f"Checkpoint saved: {checkpoint_path}")
        except Exception as e:
            logger.warning(f"Checkpoint save failed: {e}")

    def get_update_history(self) -> List[Dict[str, Any]]:
        """Get update history as dictionaries."""
        return [
            {
                'success': r.success,
                'samples_used': r.samples_used,
                'previous_metrics': r.previous_metrics,
                'new_metrics': r.new_metrics,
                'improvement': r.improvement,
                'duration_seconds': r.update_duration_seconds
            }
            for r in self._update_history
        ]

    def get_best_metrics(self) -> Dict[str, float]:
        """Get a copy of the best metrics (lowest for log_loss, highest otherwise)."""
        return dict(self._best_metrics)


# =============================================================================
# ActiveLearningSystem (Main Orchestrator)
# =============================================================================

class ActiveLearningSystem:
    """
    Complete active learning system orchestrating all components.

    Manages the full active learning loop:
    1. Select informative samples from unlabeled pool
    2. Submit samples for labeling via oracle
    3. Update model with new labels
    4. Track budget and performance
    """

    def __init__(
        self,
        model: PredictionModel,
        total_budget: int = 1000,
        budget_per_cycle: int = 50,
        query_strategy: QueryStrategy = QueryStrategy.HYBRID,
        uncertainty_method: UncertaintyMethod = UncertaintyMethod.ENTROPY,
        checkpoint_dir: Optional[str] = None
    ):
        """
        Initialize active learning system.

        Args:
            model: Prediction model to train
            total_budget: Total labeling budget
            budget_per_cycle: Labels per active learning cycle
            query_strategy: Strategy for selecting samples
            uncertainty_method: Method for uncertainty estimation
            checkpoint_dir: Directory for checkpoints (forwarded to ModelUpdater)
        """
        # Initialize components
        self.uncertainty_estimator = UncertaintyEstimator(method=uncertainty_method)
        self.query_strategist = QueryStrategist(
            strategy=query_strategy,
            uncertainty_estimator=self.uncertainty_estimator
        )
        self.oracle = OracleInterface()
        self.budget_manager = LabelBudgetManager(
            total_budget=total_budget,
            budget_per_cycle=budget_per_cycle
        )
        self.sample_selector = SampleSelector(
            query_strategist=self.query_strategist,
            budget_manager=self.budget_manager
        )
        self.model_updater = ModelUpdater(
            model=model,
            checkpoint_dir=checkpoint_dir
        )

        self.model = model
        # Kept so budget spend is attributed to the strategy actually in use.
        self._query_strategy = query_strategy
        self._unlabeled_pool: List[Sample] = []
        self._labeled_pool: List[Sample] = []
        self._cycle_count: int = 0
        # Monotonic counters for auto-generated IDs. Deriving IDs from the
        # current pool size would reuse IDs once labeled samples have been
        # moved out of the unlabeled pool.
        self._unlabeled_id_counter: int = 0
        self._labeled_id_counter: int = 0

        logger.info("ActiveLearningSystem initialized")

    def add_unlabeled_data(
        self,
        features: np.ndarray,
        sample_ids: Optional[List[str]] = None
    ) -> int:
        """
        Add unlabeled data to the pool.

        Args:
            features: Feature matrix (n_samples, n_features); one Sample is
                created per row.
            sample_ids: Optional sample identifiers, one per row. When omitted,
                IDs ``sample_<k>`` are generated from a counter that never
                repeats a value.

        Returns:
            Number of samples added

        Raises:
            ValueError: If ``sample_ids`` does not have exactly one entry per row.
        """
        n_samples = len(features)

        if sample_ids is None:
            start = self._unlabeled_id_counter
            sample_ids = [f"sample_{start + i}" for i in range(n_samples)]
        elif len(sample_ids) != n_samples:
            # Validate up front so the pool is never partially extended.
            raise ValueError(
                f"Expected {n_samples} sample_ids, got {len(sample_ids)}"
            )
        # Advance for every added sample so generated IDs track the total
        # number of unlabeled samples ever added.
        self._unlabeled_id_counter += n_samples

        self._unlabeled_pool.extend(
            Sample(sample_id=sample_id, features=row)
            for sample_id, row in zip(sample_ids, features)
        )

        logger.info(f"Added {n_samples} unlabeled samples")
        return n_samples

    def add_labeled_data(
        self,
        features: np.ndarray,
        labels: np.ndarray,
        sample_ids: Optional[List[str]] = None
    ) -> int:
        """
        Add labeled data (seed data for initial training).

        Args:
            features: Feature matrix; one Sample is created per row.
            labels: Labels, one per row of ``features``.
            sample_ids: Optional sample identifiers, one per row. When omitted,
                IDs ``labeled_<k>`` are generated from a counter that never
                repeats a value.

        Returns:
            Number of samples added

        Raises:
            ValueError: If ``labels`` or ``sample_ids`` does not have exactly
                one entry per row.
        """
        n_samples = len(features)

        if len(labels) != n_samples:
            raise ValueError(f"Expected {n_samples} labels, got {len(labels)}")

        if sample_ids is None:
            start = self._labeled_id_counter
            sample_ids = [f"labeled_{start + i}" for i in range(n_samples)]
        elif len(sample_ids) != n_samples:
            raise ValueError(
                f"Expected {n_samples} sample_ids, got {len(sample_ids)}"
            )
        self._labeled_id_counter += n_samples

        self._labeled_pool.extend(
            Sample(
                sample_id=sample_id,
                features=row,
                label=label,
                status=LabelStatus.LABELED
            )
            for sample_id, row, label in zip(sample_ids, features, labels)
        )

        logger.info(f"Added {n_samples} labeled samples")
        return n_samples

    def run_cycle(
        self,
        label_function: Optional[Callable[[Sample], Any]] = None
    ) -> Dict[str, Any]:
        """
        Run one active learning cycle.

        Args:
            label_function: Function to provide labels (for simulation). When
                omitted, selected samples are only submitted to the oracle and
                nothing is labeled in this cycle.

        Returns:
            Cycle results. When no samples are selected, only 'cycle',
            'success' (False), 'reason' and 'duration_seconds' are present.
        """
        self._cycle_count += 1
        cycle_start = datetime.now()

        logger.info(f"Starting active learning cycle {self._cycle_count}")

        # 1. Get current metrics
        metrics_before = self._get_current_metrics()

        # 2. Select samples
        selected_samples = self.sample_selector.select_samples(
            model=self.model,
            unlabeled_pool=self._unlabeled_pool,
            labeled_pool=self._labeled_pool
        )

        if not selected_samples:
            return {
                'cycle': self._cycle_count,
                'success': False,
                'reason': 'No samples selected',
                'duration_seconds': (datetime.now() - cycle_start).total_seconds()
            }

        # 3. Submit for labeling
        task_ids = self.oracle.submit_for_labeling(selected_samples)

        # 4. Get labels (simulate); otherwise tasks stay queued with the oracle.
        if label_function:
            self.oracle.simulate_labeling(label_function, len(task_ids))

        # 5. Move samples the oracle marked LABELED into the labeled pool.
        newly_labeled = [
            sample for sample in selected_samples
            if sample.status == LabelStatus.LABELED
        ]
        if newly_labeled:
            # One O(n) pass instead of list.remove per sample (O(n*k)).
            # Matches by identity; assumes the selector returns the pool's own
            # Sample objects -- TODO confirm in SampleSelector.
            moved = {id(sample) for sample in newly_labeled}
            self._unlabeled_pool[:] = [
                sample for sample in self._unlabeled_pool
                if id(sample) not in moved
            ]
            self._labeled_pool.extend(newly_labeled)

        # 6. Update model
        self.model_updater.add_labeled_samples(newly_labeled)

        update_result = None
        if self.model_updater.should_update():
            update_result = self.model_updater.update(self._labeled_pool)

        # 7. Update budget (attributed to the configured query strategy)
        metrics_after = self._get_current_metrics()
        self.budget_manager.consume_budget(
            amount=len(newly_labeled),
            strategy=self._query_strategy.name.lower(),
            metrics_before=metrics_before,
            metrics_after=metrics_after
        )

        cycle_duration = (datetime.now() - cycle_start).total_seconds()

        return {
            'cycle': self._cycle_count,
            'success': True,
            'samples_selected': len(selected_samples),
            'samples_labeled': len(newly_labeled),
            'metrics_before': metrics_before,
            'metrics_after': metrics_after,
            'model_updated': update_result is not None and update_result.success,
            'remaining_budget': self.budget_manager.allocation.remaining_budget,
            'remaining_unlabeled': len(self._unlabeled_pool),
            'total_labeled': len(self._labeled_pool),
            'duration_seconds': cycle_duration
        }

    def run_until_budget_exhausted(
        self,
        label_function: Callable[[Sample], Any],
        max_cycles: int = 100
    ) -> List[Dict[str, Any]]:
        """
        Run active learning until budget is exhausted.

        Stops when the budget manager says to stop, ``max_cycles`` is reached,
        the unlabeled pool is empty, a cycle fails, or a cycle labels nothing.
        Without the last check, a stalled oracle would re-run identical cycles
        until ``max_cycles``, because no budget is consumed.

        Args:
            label_function: Function to provide labels
            max_cycles: Maximum number of cycles

        Returns:
            List of cycle results
        """
        results = []

        while (
            self.budget_manager.should_continue() and
            len(results) < max_cycles and
            self._unlabeled_pool
        ):
            result = self.run_cycle(label_function)
            results.append(result)

            if not result['success']:
                break
            if result['samples_labeled'] == 0:
                logger.warning(
                    f"Cycle {result['cycle']} labeled no samples; stopping"
                )
                break

        logger.info(
            f"Active learning complete: {len(results)} cycles, "
            f"{len(self._labeled_pool)} total labeled samples"
        )

        return results

    def _get_current_metrics(self) -> Dict[str, float]:
        """
        Compute accuracy on the most recent (up to 100) labeled samples.

        Returns an empty dict when fewer than 10 labeled samples exist or when
        prediction fails. These samples are also what ``model_updater.update``
        trains on, so this is an optimistic training-set estimate, not
        held-out accuracy.
        """
        if len(self._labeled_pool) < 10:
            return {}

        # Use a subset for evaluation
        eval_samples = self._labeled_pool[-100:]
        features = np.array([s.features for s in eval_samples])
        labels = np.array([s.label for s in eval_samples])

        try:
            predictions = self.model.predict(features)
            accuracy = np.mean(predictions == labels)
        except Exception as e:
            # Metrics are advisory; log instead of silently discarding errors.
            logger.warning(f"Metric computation failed: {e}")
            return {}
        return {'accuracy': float(accuracy)}

    def get_status(self) -> Dict[str, Any]:
        """Get comprehensive system status."""
        return {
            'cycles_completed': self._cycle_count,
            'unlabeled_pool_size': len(self._unlabeled_pool),
            'labeled_pool_size': len(self._labeled_pool),
            'budget_status': self.budget_manager.get_status(),
            'oracle_status': self.oracle.get_queue_status(),
            'selection_stats': self.sample_selector.get_selection_stats(),
            'update_history': self.model_updater.get_update_history(),
            'best_metrics': self.model_updater.get_best_metrics()
        }

    def save_state(self, filepath: str) -> None:
        """Save pools, cycle count and budget status to a JSON file."""
        state = {
            'cycle_count': self._cycle_count,
            'unlabeled_pool': [s.to_dict() for s in self._unlabeled_pool],
            'labeled_pool': [s.to_dict() for s in self._labeled_pool],
            'budget_status': self.budget_manager.get_status(),
            'timestamp': datetime.now().isoformat()
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(state, f, indent=2)

        logger.info(f"State saved to {filepath}")


# =============================================================================
# Example Usage and Testing
# =============================================================================

class SimpleClassifier:
    """Minimal softmax-regression classifier for testing purposes.

    Initial weights are drawn from NumPy's global RNG, so seed ``np.random``
    beforehand for reproducible runs.
    """

    def __init__(self, n_features: int = 10, n_classes: int = 2):
        self.n_features = n_features
        self.n_classes = n_classes
        self.weights = np.random.randn(n_features, n_classes)
        self.bias = np.zeros(n_classes)
        self._fitted = False

    def predict_proba(self, features: np.ndarray) -> np.ndarray:
        """Return softmax class probabilities, shape (n_samples, n_classes)."""
        logits = np.asarray(features) @ self.weights + self.bias
        # Subtract the row max before exponentiating for numerical stability.
        exp_logits = np.exp(logits - np.max(logits, axis=1, keepdims=True))
        return exp_logits / np.sum(exp_logits, axis=1, keepdims=True)

    def predict(self, features: np.ndarray) -> np.ndarray:
        """Return the most probable class index for each row."""
        return np.argmax(self.predict_proba(features), axis=1)

    def fit(
        self,
        features: np.ndarray,
        labels: np.ndarray,
        *,
        learning_rate: float = 0.01,
        n_iterations: int = 100
    ) -> None:
        """Train with full-batch gradient descent.

        Args:
            features: Feature matrix (n_samples, n_features); lists accepted.
            labels: Integer class labels, one per row; lists accepted.
            learning_rate: Step size for each gradient update.
            n_iterations: Number of full-batch gradient steps.

        Empty input is a no-op; the model is not marked fitted.
        """
        features = np.asarray(features)
        labels = np.asarray(labels)
        if len(labels) == 0:
            return
        for _ in range(n_iterations):
            self._gradient_step(features, labels, learning_rate)
        self._fitted = True

    def partial_fit(
        self,
        features: np.ndarray,
        labels: np.ndarray,
        *,
        learning_rate: float = 0.1
    ) -> None:
        """Apply a single gradient step on the given batch.

        Empty input is a no-op; otherwise the model is marked fitted.
        """
        features = np.asarray(features)
        labels = np.asarray(labels)
        if len(labels) == 0:
            return
        self._gradient_step(features, labels, learning_rate)
        self._fitted = True

    def _gradient_step(
        self,
        features: np.ndarray,
        labels: np.ndarray,
        learning_rate: float
    ) -> None:
        """Apply one cross-entropy gradient update; assumes a non-empty batch."""
        n_samples = len(labels)
        probs = self.predict_proba(features)

        # One-hot encode labels
        targets = np.zeros((n_samples, self.n_classes))
        targets[np.arange(n_samples), labels.astype(int)] = 1

        # Gradient of mean softmax cross-entropy w.r.t. weights and bias
        error = probs - targets
        grad_w = features.T @ error / n_samples
        grad_b = np.mean(error, axis=0)

        self.weights -= learning_rate * grad_w
        self.bias -= learning_rate * grad_b


def generate_test_data(
    n_samples: int = 1000,
    n_features: int = 10,
    n_classes: int = 2,
    seed: int = 42
) -> Tuple[np.ndarray, np.ndarray]:
    """Generate linearly separable synthetic classification data.

    Args:
        n_samples: Number of rows to generate.
        n_features: Number of standard-normal features per row.
        n_classes: Number of classes (>= 2). Binary labels come from the sign
            of a random linear projection; for more classes, each row gets the
            argmax of one random linear score per class.
        seed: Seed applied to NumPy's *global* RNG.

    Returns:
        ``(features, labels)`` with shapes (n_samples, n_features) and
        (n_samples,), labels in ``0 .. n_classes - 1``.

    Raises:
        ValueError: If ``n_classes < 2``.
    """
    if n_classes < 2:
        raise ValueError(f"n_classes must be at least 2, got {n_classes}")

    # Deliberately seeds the global RNG: later global draws (e.g. a
    # SimpleClassifier's initial weights) then become reproducible as well.
    np.random.seed(seed)

    features = np.random.randn(n_samples, n_features)

    if n_classes == 2:
        # Same draws as the original binary generator, so existing seeds
        # reproduce identical data.
        true_weights = np.random.randn(n_features)
        labels = (features @ true_weights > 0).astype(int)
    else:
        true_weights = np.random.randn(n_features, n_classes)
        labels = np.argmax(features @ true_weights, axis=1).astype(int)

    return features, labels


def main() -> None:
    """Run a simulated active-learning session, then smoke-test components.

    Labels for the "unlabeled" pool come from the synthetic ground truth,
    standing in for a human oracle.
    """
    print("=" * 70)
    print("AIVA Queen Active Learning System - Test Run")
    print("=" * 70)

    # Generate test data
    print("\n1. Generating test data...")
    features, labels = generate_test_data(n_samples=1000, n_features=10)

    # Split into seed labeled and unlabeled
    seed_size = 50
    seed_features, seed_labels = features[:seed_size], labels[:seed_size]
    unlabeled_features = features[seed_size:]
    unlabeled_labels = labels[seed_size:]  # Ground truth for simulation

    # Assign explicit IDs so the ground-truth lookup does not depend on the
    # system's default ID scheme.
    unlabeled_ids = [f"sample_{i}" for i in range(len(unlabeled_features))]
    label_lookup = dict(zip(unlabeled_ids, unlabeled_labels))

    def oracle_label_function(sample: Sample) -> int:
        # Strict lookup: an unknown ID is a bug, not a reason to answer 0.
        return int(label_lookup[sample.sample_id])

    # Initialize model
    print("\n2. Initializing model and active learning system...")
    model = SimpleClassifier(n_features=10, n_classes=2)

    # Initialize active learning system
    al_system = ActiveLearningSystem(
        model=model,
        total_budget=200,
        budget_per_cycle=20,
        query_strategy=QueryStrategy.HYBRID,
        uncertainty_method=UncertaintyMethod.ENTROPY
    )

    # Add seed data
    print(f"\n3. Adding {seed_size} seed samples...")
    al_system.add_labeled_data(seed_features, seed_labels)

    # Train initial model
    print("\n4. Training initial model on seed data...")
    model.fit(seed_features, seed_labels)

    # Add unlabeled data
    print(f"\n5. Adding {len(unlabeled_features)} unlabeled samples...")
    al_system.add_unlabeled_data(unlabeled_features, sample_ids=unlabeled_ids)

    # Run active learning
    print("\n6. Running active learning cycles...")
    results = al_system.run_until_budget_exhausted(
        label_function=oracle_label_function,
        max_cycles=10
    )

    # Print results
    print("\n" + "=" * 70)
    print("ACTIVE LEARNING RESULTS")
    print("=" * 70)

    for result in results:
        print(f"\nCycle {result['cycle']}:")
        # Failed cycles (e.g. nothing selected) carry only a 'reason', not the
        # per-cycle statistics printed below.
        if not result['success']:
            print(f"  - Not completed: {result.get('reason', 'unknown')}")
            continue
        print(f"  - Samples selected: {result['samples_selected']}")
        print(f"  - Samples labeled: {result['samples_labeled']}")
        print(f"  - Model updated: {result['model_updated']}")
        print(f"  - Metrics: {result['metrics_after']}")
        print(f"  - Remaining budget: {result['remaining_budget']}")

    # Final status
    print("\n" + "=" * 70)
    print("FINAL STATUS")
    print("=" * 70)

    status = al_system.get_status()
    print(f"\nCycles completed: {status['cycles_completed']}")
    print(f"Labeled samples: {status['labeled_pool_size']}")
    print(f"Remaining unlabeled: {status['unlabeled_pool_size']}")
    print(f"Best metrics: {status['best_metrics']}")
    print(f"Budget utilization: {status['budget_status']['utilization_rate']:.2%}")

    # Test individual components
    print("\n" + "=" * 70)
    print("COMPONENT TESTS")
    print("=" * 70)

    # Test uncertainty estimator
    print("\n1. Uncertainty Estimator:")
    estimator = UncertaintyEstimator(method=UncertaintyMethod.ENTROPY)
    test_samples = [Sample(sample_id=f"test_{i}", features=features[i]) for i in range(10)]
    uncertainties = estimator.estimate(model, test_samples)
    print(f"   Sample uncertainties: {[f'{u:.3f}' for u in uncertainties[:5]]}...")

    # Test query strategist
    print("\n2. Query Strategist:")
    strategist = QueryStrategist(strategy=QueryStrategy.UNCERTAINTY_SAMPLING)
    selected = strategist.select_queries(
        model=model,
        unlabeled_pool=test_samples,
        labeled_pool=[],
        n_queries=3
    )
    print(f"   Selected {len(selected)} samples for labeling")

    # Test budget manager
    print("\n3. Budget Manager:")
    budget_mgr = LabelBudgetManager(total_budget=100, budget_per_cycle=10)
    print(f"   Cycle budget: {budget_mgr.get_cycle_budget()}")
    print(f"   Remaining cycles: {budget_mgr.estimate_remaining_cycles()}")

    print("\n" + "=" * 70)
    print("ALL TESTS COMPLETED SUCCESSFULLY")
    print("=" * 70)


if __name__ == "__main__":
    main()