"""
PPO (Proximal Policy Optimization) Training Engine for LLM Fine-Tuning

This module implements a complete PPO training pipeline optimized for language model
reinforcement learning from human feedback (RLHF). It includes:

- PPOTrainer: Full PPO algorithm with clipped objectives
- ValueHead: MLP value-function estimator with configurable hidden layers and activation
- PolicyGradient: Clipped surrogate policy loss with entropy regularization
- GAE: Generalized Advantage Estimation for variance reduction
- KLDivergence: Adaptive KL penalty controller for policy updates
- RolloutBuffer: On-policy rollout storage with shuffled mini-batch sampling

Author: Genesis-OS RLM Module
Version: 1.0.0
"""

import math
import time
import logging
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    Union,
    Iterator,
    TypeVar,
    Generic,
)
from collections import deque
from enum import Enum
import random
import copy

# Logging setup.
# NOTE(review): basicConfig() runs at import time and configures the root
# logger for the whole process; consider moving it to an application entry point.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("PPOEngine")


# Type aliases: plain Python lists stand in for tensors (no numpy dependency).
Tensor = List[float]  # a single vector of floats
BatchTensor = List[Tensor]  # a batch of such vectors
T = TypeVar("T")


class ActivationFunction(Enum):
    """Activation functions available to value-head layers.

    Each member's value is its lowercase name, so a member can be looked up
    from a config string, e.g. ``ActivationFunction("gelu")``.
    """

    RELU = "relu"  # max(0, x)
    TANH = "tanh"  # hyperbolic tangent
    GELU = "gelu"  # Gaussian error linear unit
    SILU = "silu"  # x * sigmoid(x)
    SIGMOID = "sigmoid"  # logistic 1 / (1 + e^-x)


@dataclass
class PPOConfig:
    """Hyperparameters for a PPO training run.

    Fields are grouped by concern: clipped-objective weights, learning
    rates, GAE discounting, batch sizing, KL control, value-head
    architecture, regularization, precision and logging cadence.
    """

    # --- Clipped surrogate objective and loss weights ---
    clip_epsilon: float = 0.2  # policy ratio is clipped to [1 - eps, 1 + eps]
    value_clip_epsilon: float = 0.2
    entropy_coefficient: float = 0.01  # weight of the entropy bonus
    value_coefficient: float = 0.5  # weight of the value-function loss
    max_grad_norm: float = 0.5

    # --- Optimizer learning rates ---
    policy_lr: float = 3e-4
    value_lr: float = 1e-3

    # --- Generalized Advantage Estimation ---
    gamma: float = 0.99  # discount factor
    gae_lambda: float = 0.95  # bias/variance trade-off

    # --- Batch sizing and epochs ---
    batch_size: int = 64
    mini_batch_size: int = 8
    ppo_epochs: int = 4
    rollout_buffer_size: int = 2048

    # --- KL-divergence control ---
    target_kl: float = 0.01
    kl_coefficient: float = 0.2  # presumably the initial KL penalty coefficient
    adaptive_kl: bool = True
    kl_horizon: int = 10000

    # --- Value-head architecture (default_factory gives each instance its own list) ---
    value_hidden_dims: List[int] = field(default_factory=lambda: [256, 256])
    value_activation: ActivationFunction = ActivationFunction.TANH

    # --- Regularization ---
    weight_decay: float = 0.0
    dropout_rate: float = 0.0

    # --- Precision and gradient accumulation ---
    use_mixed_precision: bool = False
    gradient_accumulation_steps: int = 1

    # --- Logging / checkpoint cadence ---
    log_interval: int = 100
    save_interval: int = 1000


@dataclass
class Experience:
    """One recorded interaction step used for PPO training.

    Attributes:
        state: Observation at this step (e.g. token IDs or embeddings).
        action: Action taken (e.g. the generated token).
        action_log_prob: Log-probability of ``action`` when it was taken.
        reward: Scalar reward for this step.
        value: Value estimate for ``state``.
        done: True if the episode terminated at this step.
        info: Free-form metadata; each instance gets its own fresh dict.
    """

    state: Any
    action: Any
    action_log_prob: float
    reward: float
    value: float
    done: bool
    info: Dict[str, Any] = field(default_factory=dict)


@dataclass
class RolloutBatch:
    """A batch of rollout data stored column-wise (one list per field).

    The lists are index-aligned: element ``i`` of every field describes the
    same step.
    """

    states: List[Any]
    actions: List[Any]
    action_log_probs: Tensor  # log-probs recorded when the actions were taken
    rewards: Tensor
    values: Tensor  # value estimates recorded at collection time
    advantages: Tensor
    returns: Tensor  # regression targets for the value function
    dones: List[bool]


class MathOps:
    """Mathematical helpers for PPO computations (numpy-free implementation).

    Vectors are plain Python lists of floats (the module's ``Tensor`` alias).
    """

    @staticmethod
    def exp(x: float) -> float:
        """Return ``e**x`` with the exponent capped at 700 so ``math.exp`` cannot overflow."""
        return math.exp(min(x, 700))

    @staticmethod
    def log(x: float, eps: float = 1e-8) -> float:
        """Return ``ln(max(x, eps))``; the floor avoids a domain error at 0."""
        return math.log(max(x, eps))

    @staticmethod
    def softmax(logits: Tensor) -> Tensor:
        """Return the softmax of ``logits``.

        The maximum logit is subtracted before exponentiating so the
        exponentials stay bounded. An empty input returns an empty list
        (``max()`` would otherwise raise ``ValueError``).
        """
        # len() rather than truthiness so array-like inputs keep working.
        if len(logits) == 0:
            return []
        max_logit = max(logits)
        exp_logits = [MathOps.exp(x - max_logit) for x in logits]
        sum_exp = sum(exp_logits)
        return [x / sum_exp for x in exp_logits]

    @staticmethod
    def log_softmax(logits: Tensor) -> Tensor:
        """Return the log-softmax of ``logits`` using the max-shift trick.

        An empty input returns an empty list.
        """
        if len(logits) == 0:
            return []
        max_logit = max(logits)
        shifted = [x - max_logit for x in logits]
        # After the shift the largest term is exp(0) == 1, so the sum is >= 1.
        log_sum_exp = MathOps.log(sum(MathOps.exp(x) for x in shifted))
        return [x - log_sum_exp for x in shifted]

    @staticmethod
    def mean(values: Tensor) -> float:
        """Return the arithmetic mean, or 0.0 for an empty input."""
        if not values:
            return 0.0
        return sum(values) / len(values)

    @staticmethod
    def std(values: Tensor, eps: float = 1e-8) -> float:
        """Return the sample standard deviation (n - 1 denominator).

        ``eps`` is added to the variance before the square root, so the result
        is never zero. Inputs with fewer than two values return ``eps``, which
        keeps callers that divide by the result safe.
        """
        if len(values) < 2:
            return eps
        mean_val = MathOps.mean(values)
        variance = sum((x - mean_val) ** 2 for x in values) / (len(values) - 1)
        return math.sqrt(variance + eps)

    @staticmethod
    def normalize(values: Tensor, eps: float = 1e-8) -> Tensor:
        """Standardize ``values`` to zero mean and (approximately) unit variance."""
        mean_val = MathOps.mean(values)
        std_val = MathOps.std(values, eps)
        return [(x - mean_val) / std_val for x in values]

    @staticmethod
    def clip(value: float, min_val: float, max_val: float) -> float:
        """Clamp ``value`` into ``[min_val, max_val]``."""
        return max(min_val, min(max_val, value))

    @staticmethod
    def dot_product(a: Tensor, b: Tensor) -> float:
        """Return the dot product of ``a`` and ``b``.

        Like ``zip``, this stops at the shorter vector; lengths are not checked.
        """
        return sum(x * y for x, y in zip(a, b))


