"""
AIVA Queen Imitation Learning System
=====================================

A comprehensive imitation learning framework implementing:
1. DemonstrationCollector - Collect expert demonstrations
2. BehaviorCloner - Clone expert behavior via supervised learning
3. InverseRL - Infer reward functions from demonstrations
4. GAIL - Generative Adversarial Imitation Learning
5. DAgger - Dataset Aggregation for iterative improvement
6. Discriminator - Expert vs learned policy discrimination

This module enables AIVA to learn from expert demonstrations across
various domains including voice AI, patent validation, and business automation.

Author: Genesis System / AIVA Queen
Version: 1.0.0
"""

import json
import time
import math
import random
import logging
import hashlib
import threading
from abc import ABC, abstractmethod
from enum import Enum
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import (
    Any, Callable, Dict, List, Optional, Tuple,
    TypeVar, Generic, Union, Iterator
)
from collections import defaultdict
import uuid

# Configure logging.
# NOTE(review): basicConfig() runs as an import-time side effect of this
# module and configures the root logger for the whole process - confirm that
# this library module is meant to own that configuration.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger used by the classes defined below.
logger = logging.getLogger("ImitationLearning")


# ============================================================================
# Type Definitions and Enums
# ============================================================================

# Generic placeholders for state, action and observation types.
# NOTE(review): nothing visible in this part of the file references these
# TypeVars - confirm they are consumed further down before relying on them.
State = TypeVar('State')
Action = TypeVar('Action')
Observation = TypeVar('Observation')


class DemonstrationQuality(Enum):
    """Quality levels for expert demonstrations.

    The string value of each member is what ``Demonstration.to_dict`` writes
    out for the ``quality`` field.
    """
    # The trailing notes describe the intended source of each tier.
    EXPERT = "expert"           # Highest quality - verified expert
    PROFICIENT = "proficient"   # Good quality - experienced user
    NOVICE = "novice"           # Lower quality - learning user
    SYNTHETIC = "synthetic"     # Generated demonstrations


class LearningMode(Enum):
    """Modes for imitation learning algorithms.

    NOTE(review): no definition visible in this part of the file reads this
    enum - confirm where it is consumed.
    """
    OFFLINE = "offline"         # Learn from fixed dataset
    ONLINE = "online"           # Learn interactively
    HYBRID = "hybrid"           # Combination of both


@dataclass
class Demonstration:
    """One recorded expert episode stored as ordered (state, action) steps.

    Alongside the trajectory the record keeps its quality tier, the expert and
    domain it came from, a capture timestamp, free-form metadata, a reward
    total and a success flag.
    """
    id: str
    trajectory: List[Tuple[Any, Any]]  # ordered (state, action) steps
    quality: DemonstrationQuality
    expert_id: str
    domain: str
    timestamp: datetime
    metadata: Dict[str, Any] = field(default_factory=dict)
    reward_sum: float = 0.0
    success: bool = True

    def __len__(self) -> int:
        """Return how many (state, action) steps the trajectory holds."""
        step_count = len(self.trajectory)
        return step_count

    def get_states(self) -> List[Any]:
        """Return the state component of every step, in trajectory order."""
        return [state for state, _ in self.trajectory]

    def get_actions(self) -> List[Any]:
        """Return the action component of every step, in trajectory order."""
        return [action for _, action in self.trajectory]

    def to_dict(self) -> Dict[str, Any]:
        """Build a dictionary snapshot of this demonstration.

        States and actions are converted with ``str()``, the quality is
        reduced to its enum value and the timestamp to ISO-8601 text; the
        remaining fields are passed through as they are.
        """
        stringified_steps = []
        for state, action in self.trajectory:
            stringified_steps.append((str(state), str(action)))

        snapshot: Dict[str, Any] = {}
        snapshot["id"] = self.id
        snapshot["trajectory"] = stringified_steps
        snapshot["quality"] = self.quality.value
        snapshot["expert_id"] = self.expert_id
        snapshot["domain"] = self.domain
        snapshot["timestamp"] = self.timestamp.isoformat()
        snapshot["metadata"] = self.metadata
        snapshot["reward_sum"] = self.reward_sum
        snapshot["success"] = self.success
        return snapshot


@dataclass
class ExpertProfile:
    """Bookkeeping record for one registered expert demonstrator."""
    id: str
    name: str
    domains: List[str]
    skill_level: float  # 0.0 to 1.0
    demonstrations_count: int = 0
    success_rate: float = 1.0
    created_at: datetime = field(default_factory=datetime.now)

    def update_success_rate(self, success: bool):
        """Fold one more demonstration outcome into the running success rate.

        The stored rate is treated as the mean over the demonstrations counted
        so far; the new outcome (1.0 for success, 0.0 for failure) is averaged
        in and the demonstration counter is advanced by one.
        """
        seen_so_far = self.demonstrations_count
        outcome = 1.0 if success else 0.0
        weighted_total = self.success_rate * seen_so_far + outcome
        self.success_rate = weighted_total / (seen_so_far + 1)
        self.demonstrations_count = seen_so_far + 1


@dataclass
class PolicyUpdate:
    """Represents an update to the learned policy.

    ``BehaviorCloner.train`` appends one of these per epoch to its
    ``training_history``; the field notes below describe what that caller
    stores in each slot.
    """
    iteration: int           # running update counter of the learner
    loss: float              # training loss recorded for the update
    accuracy: float          # validation accuracy (0.0 when nothing was validated)
    timestamp: datetime      # wall-clock time the record was created
    parameters_changed: int  # BehaviorCloner.train stores the total weight count here
    learning_rate: float     # learning rate in effect for the update


# ============================================================================
# Demonstration Collector
# ============================================================================