class ValueHead:
    """
    Value function head for estimating state values.

    A small pure-Python multi-layer perceptron that maps a state feature
    vector to a single scalar value estimate. It serves as the critic in an
    actor-critic PPO setup.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: Optional[List[int]] = None,
        activation: ActivationFunction = ActivationFunction.TANH,
        dropout_rate: float = 0.0,
        output_scale: float = 1.0,
    ):
        """
        Initialize ValueHead.

        Args:
            input_dim: Length of the input feature vector.
            hidden_dims: Hidden layer widths. ``None`` selects the default
                ``[256, 256]``. An empty list gives a single linear layer.
                The list is copied, so later changes by the caller do not
                affect this head.
            activation: Activation applied after every hidden layer.
            dropout_rate: Stored for configuration only; ``forward`` does
                not apply dropout.
            output_scale: Multiplier applied to the network output.
        """
        self.input_dim = input_dim
        # Use ``is None`` rather than ``or`` so an explicit [] means "no hidden
        # layers" instead of silently becoming [256, 256]. The copy avoids
        # aliasing a caller-owned list (such as one held by a shared config).
        self.hidden_dims = [256, 256] if hidden_dims is None else list(hidden_dims)
        self.activation = activation
        # NOTE: never read by forward(); kept so the configuration is inspectable.
        self.dropout_rate = dropout_rate
        self.output_scale = output_scale

        # weights[k] is a (fan_out x fan_in) matrix stored as a list of rows;
        # biases[k] holds fan_out entries.
        self.weights: List[List[Tensor]] = []
        self.biases: List[Tensor] = []
        self._initialize_weights()

        # While True, forward() records every output in value_history.
        self.training = True
        self.value_history: deque = deque(maxlen=1000)

    def _initialize_weights(self):
        """Create Xavier/Glorot-normal weights and zero biases for every layer."""
        dims = [self.input_dim] + self.hidden_dims + [1]

        for fan_in, fan_out in zip(dims[:-1], dims[1:]):
            std = math.sqrt(2.0 / (fan_in + fan_out))
            self.weights.append(
                [[random.gauss(0, std) for _ in range(fan_in)] for _ in range(fan_out)]
            )
            self.biases.append([0.0] * fan_out)

    def _apply_activation(self, x: float) -> float:
        """Apply the configured activation to a scalar (identity if unrecognized)."""
        if self.activation == ActivationFunction.RELU:
            return max(0.0, x)
        elif self.activation == ActivationFunction.TANH:
            return math.tanh(x)
        elif self.activation == ActivationFunction.GELU:
            # tanh approximation of GELU
            return 0.5 * x * (1 + math.tanh(
                math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)
            ))
        elif self.activation == ActivationFunction.SILU:
            return x / (1 + MathOps.exp(-x))
        elif self.activation == ActivationFunction.SIGMOID:
            return 1 / (1 + MathOps.exp(-x))
        return x

    def forward(self, x: Tensor) -> float:
        """
        Run the value network on a single input vector.

        Args:
            x: State feature vector, expected to have ``input_dim`` entries.
                The length is not checked. The dot products stop at the
                shorter operand, so a mismatched input is silently truncated.

        Returns:
            Scalar value estimate, multiplied by ``output_scale``.
        """
        hidden = x

        # Hidden layers: affine transform followed by the activation.
        for weight, bias in zip(self.weights[:-1], self.biases[:-1]):
            hidden = [
                self._apply_activation(MathOps.dot_product(row, hidden) + b)
                for row, b in zip(weight, bias)
            ]

        # Output layer: one linear unit, no activation.
        out_weight, out_bias = self.weights[-1], self.biases[-1]
        value = MathOps.dot_product(out_weight[0], hidden) + out_bias[0]
        value *= self.output_scale

        if self.training:
            self.value_history.append(value)

        return value

    def get_statistics(self) -> Dict[str, float]:
        """Return mean/std/min/max of recently recorded outputs (zeros if none)."""
        if not self.value_history:
            return {"mean": 0.0, "std": 0.0, "min": 0.0, "max": 0.0}

        values = list(self.value_history)
        return {
            "mean": MathOps.mean(values),
            "std": MathOps.std(values),
            "min": min(values),
            "max": max(values),
        }


class PolicyGradient:
    """
    Policy gradient computation for PPO.

    Builds the clipped surrogate objective from per-sample log-probabilities
    and advantages and optionally adds an entropy bonus. Rolling diagnostics
    are kept over the last 100 calls.
    """

    def __init__(
        self,
        clip_epsilon: float = 0.2,
        entropy_coefficient: float = 0.01,
        normalize_advantages: bool = True,
    ):
        """
        Initialize PolicyGradient.

        Args:
            clip_epsilon: PPO clipping parameter; the ratio is clipped to
                ``[1 - clip_epsilon, 1 + clip_epsilon]``.
            entropy_coefficient: Weight of the entropy bonus.
            normalize_advantages: Whether to standardize advantages per call.
        """
        self.clip_epsilon = clip_epsilon
        self.entropy_coefficient = entropy_coefficient
        self.normalize_advantages = normalize_advantages

        # Rolling metrics (last 100 calls to compute_policy_loss)
        self.policy_losses: deque = deque(maxlen=100)
        self.entropy_values: deque = deque(maxlen=100)
        self.clip_fractions: deque = deque(maxlen=100)
        self.approx_kl_divs: deque = deque(maxlen=100)

    def compute_entropy(self, log_probs: Tensor) -> float:
        """
        Compute the entropy of one action distribution.

        Args:
            log_probs: Log-probabilities over the whole action set of a single
                distribution (presumably normalized; not checked).

        Returns:
            ``-sum(p * log p)``, skipping entries with ``p <= 1e-8``.
        """
        probs = [MathOps.exp(lp) for lp in log_probs]
        entropy = -sum(p * lp for p, lp in zip(probs, log_probs) if p > 1e-8)
        return entropy

    def compute_ratio(
        self,
        new_log_prob: float,
        old_log_prob: float,
    ) -> float:
        """
        Compute the importance-sampling ratio ``pi_new / pi_old``.

        Args:
            new_log_prob: Log probability under the new policy.
            old_log_prob: Log probability under the old policy.

        Returns:
            ``exp(new_log_prob - old_log_prob)`` (overflow-capped by MathOps.exp).
        """
        log_ratio = new_log_prob - old_log_prob
        return MathOps.exp(log_ratio)

    def compute_clipped_objective(
        self,
        ratio: float,
        advantage: float,
    ) -> Tuple[float, bool]:
        """
        Compute the PPO clipped surrogate objective for one sample.

        ``L = min(r * A, clip(r, 1 - eps, 1 + eps) * A)``

        Args:
            ratio: Probability ratio (pi_new / pi_old).
            advantage: Advantage estimate.

        Returns:
            Tuple of (objective value, whether the ratio lay outside the
            clipping range).
        """
        unclipped = ratio * advantage

        clipped_ratio = MathOps.clip(
            ratio,
            1.0 - self.clip_epsilon,
            1.0 + self.clip_epsilon,
        )
        clipped = clipped_ratio * advantage

        # The pessimistic bound is min(...) for BOTH signs of the advantage.
        # For A < 0 and r > 1 + eps, this keeps the unclipped (more negative)
        # term, so making a bad action more likely is still penalized. Taking
        # max here would turn it into an optimistic bound.
        objective = min(unclipped, clipped)

        was_clipped = abs(ratio - clipped_ratio) > 1e-8

        return objective, was_clipped

    def compute_policy_loss(
        self,
        new_log_probs: Tensor,
        old_log_probs: Tensor,
        advantages: Tensor,
        action_distributions: Optional[List[Tensor]] = None,
    ) -> Dict[str, float]:
        """
        Compute the full policy loss with entropy regularization.

        Args:
            new_log_probs: Per-sample log probs under the current policy.
            old_log_probs: Per-sample log probs under the old policy.
            advantages: Per-sample advantage estimates.
            action_distributions: Optional per-sample log-prob distributions
                used for the entropy bonus.

        Returns:
            Dict with "policy_loss", "entropy_loss", "total_loss",
            "clip_fraction" and "approx_kl".
        """
        if self.normalize_advantages and len(advantages) > 1:
            advantages = MathOps.normalize(advantages)

        objectives = []
        clip_count = 0
        kl_divs = []

        for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
            ratio = self.compute_ratio(new_lp, old_lp)
            obj, was_clipped = self.compute_clipped_objective(ratio, adv)
            objectives.append(obj)

            if was_clipped:
                clip_count += 1

            # Sample-based KL(old || new) estimate: log pi_old - log pi_new
            kl_divs.append(old_lp - new_lp)

        # Negated because optimizers minimize and PPO maximizes the objective.
        policy_loss = -MathOps.mean(objectives)

        entropy_loss = 0.0
        if action_distributions:
            entropies = [self.compute_entropy(dist) for dist in action_distributions]
            mean_entropy = MathOps.mean(entropies)
            entropy_loss = -self.entropy_coefficient * mean_entropy
            self.entropy_values.append(mean_entropy)

        total_loss = policy_loss + entropy_loss

        # Divide by the number of samples actually processed (zip truncates).
        n_samples = len(objectives)
        clip_fraction = clip_count / n_samples if n_samples else 0.0
        approx_kl = MathOps.mean(kl_divs) if kl_divs else 0.0

        self.policy_losses.append(policy_loss)
        self.clip_fractions.append(clip_fraction)
        self.approx_kl_divs.append(approx_kl)

        return {
            "policy_loss": policy_loss,
            "entropy_loss": entropy_loss,
            "total_loss": total_loss,
            "clip_fraction": clip_fraction,
            "approx_kl": approx_kl,
        }

    def get_statistics(self) -> Dict[str, float]:
        """Return the means of the rolling metrics (0.0 for any that are empty)."""
        return {
            "mean_policy_loss": MathOps.mean(list(self.policy_losses)) if self.policy_losses else 0.0,
            "mean_entropy": MathOps.mean(list(self.entropy_values)) if self.entropy_values else 0.0,
            "mean_clip_fraction": MathOps.mean(list(self.clip_fractions)) if self.clip_fractions else 0.0,
            "mean_approx_kl": MathOps.mean(list(self.approx_kl_divs)) if self.approx_kl_divs else 0.0,
        }


class GAE:
    """
    Generalized Advantage Estimation (GAE).

    Implements the GAE algorithm from Schulman et al. (2015) for
    computing advantages with a controlled bias-variance tradeoff.
    """

    def __init__(
        self,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        normalize_advantages: bool = True,
    ):
        """
        Initialize GAE.

        Args:
            gamma: Discount factor.
            gae_lambda: GAE lambda for the bias-variance tradeoff.
            normalize_advantages: Whether to standardize the returned advantages.
        """
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.normalize_advantages = normalize_advantages

        # Raw (pre-normalization) advantages and returns of recent calls
        self.advantage_history: deque = deque(maxlen=1000)
        self.return_history: deque = deque(maxlen=1000)

    def compute_advantages(
        self,
        rewards: Tensor,
        values: Tensor,
        dones: List[bool],
        next_value: float = 0.0,
    ) -> Tuple[Tensor, Tensor]:
        """
        Compute GAE advantages and returns.

        Args:
            rewards: Per-step rewards.
            values: Per-step value estimates (same length as ``rewards``).
            dones: Per-step flags; ``dones[t]`` means the episode ended at step t.
            next_value: Value estimate of the state after the last step.

        Returns:
            Tuple of (advantages, returns). Returns are computed from the raw
            advantages. Only the advantages are normalized, and only when
            enabled and more than one step is present.
        """
        n_steps = len(rewards)
        advantages = [0.0] * n_steps
        returns = [0.0] * n_steps

        gae = 0.0
        for t in reversed(range(n_steps)):
            # An episode ending at step t stops both bootstrapping from the
            # next value and the accumulation of later GAE terms.
            next_non_terminal = 1.0 - float(dones[t])
            next_val = next_value if t == n_steps - 1 else values[t + 1]

            # TD error
            delta = rewards[t] + self.gamma * next_val * next_non_terminal - values[t]

            gae = delta + self.gamma * self.gae_lambda * next_non_terminal * gae
            advantages[t] = gae

            # Value-function target
            returns[t] = gae + values[t]

        self.advantage_history.extend(advantages)
        self.return_history.extend(returns)

        if self.normalize_advantages and len(advantages) > 1:
            advantages = MathOps.normalize(advantages)

        return advantages, returns

    def compute_td_lambda_returns(
        self,
        rewards: Tensor,
        values: Tensor,
        dones: List[bool],
        next_value: float = 0.0,
    ) -> Tensor:
        """
        Compute TD(lambda) returns.

        Uses the recursion
        ``G_t = r_t + gamma * ((1 - lambda) * V(s_{t+1}) + lambda * G_{t+1})``,
        where ``V(s_{t+1})`` is ``values[t + 1]`` or ``next_value`` at the final
        step. A terminal step uses ``G_t = r_t``. The result equals the
        (unnormalized) returns from ``compute_advantages``.

        Args:
            rewards: Per-step rewards.
            values: Per-step value estimates.
            dones: Episode termination flags.
            next_value: Bootstrap value after the last step.

        Returns:
            TD(lambda) returns, one per step.
        """
        n_steps = len(rewards)
        returns = [0.0] * n_steps

        running_return = next_value
        for t in reversed(range(n_steps)):
            if dones[t]:
                running_return = rewards[t]
            else:
                # The state after the final step is valued by next_value,
                # not by values[t].
                bootstrap = values[t + 1] if t < n_steps - 1 else next_value
                running_return = rewards[t] + self.gamma * (
                    self.gae_lambda * running_return
                    + (1 - self.gae_lambda) * bootstrap
                )
            returns[t] = running_return

        return returns

    def get_statistics(self) -> Dict[str, float]:
        """Return mean/std of recently recorded raw advantages and returns."""
        advantages = list(self.advantage_history)
        returns = list(self.return_history)

        return {
            "advantage_mean": MathOps.mean(advantages) if advantages else 0.0,
            "advantage_std": MathOps.std(advantages) if advantages else 0.0,
            "return_mean": MathOps.mean(returns) if returns else 0.0,
            "return_std": MathOps.std(returns) if returns else 0.0,
        }


class KLDivergence:
    """
    KL divergence controller for PPO policy updates.

    Estimates KL(old || new) from log-probabilities and adaptively rescales
    the KL penalty coefficient toward a target divergence.
    """

    def __init__(
        self,
        target_kl: float = 0.01,
        initial_coefficient: float = 0.2,
        adaptive: bool = True,
        kl_horizon: int = 10000,
        min_coefficient: float = 0.0,
        max_coefficient: float = 100.0,
    ):
        """
        Initialize the KL divergence controller.

        Args:
            target_kl: Target KL divergence.
            initial_coefficient: Initial KL penalty coefficient.
            adaptive: Whether update_coefficient() adjusts the coefficient.
            kl_horizon: Controls how fast the adaptation rate decays.
            min_coefficient: Lower clamp for the coefficient.
            max_coefficient: Upper clamp for the coefficient.
        """
        self.target_kl = target_kl
        self.coefficient = initial_coefficient
        self.adaptive = adaptive
        self.kl_horizon = kl_horizon
        self.min_coefficient = min_coefficient
        self.max_coefficient = max_coefficient

        # History tracking
        self.kl_history: deque = deque(maxlen=1000)
        self.coefficient_history: deque = deque(maxlen=1000)  # adaptive updates only
        self.update_count = 0

    def compute_kl_divergence(
        self,
        old_log_probs: Tensor,
        new_log_probs: Tensor,
        reduction: str = "mean",
    ) -> Union[float, List[float]]:
        """
        Estimate the KL divergence between the old and new policies.

        KL(old || new) = E_{a ~ old}[log old(a) - log new(a)]

        The inputs are log-probabilities of actions that were sampled from
        the old policy. Each per-sample log-ratio ``old_lp - new_lp`` is
        therefore already a single-sample estimate of the expectation, and
        the "mean" reduction is the Monte-Carlo estimate. The terms must NOT
        be weighted by ``exp(old_lp)``: the sampling already applies that
        weighting, and doing it again shrinks the estimate toward zero.
        Individual terms may be negative; only their expectation is
        non-negative. The inputs are zipped, so extra trailing entries in
        the longer list are ignored.

        Args:
            old_log_probs: Per-sample log probs under the old policy.
            new_log_probs: Per-sample log probs under the new policy.
            reduction: "mean" or "sum"; any other value returns the
                per-sample list.

        Returns:
            The reduced KL estimate, or the list of per-sample terms.
        """
        kl_values = [
            old_lp - new_lp for old_lp, new_lp in zip(old_log_probs, new_log_probs)
        ]

        if reduction == "mean":
            return MathOps.mean(kl_values)
        elif reduction == "sum":
            return sum(kl_values)
        else:
            return kl_values

    def compute_kl_penalty(self, kl_divergence: float) -> float:
        """
        Compute the KL penalty term for the loss.

        Args:
            kl_divergence: Current KL divergence.

        Returns:
            ``coefficient * kl_divergence``.
        """
        return self.coefficient * kl_divergence

    def update_coefficient(self, kl_divergence: float) -> float:
        """
        Record ``kl_divergence`` and, if adaptive, update the KL coefficient.

        Uses a multiplicative rule:
        - If KL > 1.5 * target: multiply the coefficient by (1 + rate).
        - If KL < target / 1.5: divide the coefficient by (1 + rate).
        The result is clamped to [min_coefficient, max_coefficient].

        Args:
            kl_divergence: Current KL divergence.

        Returns:
            The (possibly updated) coefficient.
        """
        self.update_count += 1
        self.kl_history.append(kl_divergence)

        if not self.adaptive:
            return self.coefficient

        # Rate starts near 1.5 and halves once update_count reaches kl_horizon.
        adaptation_rate = 1.5 / (1 + self.update_count / self.kl_horizon)

        if kl_divergence > 1.5 * self.target_kl:
            # KL too high: strengthen the penalty
            self.coefficient *= (1 + adaptation_rate)
        elif kl_divergence < self.target_kl / 1.5:
            # KL too low: relax the penalty
            self.coefficient /= (1 + adaptation_rate)

        self.coefficient = MathOps.clip(
            self.coefficient,
            self.min_coefficient,
            self.max_coefficient,
        )

        self.coefficient_history.append(self.coefficient)

        return self.coefficient

    def should_early_stop(self, kl_divergence: float) -> bool:
        """
        Check whether the KL is high enough to stop the current update.

        Args:
            kl_divergence: Current KL divergence.

        Returns:
            True if ``kl_divergence`` exceeds twice ``target_kl``.
        """
        return kl_divergence > 2.0 * self.target_kl

    def get_statistics(self) -> Dict[str, float]:
        """Return KL controller statistics (means over the recorded history)."""
        kl_values = list(self.kl_history)
        coef_values = list(self.coefficient_history)

        return {
            "mean_kl": MathOps.mean(kl_values) if kl_values else 0.0,
            "current_coefficient": self.coefficient,
            "coefficient_mean": MathOps.mean(coef_values) if coef_values else self.coefficient,
            "update_count": self.update_count,
        }


class RolloutBuffer:
    """
    Rollout storage for on-policy PPO updates.

    Stores experiences in insertion order, computes GAE advantages/returns
    over them, and yields full or shuffled mini-batches.
    """

    def __init__(
        self,
        buffer_size: int = 2048,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        n_envs: int = 1,
    ):
        """
        Initialize RolloutBuffer.

        Args:
            buffer_size: Capacity used by ``is_full``; ``add`` does not
                enforce it.
            gamma: Discount factor for GAE.
            gae_lambda: GAE lambda parameter.
            n_envs: Number of parallel environments. NOTE(review): only
                stored. Experiences are treated as one sequential stream,
                so interleaved multi-env data would mix episodes in GAE.
        """
        self.buffer_size = buffer_size
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.n_envs = n_envs

        # Storage
        self.experiences: List[Experience] = []
        self.computed_advantages: Optional[Tensor] = None
        self.computed_returns: Optional[Tensor] = None

        # GAE computer
        self.gae = GAE(gamma=gamma, gae_lambda=gae_lambda)

        # Statistics (not reset by clear())
        self.total_experiences = 0
        self.episode_rewards: List[float] = []
        self.episode_lengths: List[int] = []

    def add(self, experience: Experience) -> None:
        """
        Append one experience and invalidate any computed advantages/returns.

        Args:
            experience: Experience tuple to add.
        """
        self.experiences.append(experience)
        self.total_experiences += 1

        self.computed_advantages = None
        self.computed_returns = None

    def add_batch(self, experiences: List[Experience]) -> None:
        """Append several experiences in order."""
        for exp in experiences:
            self.add(exp)

    @property
    def size(self) -> int:
        """Current number of stored experiences."""
        return len(self.experiences)

    @property
    def is_full(self) -> bool:
        """True once ``size`` has reached ``buffer_size``."""
        return self.size >= self.buffer_size

    def compute_advantages_and_returns(
        self,
        last_value: float = 0.0,
    ) -> None:
        """
        Compute GAE advantages and returns for the stored experiences.

        Does nothing when the buffer is empty.

        Args:
            last_value: Value estimate for the state after the last experience.
        """
        if not self.experiences:
            return

        rewards = [exp.reward for exp in self.experiences]
        values = [exp.value for exp in self.experiences]
        dones = [exp.done for exp in self.experiences]

        self.computed_advantages, self.computed_returns = self.gae.compute_advantages(
            rewards=rewards,
            values=values,
            dones=dones,
            next_value=last_value,
        )

    def _make_batch(self, indices: List[int]) -> RolloutBatch:
        """Assemble a RolloutBatch from the experiences at ``indices``, in order."""
        selected = [self.experiences[i] for i in indices]
        # An empty buffer never computes advantages; fall back to [] so the
        # batch always holds lists rather than None.
        advantages = self.computed_advantages or []
        returns = self.computed_returns or []
        return RolloutBatch(
            states=[exp.state for exp in selected],
            actions=[exp.action for exp in selected],
            action_log_probs=[exp.action_log_prob for exp in selected],
            rewards=[exp.reward for exp in selected],
            values=[exp.value for exp in selected],
            advantages=[advantages[i] for i in indices],
            returns=[returns[i] for i in indices],
            dones=[exp.done for exp in selected],
        )

    def get_batch(self) -> RolloutBatch:
        """
        Return the whole rollout as one batch.

        If advantages have not been computed yet, they are computed with
        ``last_value=0.0`` (no bootstrap).

        Returns:
            RolloutBatch containing all experiences (empty lists if none).
        """
        if self.computed_advantages is None:
            self.compute_advantages_and_returns()

        return self._make_batch(list(range(self.size)))

    def get_minibatches(
        self,
        batch_size: int,
        shuffle: bool = True,
    ) -> Iterator[RolloutBatch]:
        """
        Generate mini-batches for PPO updates.

        If advantages have not been computed yet, they are computed with
        ``last_value=0.0`` (no bootstrap). The last batch may be smaller
        than ``batch_size``.

        Args:
            batch_size: Size of each mini-batch; must be at least 1.
            shuffle: Whether to shuffle experience order (uses ``random``).

        Yields:
            RolloutBatch for each mini-batch.

        Raises:
            ValueError: If ``batch_size`` is less than 1 (raised on first
                iteration, as this is a generator).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        if self.computed_advantages is None:
            self.compute_advantages_and_returns()

        indices = list(range(self.size))
        if shuffle:
            random.shuffle(indices)

        for start in range(0, len(indices), batch_size):
            yield self._make_batch(indices[start:start + batch_size])

    def clear(self) -> None:
        """Drop stored experiences and computed values; episode stats are kept."""
        self.experiences = []
        self.computed_advantages = None
        self.computed_returns = None

    def log_episode(self, episode_reward: float, episode_length: int) -> None:
        """Record a completed episode's total reward and length."""
        self.episode_rewards.append(episode_reward)
        self.episode_lengths.append(episode_length)

    def get_statistics(self) -> Dict[str, float]:
        """Return buffer statistics (episode means over the last 100 episodes)."""
        stats = {
            "buffer_size": self.size,
            "total_experiences": self.total_experiences,
        }

        if self.episode_rewards:
            stats["mean_episode_reward"] = MathOps.mean(self.episode_rewards[-100:])
            stats["mean_episode_length"] = MathOps.mean(self.episode_lengths[-100:])

        if self.computed_advantages:
            stats["advantage_mean"] = MathOps.mean(self.computed_advantages)
            stats["advantage_std"] = MathOps.std(self.computed_advantages)

        return stats


class PPOTrainer:
    """
    Full PPO (Proximal Policy Optimization) training implementation.

    Implements the complete PPO algorithm for LLM fine-tuning including:
    - Clipped surrogate objective
    - Value function loss with optional clipping
    - Entropy regularization
    - Adaptive KL penalty
    - Generalized Advantage Estimation

    This trainer is designed for RLHF (Reinforcement Learning from Human Feedback)
    scenarios where a language model is fine-tuned using reward signals.

    Model integration goes through four optional callables (policy/value
    forward and update functions). When a callable is absent, the trainer
    uses simple placeholders so the pipeline can run without a real model.

    NOTE(review): the KL controller's coefficient is adapted and reported in
    the metrics, but this class does not itself add a KL term to the loss.
    """

    def __init__(
        self,
        config: Optional["PPOConfig"] = None,
        policy_forward_fn: Optional[Callable[[Any], Tuple["Tensor", float]]] = None,
        value_forward_fn: Optional[Callable[[Any], float]] = None,
        policy_update_fn: Optional[Callable[[Dict[str, float]], None]] = None,
        value_update_fn: Optional[Callable[[Dict[str, float]], None]] = None,
    ):
        """
        Initialize PPOTrainer.

        Args:
            config: PPO configuration (a default ``PPOConfig()`` when None)
            policy_forward_fn: ``state -> (action_logits, log_prob)``. Only the
                logits are used. The log prob of the action actually taken is
                derived from them, because this callable never sees that action.
            value_forward_fn: Function to get value estimate for a state
            policy_update_fn: Receives ``{"loss", "grad_norm"}`` for the policy
            value_update_fn: Receives ``{"loss", "grad_norm"}`` for the value model
        """
        self.config = config or PPOConfig()

        # Model functions (to be set externally for actual LLM integration)
        self.policy_forward_fn = policy_forward_fn
        self.value_forward_fn = value_forward_fn
        self.policy_update_fn = policy_update_fn
        self.value_update_fn = value_update_fn

        # Core components
        self.value_head = ValueHead(
            input_dim=768,  # Default for transformer hidden size
            hidden_dims=self.config.value_hidden_dims,
            activation=self.config.value_activation,
        )

        self.policy_gradient = PolicyGradient(
            clip_epsilon=self.config.clip_epsilon,
            entropy_coefficient=self.config.entropy_coefficient,
        )

        self.gae = GAE(
            gamma=self.config.gamma,
            gae_lambda=self.config.gae_lambda,
        )

        self.kl_controller = KLDivergence(
            target_kl=self.config.target_kl,
            initial_coefficient=self.config.kl_coefficient,
            adaptive=self.config.adaptive_kl,
            kl_horizon=self.config.kl_horizon,
        )

        self.rollout_buffer = RolloutBuffer(
            buffer_size=self.config.rollout_buffer_size,
            gamma=self.config.gamma,
            gae_lambda=self.config.gae_lambda,
        )

        # Training state
        self.global_step = 0
        self.epochs_trained = 0
        self.best_reward = float("-inf")
        # State reached at the end of the most recent rollout; lets `train`
        # resume the next rollout where the environment actually is.
        self._last_rollout_state: Any = None

        # Metrics tracking
        self.training_history: List[Dict[str, float]] = []
        self.checkpoint_history: List[Dict[str, Any]] = []

        logger.info(f"PPOTrainer initialized with config: {self.config}")

    def collect_rollouts(
        self,
        env_step_fn: Callable[[Any], Tuple[Any, float, bool, Dict]],
        initial_state: Any,
        n_steps: Optional[int] = None,
    ) -> float:
        """
        Collect rollout experiences from the environment into the rollout buffer.

        The buffer is cleared first. ``n_steps`` transitions are then gathered,
        and advantages/returns are computed by bootstrapping from the value of
        the final state. The final state is remembered so that a subsequent
        rollout started by `train` can continue from it.

        Args:
            env_step_fn: Function to step environment (action -> next_state, reward, done, info)
            initial_state: Starting state
            n_steps: Number of steps to collect. A falsy value (None or 0) uses
                ``config.rollout_buffer_size``.

        Returns:
            Mean reward of the episodes that finished during this rollout, or
            0.0 if none finished.
        """
        n_steps = n_steps or self.config.rollout_buffer_size
        self.rollout_buffer.clear()

        state = initial_state
        episode_reward = 0.0
        episode_length = 0
        episode_rewards: List[float] = []

        for _ in range(n_steps):
            # Get action from policy
            if self.policy_forward_fn:
                action_dist, _ = self.policy_forward_fn(state)
                action = self._sample_action(action_dist)
                # The stored log prob must belong to the action actually taken.
                # The callable's own log prob cannot, since it never sees `action`.
                action_log_prob = self._action_log_prob(action_dist, action)
            else:
                # Placeholder for testing
                action = 0
                action_log_prob = -1.0

            # Get value estimate
            if self.value_forward_fn:
                value = self.value_forward_fn(state)
            else:
                value = 0.0

            # Step environment
            next_state, reward, done, info = env_step_fn(action)

            # Store experience
            experience = Experience(
                state=state,
                action=action,
                action_log_prob=action_log_prob,
                reward=reward,
                value=value,
                done=done,
                info=info,
            )
            self.rollout_buffer.add(experience)

            # Track episode stats
            episode_reward += reward
            episode_length += 1

            if done:
                self.rollout_buffer.log_episode(episode_reward, episode_length)
                episode_rewards.append(episode_reward)
                episode_reward = 0.0
                episode_length = 0

            state = next_state

        self._last_rollout_state = state

        # Compute advantages, bootstrapping from the value of the final state
        last_value = self.value_forward_fn(state) if self.value_forward_fn else 0.0
        self.rollout_buffer.compute_advantages_and_returns(last_value)

        return MathOps.mean(episode_rewards) if episode_rewards else 0.0

    def _sample_action(self, action_dist: "Tensor") -> int:
        """Sample an action index from logits ``action_dist`` (softmax + inverse CDF)."""
        probs = MathOps.softmax(action_dist)
        r = random.random()
        cumsum = 0.0
        for i, p in enumerate(probs):
            cumsum += p
            if r < cumsum:
                return i
        # Guard against floating-point shortfall in the cumulative sum
        return len(probs) - 1

    @staticmethod
    def _action_log_prob(action_dist: "Tensor", action: int) -> float:
        """Return log pi(action) under the distribution that `_sample_action` samples from."""
        return MathOps.log(MathOps.softmax(action_dist)[action])

    @staticmethod
    def _interval_crossed(step_before: int, step_after: int, interval: int) -> bool:
        """Return True if a multiple of ``interval`` lies in ``(step_before, step_after]``.

        ``global_step`` advances once per mini-batch, so a single training step
        usually jumps over several values. Checking ``step % interval == 0``
        afterwards would miss most boundaries. A non-positive interval
        disables the check.
        """
        if interval <= 0:
            return False
        return step_after // interval > step_before // interval

    def compute_value_loss(
        self,
        values: "Tensor",
        old_values: "Tensor",
        returns: "Tensor",
        clip: bool = True,
    ) -> Dict[str, float]:
        """
        Compute value function loss.

        With ``clip=True`` the new prediction is limited to within
        ``config.value_clip_epsilon`` of the old one, and the larger
        (pessimistic) of the clipped/unclipped squared errors is used per sample.

        Args:
            values: New value predictions
            old_values: Old value predictions
            returns: Target returns
            clip: Whether to clip value function

        Returns:
            Dictionary with ``value_loss`` (0.5 * mean squared error) and
            ``explained_variance``
        """
        # Unclipped loss
        value_loss_unclipped = [
            (v - r) ** 2 for v, r in zip(values, returns)
        ]

        if clip:
            # Clipped value predictions
            clipped_values = [
                old_v + MathOps.clip(
                    v - old_v,
                    -self.config.value_clip_epsilon,
                    self.config.value_clip_epsilon
                )
                for v, old_v in zip(values, old_values)
            ]

            value_loss_clipped = [
                (cv - r) ** 2 for cv, r in zip(clipped_values, returns)
            ]

            # Take maximum of clipped and unclipped (pessimistic)
            value_loss_per_sample = [
                max(uc, c) for uc, c in zip(value_loss_unclipped, value_loss_clipped)
            ]
        else:
            value_loss_per_sample = value_loss_unclipped

        value_loss = 0.5 * MathOps.mean(value_loss_per_sample)

        return {
            "value_loss": value_loss,
            "explained_variance": self._compute_explained_variance(values, returns),
        }

    def _compute_explained_variance(
        self,
        predictions: "Tensor",
        targets: "Tensor",
    ) -> float:
        """Compute explained variance ``1 - Var(targets - predictions) / Var(targets)``.

        Returns 0.0 for fewer than two samples or near-constant targets.
        """
        if len(predictions) < 2:
            return 0.0

        target_var = MathOps.std(targets) ** 2
        if target_var < 1e-8:
            return 0.0

        residuals = [t - p for t, p in zip(targets, predictions)]
        residual_var = MathOps.std(residuals) ** 2

        return 1.0 - residual_var / target_var

    def train_step(self) -> Dict[str, float]:
        """
        Perform one PPO training step on collected rollouts.

        Runs up to ``config.ppo_epochs`` passes over shuffled mini-batches of
        the rollout buffer. It stops early once the KL controller reports that
        the policy has drifted too far.

        Returns:
            Dictionary of training metrics (empty if the buffer has no experiences)
        """
        if self.rollout_buffer.size == 0:
            logger.warning("No experiences in rollout buffer")
            return {}

        # NOTE(review): the returned batch is unused. The call is kept in case
        # get_batch() validates or prepares buffer state; confirm, then remove.
        self.rollout_buffer.get_batch()

        step_at_start = self.global_step
        total_policy_loss = 0.0
        total_value_loss = 0.0
        n_updates = 0
        early_stopped = False
        # Initialised up front so the metrics below never reference an unbound
        # name when no mini-batch is processed (e.g. ppo_epochs == 0).
        value_result: Dict[str, float] = {}

        # Multiple epochs of updates
        for epoch in range(self.config.ppo_epochs):
            if early_stopped:
                break

            # Iterate over mini-batches
            for minibatch in self.rollout_buffer.get_minibatches(
                self.config.mini_batch_size,
                shuffle=True,
            ):
                # Get new log probs and values
                if self.policy_forward_fn:
                    new_log_probs = []
                    action_dists = []
                    for state, action in zip(minibatch.states, minibatch.actions):
                        dist, _ = self.policy_forward_fn(state)
                        # Re-evaluate the *stored* action so the PPO ratio
                        # compares the same action under the old and new policy.
                        new_log_probs.append(self._action_log_prob(dist, action))
                        action_dists.append(dist)
                else:
                    # Placeholder
                    new_log_probs = minibatch.action_log_probs
                    action_dists = None

                if self.value_forward_fn:
                    new_values = [self.value_forward_fn(s) for s in minibatch.states]
                else:
                    new_values = minibatch.values

                # Compute policy loss
                policy_result = self.policy_gradient.compute_policy_loss(
                    new_log_probs=new_log_probs,
                    old_log_probs=minibatch.action_log_probs,
                    advantages=list(minibatch.advantages),
                    action_distributions=action_dists,
                )

                # Compute value loss
                value_result = self.compute_value_loss(
                    values=new_values,
                    old_values=list(minibatch.values),
                    returns=list(minibatch.returns),
                )

                # Check KL divergence for early stopping (before applying the update)
                kl_div = policy_result["approx_kl"]
                self.kl_controller.update_coefficient(kl_div)

                if self.kl_controller.should_early_stop(kl_div):
                    logger.info(f"Early stopping at epoch {epoch} due to high KL: {kl_div:.4f}")
                    early_stopped = True
                    break

                # Update models (gradient step would happen here)
                if self.policy_update_fn:
                    self.policy_update_fn({
                        "loss": policy_result["total_loss"],
                        "grad_norm": self.config.max_grad_norm,
                    })

                if self.value_update_fn:
                    self.value_update_fn({
                        "loss": value_result["value_loss"],
                        "grad_norm": self.config.max_grad_norm,
                    })

                # Accumulate metrics
                total_policy_loss += policy_result["policy_loss"]
                total_value_loss += value_result["value_loss"]
                n_updates += 1

                self.global_step += 1

        self.epochs_trained += 1

        # Compute average metrics
        metrics = {
            "policy_loss": total_policy_loss / max(n_updates, 1),
            "value_loss": total_value_loss / max(n_updates, 1),
            "kl_divergence": self.kl_controller.get_statistics()["mean_kl"],
            "kl_coefficient": self.kl_controller.coefficient,
            "clip_fraction": self.policy_gradient.get_statistics()["mean_clip_fraction"],
            "explained_variance": value_result.get("explained_variance", 0.0),
            "n_updates": n_updates,
            "early_stopped": early_stopped,
            "global_step": self.global_step,
            "epochs_trained": self.epochs_trained,
        }

        # Add buffer stats
        metrics.update(self.rollout_buffer.get_statistics())

        self.training_history.append(metrics)

        # Log whenever this step crossed a log_interval boundary
        if self._interval_crossed(step_at_start, self.global_step, self.config.log_interval):
            logger.info(f"Step {self.global_step}: " + ", ".join(
                f"{k}={v:.4f}" if isinstance(v, float) else f"{k}={v}"
                for k, v in metrics.items()
            ))

        return metrics

    def train(
        self,
        env_step_fn: Callable[[Any], Tuple[Any, float, bool, Dict]],
        initial_state: Any,
        total_timesteps: int,
        callback: Optional[Callable[[Dict[str, float]], bool]] = None,
    ) -> Dict[str, Any]:
        """
        Full training loop.

        Alternates rollout collection and `train_step` until at least
        ``total_timesteps`` transitions have been collected. Each rollout
        continues from the state where the previous one ended. Training stops
        early if the callback returns False or a rollout collects nothing.

        Args:
            env_step_fn: Environment step function
            initial_state: Starting state
            total_timesteps: Total timesteps to train
            callback: Optional callback after each update (return False to stop)

        Returns:
            Training summary
        """
        logger.info(f"Starting PPO training for {total_timesteps} timesteps")
        start_time = time.time()

        timesteps_collected = 0
        state = initial_state
        self._last_rollout_state = initial_state

        while timesteps_collected < total_timesteps:
            # Collect rollouts
            mean_reward = self.collect_rollouts(
                env_step_fn=env_step_fn,
                initial_state=state,
            )
            collected = self.rollout_buffer.size
            if collected == 0:
                # Without this guard the loop would never make progress.
                logger.warning("Rollout collected no experiences; stopping training")
                break
            timesteps_collected += collected
            # Resume the next rollout from where the environment actually is
            state = self._last_rollout_state

            # Train on rollouts
            step_before_update = self.global_step
            metrics = self.train_step()
            metrics["mean_episode_reward"] = mean_reward
            metrics["timesteps"] = timesteps_collected

            # Track best reward
            if mean_reward > self.best_reward:
                self.best_reward = mean_reward
                logger.info(f"New best reward: {self.best_reward:.4f}")

            # Callback
            if callback and not callback(metrics):
                logger.info("Training stopped by callback")
                break

            # Checkpoint whenever this update crossed a save_interval boundary
            if self._interval_crossed(step_before_update, self.global_step, self.config.save_interval):
                self._save_checkpoint(metrics)

        elapsed_time = time.time() - start_time

        summary = {
            "total_timesteps": timesteps_collected,
            "total_epochs": self.epochs_trained,
            "global_steps": self.global_step,
            "best_reward": self.best_reward,
            "elapsed_time": elapsed_time,
            # Coarse clocks can report zero elapsed time for very short runs
            "timesteps_per_second": timesteps_collected / elapsed_time if elapsed_time > 0 else 0.0,
        }

        logger.info(f"Training complete: {summary}")
        return summary

    def _save_checkpoint(self, metrics: Dict[str, float]) -> None:
        """Record a checkpoint of training state in ``checkpoint_history`` (in memory only)."""
        checkpoint = {
            "global_step": self.global_step,
            "epochs_trained": self.epochs_trained,
            "best_reward": self.best_reward,
            "kl_coefficient": self.kl_controller.coefficient,
            "metrics": metrics,
            "timestamp": time.time(),
        }
        self.checkpoint_history.append(checkpoint)
        logger.info(f"Checkpoint saved at step {self.global_step}")

    def get_training_summary(self) -> Dict[str, Any]:
        """Get comprehensive training summary.

        Aggregates mean/min/max of policy loss, value loss and KL divergence
        over ``training_history``, plus per-component statistics.
        """
        if not self.training_history:
            return {"message": "No training history available"}

        # Aggregate metrics
        all_policy_losses = [h["policy_loss"] for h in self.training_history]
        all_value_losses = [h["value_loss"] for h in self.training_history]
        all_kl_divs = [h["kl_divergence"] for h in self.training_history]

        return {
            "global_step": self.global_step,
            "epochs_trained": self.epochs_trained,
            "best_reward": self.best_reward,
            "final_metrics": self.training_history[-1],
            "policy_loss": {
                "mean": MathOps.mean(all_policy_losses),
                "min": min(all_policy_losses),
                "max": max(all_policy_losses),
            },
            "value_loss": {
                "mean": MathOps.mean(all_value_losses),
                "min": min(all_value_losses),
                "max": max(all_value_losses),
            },
            "kl_divergence": {
                "mean": MathOps.mean(all_kl_divs),
                "min": min(all_kl_divs),
                "max": max(all_kl_divs),
            },
            "component_stats": {
                "value_head": self.value_head.get_statistics(),
                "policy_gradient": self.policy_gradient.get_statistics(),
                "gae": self.gae.get_statistics(),
                "kl_controller": self.kl_controller.get_statistics(),
                "rollout_buffer": self.rollout_buffer.get_statistics(),
            },
            "checkpoints": len(self.checkpoint_history),
        }