class DemonstrationCollector:
    """
    Collects expert demonstrations for imitation learning.

    Supports multiple collection modes:
    - Direct recording from expert interactions
    - Replay buffer for experience replay
    - Synthetic generation via expert policies
    - Import from external sources
    """

    def __init__(
        self,
        storage_path: Optional[Path] = None,
        max_buffer_size: int = 10000,
        quality_threshold: float = 0.7
    ):
        """
        Set up an empty collector and make sure its storage directory exists.

        Args:
            storage_path: Directory used by ``save``; defaults to
                ``./demonstrations`` and is created if missing
            max_buffer_size: Number of demonstrations kept in memory before
                eviction starts
            quality_threshold: Minimum quality score a demonstration needs
                to be accepted
        """
        if not storage_path:
            storage_path = Path("./demonstrations")
        self.storage_path = storage_path
        self.storage_path.mkdir(parents=True, exist_ok=True)

        self.max_buffer_size = max_buffer_size
        self.quality_threshold = quality_threshold

        # In-memory state; _lock guards the buffer and its bookkeeping.
        self._lock = threading.Lock()
        self.demonstrations: List[Demonstration] = []
        self.experts: Dict[str, ExpertProfile] = {}
        self.domain_indices: Dict[str, List[int]] = defaultdict(list)

        # Running counters reported by get_statistics().
        self.collection_stats = dict(
            total_collected=0,
            total_rejected=0,
            by_quality=defaultdict(int),
            by_domain=defaultdict(int),
            average_length=0.0,
        )

        logger.info(f"DemonstrationCollector initialized with buffer size {max_buffer_size}")

    def register_expert(
        self,
        name: str,
        domains: List[str],
        skill_level: float = 0.9
    ) -> str:
        """
        Create a profile for a new expert and store it under a fresh ID.

        Args:
            name: Human-readable name of the expert
            domains: Domains the expert is expected to demonstrate in
            skill_level: Skill estimate; values outside 0.0-1.0 are clamped

        Returns:
            The generated expert ID (first 8 characters of a UUID4 string)
        """
        expert_id = str(uuid.uuid4())[:8]
        clamped_skill = min(1.0, max(0.0, skill_level))
        profile = ExpertProfile(
            id=expert_id,
            name=name,
            domains=domains,
            skill_level=clamped_skill,
        )
        self.experts[expert_id] = profile
        logger.info(f"Registered expert {name} (ID: {expert_id}) for domains: {domains}")
        return expert_id

    def start_recording(self, expert_id: str, domain: str) -> "RecordingSession":
        """
        Open a recording session bound to this collector.

        Args:
            expert_id: ID previously returned by ``register_expert``
            domain: Domain the upcoming demonstration belongs to

        Returns:
            A RecordingSession that feeds its trajectory back into this
            collector

        Raises:
            ValueError: If ``expert_id`` has not been registered
        """
        registry = self.experts
        if expert_id in registry:
            expert = registry[expert_id]
        else:
            raise ValueError(f"Unknown expert ID: {expert_id}")

        # Recording outside the declared domains is allowed, only flagged.
        declared_domains = expert.domains
        if domain not in declared_domains:
            logger.warning(f"Expert {expert.name} demonstrating outside known domains")

        session = RecordingSession(collector=self, expert_id=expert_id, domain=domain)
        return session

    def add_demonstration(
        self,
        trajectory: List[Tuple[Any, Any]],
        expert_id: str,
        domain: str,
        quality: DemonstrationQuality = DemonstrationQuality.EXPERT,
        metadata: Optional[Dict[str, Any]] = None,
        success: bool = True
    ) -> Optional[str]:
        """
        Add a demonstration to the collection.

        The trajectory is scored with ``_assess_quality``; empty trajectories
        and ones scoring below ``quality_threshold`` are rejected. Accepted
        demonstrations are appended to the buffer (evicting the lowest-scoring
        entry first when the buffer is full) and the statistics and the
        expert's profile are updated.

        Args:
            trajectory: List of (state, action) pairs
            expert_id: ID of the demonstrating expert
            domain: Domain of the demonstration
            quality: Quality level of the demonstration
            metadata: Additional metadata
            success: Whether the demonstration achieved its goal

        Returns:
            Demonstration ID if accepted, None if rejected
        """
        if not trajectory:
            logger.warning("Rejected empty trajectory")
            # collection_stats is shared state: every write goes through the
            # same lock as the accepted path below.
            with self._lock:
                self.collection_stats["total_rejected"] += 1
            return None

        # Calculate quality score
        quality_score = self._assess_quality(trajectory, quality, success)
        if quality_score < self.quality_threshold:
            logger.info(f"Rejected demonstration with quality {quality_score:.3f}")
            with self._lock:
                self.collection_stats["total_rejected"] += 1
            return None

        demo_id = str(uuid.uuid4())[:12]
        demonstration = Demonstration(
            id=demo_id,
            trajectory=trajectory,
            quality=quality,
            expert_id=expert_id,
            domain=domain,
            timestamp=datetime.now(),
            metadata=metadata or {},
            success=success
        )

        with self._lock:
            # Evict old demonstrations if at capacity
            if len(self.demonstrations) >= self.max_buffer_size:
                self._evict_lowest_quality()

            # The index must be taken after any eviction so that it matches
            # the position the new entry actually lands at.
            idx = len(self.demonstrations)
            self.demonstrations.append(demonstration)
            self.domain_indices[domain].append(idx)

            # Update stats (the counter is bumped before the running average
            # because _update_average_length divides by it).
            self.collection_stats["total_collected"] += 1
            self.collection_stats["by_quality"][quality.value] += 1
            self.collection_stats["by_domain"][domain] += 1
            self._update_average_length(len(trajectory))

            # Update expert profile; unknown expert IDs are accepted but have
            # no profile to update.
            if expert_id in self.experts:
                self.experts[expert_id].update_success_rate(success)

        logger.info(f"Added demonstration {demo_id} ({len(trajectory)} steps) in domain {domain}")
        return demo_id

    def _assess_quality(
        self,
        trajectory: List[Tuple[Any, Any]],
        quality: DemonstrationQuality,
        success: bool
    ) -> float:
        """Score a demonstration on a 0-1 scale from tier, outcome and length.

        The tier supplies a base score (0.5 for an unrecognised tier), a
        failed demonstration keeps 60% of it, and the trajectory length adds
        up to 0.2 on top; the total is capped at 1.0.
        """
        tier_scores = dict(zip(
            (
                DemonstrationQuality.EXPERT,
                DemonstrationQuality.PROFICIENT,
                DemonstrationQuality.NOVICE,
                DemonstrationQuality.SYNTHETIC,
            ),
            (1.0, 0.8, 0.5, 0.7),
        ))
        score = tier_scores[quality] if quality in tier_scores else 0.5

        # Failed attempts are worth less than successful ones.
        if not success:
            score = score * 0.6

        # One hundredth of a point per step, saturating at 0.2.
        bonus = len(trajectory) / 100.0
        if bonus > 0.2:
            bonus = 0.2

        total = score + bonus
        return 1.0 if total > 1.0 else total

    def _evict_lowest_quality(self):
        """Remove the lowest quality demonstration from the buffer.

        The demonstration with the smallest ``_assess_quality`` score is
        dropped (the earliest one on ties). Removing it shifts every later
        buffer entry down by one position, so the per-domain index lists of
        *all* domains are renumbered, not only the evicted demonstration's
        own domain.

        The only caller, ``add_demonstration``, invokes this while holding
        ``self._lock``; the method does not take the lock itself.
        """
        if not self.demonstrations:
            return

        scores = [
            self._assess_quality(demo.trajectory, demo.quality, demo.success)
            for demo in self.demonstrations
        ]
        # min() keeps the first of equal candidates, i.e. the oldest entry.
        min_idx = min(range(len(scores)), key=scores.__getitem__)

        removed = self.demonstrations.pop(min_idx)

        # Re-index every domain: positions above the evicted slot move down
        # by one, and the evicted position itself disappears.
        for domain in list(self.domain_indices):
            self.domain_indices[domain] = [
                i if i < min_idx else i - 1
                for i in self.domain_indices[domain]
                if i != min_idx
            ]

        logger.debug(f"Evicted demonstration {removed.id}")

    def _update_average_length(self, new_length: int):
        """Fold one more trajectory length into the running mean.

        Reads ``total_collected`` as the sample count *including* the new
        trajectory, so the counter has to be incremented before this is
        called.
        """
        stats = self.collection_stats
        count = stats["total_collected"]
        previous_total = stats["average_length"] * (count - 1)
        stats["average_length"] = (previous_total + new_length) / count

    def sample_batch(
        self,
        batch_size: int,
        domain: Optional[str] = None,
        min_quality: DemonstrationQuality = DemonstrationQuality.NOVICE
    ) -> List[Demonstration]:
        """
        Sample a batch of demonstrations, weighted by quality score.

        Candidates are the buffered demonstrations whose quality tier is at
        least ``min_quality`` (ordering: NOVICE < SYNTHETIC < PROFICIENT <
        EXPERT) and, if given, whose domain matches. Draws are independent
        and made *with replacement*, each candidate weighted by its
        ``_assess_quality`` score, so the same demonstration can appear more
        than once in the result.

        Args:
            batch_size: Number of demonstrations to sample
            domain: Optional domain filter
            min_quality: Minimum quality level

        Returns:
            List of ``min(batch_size, number of candidates)`` sampled
            demonstrations; empty when nothing qualifies or ``batch_size``
            is not positive
        """
        quality_order = [
            DemonstrationQuality.NOVICE,
            DemonstrationQuality.SYNTHETIC,
            DemonstrationQuality.PROFICIENT,
            DemonstrationQuality.EXPERT
        ]
        min_idx = quality_order.index(min_quality)
        allowed_qualities = set(quality_order[min_idx:])

        # Only the snapshot of matching entries is taken under the lock;
        # scoring and sampling work on that private list.
        with self._lock:
            candidates = [
                d for d in self.demonstrations
                if d.quality in allowed_qualities
                and (domain is None or d.domain == domain)
            ]

        if not candidates or batch_size <= 0:
            return []

        weights = [
            self._assess_quality(d.trajectory, d.quality, d.success)
            for d in candidates
        ]

        # random.choices performs the weighted draw; unlike a hand-rolled
        # cumulative-sum scan it cannot come up short when floating-point
        # rounding leaves the normalised weights summing to just under 1.
        n_samples = min(batch_size, len(candidates))
        return random.choices(candidates, weights=weights, k=n_samples)

    def get_state_action_pairs(
        self,
        domain: Optional[str] = None
    ) -> List[Tuple[Any, Any]]:
        """
        Flatten the buffered demonstrations into one list of steps.

        Args:
            domain: When given, only demonstrations of this domain contribute

        Returns:
            Every (state, action) tuple of the selected demonstrations, in
            buffer order and then trajectory order
        """
        with self._lock:
            return [
                step
                for demo in self.demonstrations
                if domain is None or demo.domain == domain
                for step in demo.trajectory
            ]

    def save(self, filename: Optional[str] = None):
        """Save demonstrations, expert profiles and statistics to disk as JSON.

        The in-memory state is copied while holding ``self._lock`` so that a
        concurrent ``add_demonstration`` cannot change the buffers halfway
        through serialisation; the file itself is written after the lock has
        been released.

        Args:
            filename: File name inside ``storage_path``; defaults to
                ``demonstrations_<YYYYmmdd_HHMMSS>.json`` using the current
                local time
        """
        filename = filename or f"demonstrations_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        filepath = self.storage_path / filename

        with self._lock:
            demonstrations = [d.to_dict() for d in self.demonstrations]
            experts = {
                eid: {
                    "id": e.id,
                    "name": e.name,
                    "domains": e.domains,
                    "skill_level": e.skill_level,
                    "demonstrations_count": e.demonstrations_count,
                    "success_rate": e.success_rate
                }
                for eid, e in self.experts.items()
            }
            # Copy the nested counters too, so the snapshot is detached from
            # the live defaultdicts.
            stats = {
                key: dict(value) if isinstance(value, dict) else value
                for key, value in self.collection_stats.items()
            }

        data = {
            "demonstrations": demonstrations,
            "experts": experts,
            "stats": stats
        }

        # default=str stringifies anything json cannot encode natively
        # (for example values placed in a demonstration's metadata).
        with open(filepath, 'w', encoding="utf-8") as f:
            json.dump(data, f, indent=2, default=str)

        logger.info(f"Saved {len(demonstrations)} demonstrations to {filepath}")

    def get_statistics(self) -> Dict[str, Any]:
        """Return the running counters plus a few figures about current state.

        The result is a new dict holding every ``collection_stats`` entry
        followed by the buffer size, the number of registered experts and the
        list of domains seen so far.
        """
        report = dict(self.collection_stats)
        report["buffer_size"] = len(self.demonstrations)
        report["num_experts"] = len(self.experts)
        report["domains"] = list(self.domain_indices)
        return report


class RecordingSession:
    """Context manager that buffers one demonstration and submits it on exit."""

    def __init__(
        self,
        collector: DemonstrationCollector,
        expert_id: str,
        domain: str
    ):
        """Bind the session to a collector, an expert and a domain."""
        self.collector = collector
        self.expert_id = expert_id
        self.domain = domain
        # Per-session state, filled in while recording.
        self.trajectory: List[Tuple[Any, Any]] = []
        self.metadata: Dict[str, Any] = {}
        self.start_time: Optional[datetime] = None
        self.success = True

    def __enter__(self):
        """Stamp the start time and hand the session to the ``with`` body."""
        self.start_time = datetime.now()
        logger.info(f"Started recording session for expert {self.expert_id}")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Submit whatever was recorded; exceptions are never suppressed.

        An exception raised inside the ``with`` body marks the demonstration
        as failed. A non-empty trajectory is passed to the collector together
        with the elapsed time in ``metadata["duration_seconds"]``.
        """
        body_raised = exc_type is not None
        if body_raised:
            self.success = False
            logger.warning(f"Recording session failed: {exc_val}")

        # Nothing recorded means nothing to submit.
        if not self.trajectory:
            return False

        elapsed = datetime.now() - self.start_time
        self.metadata["duration_seconds"] = elapsed.total_seconds()
        self.collector.add_demonstration(
            trajectory=self.trajectory,
            expert_id=self.expert_id,
            domain=self.domain,
            metadata=self.metadata,
            success=self.success,
        )
        return False

    def record(self, state: Any, action: Any):
        """Append one (state, action) step to the session's trajectory."""
        step = (state, action)
        self.trajectory.append(step)

    def mark_failure(self):
        """Flag the demonstration being recorded as unsuccessful."""
        self.success = False


# ============================================================================
# Behavior Cloning
# ============================================================================