# ==============================================================================
# Utility Functions
# ==============================================================================

def create_reward_model_wrapper(
    reward_fn: Callable[[str, str], float],
) -> Callable[[Any], Tuple[Any, float, bool, Dict]]:
    """
    Build an environment step function around a reward model.

    In RLHF the "environment" is the reward model that scores generated
    text. The returned step function is currently a placeholder:
    ``reward_fn`` is never invoked. Every step reports a reward of 0.0 and
    ``done=False``, echoes the action back as the next state, and returns a
    fresh empty info dict.

    Args:
        reward_fn: Function that scores (prompt, response) pairs (not yet used)

    Returns:
        Environment step function compatible with PPOTrainer
    """
    def env_step(action: Any) -> Tuple[Any, float, bool, Dict]:
        # Intended RLHF mapping (not implemented yet):
        #   action     -> the generated token/text
        #   reward     -> score from reward_fn
        #   done       -> True at the end of generation
        #   next_state -> the updated context
        placeholder_reward = 0.0  # TODO: call reward_fn once prompt tracking exists
        return action, placeholder_reward, False, {}

    return env_step


def compute_reward_baseline(
    rewards: List[float],
    method: str = "mean",
    alpha: float = 0.1,
) -> float:
    """
    Compute reward baseline for variance reduction.

    Args:
        rewards: List of rewards
        method: Baseline method ("mean", "median", "exponential")
        alpha: Smoothing factor for the "exponential" method, i.e. the weight
            given to each newer reward. Defaults to 0.1, the value that was
            previously hard-coded.

    Returns:
        Baseline value. Returns 0.0 for an empty list or an unrecognised
        method (kept for backward compatibility).

    Raises:
        ValueError: If ``method == "exponential"`` and ``alpha`` is not in (0, 1].
    """
    if not rewards:
        return 0.0

    if method == "mean":
        return MathOps.mean(rewards)

    if method == "median":
        ordered = sorted(rewards)
        mid = len(ordered) // 2
        if len(ordered) % 2 == 0:
            # Even count: average the two middle values
            return (ordered[mid - 1] + ordered[mid]) / 2
        return ordered[mid]

    if method == "exponential":
        if not 0.0 < alpha <= 1.0:
            raise ValueError(f"alpha must be in (0, 1], got {alpha}")
        # Exponential moving average seeded with the first reward
        baseline = rewards[0]
        for reward in rewards[1:]:
            baseline = alpha * reward + (1 - alpha) * baseline
        return baseline

    return 0.0