class BehaviorCloner:
    """
    Implements Behavior Cloning (BC) - supervised learning from demonstrations.

    BC directly learns a policy pi(a|s) by treating imitation as a
    supervised learning problem where states are inputs and expert
    actions are labels.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_layers: Optional[List[int]] = None,
        learning_rate: float = 0.001,
        l2_regularization: float = 0.01
    ):
        """
        Initialize the behavior cloner.

        Args:
            state_dim: Dimensionality of state features
            action_dim: Dimensionality of action space
            hidden_layers: List of hidden layer sizes; ``None`` (or an empty
                list) selects the default ``[128, 64]``
            learning_rate: Learning rate for optimization
            l2_regularization: L2 regularization strength
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_layers = hidden_layers or [128, 64]
        self.learning_rate = learning_rate
        self.l2_reg = l2_regularization

        # Initialize network parameters: one weight matrix (in_dim rows of
        # out_dim columns) and one bias vector per layer.
        self.weights: List[List[List[float]]] = []
        self.biases: List[List[float]] = []
        self._init_network()

        self.training_history: List[PolicyUpdate] = []
        self.iteration = 0

        # Log the resolved architecture, not the raw argument, which is None
        # whenever the default hidden layers are used.
        logger.info(f"BehaviorCloner initialized: {state_dim} -> {self.hidden_layers} -> {action_dim}")

    def _init_network(self):
        """Initialize neural network weights using Xavier initialization."""
        layer_sizes = [self.state_dim] + self.hidden_layers + [self.action_dim]

        for i in range(len(layer_sizes) - 1):
            in_dim = layer_sizes[i]
            out_dim = layer_sizes[i + 1]

            # Xavier initialization
            scale = math.sqrt(2.0 / (in_dim + out_dim))
            weight = [
                [random.gauss(0, scale) for _ in range(out_dim)]
                for _ in range(in_dim)
            ]
            bias = [0.0 for _ in range(out_dim)]

            self.weights.append(weight)
            self.biases.append(bias)

    def _relu(self, x: float) -> float:
        """ReLU activation function."""
        return max(0.0, x)

    def _softmax(self, x: List[float]) -> List[float]:
        """Softmax activation for output layer."""
        max_x = max(x)
        exp_x = [math.exp(xi - max_x) for xi in x]
        sum_exp = sum(exp_x)
        return [e / sum_exp for e in exp_x]

    def _forward(self, state: List[float]) -> Tuple[List[List[float]], List[float]]:
        """
        Forward pass through the network.

        Args:
            state: Input state features

        Returns:
            Tuple of (activations for each layer, output probabilities)
        """
        activations = [state]
        current = state

        for i, (weight, bias) in enumerate(zip(self.weights, self.biases)):
            # Linear transformation
            output = []
            for j in range(len(bias)):
                val = bias[j]
                for k in range(len(current)):
                    val += current[k] * weight[k][j]
                output.append(val)

            # Apply activation (ReLU for hidden, no activation for last layer)
            if i < len(self.weights) - 1:
                output = [self._relu(x) for x in output]

            activations.append(output)
            current = output

        # Apply softmax to get action probabilities
        probs = self._softmax(current)
        return activations, probs

    def predict(self, state: List[float]) -> int:
        """
        Pick the most probable action for a state.

        Args:
            state: Input feature vector

        Returns:
            Index of the action with the highest probability (the lowest
            index wins a tie)
        """
        _activations, probs = self._forward(state)
        return max(range(len(probs)), key=probs.__getitem__)

    def predict_proba(self, state: List[float]) -> List[float]:
        """
        Return the full action distribution for a state.

        Args:
            state: Input feature vector

        Returns:
            One probability per action, as produced by the forward pass
        """
        _activations, distribution = self._forward(state)
        return distribution

    def train_step(
        self,
        states: List[List[float]],
        actions: List[int]
    ) -> float:
        """
        Perform a single training step.

        Args:
            states: Batch of state features
            actions: Batch of expert action indices

        Returns:
            Batch loss value
        """
        total_loss = 0.0

        # Accumulate gradients
        grad_weights = [
            [[0.0 for _ in row] for row in w]
            for w in self.weights
        ]
        grad_biases = [
            [0.0 for _ in b]
            for b in self.biases
        ]

        for state, action in zip(states, actions):
            activations, probs = self._forward(state)

            # Cross-entropy loss
            loss = -math.log(max(probs[action], 1e-10))
            total_loss += loss

            # Backpropagation
            # Output layer gradient
            delta = list(probs)
            delta[action] -= 1.0  # derivative of cross-entropy with softmax

            # Backward pass through layers
            deltas = [delta]
            for i in range(len(self.weights) - 2, -1, -1):
                new_delta = []
                for j in range(len(self.weights[i][0]) if self.weights[i] else 0):
                    grad = 0.0
                    for k in range(len(deltas[0])):
                        grad += deltas[0][k] * self.weights[i + 1][j][k] if j < len(self.weights[i + 1]) else 0
                    # ReLU derivative
                    if activations[i + 1][j] > 0:
                        new_delta.append(grad)
                    else:
                        new_delta.append(0.0)
                deltas.insert(0, new_delta)

            # Accumulate gradients
            for i in range(len(self.weights)):
                for j in range(len(self.weights[i])):
                    for k in range(len(self.weights[i][j])):
                        if i < len(deltas) and k < len(deltas[i]):
                            grad_weights[i][j][k] += activations[i][j] * deltas[i][k]

                for j in range(len(self.biases[i])):
                    if i < len(deltas) and j < len(deltas[i]):
                        grad_biases[i][j] += deltas[i][j]

        # Apply gradients with L2 regularization
        batch_size = len(states)
        for i in range(len(self.weights)):
            for j in range(len(self.weights[i])):
                for k in range(len(self.weights[i][j])):
                    grad = grad_weights[i][j][k] / batch_size
                    grad += self.l2_reg * self.weights[i][j][k]
                    self.weights[i][j][k] -= self.learning_rate * grad

            for j in range(len(self.biases[i])):
                grad = grad_biases[i][j] / batch_size
                self.biases[i][j] -= self.learning_rate * grad

        return total_loss / batch_size

    def train(
        self,
        collector: DemonstrationCollector,
        epochs: int = 100,
        batch_size: int = 32,
        state_encoder: Optional[Callable[[Any], List[float]]] = None,
        action_encoder: Optional[Callable[[Any], int]] = None,
        validation_split: float = 0.1
    ) -> Dict[str, Any]:
        """
        Train the behavior cloner on collected demonstrations.

        All (state, action) pairs of the collector are encoded, shuffled and
        split into a training and a validation part. Every epoch reshuffles
        the training part and runs ``train_step`` over all of it in
        mini-batches, including a final partial batch. When a validation part
        exists, the weights with the lowest validation loss are restored at
        the end; without one, the weights of the last epoch are kept.

        Args:
            collector: DemonstrationCollector with expert demonstrations
            epochs: Number of training epochs
            batch_size: Training batch size (must be at least 1)
            state_encoder: Function to encode states to feature vectors.
                The default passes lists through unchanged and maps anything
                else to ``state_dim`` copies of a value in [0, 1) derived
                from a digest of ``str(state)``
            action_encoder: Function to encode actions to indices. The
                default maps a digest of ``str(action)`` into
                ``range(action_dim)``
            validation_split: Fraction of data for validation

        Returns:
            Training statistics dictionary, or ``{"error": "No
            demonstrations"}`` when the collector is empty.
            ``best_val_loss`` is ``inf`` when no validation samples were
            available

        Raises:
            ValueError: If ``batch_size`` is smaller than 1
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # Get all state-action pairs
        pairs = collector.get_state_action_pairs()
        if not pairs:
            logger.warning("No demonstrations available for training")
            return {"error": "No demonstrations"}

        def _stable_hash(value: Any) -> int:
            # The built-in hash() of a str is salted per process, which would
            # give every run a different encoding; a digest is reproducible.
            digest = hashlib.sha256(str(value).encode("utf-8")).digest()
            return int.from_bytes(digest[:8], "big")

        def _default_state_encoder(s: Any) -> List[float]:
            if isinstance(s, list):
                return s
            scalar = float(_stable_hash(s) % 1000) / 1000
            return [scalar] * self.state_dim

        def _default_action_encoder(a: Any) -> int:
            return _stable_hash(a) % self.action_dim

        # Encode data
        if state_encoder is None:
            state_encoder = _default_state_encoder
        if action_encoder is None:
            action_encoder = _default_action_encoder

        states = [state_encoder(s) for s, a in pairs]
        actions = [action_encoder(a) for s, a in pairs]

        # Split into train/validation (the split size is clamped so that an
        # out-of-range fraction cannot produce negative or oversized slices).
        n_total = len(states)
        n_val = min(n_total, max(0, int(n_total * validation_split)))
        indices = list(range(n_total))
        random.shuffle(indices)

        train_data = [(states[i], actions[i]) for i in indices[n_val:]]
        val_states = [states[i] for i in indices[:n_val]]
        val_actions = [actions[i] for i in indices[:n_val]]
        has_validation = bool(val_states)

        best_val_loss = float('inf')
        best_weights = None
        epoch_loss = 0.0  # defined up front so that epochs=0 still returns
        n_parameters = sum(len(w) * len(w[0]) for w in self.weights if w)

        logger.info(f"Training on {len(train_data)} examples, validating on {len(val_states)}")

        for epoch in range(epochs):
            # Shuffle training data
            random.shuffle(train_data)

            # Train in batches; the stride loop also covers the final partial
            # batch. Batch means are re-weighted by batch size so that
            # epoch_loss is a per-sample mean.
            loss_sum = 0.0
            for start in range(0, len(train_data), batch_size):
                batch = train_data[start:start + batch_size]
                batch_states = [s for s, _ in batch]
                batch_actions = [a for _, a in batch]
                loss_sum += self.train_step(batch_states, batch_actions) * len(batch)
            epoch_loss = loss_sum / len(train_data) if train_data else 0.0

            # Validation
            val_loss = 0.0
            val_acc = 0.0
            if has_validation:
                val_correct = 0
                for state, action in zip(val_states, val_actions):
                    _, probs = self._forward(state)
                    val_loss -= math.log(max(probs[action], 1e-10))
                    if probs.index(max(probs)) == action:
                        val_correct += 1
                val_loss /= len(val_states)
                val_acc = val_correct / len(val_states)

                # Track best model. This only happens when a validation loss
                # was actually measured; otherwise the placeholder 0.0 would
                # pin the checkpoint to the first epoch.
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_weights = (
                        [[list(row) for row in w] for w in self.weights],
                        [list(b) for b in self.biases]
                    )

            self.iteration += 1
            update = PolicyUpdate(
                iteration=self.iteration,
                loss=epoch_loss,
                accuracy=val_acc,
                timestamp=datetime.now(),
                parameters_changed=n_parameters,
                learning_rate=self.learning_rate
            )
            self.training_history.append(update)

            if epoch % 10 == 0:
                logger.info(f"Epoch {epoch}: train_loss={epoch_loss:.4f}, val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")

        # Restore best weights (only set when validation data existed)
        if best_weights is not None:
            self.weights, self.biases = best_weights

        return {
            "epochs": epochs,
            "final_train_loss": epoch_loss,
            "best_val_loss": best_val_loss,
            "training_samples": len(train_data),
            "validation_samples": len(val_states)
        }

    def get_policy_parameters(self) -> Dict[str, Any]:
        """Return the live weights and biases plus the layer-size list.

        The weight and bias lists are the network's own objects, not copies.
        """
        layer_sizes = [self.state_dim] + self.hidden_layers + [self.action_dim]
        parameters: Dict[str, Any] = {}
        parameters["weights"] = self.weights
        parameters["biases"] = self.biases
        parameters["architecture"] = layer_sizes
        return parameters


# ============================================================================
# Inverse Reinforcement Learning
# ============================================================================

class InverseRL:
    """
    Implements Inverse Reinforcement Learning (IRL).

    IRL infers the reward function that the expert is optimizing,
    enabling transfer to new environments and better generalization.

    The reward is linear in the state features. The weights are moved along
    the difference between the expert's and the current policy's discounted
    feature expectations.
    """

    def __init__(
        self,
        state_dim: int,
        feature_extractor: Optional[Callable[[Any], List[float]]] = None,
        learning_rate: float = 0.01,
        discount_factor: float = 0.99
    ):
        """
        Initialize Inverse RL.

        Args:
            state_dim: Dimensionality of state features
            feature_extractor: Function to extract features from states;
                defaults to the identity (the state is used as its own
                feature vector)
            learning_rate: Learning rate for reward learning
            discount_factor: Discount factor gamma
        """
        self.state_dim = state_dim
        self.feature_extractor = feature_extractor or (lambda s: s)
        self.learning_rate = learning_rate
        self.gamma = discount_factor

        # Reward function parameters (linear in features)
        self.reward_weights = [random.gauss(0, 0.1) for _ in range(state_dim)]

        # Feature expectations
        self.expert_feature_expectations: Optional[List[float]] = None
        self.policy_feature_expectations: Optional[List[float]] = None

        # Both accumulate over every learn_reward() call on this instance.
        self.iteration = 0
        self.convergence_history: List[float] = []

        logger.info(f"InverseRL initialized with {state_dim} features")

    def compute_reward(self, state: Any) -> float:
        """
        Compute reward for a state using learned reward function.

        Args:
            state: State to evaluate

        Returns:
            Scalar reward value (dot product of reward weights and features)
        """
        features = self.feature_extractor(state)
        # zip() stops at the shorter sequence, so surplus entries are ignored.
        return sum(w * f for w, f in zip(self.reward_weights, features))

    def compute_feature_expectations(
        self,
        demonstrations: List["Demonstration"]
    ) -> List[float]:
        """
        Compute empirical feature expectations from demonstrations.

        Each step ``t`` of a trajectory contributes its features weighted by
        ``gamma ** t``; the result is normalized by the sum of those same
        weights, so constant features yield that constant.

        Args:
            demonstrations: Demonstrations whose ``trajectory`` is a sequence
                of ``(state, action)`` pairs

        Returns:
            Feature expectations vector of length ``state_dim``; all zeros
            when there are no steps to average over
        """
        feature_sum = [0.0] * self.state_dim
        total_weight = 0.0

        for demo in demonstrations:
            discount = 1.0
            for state, _ in demo.trajectory:
                features = self.feature_extractor(state)
                for i, f in enumerate(features):
                    feature_sum[i] += discount * f
                # Record the weight actually applied to this step *before*
                # decaying it; otherwise the normalizer is off by a factor of
                # gamma (and is zero when gamma == 0).
                total_weight += discount
                discount *= self.gamma

        if total_weight > 0:
            return [f / total_weight for f in feature_sum]
        return feature_sum

    def learn_reward(
        self,
        collector: "DemonstrationCollector",
        policy_sampler: Callable[[], List[Tuple[Any, Any]]],
        max_iterations: int = 100,
        tolerance: float = 0.01
    ) -> Dict[str, Any]:
        """
        Learn the reward function via maximum entropy IRL.

        Args:
            collector: Demonstration collector with expert data
            policy_sampler: Function that generates trajectories from current policy
            max_iterations: Maximum iterations for learning
            tolerance: Convergence tolerance on the gradient norm

        Returns:
            Learning statistics. ``iterations`` counts updates over the
            lifetime of this instance; ``final_gradient_norm`` and
            ``converged`` describe this call only. Returns
            ``{"error": "No demonstrations"}`` when the collector is empty.
        """
        # Compute expert feature expectations
        demonstrations = collector.demonstrations
        if not demonstrations:
            logger.warning("No demonstrations for IRL")
            return {"error": "No demonstrations"}

        self.expert_feature_expectations = self.compute_feature_expectations(demonstrations)
        logger.info(f"Computed expert feature expectations from {len(demonstrations)} demos")

        # Tracked locally so a call that performs no update cannot report the
        # outcome of an earlier call stored in convergence_history.
        last_grad_norm: Optional[float] = None
        samples_per_iteration = min(10, len(demonstrations))

        for iteration in range(max_iterations):
            # Sample trajectories from current policy
            policy_trajectories = [policy_sampler() for _ in range(samples_per_iteration)]

            # Create pseudo-demonstrations for feature expectation computation
            policy_demos = [
                Demonstration(
                    id=f"policy_{i}",
                    trajectory=traj,
                    quality=DemonstrationQuality.SYNTHETIC,
                    expert_id="policy",
                    domain="policy",
                    timestamp=datetime.now()
                )
                for i, traj in enumerate(policy_trajectories) if traj
            ]

            if not policy_demos:
                logger.warning("Policy sampler returned no trajectories")
                break

            self.policy_feature_expectations = self.compute_feature_expectations(policy_demos)

            # Gradient of the feature-matching objective: expert minus policy
            gradient = [
                expert - policy
                for expert, policy in zip(
                    self.expert_feature_expectations,
                    self.policy_feature_expectations
                )
            ]

            # Update weights
            for i, g in enumerate(gradient):
                self.reward_weights[i] += self.learning_rate * g

            # Compute convergence metric
            last_grad_norm = math.sqrt(sum(g ** 2 for g in gradient))
            self.convergence_history.append(last_grad_norm)

            self.iteration += 1

            if iteration % 10 == 0:
                logger.info(f"IRL iteration {iteration}: gradient_norm={last_grad_norm:.6f}")

            if last_grad_norm < tolerance:
                logger.info(f"IRL converged at iteration {iteration}")
                break

        return {
            "iterations": self.iteration,
            "final_gradient_norm": last_grad_norm if last_grad_norm is not None else 0,
            "reward_weights": self.reward_weights,
            "converged": last_grad_norm is not None and last_grad_norm < tolerance
        }

    def get_reward_function(self) -> Callable[[Any], float]:
        """
        Get the learned reward function.

        Returns:
            Callable that maps states to rewards (the bound
            ``compute_reward`` method, so later weight updates are reflected)
        """
        return self.compute_reward


# ============================================================================
# Generative Adversarial Imitation Learning (GAIL)
# ============================================================================

class Discriminator:
    """
    Discriminator network for GAIL.

    Distinguishes between expert demonstrations and learned policy rollouts.
    The network has one ReLU hidden layer and a sigmoid output, stored as
    plain Python lists and trained with batch gradient descent on the binary
    cross-entropy loss.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dim: int = 64,
        learning_rate: float = 0.001
    ):
        """
        Initialize the discriminator.

        Args:
            state_dim: State feature dimensionality
            action_dim: Action space dimensionality
            hidden_dim: Hidden layer size
            learning_rate: Learning rate
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate

        # Network weights (state-action -> hidden -> 1)
        input_dim = state_dim + action_dim
        scale1 = math.sqrt(2.0 / (input_dim + hidden_dim))
        scale2 = math.sqrt(2.0 / (hidden_dim + 1))

        self.w1 = [[random.gauss(0, scale1) for _ in range(hidden_dim)] for _ in range(input_dim)]
        self.b1 = [0.0 for _ in range(hidden_dim)]
        self.w2 = [[random.gauss(0, scale2)] for _ in range(hidden_dim)]
        self.b2 = [0.0]

        # Accuracies measured during the most recent train_step() call
        self.expert_accuracy = 0.0
        self.policy_accuracy = 0.0

        logger.info(f"Discriminator initialized: ({state_dim}+{action_dim}) -> {hidden_dim} -> 1")

    def _sigmoid(self, x: float) -> float:
        """Sigmoid activation; the input is clamped to [-500, 500] to avoid overflow."""
        return 1.0 / (1.0 + math.exp(-max(-500, min(500, x))))

    def _hidden_activations(self, x: List[float]) -> List[float]:
        """Return the ReLU activations of the hidden layer for input ``x``."""
        hidden = []
        for j in range(self.hidden_dim):
            pre_activation = self.b1[j]
            # zip() drops inputs beyond the rows of w1, so over-long inputs
            # are tolerated rather than raising.
            for xi, row in zip(x, self.w1):
                pre_activation += xi * row[j]
            hidden.append(max(0.0, pre_activation))  # ReLU
        return hidden

    def _logit(self, hidden: List[float]) -> float:
        """Return the pre-sigmoid output for the given hidden activations."""
        out = self.b2[0]
        for j, hj in enumerate(hidden):
            out += hj * self.w2[j][0]
        return out

    def _forward(self, state: List[float], action: List[float]) -> float:
        """
        Forward pass returning probability of being expert.

        Args:
            state: State features
            action: One-hot action vector

        Returns:
            Probability that (state, action) is from expert
        """
        x = state + action
        return self._sigmoid(self._logit(self._hidden_activations(x)))

    def predict(self, state: List[float], action: List[float]) -> float:
        """
        Predict probability that state-action is from expert.

        Args:
            state: State features
            action: One-hot action vector

        Returns:
            Probability [0, 1]
        """
        return self._forward(state, action)

    def get_reward(self, state: List[float], action: List[float]) -> float:
        """
        Get GAIL reward signal.

        Args:
            state: State features
            action: One-hot action vector

        Returns:
            Reward signal based on discriminator output
        """
        d = self._forward(state, action)
        # Reward = -log(1 - D(s,a)) for GAIL; clamped so log never sees 0
        return -math.log(max(1.0 - d, 1e-10))

    def _accumulate_gradients(
        self,
        x: List[float],
        label: float,
        grad_w1: List[List[float]],
        grad_b1: List[float],
        grad_w2: List[List[float]],
        grad_b2: List[float]
    ) -> float:
        """
        Backpropagate one sample and add its gradients to the accumulators.

        Args:
            x: Concatenated state-action input
            label: Target output (1.0 for expert, 0.0 for policy)
            grad_w1: Accumulator shaped like ``w1`` (updated in place)
            grad_b1: Accumulator shaped like ``b1`` (updated in place)
            grad_w2: Accumulator shaped like ``w2`` (updated in place)
            grad_b2: Accumulator shaped like ``b2`` (updated in place)

        Returns:
            The discriminator output for this sample before any update
        """
        hidden = self._hidden_activations(x)
        d = self._sigmoid(self._logit(hidden))

        # For a sigmoid output with cross-entropy loss, dL/d(logit) = D - label.
        delta_out = d - label
        usable_inputs = x[:len(grad_w1)]

        grad_b2[0] += delta_out
        for j, hj in enumerate(hidden):
            grad_w2[j][0] += delta_out * hj
            if hj > 0.0:
                # ReLU only passes gradient through units that are active.
                delta_hidden = delta_out * self.w2[j][0]
                grad_b1[j] += delta_hidden
                for i, xi in enumerate(usable_inputs):
                    grad_w1[i][j] += delta_hidden * xi

        return d

    def train_step(
        self,
        expert_states: List[List[float]],
        expert_actions: List[List[float]],
        policy_states: List[List[float]],
        policy_actions: List[List[float]]
    ) -> Dict[str, float]:
        """
        Train discriminator on expert and policy data.

        Performs one batch gradient-descent step on the binary cross-entropy
        loss, updating both layers' weights and biases.

        Args:
            expert_states: Expert state batch
            expert_actions: Expert action batch
            policy_states: Policy state batch
            policy_actions: Policy action batch

        Returns:
            Training metrics: mean ``loss`` plus ``expert_accuracy`` and
            ``policy_accuracy``, all measured before the weight update
        """
        total_loss = 0.0
        expert_correct = 0
        policy_correct = 0

        # Initialize gradients
        grad_w1 = [[0.0] * len(row) for row in self.w1]
        grad_b1 = [0.0] * len(self.b1)
        grad_w2 = [[0.0] for _ in self.w2]
        grad_b2 = [0.0]

        # Expert samples (label = 1)
        for state, action in zip(expert_states, expert_actions):
            d = self._accumulate_gradients(
                state + action, 1.0, grad_w1, grad_b1, grad_w2, grad_b2
            )
            total_loss -= math.log(max(d, 1e-10))
            if d > 0.5:
                expert_correct += 1

        # Policy samples (label = 0)
        for state, action in zip(policy_states, policy_actions):
            d = self._accumulate_gradients(
                state + action, 0.0, grad_w1, grad_b1, grad_w2, grad_b2
            )
            total_loss -= math.log(max(1.0 - d, 1e-10))
            if d < 0.5:
                policy_correct += 1

        # Apply the batch-averaged gradients to every parameter
        batch_size = len(expert_states) + len(policy_states)
        step = self.learning_rate / max(batch_size, 1)
        for i, grad_row in enumerate(grad_w1):
            for j, g in enumerate(grad_row):
                self.w1[i][j] -= step * g
        for j, g in enumerate(grad_b1):
            self.b1[j] -= step * g
        for j, grad_row in enumerate(grad_w2):
            self.w2[j][0] -= step * grad_row[0]
        self.b2[0] -= step * grad_b2[0]

        self.expert_accuracy = expert_correct / max(len(expert_states), 1)
        self.policy_accuracy = policy_correct / max(len(policy_states), 1)

        return {
            "loss": total_loss / max(batch_size, 1),
            "expert_accuracy": self.expert_accuracy,
            "policy_accuracy": self.policy_accuracy
        }


class GAIL:
    """
    Generative Adversarial Imitation Learning.

    Trains a discriminator to separate expert state-action pairs from
    policy rollouts and reports the discriminator-derived reward earned by
    those rollouts.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        discriminator: Optional[Discriminator] = None,
        policy: Optional[BehaviorCloner] = None,
        discriminator_steps: int = 5,
        policy_steps: int = 3
    ):
        """
        Set up the GAIL trainer.

        Args:
            state_dim: Number of state features
            action_dim: Number of discrete actions
            discriminator: Existing discriminator to use; a new one is built
                when omitted
            policy: Existing policy to use; a new one is built when omitted
            discriminator_steps: Discriminator updates per iteration
            policy_steps: Policy updates per iteration (stored; not read by
                the methods of this class)
        """
        self.state_dim = state_dim
        self.action_dim = action_dim

        # Fall back to freshly built components when none are supplied.
        if not discriminator:
            discriminator = Discriminator(state_dim, action_dim)
        self.discriminator = discriminator
        if not policy:
            policy = BehaviorCloner(state_dim, action_dim)
        self.policy = policy

        self.discriminator_steps = discriminator_steps
        self.policy_steps = policy_steps

        self.training_history: List[Dict[str, float]] = []
        self.iteration = 0

        logger.info("GAIL initialized")

    def _one_hot_action(self, action: int) -> List[float]:
        """Return a one-hot vector for ``action``; all zeros if out of range."""
        vector = [0.0 for _ in range(self.action_dim)]
        in_range = 0 <= action < self.action_dim
        if in_range:
            vector[action] = 1.0
        return vector

    def train_iteration(
        self,
        expert_demonstrations: List[Demonstration],
        policy_rollouts: List[List[Tuple[Any, Any]]],
        state_encoder: Callable[[Any], List[float]],
        action_encoder: Callable[[Any], int]
    ) -> Dict[str, float]:
        """
        Run a single GAIL iteration.

        Args:
            expert_demonstrations: Demonstrations supplying expert pairs
            policy_rollouts: Trajectories produced by the current policy
            state_encoder: Maps a raw state to a feature vector
            action_encoder: Maps a raw action to an action index

        Returns:
            Metrics for this iteration, or ``{"error": "Insufficient data"}``
            when either side has no state-action pairs
        """
        def encode_step(raw_state: Any, raw_action: Any) -> Tuple[List[float], List[float]]:
            features = state_encoder(raw_state)
            action_vector = self._one_hot_action(action_encoder(raw_action))
            return (features, action_vector)

        # Flatten both sources into encoded (state, one-hot action) pairs.
        expert_pairs = [
            encode_step(raw_state, raw_action)
            for demo in expert_demonstrations
            for raw_state, raw_action in demo.trajectory
        ]
        policy_pairs = [
            encode_step(raw_state, raw_action)
            for rollout in policy_rollouts
            for raw_state, raw_action in rollout
        ]

        if not (expert_pairs and policy_pairs):
            return {"error": "Insufficient data"}

        # Discriminator updates; the metrics of the last step are reported.
        disc_metrics = {"loss": 0.0, "expert_accuracy": 0.0, "policy_accuracy": 0.0}
        sample_size = min(32, len(expert_pairs), len(policy_pairs))
        for _ in range(self.discriminator_steps):
            expert_batch = random.sample(expert_pairs, sample_size)
            policy_batch = random.sample(policy_pairs, sample_size)

            step_metrics = self.discriminator.train_step(
                [features for features, _ in expert_batch],
                [vector for _, vector in expert_batch],
                [features for features, _ in policy_batch],
                [vector for _, vector in policy_batch]
            )
            disc_metrics.update(step_metrics)

        # Score each rollout with the discriminator reward.
        # (Simplified - in practice you'd use policy gradients)
        policy_rewards = []
        for rollout in policy_rollouts:
            step_rewards = (
                self.discriminator.get_reward(
                    state_encoder(raw_state),
                    self._one_hot_action(action_encoder(raw_action))
                )
                for raw_state, raw_action in rollout
            )
            mean_reward = sum(step_rewards, 0.0) / max(len(rollout), 1)
            policy_rewards.append(mean_reward)

        avg_reward = sum(policy_rewards) / max(len(policy_rewards), 1)

        self.iteration += 1
        result = {
            "iteration": self.iteration,
            "discriminator_loss": disc_metrics["loss"],
            "expert_accuracy": disc_metrics["expert_accuracy"],
            "policy_accuracy": disc_metrics["policy_accuracy"],
            "average_policy_reward": avg_reward
        }

        self.training_history.append(result)
        logger.info(f"GAIL iteration {self.iteration}: avg_reward={avg_reward:.4f}")

        return result

    def train(
        self,
        collector: DemonstrationCollector,
        policy_sampler: Callable[[], List[Tuple[Any, Any]]],
        state_encoder: Callable[[Any], List[float]],
        action_encoder: Callable[[Any], int],
        num_iterations: int = 100,
        rollouts_per_iteration: int = 10
    ) -> Dict[str, Any]:
        """
        Run the complete GAIL loop.

        Args:
            collector: Source of expert demonstrations
            policy_sampler: Produces one policy rollout per call
            state_encoder: Maps a raw state to a feature vector
            action_encoder: Maps a raw action to an action index
            num_iterations: How many iterations to run
            rollouts_per_iteration: Rollouts sampled for each iteration

        Returns:
            Summary with the iteration count, last average reward and the
            full history, or ``{"error": "No demonstrations"}`` when the
            collector holds none
        """
        demonstrations = collector.demonstrations
        if not demonstrations:
            return {"error": "No demonstrations"}

        for i in range(num_iterations):
            # Sample first, then discard empty rollouts.
            sampled = []
            for _ in range(rollouts_per_iteration):
                sampled.append(policy_sampler())
            rollouts = list(filter(None, sampled))

            if rollouts:
                self.train_iteration(
                    demonstrations,
                    rollouts,
                    state_encoder,
                    action_encoder
                )
            else:
                logger.warning(f"No valid rollouts at iteration {i}")

        if self.training_history:
            final_avg_reward = self.training_history[-1]["average_policy_reward"]
        else:
            final_avg_reward = 0

        return {
            "total_iterations": self.iteration,
            "final_avg_reward": final_avg_reward,
            "history": self.training_history
        }


# ============================================================================
# DAgger (Dataset Aggregation)
# ============================================================================

class DAgger:
    """
    Dataset Aggregation (DAgger) algorithm.

    Iteratively improves the policy by:
    1. Running the current policy
    2. Getting expert labels for visited states
    3. Aggregating new data with previous dataset
    4. Retraining the policy
    """

    def __init__(
        self,
        policy: "BehaviorCloner",
        expert_oracle: Callable[[Any], Any],
        state_encoder: Callable[[Any], List[float]],
        action_encoder: Callable[[Any], int],
        beta_schedule: str = "linear"
    ):
        """
        Initialize DAgger.

        Args:
            policy: Initial policy (typically from BC)
            expert_oracle: Function that provides expert action for any state
            state_encoder: State encoding function
            action_encoder: Action encoding function
            beta_schedule: Schedule for mixing policy/expert ("linear", "exponential")
        """
        self.policy = policy
        self.expert_oracle = expert_oracle
        self.state_encoder = state_encoder
        self.action_encoder = action_encoder
        self.beta_schedule = beta_schedule

        # Aggregated dataset of (encoded state, encoded expert action)
        self.dataset: List[Tuple[List[float], int]] = []
        self.iteration = 0

        self.performance_history: List[Dict[str, float]] = []

        logger.info("DAgger initialized")

    def _get_beta(self, iteration: int, max_iterations: int) -> float:
        """
        Get the expert mixing coefficient.

        Args:
            iteration: Current iteration
            max_iterations: Total iterations

        Returns:
            Beta in [0, 1] where 1 = full expert, 0 = full policy.
            Unknown schedule names yield a constant 0.5.
        """
        if self.beta_schedule == "linear":
            # Guard the denominator so a zero-length schedule cannot divide by zero.
            return max(0.0, 1.0 - iteration / max(max_iterations, 1))
        elif self.beta_schedule == "exponential":
            return 0.99 ** iteration
        else:
            return 0.5

    def collect_trajectory(
        self,
        initial_state: Any,
        transition_fn: Callable[[Any, Any], Any],
        max_steps: int = 100,
        beta: float = 0.5
    ) -> List[Tuple[Any, Any, Any]]:
        """
        Collect a trajectory using beta-mixture of policy and expert.

        The rollout ends after ``max_steps`` steps or as soon as
        ``transition_fn`` raises; the step that raised is still recorded.

        Args:
            initial_state: Starting state
            transition_fn: Function(state, action) -> next_state
            max_steps: Maximum trajectory length
            beta: Expert mixing coefficient (probability of executing the
                expert's action instead of the policy's)

        Returns:
            List of (state, executed_action, expert_action) tuples
        """
        trajectory = []
        state = initial_state

        for _ in range(max_steps):
            # Get expert action
            expert_action = self.expert_oracle(state)

            # Get policy action
            state_encoded = self.state_encoder(state)
            policy_action_idx = self.policy.predict(state_encoded)

            # Mix actions based on beta
            if random.random() < beta:
                executed_action = expert_action
            else:
                executed_action = policy_action_idx

            trajectory.append((state, executed_action, expert_action))

            # Transition to next state
            try:
                state = transition_fn(state, executed_action)
            except Exception:
                # A failing transition ends the rollout (best effort), but
                # leave a trace so real errors are not completely invisible.
                logger.debug(
                    "DAgger rollout stopped: transition_fn raised",
                    exc_info=True
                )
                break

        return trajectory

    def aggregate_data(
        self,
        trajectories: List[List[Tuple[Any, Any, Any]]]
    ):
        """
        Aggregate new data from trajectories.

        Every visited state is stored with the expert's action as its label;
        the action that was actually executed is ignored.

        Args:
            trajectories: List of trajectories with expert labels
        """
        for trajectory in trajectories:
            for state, _, expert_action in trajectory:
                state_encoded = self.state_encoder(state)
                action_encoded = self.action_encoder(expert_action)
                self.dataset.append((state_encoded, action_encoded))

        logger.info(f"Dataset size: {len(self.dataset)}")

    def train_policy(
        self,
        epochs: int = 10,
        batch_size: int = 32
    ) -> float:
        """
        Retrain policy on aggregated dataset.

        The dataset is shuffled in place once, then every epoch walks all of
        it in consecutive batches, including a final partial batch.

        Args:
            epochs: Training epochs
            batch_size: Batch size (must be at least 1)

        Returns:
            Mean batch loss of the final epoch (0.0 when the dataset is
            empty or ``epochs`` is 0)

        Raises:
            ValueError: If ``batch_size`` is smaller than 1 and there is
                data to train on
        """
        if not self.dataset:
            return 0.0
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        # Shuffle dataset
        random.shuffle(self.dataset)

        # Stepping by batch_size covers the whole dataset; floor division
        # would silently drop the trailing partial batch.
        batch_starts = range(0, len(self.dataset), batch_size)
        n_batches = len(batch_starts)

        final_loss = 0.0
        for _ in range(epochs):
            epoch_loss = 0.0
            for start in batch_starts:
                batch = self.dataset[start:start + batch_size]
                states = [sample[0] for sample in batch]
                actions = [sample[1] for sample in batch]
                epoch_loss += self.policy.train_step(states, actions)

            final_loss = epoch_loss / n_batches

        return final_loss

    def evaluate_policy(
        self,
        evaluation_states: List[Any]
    ) -> Dict[str, float]:
        """
        Evaluate current policy against expert.

        Args:
            evaluation_states: States to evaluate on

        Returns:
            Evaluation metrics: ``accuracy`` is the fraction of states where
            the policy's action index equals the encoded expert action
        """
        if not evaluation_states:
            return {"accuracy": 0.0}

        correct = 0
        for state in evaluation_states:
            expert_action = self.expert_oracle(state)
            state_encoded = self.state_encoder(state)
            policy_action = self.policy.predict(state_encoded)

            if policy_action == self.action_encoder(expert_action):
                correct += 1

        return {"accuracy": correct / len(evaluation_states)}

    def run(
        self,
        initial_states: List[Any],
        transition_fn: Callable[[Any, Any], Any],
        num_iterations: int = 10,
        trajectories_per_iteration: int = 5,
        max_trajectory_length: int = 100,
        training_epochs: int = 10
    ) -> Dict[str, Any]:
        """
        Run the full DAgger algorithm.

        Args:
            initial_states: List of initial states for rollouts
            transition_fn: State transition function
            num_iterations: DAgger iterations
            trajectories_per_iteration: Rollouts per iteration
            max_trajectory_length: Maximum trajectory length
            training_epochs: Policy training epochs per iteration

        Returns:
            Training summary
        """
        for iteration in range(num_iterations):
            self.iteration = iteration
            beta = self._get_beta(iteration, num_iterations)

            logger.info(f"DAgger iteration {iteration}: beta={beta:.3f}")

            # Collect trajectories
            trajectories = []
            for _ in range(trajectories_per_iteration):
                init_state = random.choice(initial_states)
                traj = self.collect_trajectory(
                    init_state,
                    transition_fn,
                    max_trajectory_length,
                    beta
                )
                trajectories.append(traj)

            # Aggregate data
            self.aggregate_data(trajectories)

            # Retrain policy
            loss = self.train_policy(epochs=training_epochs)

            # Evaluate on (at most 20) randomly chosen initial states
            eval_states = random.sample(initial_states, min(20, len(initial_states)))
            metrics = self.evaluate_policy(eval_states)
            metrics["iteration"] = iteration
            metrics["beta"] = beta
            metrics["loss"] = loss
            metrics["dataset_size"] = len(self.dataset)

            self.performance_history.append(metrics)

            logger.info(f"DAgger iteration {iteration}: accuracy={metrics['accuracy']:.4f}, loss={loss:.4f}")

        return {
            "total_iterations": num_iterations,
            "final_accuracy": self.performance_history[-1]["accuracy"] if self.performance_history else 0,
            "final_dataset_size": len(self.dataset),
            "history": self.performance_history
        }