# ==============================================================================
# Example Usage and Testing
# ==============================================================================

def run_ppo_example():
    """Run an end-to-end PPO demo on a toy environment with mock models.

    Builds a small PPOConfig and a PPOTrainer, and attaches a random mock
    policy and a random mock value function. It then collects one 64-step
    rollout from ``SimpleEnv``, runs a single ``train_step``, and logs the
    resulting metrics and training summary.

    Returns:
        The PPOTrainer instance used for the demo.
    """
    logger.info("=" * 60)
    logger.info("PPO Engine Example")
    logger.info("=" * 60)

    # Create configuration
    config = PPOConfig(
        clip_epsilon=0.2,
        entropy_coefficient=0.01,
        gamma=0.99,
        gae_lambda=0.95,
        ppo_epochs=4,
        mini_batch_size=8,
        rollout_buffer_size=64,  # Small for demo
        target_kl=0.02,
        log_interval=10,
    )

    # Create trainer
    trainer = PPOTrainer(config=config)

    # Simple environment simulator
    class SimpleEnv:
        """Toy environment with an integer state in [0, 10) and ``max_steps``-long episodes.

        Only the step counter is reset when an episode ends; ``state`` carries over.
        """

        def __init__(self):
            self.state = 0
            self.max_steps = 50
            self.step_count = 0

        def step(self, action):
            """Apply ``action`` and return ``(state, reward, done, info)``.

            The reward is 1.0 if ``action`` equals the pre-transition
            ``state % 2``, otherwise -0.5.
            """
            self.step_count += 1

            # Simple reward function
            reward = 1.0 if action == (self.state % 2) else -0.5

            # State transition
            self.state = (self.state + action + 1) % 10

            # Termination
            done = self.step_count >= self.max_steps
            if done:
                self.step_count = 0

            # Note: after a terminal step this reports step 0 (already reset)
            info = {"step": self.step_count}

            return self.state, reward, done, info

    env = SimpleEnv()

    # Mock policy and value functions
    def mock_policy(state):
        """Ignore ``state`` and return random 2-action logits plus a log prob."""
        # Returns (action_logits, action_log_prob)
        logits = [random.gauss(0, 1) for _ in range(2)]
        probs = MathOps.softmax(logits)
        # The action drawn here only selects which log prob is returned;
        # the action itself is not returned to the caller.
        action = 0 if random.random() < probs[0] else 1
        log_prob = MathOps.log(probs[action])
        return logits, log_prob

    def mock_value(state):
        """Ignore ``state`` and return a standard-normal random value estimate."""
        return random.gauss(0, 1)

    trainer.policy_forward_fn = mock_policy
    trainer.value_forward_fn = mock_value

    # Collect some rollouts
    logger.info("\nCollecting rollouts...")
    mean_reward = trainer.collect_rollouts(
        env_step_fn=env.step,
        initial_state=env.state,
        n_steps=64,
    )
    logger.info(f"Mean episode reward: {mean_reward:.4f}")

    # Train
    logger.info("\nTraining...")
    metrics = trainer.train_step()

    logger.info("\nTraining metrics:")
    for key, value in metrics.items():
        if isinstance(value, float):
            logger.info(f"  {key}: {value:.4f}")
        else:
            logger.info(f"  {key}: {value}")

    # Get summary (default=str stringifies anything json cannot encode)
    summary = trainer.get_training_summary()
    logger.info("\nTraining summary:")
    logger.info(json.dumps(summary, indent=2, default=str))

    return trainer


def test_components():
    """Smoke-run each PPO component in isolation and log its outputs.

    Exercises, in order: GAE, PolicyGradient, KLDivergence, ValueHead and
    RolloutBuffer (including mini-batching). Results are only logged.

    NOTE(review): despite the name and the final "All component tests passed!"
    message, this function contains no assertions. It reports success
    whenever no component raises.
    """
    logger.info("=" * 60)
    logger.info("Component Tests")
    logger.info("=" * 60)

    # Test GAE on a single 5-step episode that terminates at the last step
    logger.info("\n1. Testing GAE...")
    gae = GAE(gamma=0.99, gae_lambda=0.95)
    rewards = [1.0, 0.5, 0.8, 1.2, 0.3]
    values = [0.9, 0.6, 0.7, 1.0, 0.4]
    dones = [False, False, False, False, True]

    advantages, returns = gae.compute_advantages(rewards, values, dones)
    logger.info(f"  Advantages: {[f'{a:.3f}' for a in advantages]}")
    logger.info(f"  Returns: {[f'{r:.3f}' for r in returns]}")
    logger.info(f"  GAE Stats: {gae.get_statistics()}")

    # Test PolicyGradient (reuses the GAE advantages computed above)
    logger.info("\n2. Testing PolicyGradient...")
    pg = PolicyGradient(clip_epsilon=0.2, entropy_coefficient=0.01)

    old_log_probs = [-1.0, -0.8, -1.2, -0.9, -1.1]
    new_log_probs = [-1.1, -0.7, -1.0, -1.0, -1.2]

    result = pg.compute_policy_loss(new_log_probs, old_log_probs, advantages)
    logger.info(f"  Policy loss result: {result}")
    logger.info(f"  PG Stats: {pg.get_statistics()}")

    # Test KLDivergence between the two log-prob vectors above
    logger.info("\n3. Testing KLDivergence...")
    kl_ctrl = KLDivergence(target_kl=0.01, adaptive=True)

    kl_value = kl_ctrl.compute_kl_divergence(old_log_probs, new_log_probs)
    logger.info(f"  KL divergence: {kl_value:.6f}")

    new_coef = kl_ctrl.update_coefficient(kl_value)
    logger.info(f"  Updated coefficient: {new_coef:.4f}")
    logger.info(f"  Should early stop: {kl_ctrl.should_early_stop(kl_value)}")
    logger.info(f"  KL Stats: {kl_ctrl.get_statistics()}")

    # Test ValueHead on one random 8-dimensional input
    logger.info("\n4. Testing ValueHead...")
    vh = ValueHead(input_dim=8, hidden_dims=[16, 8])

    test_input = [random.gauss(0, 1) for _ in range(8)]
    value = vh.forward(test_input)
    logger.info(f"  Value prediction: {value:.4f}")
    logger.info(f"  Value head stats: {vh.get_statistics()}")

    # Test RolloutBuffer with 20 random experiences; only the last one is terminal
    logger.info("\n5. Testing RolloutBuffer...")
    buffer = RolloutBuffer(buffer_size=32)

    for i in range(20):
        exp = Experience(
            state=i,
            action=i % 2,
            action_log_prob=-random.random(),
            reward=random.random(),
            value=random.gauss(0, 1),
            done=(i == 19),
        )
        buffer.add(exp)

    buffer.compute_advantages_and_returns()
    logger.info(f"  Buffer size: {buffer.size}")
    logger.info(f"  Buffer stats: {buffer.get_statistics()}")

    # Test mini-batching (batch_size=8; the final batch may be smaller)
    batch_count = 0
    for batch in buffer.get_minibatches(batch_size=8):
        batch_count += 1
        logger.info(f"  Batch {batch_count}: {len(batch.states)} samples")

    logger.info("\nAll component tests passed!")


if __name__ == "__main__":
    separator = "=" * 60

    # Exercise each component in isolation first...
    test_components()

    # ...then run the end-to-end example training loop.
    print(f"\n{separator}\n")
    trainer = run_ppo_example()

    print(f"\n{separator}")
    print("PPO Engine implementation complete!")
    print(separator)