# ============================================================================
# Unified Imitation Learning System
# ============================================================================

class ImitationLearningSystem:
    """
    Unified imitation learning system for AIVA Queen.

    Coordinates all imitation learning components:
    - DemonstrationCollector
    - BehaviorCloner
    - InverseRL
    - GAIL
    - DAgger
    """

    def __init__(
        self,
        state_dim: int = 32,
        action_dim: int = 10,
        storage_path: Optional[Path] = None
    ):
        """
        Initialize the unified imitation learning system.

        Args:
            state_dim: State feature dimensionality
            action_dim: Action space size
            storage_path: Path for persistent storage (created if missing)
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.storage_path = storage_path or Path("./imitation_learning")
        self.storage_path.mkdir(parents=True, exist_ok=True)

        # Encoders (can be customized via set_encoders)
        self.state_encoder: Callable[[Any], List[float]] = self._default_state_encoder
        self.action_encoder: Callable[[Any], int] = self._default_action_encoder

        # Initialize components
        self.collector = DemonstrationCollector(
            storage_path=self.storage_path / "demonstrations"
        )
        self.behavior_cloner = BehaviorCloner(state_dim, action_dim)
        # Reward learning sees the same encoded features as the other
        # components; the indirection picks up encoders set later.
        self.inverse_rl = InverseRL(state_dim, feature_extractor=self._reward_features)
        self.discriminator = Discriminator(state_dim, action_dim)
        self.gail = GAIL(
            state_dim, action_dim,
            discriminator=self.discriminator,
            policy=self.behavior_cloner
        )

        self.training_mode: Optional[str] = None
        self.training_stats: Dict[str, Any] = {}

        logger.info(f"ImitationLearningSystem initialized: state_dim={state_dim}, action_dim={action_dim}")

    @staticmethod
    def _stable_hash(value: Any) -> int:
        """
        Return a non-negative hash of ``str(value)`` that is stable across runs.

        The built-in ``hash()`` of a string is salted per process, so it
        cannot be used for encodings that must match a saved model.
        """
        data = str(value).encode("utf-8", "backslashreplace")
        return int.from_bytes(hashlib.sha256(data).digest()[:8], "big")

    def _default_state_encoder(self, s: Any) -> List[float]:
        """
        Encode a state as ``state_dim`` floats.

        A list of exactly ``state_dim`` items is returned unchanged; anything
        else is mapped to deterministic pseudo-features in [0, 1) derived
        from its string form, one independent value per dimension.
        """
        if isinstance(s, list) and len(s) == self.state_dim:
            return s
        return [
            (self._stable_hash(f"{i}:{s}") % 1000) / 1000.0
            for i in range(self.state_dim)
        ]

    def _default_action_encoder(self, a: Any) -> int:
        """
        Encode an action as an index in ``[0, action_dim)``.

        An in-range integer is returned unchanged; anything else is mapped
        to a deterministic index derived from its string form.
        """
        if isinstance(a, int) and 0 <= a < self.action_dim:
            return a
        return self._stable_hash(a) % self.action_dim

    def _reward_features(self, state: Any) -> List[float]:
        """Feature extractor handed to InverseRL; defers to the current state encoder."""
        return self.state_encoder(state)

    def set_encoders(
        self,
        state_encoder: Callable[[Any], List[float]],
        action_encoder: Callable[[Any], int]
    ):
        """
        Set custom state and action encoders.

        Args:
            state_encoder: Maps a raw state to a feature vector
            action_encoder: Maps a raw action to an action index
        """
        self.state_encoder = state_encoder
        self.action_encoder = action_encoder

    def train_behavior_cloning(
        self,
        epochs: int = 100,
        batch_size: int = 32
    ) -> Dict[str, Any]:
        """
        Train using pure behavior cloning.

        Args:
            epochs: Training epochs
            batch_size: Batch size

        Returns:
            Training statistics
        """
        self.training_mode = "behavior_cloning"
        stats = self.behavior_cloner.train(
            self.collector,
            epochs=epochs,
            batch_size=batch_size,
            state_encoder=self.state_encoder,
            action_encoder=self.action_encoder
        )
        self.training_stats["behavior_cloning"] = stats
        return stats

    def train_inverse_rl(
        self,
        policy_sampler: Callable[[], List[Tuple[Any, Any]]],
        max_iterations: int = 100
    ) -> Dict[str, Any]:
        """
        Train using inverse reinforcement learning.

        Args:
            policy_sampler: Function to sample policy trajectories
            max_iterations: Maximum iterations

        Returns:
            Training statistics
        """
        self.training_mode = "inverse_rl"
        stats = self.inverse_rl.learn_reward(
            self.collector,
            policy_sampler,
            max_iterations=max_iterations
        )
        self.training_stats["inverse_rl"] = stats
        return stats

    def train_gail(
        self,
        policy_sampler: Callable[[], List[Tuple[Any, Any]]],
        num_iterations: int = 100
    ) -> Dict[str, Any]:
        """
        Train using GAIL.

        Args:
            policy_sampler: Function to sample policy trajectories
            num_iterations: Training iterations

        Returns:
            Training statistics
        """
        self.training_mode = "gail"
        stats = self.gail.train(
            self.collector,
            policy_sampler,
            self.state_encoder,
            self.action_encoder,
            num_iterations=num_iterations
        )
        self.training_stats["gail"] = stats
        return stats

    def train_dagger(
        self,
        expert_oracle: Callable[[Any], Any],
        initial_states: List[Any],
        transition_fn: Callable[[Any, Any], Any],
        num_iterations: int = 10
    ) -> Dict[str, Any]:
        """
        Train using DAgger.

        Args:
            expert_oracle: Expert action provider
            initial_states: Initial states for rollouts
            transition_fn: State transition function
            num_iterations: DAgger iterations

        Returns:
            Training statistics
        """
        self.training_mode = "dagger"
        dagger = DAgger(
            self.behavior_cloner,
            expert_oracle,
            self.state_encoder,
            self.action_encoder
        )
        stats = dagger.run(
            initial_states,
            transition_fn,
            num_iterations=num_iterations
        )
        self.training_stats["dagger"] = stats
        return stats

    def predict(self, state: Any) -> int:
        """
        Predict action for a state using trained policy.

        Args:
            state: Input state (raw; it is encoded here)

        Returns:
            Action index
        """
        encoded = self.state_encoder(state)
        return self.behavior_cloner.predict(encoded)

    def get_reward(self, state: Any) -> float:
        """
        Get learned reward for a state.

        Args:
            state: Input state (raw; the reward model encodes it)

        Returns:
            Reward value
        """
        return self.inverse_rl.compute_reward(state)

    def is_expert_like(self, state: Any, action: Any) -> float:
        """
        Check how expert-like a state-action pair is.

        Args:
            state: Input state
            action: Input action

        Returns:
            Probability of being expert-like
        """
        state_encoded = self.state_encoder(state)
        action_encoded = self.action_encoder(action)
        # An out-of-range index leaves the vector all zeros.
        action_one_hot = [0.0] * self.action_dim
        if 0 <= action_encoded < self.action_dim:
            action_one_hot[action_encoded] = 1.0
        return self.discriminator.predict(state_encoded, action_one_hot)

    def save(self, filename: str = "imitation_system.json"):
        """
        Save system state.

        Writes a JSON snapshot to ``storage_path / filename`` and asks the
        collector to save itself as well.

        Args:
            filename: Name of the JSON file inside the storage directory
        """
        filepath = self.storage_path / filename

        state = {
            "state_dim": self.state_dim,
            "action_dim": self.action_dim,
            "training_mode": self.training_mode,
            "training_stats": self.training_stats,
            "behavior_cloner": self.behavior_cloner.get_policy_parameters(),
            "inverse_rl_weights": self.inverse_rl.reward_weights,
            "collector_stats": self.collector.get_statistics()
        }

        # Explicit encoding keeps the file independent of the platform default.
        with open(filepath, 'w', encoding="utf-8") as f:
            # default=str stringifies values json cannot serialize natively.
            json.dump(state, f, indent=2, default=str)

        self.collector.save()
        logger.info(f"Saved imitation learning system to {filepath}")

    def get_summary(self) -> Dict[str, Any]:
        """
        Get system summary.

        Returns:
            Dimensions, current training mode, number of collected
            demonstrations and the statistics of every training run so far
        """
        return {
            "state_dim": self.state_dim,
            "action_dim": self.action_dim,
            "training_mode": self.training_mode,
            "demonstrations_collected": self.collector.get_statistics()["total_collected"],
            "training_stats": self.training_stats
        }


# ============================================================================
# Example Usage and Testing
# ============================================================================

def example_environment(size: int = 5):
    """
    Create a simple grid world example environment.

    The world is a ``size`` x ``size`` grid. The agent starts at ``(0, 0)``
    and the goal sits in the opposite corner ``(size - 1, size - 1)``.
    States are ``(x, y)`` tuples; the environment itself is stateless, so
    the caller threads the current state through ``step``.

    Args:
        size: Side length of the square grid (defaults to 5)

    Returns:
        A ``GridWorld`` instance exposing ``reset()``, ``step(state, action)``
        and ``expert_action(state)``

    Raises:
        ValueError: If ``size`` is smaller than 1
    """
    if size < 1:
        raise ValueError(f"size must be at least 1, got {size}")

    class GridWorld:
        """Deterministic square grid with a single goal cell."""

        # Action index -> (dx, dy): 0=up, 1=right, 2=down, 3=left
        MOVES = ((0, -1), (1, 0), (0, 1), (-1, 0))

        def __init__(self, size: int = 5):
            self.size = size
            self.goal = (size - 1, size - 1)

        def reset(self) -> Tuple[int, int]:
            """Return the fixed start state ``(0, 0)``."""
            return (0, 0)

        def step(self, state: Tuple[int, int], action: int) -> Tuple[Tuple[int, int], float, bool]:
            """
            Apply ``action`` to ``state``.

            Returns:
                ``(new_state, reward, done)``; reward is 1.0 on reaching the
                goal and -0.1 otherwise
            """
            x, y = state
            # Out-of-range action ids wrap around onto the four moves.
            dx, dy = self.MOVES[action % 4]
            # Clamp to the grid so moves into a wall leave that coordinate unchanged.
            upper = self.size - 1
            nx = max(0, min(upper, x + dx))
            ny = max(0, min(upper, y + dy))
            new_state = (nx, ny)
            done = new_state == self.goal
            reward = 1.0 if done else -0.1
            return new_state, reward, done

        def expert_action(self, state: Tuple[int, int]) -> int:
            """Return the scripted expert move: go right first, then down."""
            x, y = state
            gx, gy = self.goal
            if x < gx:
                return 1  # right
            elif y < gy:
                return 2  # down
            elif x > gx:
                return 3  # left
            else:
                return 0  # up

    return GridWorld(size)


def main():
    """
    Run the end-to-end grid world demo of the imitation learning system.

    The demo builds a system with 2-dimensional states and 4 actions,
    records 20 scripted expert episodes, trains behavior cloning, rolls the
    learned policy out once, runs GAIL training with a policy sampler, logs
    the system summary and finally saves the system.
    """
    banner = "=" * 60
    max_steps = 20  # step cap for every episode in this demo

    logger.info(banner)
    logger.info("AIVA Queen Imitation Learning System - Demo")
    logger.info(banner)

    # States are (x, y) coordinates; actions are the 4 movement directions.
    system = ImitationLearningSystem(state_dim=2, action_dim=4)

    def encode_state(cell):
        # Scale both coordinates by a fixed 5.0 (presumably the default
        # grid size - confirm if the environment size ever changes).
        return [float(cell[0]) / 5.0, float(cell[1]) / 5.0]

    def encode_action(move):
        return move % 4

    system.set_encoders(
        state_encoder=encode_state,
        action_encoder=encode_action,
    )

    env = example_environment()

    expert_id = system.collector.register_expert(
        name="GridWorldExpert",
        domains=["navigation"],
        skill_level=0.95,
    )

    def rollout():
        """Run the current policy for one episode.

        Returns the start state, the list of
        ``(state, action, next_state)`` transitions and whether the episode
        ended at the goal.
        """
        start = env.reset()
        current = start
        transitions = []
        for _ in range(max_steps):
            chosen = system.predict(current)
            following, _, finished = env.step(current, chosen)
            transitions.append((current, chosen, following))
            current = following
            if finished:
                break
        return start, transitions, finished

    def policy_sampler():
        """Return one policy episode as a list of (state, action) pairs."""
        _, transitions, _ = rollout()
        return [(visited, taken) for visited, taken, _ in transitions]

    # --- Expert demonstrations -------------------------------------------
    logger.info("\n--- Collecting Expert Demonstrations ---")
    for _ in range(20):
        with system.collector.start_recording(expert_id, "navigation") as session:
            position = env.reset()
            for _ in range(max_steps):
                move = env.expert_action(position)
                session.record(position, move)
                position, _, reached_goal = env.step(position, move)
                if reached_goal:
                    break

    logger.info(f"Collected demonstrations: {system.collector.get_statistics()}")

    # --- Behavior cloning ------------------------------------------------
    logger.info("\n--- Training Behavior Cloning ---")
    bc_stats = system.train_behavior_cloning(epochs=50)
    logger.info(f"BC Results: {bc_stats}")

    # --- Evaluate the learned policy -------------------------------------
    logger.info("\n--- Testing Learned Policy ---")
    start, transitions, reached_goal = rollout()
    if reached_goal:
        logger.info("Policy reached goal!")
    trajectory = [start]
    trajectory.extend(following for _, _, following in transitions)

    logger.info(f"Trajectory: {trajectory}")

    # --- GAIL ------------------------------------------------------------
    logger.info("\n--- Training GAIL ---")
    gail_stats = system.train_gail(policy_sampler, num_iterations=20)
    avg_reward = gail_stats.get('final_avg_reward', 0)
    logger.info(f"GAIL Results: avg_reward={avg_reward:.4f}")

    # --- Summary and persistence -----------------------------------------
    logger.info("\n--- System Summary ---")
    summary_text = json.dumps(system.get_summary(), indent=2, default=str)
    logger.info(summary_text)

    system.save()

    logger.info("\n" + banner)
    logger.info("Demo Complete!")
    logger.info(banner)


# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
