#!/usr/bin/env python3
"""
AIVA Queen Neural Transformer Implementation
=============================================

A complete, production-ready transformer architecture for the AIVA Queen system.
This implementation includes all core components of the transformer architecture
as described in "Attention Is All You Need" (Vaswani et al., 2017) with modern
improvements and optimizations.

Components:
-----------
1. MultiHeadSelfAttention - Scaled dot-product attention with multiple heads
2. PositionalEncoding - Sinusoidal and learned position embeddings
3. FeedForwardNetwork - Two-layer MLP with GELU activation
4. TransformerEncoderLayer - Complete encoder layer with residual connections
5. TransformerDecoderLayer - Complete decoder layer with cross-attention
6. QueenTransformer - Full encoder-decoder transformer model

Features:
---------
- Pre-layer normalization (more stable training)
- GELU activation (smoother than ReLU)
- Rotary position embeddings (optional)
- Flash attention support (optional)
- Gradient checkpointing for memory efficiency
- Attention visualization utilities
- Complete training and inference pipelines

Author: AIVA Queen Neural Core
Version: 1.0.0
"""

import json
import logging
import math
from dataclasses import asdict, dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple, Dict, List, Union, Any

import numpy as np

# Configure logging.
# NOTE(review): calling basicConfig at import time configures the *root*
# logger for any application that imports this module; fine for a standalone
# script, worth confirming if this is consumed as a library.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("QueenTransformer")

# Try to import PyTorch; without it the module still imports (e.g. to use the
# config dataclasses), but every nn.Module definition below is skipped via
# the `if TORCH_AVAILABLE:` guards.
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    logger.warning("PyTorch not available. Running in documentation mode.")

# Optional plotting stack; presumably consumed by the attention-visualization
# utilities later in the file (gated on PLOTTING_AVAILABLE).
try:
    import matplotlib.pyplot as plt
    import seaborn as sns
    PLOTTING_AVAILABLE = True
except ImportError:
    PLOTTING_AVAILABLE = False
    logger.warning("Matplotlib/Seaborn not available. Visualization disabled.")


# ==============================================================================
# Configuration Classes
# ==============================================================================

@dataclass
class TransformerConfig:
    """Configuration for the Queen Transformer model.

    Fields are grouped by concern: architecture sizes, dropout rates,
    positional encoding, normalization, activation, special token ids, and
    training-time switches.  Validation happens in ``__post_init__``.
    """

    # --- Model architecture ---
    d_model: int = 512  # Model dimension (must be divisible by n_heads)
    n_heads: int = 8  # Number of attention heads
    n_encoder_layers: int = 6  # Number of encoder layers
    n_decoder_layers: int = 6  # Number of decoder layers
    d_ff: int = 2048  # Feed-forward hidden dimension
    max_seq_length: int = 512  # Maximum sequence length
    vocab_size: int = 32000  # Vocabulary size

    # --- Dropout rates ---
    dropout: float = 0.1
    attention_dropout: float = 0.1
    ff_dropout: float = 0.1

    # --- Positional encoding ---
    pos_encoding_type: str = "sinusoidal"  # "sinusoidal" or "learned"
    use_rotary_embeddings: bool = False

    # --- Normalization ---
    layer_norm_eps: float = 1e-6
    pre_norm: bool = True  # Pre-layer normalization (more stable training)

    # --- Activation ---
    activation: str = "gelu"  # "gelu", "relu", "swish"

    # --- Special tokens ---
    pad_token_id: int = 0
    bos_token_id: int = 1
    eos_token_id: int = 2

    # --- Training settings ---
    gradient_checkpointing: bool = False
    tie_word_embeddings: bool = True

    def __post_init__(self):
        """Validate configuration after initialization."""
        assert self.d_model % self.n_heads == 0, \
            f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
        assert self.pos_encoding_type in ["sinusoidal", "learned"], \
            f"Invalid pos_encoding_type: {self.pos_encoding_type}"
        assert self.activation in ["gelu", "relu", "swish"], \
            f"Invalid activation: {self.activation}"

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to a plain dictionary.

        Uses ``dataclasses.asdict`` so the mapping always stays in sync with
        the declared fields; the previous hand-written dict literal could
        silently drift when fields were added or renamed.
        """
        return asdict(self)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "TransformerConfig":
        """Create config from dictionary (inverse of ``to_dict``)."""
        return cls(**config_dict)

    def save(self, path: Union[str, Path]) -> None:
        """Save configuration to a JSON file at ``path``."""
        path = Path(path)
        with open(path, 'w') as f:
            json.dump(self.to_dict(), f, indent=2)
        logger.info(f"Config saved to {path}")

    @classmethod
    def load(cls, path: Union[str, Path]) -> "TransformerConfig":
        """Load configuration from a JSON file at ``path``."""
        path = Path(path)
        with open(path, 'r') as f:
            config_dict = json.load(f)
        return cls.from_dict(config_dict)


@dataclass
class TrainingConfig:
    """Configuration for training the transformer.

    Mirrors ``TransformerConfig``: plain dataclass fields with
    assertion-based validation in ``__post_init__``.
    """

    # --- Optimization (AdamW hyper-parameters) ---
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    beta1: float = 0.9
    beta2: float = 0.98
    eps: float = 1e-9
    max_grad_norm: float = 1.0  # gradient clipping threshold

    # --- Learning rate schedule ---
    warmup_steps: int = 4000
    lr_scheduler: str = "cosine"  # "cosine", "linear", "constant"

    # --- Training duration ---
    num_epochs: int = 10
    max_steps: int = -1  # -1 means use epochs

    # --- Batch settings ---
    batch_size: int = 32
    accumulation_steps: int = 1  # effective batch = batch_size * accumulation_steps

    # --- Checkpointing / evaluation / logging cadence (in steps) ---
    save_steps: int = 1000
    eval_steps: int = 500
    logging_steps: int = 100

    # --- Mixed precision ---
    use_amp: bool = True

    # --- Label smoothing ---
    label_smoothing: float = 0.1

    def __post_init__(self):
        """Validate configuration, consistent with TransformerConfig's style.

        Previously ``lr_scheduler`` accepted any string and the typo would
        only surface later when the scheduler was built.
        """
        assert self.lr_scheduler in ["cosine", "linear", "constant"], \
            f"Invalid lr_scheduler: {self.lr_scheduler}"
        assert self.accumulation_steps >= 1, \
            f"accumulation_steps must be >= 1, got {self.accumulation_steps}"


# ==============================================================================
# Positional Encoding
# ==============================================================================

if TORCH_AVAILABLE:

    class SinusoidalPositionalEncoding(nn.Module):
        """
        Sinusoidal positional encoding from the original transformer paper.

            PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
            PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

        Fixes over the previous revision:
        - Odd ``d_model`` is supported: the cosine half is trimmed to the
          available columns (previously the in-place assignment raised a
          shape mismatch for odd dimensions).
        - Sequences longer than ``max_seq_length`` raise a clear ValueError
          instead of an opaque broadcasting error.

        Args:
            d_model: Model dimension
            max_seq_length: Maximum sequence length
            dropout: Dropout rate
        """

        def __init__(
            self,
            d_model: int,
            max_seq_length: int = 512,
            dropout: float = 0.1
        ):
            super().__init__()
            self.d_model = d_model
            self.dropout = nn.Dropout(p=dropout)

            # Precompute the (max_seq_length, d_model) encoding table once.
            pe = torch.zeros(max_seq_length, d_model)
            position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)

            # Frequencies 10000^(-2i/d_model) for i = 0 .. ceil(d_model/2)-1.
            div_term = torch.exp(
                torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
            )

            # sin on even columns, cos on odd columns.  For odd d_model the
            # odd-column slice has one column fewer than div_term, so trim
            # the cosine product to d_model // 2 columns (no-op when even).
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)[:, : d_model // 2]

            # Registered as a buffer (not a parameter): moves with the module
            # across devices but is never trained.
            pe = pe.unsqueeze(0)  # Shape: (1, max_seq_length, d_model)
            self.register_buffer('pe', pe)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """
            Add positional encoding to input embeddings.

            Args:
                x: Input tensor of shape (batch_size, seq_length, d_model)

            Returns:
                Tensor with positional encoding added, after dropout

            Raises:
                ValueError: If seq_length exceeds the precomputed table.
            """
            seq_length = x.size(1)
            if seq_length > self.pe.size(1):
                raise ValueError(
                    f"Sequence length {seq_length} exceeds max_seq_length "
                    f"{self.pe.size(1)}"
                )
            x = x + self.pe[:, :seq_length, :]
            return self.dropout(x)


    class LearnedPositionalEncoding(nn.Module):
        """
        Trainable absolute position embeddings.

        Each position index 0 .. max_seq_length-1 owns a learned
        d_model-sized vector that is added to the token embeddings, followed
        by dropout.

        Args:
            d_model: Model dimension
            max_seq_length: Maximum sequence length
            dropout: Dropout rate
        """

        def __init__(
            self,
            d_model: int,
            max_seq_length: int = 512,
            dropout: float = 0.1
        ):
            super().__init__()
            self.d_model = d_model
            self.dropout = nn.Dropout(p=dropout)

            # One trainable embedding vector per position.
            self.position_embeddings = nn.Embedding(max_seq_length, d_model)

            # Constant index row [[0, 1, ..., max_seq_length-1]], kept as a
            # buffer so it follows the module across devices.
            self.register_buffer(
                'position_ids',
                torch.arange(max_seq_length).unsqueeze(0)
            )

            self._reset_parameters()

        def _reset_parameters(self):
            """Draw the position table from N(0, 0.02^2)."""
            nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=0.02)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """
            Add the learned position vectors to ``x``.

            Args:
                x: Input tensor of shape (batch_size, seq_length, d_model)

            Returns:
                ``dropout(x + position_embeddings[:seq_length])``
            """
            indices = self.position_ids[:, : x.size(1)]
            return self.dropout(x + self.position_embeddings(indices))


    class RotaryPositionalEmbedding(nn.Module):
        """
        Rotary Position Embedding (RoPE) as described in RoFormer.

        RoPE encodes positional information by rotating the query and key
        vectors in the attention mechanism, allowing for better extrapolation
        to longer sequences.

        The sin/cos cache is grown on demand: if ``forward`` receives a
        sequence longer than what is cached, the cache is rebuilt for that
        length.  Fix: previously the cached tensors were silently sliced to
        ``max_seq_length``, so a longer q/k failed with a broadcasting error.

        Args:
            dim: Dimension of the embedding (should be head_dim)
            max_seq_length: Initial cache length
            base: Base for the frequency computation
        """

        def __init__(
            self,
            dim: int,
            max_seq_length: int = 512,
            base: float = 10000.0
        ):
            super().__init__()
            self.dim = dim
            self.max_seq_length = max_seq_length
            self.base = base

            # Inverse frequencies: base^(-2i/dim) for i = 0 .. dim/2-1.
            inv_freq = 1.0 / (
                base ** (torch.arange(0, dim, 2).float() / dim)
            )
            self.register_buffer('inv_freq', inv_freq)

            # Pre-compute sin and cos caches for the initial length.
            self._build_cache(max_seq_length)

        def _build_cache(self, seq_length: int):
            """(Re)build the sin/cos cache for ``seq_length`` positions.

            ``nn.Module.register_buffer`` overwrites an existing buffer of
            the same name, so this is safe to call again to grow the cache.
            """
            t = torch.arange(seq_length, device=self.inv_freq.device)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)

            # Duplicate the angle table so it lines up with _rotate_half's
            # (first half, second half) pairing.
            emb = torch.cat((freqs, freqs), dim=-1)

            self.register_buffer('cos_cached', emb.cos().unsqueeze(0).unsqueeze(0))
            self.register_buffer('sin_cached', emb.sin().unsqueeze(0).unsqueeze(0))

        def _rotate_half(self, x: torch.Tensor) -> torch.Tensor:
            """Rotate half the hidden dims of the input: (x1, x2) -> (-x2, x1)."""
            x1 = x[..., : x.shape[-1] // 2]
            x2 = x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)

        def forward(
            self,
            q: torch.Tensor,
            k: torch.Tensor,
            seq_len: int
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            """
            Apply rotary position embeddings to query and key tensors.

            Args:
                q: Query tensor of shape (batch, n_heads, seq_len, head_dim)
                k: Key tensor of shape (batch, n_heads, seq_len, head_dim)
                seq_len: Sequence length

            Returns:
                Rotated query and key tensors
            """
            # Grow the cache when asked for more positions than are cached.
            if seq_len > self.cos_cached.size(2):
                self._build_cache(seq_len)

            cos = self.cos_cached[:, :, :seq_len, :]
            sin = self.sin_cached[:, :, :seq_len, :]

            q_embed = (q * cos) + (self._rotate_half(q) * sin)
            k_embed = (k * cos) + (self._rotate_half(k) * sin)

            return q_embed, k_embed


# ==============================================================================
# Multi-Head Self-Attention
# ==============================================================================

if TORCH_AVAILABLE:

    class MultiHeadSelfAttention(nn.Module):
        """
        Multi-Head Self-Attention mechanism.

        Implements scaled dot-product attention with multiple heads, allowing
        the model to jointly attend to information from different
        representation subspaces at different positions:

            Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V

        Fixes over the previous revision:
        - Masked score positions are filled with the dtype's most-negative
          finite value rather than ``-inf``.  A row whose positions were all
          masked (e.g. an all-padding sequence) used to turn into NaN after
          softmax; it now degrades to a uniform distribution and stays finite.
        - ``output_attentions=True`` now returns the pre-dropout attention
          probabilities, consistent with what ``get_attention_weights``
          stores.

        Args:
            config: TransformerConfig object
        """

        def __init__(self, config: "TransformerConfig"):
            super().__init__()
            self.config = config
            self.d_model = config.d_model
            self.n_heads = config.n_heads
            self.head_dim = config.d_model // config.n_heads
            # sqrt(d_k) scaling from the original transformer paper.
            self.scale = math.sqrt(self.head_dim)

            # Linear projections for Q, K, V
            self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
            self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
            self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)

            # Output projection
            self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)

            # Dropouts: one on attention probabilities, one on the output.
            self.attention_dropout = nn.Dropout(config.attention_dropout)
            self.output_dropout = nn.Dropout(config.dropout)

            # Optional rotary embeddings, applied per head to q/k.
            self.rotary_emb = None
            if config.use_rotary_embeddings:
                self.rotary_emb = RotaryPositionalEmbedding(
                    self.head_dim,
                    config.max_seq_length
                )

            # Last attention map, kept (detached) for visualization.
            self._attention_weights = None

            # Initialize weights
            self._init_weights()

        def _init_weights(self):
            """Initialize projection weights with Xavier uniform."""
            for module in [self.q_proj, self.k_proj, self.v_proj, self.out_proj]:
                nn.init.xavier_uniform_(module.weight)

        def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
            """
            Split the last dimension into (n_heads, head_dim).

            Args:
                x: Tensor of shape (batch_size, seq_length, d_model)

            Returns:
                Tensor of shape (batch_size, n_heads, seq_length, head_dim)
            """
            batch_size, seq_length, _ = x.size()
            x = x.view(batch_size, seq_length, self.n_heads, self.head_dim)
            return x.transpose(1, 2)

        def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
            """
            Merge heads back into a single d_model dimension.

            Args:
                x: Tensor of shape (batch_size, n_heads, seq_length, head_dim)

            Returns:
                Tensor of shape (batch_size, seq_length, d_model)
            """
            batch_size, _, seq_length, _ = x.size()
            x = x.transpose(1, 2).contiguous()
            return x.view(batch_size, seq_length, self.d_model)

        def _compute_attention(
            self,
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            causal_mask: bool = False
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            """
            Compute scaled dot-product attention.

            Args:
                q: Query tensor (batch, n_heads, seq_len, head_dim)
                k: Key tensor (batch, n_heads, seq_len, head_dim)
                v: Value tensor (batch, n_heads, seq_len, head_dim)
                attention_mask: Optional mask with 1 for valid and 0 for
                    masked positions, broadcastable to
                    (batch, n_heads, seq_len, seq_len)
                causal_mask: Whether to forbid attending to future positions

            Returns:
                Tuple of (attention output, pre-dropout attention weights)
            """
            # Compute attention scores: (batch, n_heads, seq_len, seq_len)
            attention_scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale

            seq_len = q.size(2)

            # Most-negative finite value for this dtype: softmax maps it to
            # ~0, but a fully-masked row still yields finite (uniform)
            # probabilities instead of NaN, unlike -inf.
            mask_value = torch.finfo(attention_scores.dtype).min

            # Causal mask: block attention to strictly-future positions.
            if causal_mask:
                causal = torch.triu(
                    torch.ones(seq_len, seq_len, device=q.device, dtype=torch.bool),
                    diagonal=1
                )
                attention_scores = attention_scores.masked_fill(causal, mask_value)

            # User-supplied mask: 1 for valid, 0 for masked positions.
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(
                    attention_mask == 0,
                    mask_value
                )

            # Softmax to get attention probabilities
            attention_probs = F.softmax(attention_scores, dim=-1)

            # Keep a detached copy for visualization.
            self._attention_weights = attention_probs.detach()

            # Dropout is applied only to the probabilities used for the
            # weighted sum; the returned weights stay pre-dropout.
            dropped_probs = self.attention_dropout(attention_probs)

            # Weighted sum of values
            attention_output = torch.matmul(dropped_probs, v)

            return attention_output, attention_probs

        def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            causal_mask: bool = False,
            output_attentions: bool = False
        ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
            """
            Forward pass for multi-head self-attention.

            Args:
                hidden_states: Input tensor (batch_size, seq_length, d_model)
                attention_mask: Optional attention mask
                causal_mask: Whether to apply causal masking
                output_attentions: Whether to also return attention weights

            Returns:
                Output tensor, optionally with (pre-dropout) attention weights
            """
            batch_size, seq_length, _ = hidden_states.size()

            # Project to Q, K, V
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)

            # Split into multiple heads
            q = self._split_heads(q)
            k = self._split_heads(k)
            v = self._split_heads(v)

            # Apply rotary embeddings if configured
            if self.rotary_emb is not None:
                q, k = self.rotary_emb(q, k, seq_length)

            # Compute attention
            attention_output, attention_weights = self._compute_attention(
                q, k, v, attention_mask, causal_mask
            )

            # Merge heads
            attention_output = self._merge_heads(attention_output)

            # Output projection
            output = self.out_proj(attention_output)
            output = self.output_dropout(output)

            if output_attentions:
                return output, attention_weights
            return output

        def get_attention_weights(self) -> Optional[torch.Tensor]:
            """Return the last computed (pre-dropout) attention weights."""
            return self._attention_weights


    class MultiHeadCrossAttention(nn.Module):
        """
        Multi-Head Cross-Attention for decoder layers.

        Similar to self-attention but queries come from the decoder and
        keys/values from the encoder outputs (seq2seq).

        Fixes over the previous revision (mirroring MultiHeadSelfAttention):
        - Masked score positions are filled with the dtype's most-negative
          finite value instead of ``-inf``, so a fully-masked row stays
          finite (uniform) instead of becoming NaN after softmax.
        - ``output_attentions=True`` returns the pre-dropout probabilities.
        - Added ``get_attention_weights`` for parity with the self-attention
          class.

        Args:
            config: TransformerConfig object
        """

        def __init__(self, config: "TransformerConfig"):
            super().__init__()
            self.config = config
            self.d_model = config.d_model
            self.n_heads = config.n_heads
            self.head_dim = config.d_model // config.n_heads
            # sqrt(d_k) scaling, same as self-attention.
            self.scale = math.sqrt(self.head_dim)

            # Q comes from the decoder, K and V from the encoder.
            self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
            self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
            self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)

            # Output projection
            self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)

            # Dropouts: one on attention probabilities, one on the output.
            self.attention_dropout = nn.Dropout(config.attention_dropout)
            self.output_dropout = nn.Dropout(config.dropout)

            # Last attention map, kept (detached) for visualization.
            self._attention_weights = None

            # Initialize
            self._init_weights()

        def _init_weights(self):
            """Initialize projection weights with Xavier uniform."""
            for module in [self.q_proj, self.k_proj, self.v_proj, self.out_proj]:
                nn.init.xavier_uniform_(module.weight)

        def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
            """Reshape (batch, seq, d_model) -> (batch, n_heads, seq, head_dim)."""
            batch_size, seq_length, _ = x.size()
            x = x.view(batch_size, seq_length, self.n_heads, self.head_dim)
            return x.transpose(1, 2)

        def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
            """Reshape (batch, n_heads, seq, head_dim) -> (batch, seq, d_model)."""
            batch_size, _, seq_length, _ = x.size()
            x = x.transpose(1, 2).contiguous()
            return x.view(batch_size, seq_length, self.d_model)

        def forward(
            self,
            hidden_states: torch.Tensor,
            encoder_hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            output_attentions: bool = False
        ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
            """
            Forward pass for cross-attention.

            Args:
                hidden_states: Decoder hidden states (batch, tgt_len, d_model)
                encoder_hidden_states: Encoder outputs (batch, src_len, d_model)
                attention_mask: Mask for encoder positions (1 valid, 0 masked)
                output_attentions: Whether to also return attention weights

            Returns:
                Cross-attention output, optionally with (pre-dropout) weights
            """
            # Q from decoder, K and V from encoder
            q = self.q_proj(hidden_states)
            k = self.k_proj(encoder_hidden_states)
            v = self.v_proj(encoder_hidden_states)

            # Split heads
            q = self._split_heads(q)
            k = self._split_heads(k)
            v = self._split_heads(v)

            # Attention scores: (batch, n_heads, tgt_len, src_len)
            attention_scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale

            # Finite fill value avoids NaN rows when every encoder position
            # is masked for some query (see class docstring).
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(
                    attention_mask == 0,
                    torch.finfo(attention_scores.dtype).min
                )

            # Softmax
            attention_probs = F.softmax(attention_scores, dim=-1)
            self._attention_weights = attention_probs.detach()

            # Dropout only on the probabilities used for the weighted sum;
            # the returned weights stay pre-dropout.
            dropped_probs = self.attention_dropout(attention_probs)

            # Compute output
            attention_output = torch.matmul(dropped_probs, v)

            # Merge and project
            attention_output = self._merge_heads(attention_output)
            output = self.out_proj(attention_output)
            output = self.output_dropout(output)

            if output_attentions:
                return output, attention_probs
            return output

        def get_attention_weights(self) -> Optional[torch.Tensor]:
            """Return the last computed (pre-dropout) attention weights."""
            return self._attention_weights


# ==============================================================================
# Feed-Forward Network
# ==============================================================================

if TORCH_AVAILABLE:

    class FeedForwardNetwork(nn.Module):
        """
        Position-wise Feed-Forward Network.

        Applies the same two-layer MLP independently at every sequence
        position:

            FFN(x) = fc2(dropout(activation(fc1(x))))

        followed by a final dropout.

        Args:
            config: TransformerConfig object
        """

        def __init__(self, config: TransformerConfig):
            super().__init__()
            self.config = config

            # Map the configured activation name onto its module factory.
            factories = {"gelu": nn.GELU, "relu": nn.ReLU, "swish": nn.SiLU}
            factory = factories.get(config.activation)
            if factory is None:
                raise ValueError(f"Unknown activation: {config.activation}")
            self.activation = factory()

            # Expand to d_ff, then project back down to d_model.
            self.fc1 = nn.Linear(config.d_model, config.d_ff)
            self.fc2 = nn.Linear(config.d_ff, config.d_model)

            self.dropout = nn.Dropout(config.ff_dropout)

            self._init_weights()

        def _init_weights(self):
            """Xavier-uniform weights, zero biases."""
            for layer in (self.fc1, self.fc2):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """
            Apply the MLP position-wise.

            Args:
                x: Input tensor (batch_size, seq_length, d_model)

            Returns:
                Output tensor of the same shape
            """
            hidden = self.dropout(self.activation(self.fc1(x)))
            return self.dropout(self.fc2(hidden))


    class GatedFeedForward(nn.Module):
        """
        Gated feed-forward network (SwiGLU variant).

        Computes:

            GatedFFN(x) = down_proj( SiLU(gate_proj(x)) * up_proj(x) )

        Note: the implementation uses a SiLU-activated gate (SwiGLU), not the
        classic GLU ``(W1 * x) * sigmoid(W2 * x)`` that the previous
        docstring described; the docstring is corrected here to match the
        code.

        Args:
            config: TransformerConfig object
        """

        def __init__(self, config: "TransformerConfig"):
            super().__init__()
            self.config = config

            # Gate, value ("up"), and output ("down") projections, all
            # bias-free.
            self.gate_proj = nn.Linear(config.d_model, config.d_ff, bias=False)
            self.up_proj = nn.Linear(config.d_model, config.d_ff, bias=False)
            self.down_proj = nn.Linear(config.d_ff, config.d_model, bias=False)

            # Dropout on the final projection output.
            self.dropout = nn.Dropout(config.ff_dropout)

            # Initialize
            self._init_weights()

        def _init_weights(self):
            """Xavier-uniform initialization for all three projections."""
            for module in [self.gate_proj, self.up_proj, self.down_proj]:
                nn.init.xavier_uniform_(module.weight)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """
            Apply the gated feed-forward transform.

            Args:
                x: Input tensor (batch_size, seq_length, d_model)

            Returns:
                Tensor of the same shape, after dropout
            """
            gate = F.silu(self.gate_proj(x))
            up = self.up_proj(x)
            hidden = gate * up
            output = self.down_proj(hidden)
            return self.dropout(output)


# ==============================================================================
# Transformer Layers
# ==============================================================================

if TORCH_AVAILABLE:

    class TransformerEncoderLayer(nn.Module):
        """
        One transformer encoder layer.

        Self-attention followed by a position-wise feed-forward network, each
        wrapped in a residual connection with layer normalization.  The
        normalization placement is selected by ``config.pre_norm``: pre-norm
        applies LayerNorm before each sub-layer (more stable training),
        post-norm applies it after the residual sum (original transformer).

        Args:
            config: TransformerConfig object
        """

        def __init__(self, config: TransformerConfig):
            super().__init__()
            self.config = config

            # Sub-layers.
            self.self_attention = MultiHeadSelfAttention(config)
            self.feed_forward = FeedForwardNetwork(config)

            # One LayerNorm per sub-layer.
            self.norm1 = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
            self.norm2 = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)

            # Kept for interface parity; the sub-layers already apply their
            # own output dropout internally.
            self.dropout = nn.Dropout(config.dropout)

            # Select pre- vs post-layer normalization.
            self.pre_norm = config.pre_norm

        def _run_attention(self, x, attention_mask, output_attentions):
            """Invoke self-attention, always returning (output, weights-or-None)."""
            if output_attentions:
                return self.self_attention(
                    x,
                    attention_mask=attention_mask,
                    output_attentions=True
                )
            return self.self_attention(x, attention_mask=attention_mask), None

        def _forward_pre_norm(self, x, attention_mask, output_attentions):
            """Pre-norm ordering: LayerNorm -> sub-layer -> residual add."""
            attn_out, weights = self._run_attention(
                self.norm1(x), attention_mask, output_attentions
            )
            x = x + attn_out
            x = x + self.feed_forward(self.norm2(x))
            return x, weights

        def _forward_post_norm(self, x, attention_mask, output_attentions):
            """Post-norm ordering: sub-layer -> residual add -> LayerNorm."""
            attn_out, weights = self._run_attention(
                x, attention_mask, output_attentions
            )
            x = self.norm1(x + attn_out)
            x = self.norm2(x + self.feed_forward(x))
            return x, weights

        def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            output_attentions: bool = False
        ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
            """
            Run one encoder layer.

            Args:
                hidden_states: Input (batch, seq_len, d_model)
                attention_mask: Optional attention mask
                output_attentions: Whether to also return attention weights

            Returns:
                Layer output, plus attention weights when requested
            """
            branch = self._forward_pre_norm if self.pre_norm else self._forward_post_norm
            hidden_states, attention_weights = branch(
                hidden_states, attention_mask, output_attentions
            )

            if output_attentions:
                return hidden_states, attention_weights
            return hidden_states


    class TransformerDecoderLayer(nn.Module):
        """
        Single Transformer Decoder Layer.

        Consists of:
        1. Masked Multi-Head Self-Attention (causal)
        2. Multi-Head Cross-Attention (to encoder outputs)
        3. Feed-Forward Network

        Each sub-layer is wrapped in a residual connection with layer
        normalization (pre- or post-norm per config) and residual dropout.
        Residual dropout was previously constructed but never applied; it is
        now applied to each sub-layer output before the residual add, as in
        Vaswani et al. (2017).

        Args:
            config: TransformerConfig object
        """

        def __init__(self, config: TransformerConfig):
            super().__init__()
            self.config = config

            # Masked self-attention
            self.self_attention = MultiHeadSelfAttention(config)

            # Cross-attention to encoder
            self.cross_attention = MultiHeadCrossAttention(config)

            # Feed-forward
            self.feed_forward = FeedForwardNetwork(config)

            # Layer normalization (one per sub-layer)
            self.norm1 = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
            self.norm2 = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
            self.norm3 = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)

            # Residual dropout, applied to sub-layer outputs in forward()
            self.dropout = nn.Dropout(config.dropout)

            self.pre_norm = config.pre_norm

        def forward(
            self,
            hidden_states: torch.Tensor,
            encoder_hidden_states: torch.Tensor,
            self_attention_mask: Optional[torch.Tensor] = None,
            cross_attention_mask: Optional[torch.Tensor] = None,
            output_attentions: bool = False
        ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
            """
            Forward pass through decoder layer.

            Args:
                hidden_states: Decoder input (batch, tgt_len, d_model)
                encoder_hidden_states: Encoder output (batch, src_len, d_model)
                self_attention_mask: Mask for decoder self-attention
                cross_attention_mask: Mask for cross-attention
                output_attentions: Whether to return attention weights

            Returns:
                Layer output, optionally with self and cross attention weights
            """
            self_attn_weights = None
            cross_attn_weights = None

            if self.pre_norm:
                # Self-attention with causal mask
                residual = hidden_states
                hidden_states = self.norm1(hidden_states)

                if output_attentions:
                    hidden_states, self_attn_weights = self.self_attention(
                        hidden_states,
                        attention_mask=self_attention_mask,
                        causal_mask=True,
                        output_attentions=True
                    )
                else:
                    hidden_states = self.self_attention(
                        hidden_states,
                        attention_mask=self_attention_mask,
                        causal_mask=True
                    )

                # Dropout on the sub-layer output before the residual add.
                hidden_states = residual + self.dropout(hidden_states)

                # Cross-attention
                residual = hidden_states
                hidden_states = self.norm2(hidden_states)

                if output_attentions:
                    hidden_states, cross_attn_weights = self.cross_attention(
                        hidden_states,
                        encoder_hidden_states,
                        attention_mask=cross_attention_mask,
                        output_attentions=True
                    )
                else:
                    hidden_states = self.cross_attention(
                        hidden_states,
                        encoder_hidden_states,
                        attention_mask=cross_attention_mask
                    )

                hidden_states = residual + self.dropout(hidden_states)

                # Feed-forward
                residual = hidden_states
                hidden_states = self.norm3(hidden_states)
                hidden_states = self.feed_forward(hidden_states)
                hidden_states = residual + self.dropout(hidden_states)

            else:
                # Post-norm (original transformer)
                if output_attentions:
                    self_attn_output, self_attn_weights = self.self_attention(
                        hidden_states,
                        attention_mask=self_attention_mask,
                        causal_mask=True,
                        output_attentions=True
                    )
                else:
                    self_attn_output = self.self_attention(
                        hidden_states,
                        attention_mask=self_attention_mask,
                        causal_mask=True
                    )

                hidden_states = self.norm1(hidden_states + self.dropout(self_attn_output))

                if output_attentions:
                    cross_attn_output, cross_attn_weights = self.cross_attention(
                        hidden_states,
                        encoder_hidden_states,
                        attention_mask=cross_attention_mask,
                        output_attentions=True
                    )
                else:
                    cross_attn_output = self.cross_attention(
                        hidden_states,
                        encoder_hidden_states,
                        attention_mask=cross_attention_mask
                    )

                hidden_states = self.norm2(hidden_states + self.dropout(cross_attn_output))

                ff_output = self.feed_forward(hidden_states)
                hidden_states = self.norm3(hidden_states + self.dropout(ff_output))

            if output_attentions:
                return hidden_states, self_attn_weights, cross_attn_weights
            return hidden_states


# ==============================================================================
# Full Transformer Model
# ==============================================================================

if TORCH_AVAILABLE:

    class TransformerEncoder(nn.Module):
        """
        Stack of Transformer Encoder Layers with a final LayerNorm.

        Args:
            config: TransformerConfig object
        """

        def __init__(self, config: TransformerConfig):
            super().__init__()
            self.config = config

            # n_encoder_layers identical layers applied in sequence
            self.layers = nn.ModuleList([
                TransformerEncoderLayer(config)
                for _ in range(config.n_encoder_layers)
            ])

            # Final layer norm applied after the last layer
            self.final_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)

        def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            output_attentions: bool = False,
            output_hidden_states: bool = False
        ) -> Dict[str, torch.Tensor]:
            """
            Forward through all encoder layers.

            Args:
                hidden_states: Input embeddings (batch, seq_len, d_model)
                attention_mask: Optional attention mask
                output_attentions: Whether to return attention weights
                output_hidden_states: Whether to return all hidden states

            Returns:
                Dictionary with keys ``last_hidden_state``, ``hidden_states``
                (list or None) and ``attentions`` (list or None).
            """
            collected_states = [] if output_hidden_states else None
            collected_attentions = [] if output_attentions else None

            for layer in self.layers:
                # Record the state *entering* each layer.
                if collected_states is not None:
                    collected_states.append(hidden_states)

                layer_result = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    output_attentions=output_attentions
                )

                if output_attentions:
                    hidden_states, layer_attn = layer_result
                    collected_attentions.append(layer_attn)
                else:
                    hidden_states = layer_result

            hidden_states = self.final_norm(hidden_states)

            # Also record the normalized final state.
            if collected_states is not None:
                collected_states.append(hidden_states)

            return {
                "last_hidden_state": hidden_states,
                "hidden_states": collected_states,
                "attentions": collected_attentions,
            }


    class TransformerDecoder(nn.Module):
        """
        Stack of Transformer Decoder Layers with a final LayerNorm.

        Args:
            config: TransformerConfig object
        """

        def __init__(self, config: TransformerConfig):
            super().__init__()
            self.config = config

            # n_decoder_layers identical layers applied in sequence
            self.layers = nn.ModuleList([
                TransformerDecoderLayer(config)
                for _ in range(config.n_decoder_layers)
            ])

            # Final layer norm applied after the last layer
            self.final_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)

        def forward(
            self,
            hidden_states: torch.Tensor,
            encoder_hidden_states: torch.Tensor,
            self_attention_mask: Optional[torch.Tensor] = None,
            cross_attention_mask: Optional[torch.Tensor] = None,
            output_attentions: bool = False,
            output_hidden_states: bool = False
        ) -> Dict[str, torch.Tensor]:
            """
            Forward through all decoder layers.

            Args:
                hidden_states: Decoder input embeddings
                encoder_hidden_states: Encoder output
                self_attention_mask: Mask for decoder self-attention
                cross_attention_mask: Mask for cross-attention
                output_attentions: Whether to return attention weights
                output_hidden_states: Whether to return all hidden states

            Returns:
                Dictionary with keys ``last_hidden_state``, ``hidden_states``,
                ``self_attentions`` and ``cross_attentions`` (lists or None).
            """
            collected_states = [] if output_hidden_states else None
            collected_self_attn = [] if output_attentions else None
            collected_cross_attn = [] if output_attentions else None

            for layer in self.layers:
                # Record the state *entering* each layer.
                if collected_states is not None:
                    collected_states.append(hidden_states)

                layer_result = layer(
                    hidden_states,
                    encoder_hidden_states,
                    self_attention_mask=self_attention_mask,
                    cross_attention_mask=cross_attention_mask,
                    output_attentions=output_attentions
                )

                if output_attentions:
                    hidden_states, layer_self_attn, layer_cross_attn = layer_result
                    collected_self_attn.append(layer_self_attn)
                    collected_cross_attn.append(layer_cross_attn)
                else:
                    hidden_states = layer_result

            hidden_states = self.final_norm(hidden_states)

            # Also record the normalized final state.
            if collected_states is not None:
                collected_states.append(hidden_states)

            return {
                "last_hidden_state": hidden_states,
                "hidden_states": collected_states,
                "self_attentions": collected_self_attn,
                "cross_attentions": collected_cross_attn,
            }


    class QueenTransformer(nn.Module):
        """
        Complete Transformer Model for AIVA Queen.

        Full encoder-decoder transformer architecture with:
        - Token embeddings (optionally tied to the LM head)
        - Positional encodings (sinusoidal or learned)
        - Encoder stack
        - Decoder stack
        - Output projection (language modeling head)

        Args:
            config: TransformerConfig object
        """

        def __init__(self, config: TransformerConfig):
            super().__init__()
            self.config = config

            # Token embeddings (shared between encoder and decoder)
            self.token_embedding = nn.Embedding(
                config.vocab_size,
                config.d_model,
                padding_idx=config.pad_token_id
            )

            # Positional encoding
            if config.pos_encoding_type == "sinusoidal":
                self.pos_encoding = SinusoidalPositionalEncoding(
                    config.d_model,
                    config.max_seq_length,
                    config.dropout
                )
            else:
                self.pos_encoding = LearnedPositionalEncoding(
                    config.d_model,
                    config.max_seq_length,
                    config.dropout
                )

            # Encoder and decoder stacks
            self.encoder = TransformerEncoder(config)
            self.decoder = TransformerDecoder(config)

            # Language modeling head
            self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

            # Tie embeddings if configured (shares weights, reduces parameters)
            if config.tie_word_embeddings:
                self.lm_head.weight = self.token_embedding.weight

            # Scale embeddings by sqrt(d_model), as in Vaswani et al. (2017)
            self.embed_scale = math.sqrt(config.d_model)

            # Initialize weights
            self._init_weights()

            logger.info(f"QueenTransformer initialized with {self.count_parameters():,} parameters")

        def _init_weights(self):
            """Initialize embedding and (untied) LM-head weights with N(0, 0.02)."""
            nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
            # Keep the padding embedding at zero so pad tokens contribute nothing.
            if self.config.pad_token_id is not None:
                self.token_embedding.weight.data[self.config.pad_token_id].zero_()

            # Initialize LM head only if not tied (tied weights already initialized)
            if not self.config.tie_word_embeddings:
                nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.02)

        def count_parameters(self) -> int:
            """Count trainable parameters."""
            return sum(p.numel() for p in self.parameters() if p.requires_grad)

        def get_input_embeddings(self) -> nn.Embedding:
            """Get the input embedding layer."""
            return self.token_embedding

        def set_input_embeddings(self, embeddings: nn.Embedding):
            """Set the input embedding layer."""
            self.token_embedding = embeddings

        def _create_padding_mask(
            self,
            input_ids: torch.Tensor,
            pad_token_id: int
        ) -> torch.Tensor:
            """
            Create padding mask from input IDs.

            Args:
                input_ids: Token IDs (batch, seq_len)
                pad_token_id: ID of padding token

            Returns:
                Boolean mask of shape (batch, 1, 1, seq_len) with True for
                valid (non-pad) positions, broadcastable over attention heads.
            """
            return (input_ids != pad_token_id).unsqueeze(1).unsqueeze(2)

        def encode(
            self,
            src_input_ids: torch.Tensor,
            src_attention_mask: Optional[torch.Tensor] = None,
            output_attentions: bool = False,
            output_hidden_states: bool = False
        ) -> Dict[str, torch.Tensor]:
            """
            Encode source sequence.

            Args:
                src_input_ids: Source token IDs (batch, src_len)
                src_attention_mask: Source attention mask (built from padding
                    if not provided)
                output_attentions: Whether to return attention weights
                output_hidden_states: Whether to return all hidden states

            Returns:
                Encoder outputs dictionary
            """
            if src_attention_mask is None:
                src_attention_mask = self._create_padding_mask(
                    src_input_ids,
                    self.config.pad_token_id
                )

            # Embed, scale, and add positional encoding
            src_embeddings = self.token_embedding(src_input_ids) * self.embed_scale
            src_embeddings = self.pos_encoding(src_embeddings)

            return self.encoder(
                src_embeddings,
                attention_mask=src_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states
            )

        def decode(
            self,
            tgt_input_ids: torch.Tensor,
            encoder_hidden_states: torch.Tensor,
            tgt_attention_mask: Optional[torch.Tensor] = None,
            encoder_attention_mask: Optional[torch.Tensor] = None,
            output_attentions: bool = False,
            output_hidden_states: bool = False
        ) -> Dict[str, torch.Tensor]:
            """
            Decode target sequence (causal masking is applied inside the
            decoder layers).

            Args:
                tgt_input_ids: Target token IDs (batch, tgt_len)
                encoder_hidden_states: Encoder output
                tgt_attention_mask: Target attention mask (built from padding
                    if not provided)
                encoder_attention_mask: Encoder attention mask for cross-attention
                output_attentions: Whether to return attention weights
                output_hidden_states: Whether to return all hidden states

            Returns:
                Decoder outputs dictionary
            """
            if tgt_attention_mask is None:
                tgt_attention_mask = self._create_padding_mask(
                    tgt_input_ids,
                    self.config.pad_token_id
                )

            # Embed, scale, and add positional encoding
            tgt_embeddings = self.token_embedding(tgt_input_ids) * self.embed_scale
            tgt_embeddings = self.pos_encoding(tgt_embeddings)

            return self.decoder(
                tgt_embeddings,
                encoder_hidden_states,
                self_attention_mask=tgt_attention_mask,
                cross_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states
            )

        def forward(
            self,
            src_input_ids: torch.Tensor,
            tgt_input_ids: torch.Tensor,
            src_attention_mask: Optional[torch.Tensor] = None,
            tgt_attention_mask: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            output_attentions: bool = False,
            output_hidden_states: bool = False,
            return_dict: bool = True
        ) -> Dict[str, torch.Tensor]:
            """
            Full forward pass through encoder-decoder.

            Args:
                src_input_ids: Source token IDs (batch, src_len)
                tgt_input_ids: Target token IDs (batch, tgt_len)
                src_attention_mask: Source attention mask
                tgt_attention_mask: Target attention mask
                labels: Target labels for loss computation; assumed aligned
                    with ``tgt_input_ids`` (shifting is done internally)
                output_attentions: Whether to return attention weights
                output_hidden_states: Whether to return all hidden states
                return_dict: Whether to return a dictionary (otherwise a
                    ``(logits, loss)`` tuple)

            Returns:
                Model outputs with logits and optional loss
            """
            # Encode source
            encoder_outputs = self.encode(
                src_input_ids,
                src_attention_mask=src_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states
            )

            encoder_hidden_states = encoder_outputs["last_hidden_state"]

            # Mask for cross-attention over the encoder sequence
            if src_attention_mask is None:
                encoder_attention_mask = self._create_padding_mask(
                    src_input_ids,
                    self.config.pad_token_id
                )
            else:
                encoder_attention_mask = src_attention_mask

            # Decode
            decoder_outputs = self.decode(
                tgt_input_ids,
                encoder_hidden_states,
                tgt_attention_mask=tgt_attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states
            )

            # Project to vocabulary
            logits = self.lm_head(decoder_outputs["last_hidden_state"])

            # Compute loss if labels provided
            loss = None
            if labels is not None:
                loss_fct = nn.CrossEntropyLoss(
                    ignore_index=self.config.pad_token_id,
                    label_smoothing=0.1
                )
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                loss = loss_fct(
                    shift_logits.view(-1, self.config.vocab_size),
                    shift_labels.view(-1)
                )

            if return_dict:
                return {
                    "loss": loss,
                    "logits": logits,
                    "encoder_last_hidden_state": encoder_hidden_states,
                    "encoder_hidden_states": encoder_outputs.get("hidden_states"),
                    "encoder_attentions": encoder_outputs.get("attentions"),
                    "decoder_hidden_states": decoder_outputs.get("hidden_states"),
                    "decoder_self_attentions": decoder_outputs.get("self_attentions"),
                    "decoder_cross_attentions": decoder_outputs.get("cross_attentions"),
                }

            return logits, loss

        @torch.no_grad()
        def generate(
            self,
            src_input_ids: torch.Tensor,
            max_length: int = 50,
            min_length: int = 1,
            num_beams: int = 1,
            temperature: float = 1.0,
            top_k: int = 50,
            top_p: float = 0.9,
            repetition_penalty: float = 1.0,
            eos_token_id: Optional[int] = None,
            pad_token_id: Optional[int] = None,
            **kwargs
        ) -> torch.Tensor:
            """
            Generate output sequence autoregressively (sampling with
            temperature, top-k and nucleus filtering).

            Args:
                src_input_ids: Source token IDs
                max_length: Maximum generation length (including BOS)
                min_length: Minimum generation length; EOS is blocked until
                    this many tokens have been produced
                num_beams: Accepted for API compatibility; beam search is not
                    implemented and a warning is logged if num_beams > 1
                temperature: Sampling temperature
                top_k: Top-k filtering (0 disables)
                top_p: Nucleus sampling probability (1.0 disables)
                repetition_penalty: Penalty > 1.0 discourages repeated tokens
                    (CTRL-style: positive logits divided, negative multiplied)
                eos_token_id: End of sequence token ID
                pad_token_id: Padding token ID

            Returns:
                Generated token IDs (batch, <= max_length)
            """
            self.eval()

            # Use explicit None-checks: `x or default` would discard a
            # caller-supplied token id of 0 (pad is frequently 0).
            if eos_token_id is None:
                eos_token_id = self.config.eos_token_id
            if pad_token_id is None:
                pad_token_id = self.config.pad_token_id
            bos_token_id = self.config.bos_token_id

            if num_beams > 1:
                # NOTE(review): beam search is not implemented; previously the
                # argument was silently ignored. Warn instead of failing so
                # existing callers keep working.
                logger.warning(
                    "num_beams=%d requested but beam search is not implemented; "
                    "falling back to sampling", num_beams
                )

            batch_size = src_input_ids.size(0)
            device = src_input_ids.device

            # Encode source once; reused at every decoding step
            encoder_outputs = self.encode(src_input_ids)
            encoder_hidden_states = encoder_outputs["last_hidden_state"]

            # Initialize with BOS token
            generated = torch.full(
                (batch_size, 1),
                bos_token_id,
                dtype=torch.long,
                device=device
            )

            # Track which sequences have not yet emitted EOS
            unfinished = torch.ones(batch_size, dtype=torch.bool, device=device)

            for _ in range(max_length - 1):
                # Decode current sequence (no KV cache: full re-decode per step)
                decoder_outputs = self.decode(
                    generated,
                    encoder_hidden_states
                )

                # Logits for the last position only
                next_token_logits = self.lm_head(
                    decoder_outputs["last_hidden_state"][:, -1, :]
                )

                # Apply temperature
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature

                # Apply repetition penalty (Keskar et al., 2019). Positive
                # logits are divided and negative logits MULTIPLIED by the
                # penalty so repeats always become less likely; plain division
                # would make negative logits larger (i.e. MORE likely).
                # Vectorized via gather/scatter instead of a per-batch loop.
                if repetition_penalty != 1.0:
                    prev_token_scores = torch.gather(next_token_logits, 1, generated)
                    prev_token_scores = torch.where(
                        prev_token_scores < 0,
                        prev_token_scores * repetition_penalty,
                        prev_token_scores / repetition_penalty
                    )
                    next_token_logits.scatter_(1, generated, prev_token_scores)

                # Enforce min_length by forbidding EOS until enough tokens
                # exist (this parameter was previously accepted but ignored).
                if eos_token_id is not None and generated.size(1) < min_length:
                    next_token_logits[:, eos_token_id] = float('-inf')

                # Apply top-k filtering
                if top_k > 0:
                    indices_to_remove = next_token_logits < torch.topk(
                        next_token_logits, top_k
                    )[0][..., -1, None]
                    next_token_logits[indices_to_remove] = float('-inf')

                # Apply top-p (nucleus) filtering
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(
                        next_token_logits, descending=True
                    )
                    cumulative_probs = torch.cumsum(
                        F.softmax(sorted_logits, dim=-1), dim=-1
                    )

                    # Remove tokens with cumulative probability above threshold,
                    # always keeping at least the most-probable token.
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    indices_to_remove = sorted_indices_to_remove.scatter(
                        1, sorted_indices, sorted_indices_to_remove
                    )
                    next_token_logits[indices_to_remove] = float('-inf')

                # Sample next token
                probs = F.softmax(next_token_logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)

                # Finished sequences keep emitting pad
                next_tokens = torch.where(
                    unfinished,
                    next_tokens,
                    torch.full_like(next_tokens, pad_token_id)
                )

                # Append to generated
                generated = torch.cat([generated, next_tokens.unsqueeze(-1)], dim=-1)

                # Update finished status
                unfinished = unfinished & (next_tokens != eos_token_id)

                # Stop if all sequences finished
                if not unfinished.any():
                    break

            return generated

        def save_pretrained(self, save_directory: Union[str, Path]):
            """Save model weights and config to a directory."""
            save_directory = Path(save_directory)
            save_directory.mkdir(parents=True, exist_ok=True)

            # Save config
            self.config.save(save_directory / "config.json")

            # Save model weights
            torch.save(self.state_dict(), save_directory / "model.pt")

            logger.info(f"Model saved to {save_directory}")

        @classmethod
        def from_pretrained(
            cls,
            pretrained_path: Union[str, Path],
            device: str = "cpu"
        ) -> "QueenTransformer":
            """
            Load model from a directory written by :meth:`save_pretrained`.

            Args:
                pretrained_path: Directory containing config.json and model.pt
                device: Device to map the weights to

            Returns:
                Loaded QueenTransformer instance
            """
            pretrained_path = Path(pretrained_path)

            # Load config
            config = TransformerConfig.load(pretrained_path / "config.json")

            # Create model
            model = cls(config)

            # weights_only=True restricts unpickling to tensors, preventing
            # arbitrary-code execution from an untrusted checkpoint (a plain
            # state_dict needs nothing more).
            state_dict = torch.load(
                pretrained_path / "model.pt",
                map_location=device,
                weights_only=True
            )
            model.load_state_dict(state_dict)

            logger.info(f"Model loaded from {pretrained_path}")
            return model


# ==============================================================================
# Training Utilities
# ==============================================================================

if TORCH_AVAILABLE:

    class TransformerTrainer:
        """
        Training loop for the QueenTransformer.

        Handles:
        - Optimization with warmup
        - Gradient accumulation
        - Mixed precision training
        - Logging and checkpointing
        - Evaluation

        Args:
            model: QueenTransformer model
            train_config: TrainingConfig object
            train_dataloader: Training data loader
            eval_dataloader: Evaluation data loader (optional)
            device: Device to train on
        """

        def __init__(
            self,
            model: QueenTransformer,
            train_config: TrainingConfig,
            train_dataloader: DataLoader,
            eval_dataloader: Optional[DataLoader] = None,
            device: str = "cuda" if torch.cuda.is_available() else "cpu"
        ):
            """
            Set up the trainer: move the model to the device and build the
            optimizer, scheduler, and (optional) AMP gradient scaler.

            Args:
                model: QueenTransformer model
                train_config: TrainingConfig object
                train_dataloader: Training data loader
                eval_dataloader: Evaluation data loader (optional)
                device: Device string, e.g. "cpu", "cuda", or "cuda:1"
            """
            self.model = model.to(device)
            self.config = train_config
            self.train_dataloader = train_dataloader
            self.eval_dataloader = eval_dataloader
            self.device = device

            # Optimizer
            self.optimizer = self._create_optimizer()

            # Learning rate scheduler
            self.scheduler = self._create_scheduler()

            # Mixed precision scaler. Use startswith() so explicit device
            # strings like "cuda:1" also enable AMP — the previous equality
            # check (device == "cuda") silently disabled it for them.
            use_cuda_amp = train_config.use_amp and device.startswith("cuda")
            self.scaler = torch.amp.GradScaler('cuda') if use_cuda_amp else None

            # Training state
            self.global_step = 0
            self.epoch = 0
            self.best_eval_loss = float('inf')

            # Metrics history
            self.train_losses = []
            self.eval_losses = []

        def _create_optimizer(self) -> AdamW:
            """
            Build an AdamW optimizer with two parameter groups: weight decay
            for ordinary weights, and no decay for biases and LayerNorm
            weights (decaying those hurts training).
            """
            no_decay_markers = ("bias", "LayerNorm.weight", "layer_norm.weight")

            def exempt_from_decay(param_name: str) -> bool:
                # Substring match against known bias/norm naming patterns.
                return any(marker in param_name for marker in no_decay_markers)

            decayed, undecayed = [], []
            for name, param in self.model.named_parameters():
                (undecayed if exempt_from_decay(name) else decayed).append(param)

            param_groups = [
                {"params": decayed, "weight_decay": self.config.weight_decay},
                {"params": undecayed, "weight_decay": 0.0},
            ]

            return AdamW(
                param_groups,
                lr=self.config.learning_rate,
                betas=(self.config.beta1, self.config.beta2),
                eps=self.config.eps
            )

        def _create_scheduler(self):
            """Create learning rate scheduler."""
            total_steps = len(self.train_dataloader) * self.config.num_epochs

            if self.config.lr_scheduler == "cosine":
                return CosineAnnealingWarmRestarts(
                    self.optimizer,
                    T_0=self.config.warmup_steps,
                    T_mult=2
                )
            else:
                return OneCycleLR(
                    self.optimizer,
                    max_lr=self.config.learning_rate,
                    total_steps=total_steps,
                    pct_start=self.config.warmup_steps / total_steps
                )

        def _training_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
            """Execute single training step."""
            # Move batch to device
            src_ids = batch["src_input_ids"].to(self.device)
            tgt_ids = batch["tgt_input_ids"].to(self.device)
            labels = batch.get("labels", tgt_ids).to(self.device)

            # Forward pass
            if self.scaler is not None:
                with torch.amp.autocast('cuda'):
                    outputs = self.model(
                        src_ids,
                        tgt_ids,
                        labels=labels
                    )
                    loss = outputs["loss"]
            else:
                outputs = self.model(
                    src_ids,
                    tgt_ids,
                    labels=labels
                )
                loss = outputs["loss"]

            # Scale loss for accumulation
            loss = loss / self.config.accumulation_steps

            return loss

        def _backward_step(self, loss: torch.Tensor):
            """Execute backward pass."""
            if self.scaler is not None:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

        def _optimizer_step(self):
            """Execute optimizer step."""
            if self.scaler is not None:
                # Unscale gradients
                self.scaler.unscale_(self.optimizer)

            # Clip gradients
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.max_grad_norm
            )

            # Optimizer step
            if self.scaler is not None:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()

            # Scheduler step
            self.scheduler.step()

            # Zero gradients
            self.optimizer.zero_grad()

        @torch.no_grad()
        def evaluate(self) -> float:
            """Evaluate model on validation set."""
            if self.eval_dataloader is None:
                return float('inf')

            self.model.eval()
            total_loss = 0.0
            num_batches = 0

            for batch in self.eval_dataloader:
                src_ids = batch["src_input_ids"].to(self.device)
                tgt_ids = batch["tgt_input_ids"].to(self.device)
                labels = batch.get("labels", tgt_ids).to(self.device)

                outputs = self.model(
                    src_ids,
                    tgt_ids,
                    labels=labels
                )

                total_loss += outputs["loss"].item()
                num_batches += 1

            avg_loss = total_loss / num_batches
            self.model.train()

            return avg_loss

        def train(
            self,
            checkpoint_dir: Optional[Union[str, Path]] = None
        ) -> Dict[str, List[float]]:
            """
            Full training loop.

            Args:
                checkpoint_dir: Directory to save checkpoints

            Returns:
                Training history with losses
            """
            logger.info("Starting training...")
            self.model.train()

            if checkpoint_dir:
                checkpoint_dir = Path(checkpoint_dir)
                checkpoint_dir.mkdir(parents=True, exist_ok=True)

            accumulation_counter = 0

            for epoch in range(self.config.num_epochs):
                self.epoch = epoch
                epoch_loss = 0.0
                num_batches = 0

                for batch_idx, batch in enumerate(self.train_dataloader):
                    # Forward pass
                    loss = self._training_step(batch)

                    # Backward pass
                    self._backward_step(loss)

                    accumulation_counter += 1

                    # Optimizer step after accumulation
                    if accumulation_counter >= self.config.accumulation_steps:
                        self._optimizer_step()
                        accumulation_counter = 0
                        self.global_step += 1

                    epoch_loss += loss.item() * self.config.accumulation_steps
                    num_batches += 1

                    # Logging
                    if self.global_step % self.config.logging_steps == 0:
                        avg_loss = epoch_loss / num_batches
                        lr = self.scheduler.get_last_lr()[0]
                        logger.info(
                            f"Epoch {epoch+1}, Step {self.global_step}, "
                            f"Loss: {avg_loss:.4f}, LR: {lr:.2e}"
                        )

                    # Evaluation
                    if self.global_step % self.config.eval_steps == 0:
                        eval_loss = self.evaluate()
                        logger.info(f"Eval Loss: {eval_loss:.4f}")
                        self.eval_losses.append(eval_loss)

                        # Save best model
                        if checkpoint_dir and eval_loss < self.best_eval_loss:
                            self.best_eval_loss = eval_loss
                            self.model.save_pretrained(checkpoint_dir / "best")

                    # Checkpointing
                    if checkpoint_dir and self.global_step % self.config.save_steps == 0:
                        self._save_checkpoint(checkpoint_dir / f"checkpoint-{self.global_step}")

                # End of epoch
                avg_epoch_loss = epoch_loss / num_batches
                self.train_losses.append(avg_epoch_loss)
                logger.info(f"Epoch {epoch+1} completed. Avg Loss: {avg_epoch_loss:.4f}")

            logger.info("Training completed!")
            return {
                "train_losses": self.train_losses,
                "eval_losses": self.eval_losses
            }

        def _save_checkpoint(self, path: Union[str, Path]):
            """Save training checkpoint."""
            path = Path(path)
            path.mkdir(parents=True, exist_ok=True)

            checkpoint = {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "scheduler_state_dict": self.scheduler.state_dict(),
                "global_step": self.global_step,
                "epoch": self.epoch,
                "best_eval_loss": self.best_eval_loss,
                "train_losses": self.train_losses,
                "eval_losses": self.eval_losses,
            }

            if self.scaler is not None:
                checkpoint["scaler_state_dict"] = self.scaler.state_dict()

            torch.save(checkpoint, path / "checkpoint.pt")
            self.model.config.save(path / "config.json")

            logger.info(f"Checkpoint saved to {path}")

        def load_checkpoint(self, path: Union[str, Path]):
            """Load training checkpoint."""
            path = Path(path)

            checkpoint = torch.load(path / "checkpoint.pt", map_location=self.device)

            self.model.load_state_dict(checkpoint["model_state_dict"])
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
            self.global_step = checkpoint["global_step"]
            self.epoch = checkpoint["epoch"]
            self.best_eval_loss = checkpoint["best_eval_loss"]
            self.train_losses = checkpoint["train_losses"]
            self.eval_losses = checkpoint["eval_losses"]

            if self.scaler is not None and "scaler_state_dict" in checkpoint:
                self.scaler.load_state_dict(checkpoint["scaler_state_dict"])

            logger.info(f"Checkpoint loaded from {path}")


# ==============================================================================
# Attention Visualization
# ==============================================================================

if TORCH_AVAILABLE and PLOTTING_AVAILABLE:

    class AttentionVisualizer:
        """
        Utilities for visualizing transformer attention patterns.

        Supports:
        - Single attention head visualization
        - Multi-head attention comparison
        - Cross-attention patterns
        - Attention rollout
        """

        def __init__(self, model: QueenTransformer):
            self.model = model

        @torch.no_grad()
        def get_attention_weights(
            self,
            src_input_ids: torch.Tensor,
            tgt_input_ids: torch.Tensor
        ) -> Dict[str, List[torch.Tensor]]:
            """
            Get attention weights from a forward pass.

            Args:
                src_input_ids: Source token IDs
                tgt_input_ids: Target token IDs

            Returns:
                Dictionary with encoder and decoder attention weights
            """
            self.model.eval()

            outputs = self.model(
                src_input_ids,
                tgt_input_ids,
                output_attentions=True
            )

            return {
                "encoder_attentions": outputs["encoder_attentions"],
                "decoder_self_attentions": outputs["decoder_self_attentions"],
                "decoder_cross_attentions": outputs["decoder_cross_attentions"]
            }

        def plot_attention_head(
            self,
            attention_weights: torch.Tensor,
            head_idx: int = 0,
            src_tokens: Optional[List[str]] = None,
            tgt_tokens: Optional[List[str]] = None,
            title: str = "Attention Pattern",
            save_path: Optional[str] = None
        ):
            """
            Plot attention pattern for a single head.

            Args:
                attention_weights: Attention tensor (batch, heads, tgt_len, src_len)
                head_idx: Which head to visualize
                src_tokens: Source token labels
                tgt_tokens: Target token labels
                title: Plot title
                save_path: Path to save figure
            """
            # Extract single head
            attn = attention_weights[0, head_idx].cpu().numpy()

            fig, ax = plt.subplots(figsize=(10, 8))

            sns.heatmap(
                attn,
                ax=ax,
                cmap="viridis",
                xticklabels=src_tokens if src_tokens else False,
                yticklabels=tgt_tokens if tgt_tokens else False,
                square=True
            )

            ax.set_xlabel("Source Position")
            ax.set_ylabel("Target Position")
            ax.set_title(f"{title} - Head {head_idx}")

            plt.tight_layout()

            if save_path:
                plt.savefig(save_path, dpi=150, bbox_inches='tight')

            plt.show()

        def plot_multi_head(
            self,
            attention_weights: torch.Tensor,
            num_heads: Optional[int] = None,
            src_tokens: Optional[List[str]] = None,
            tgt_tokens: Optional[List[str]] = None,
            title: str = "Multi-Head Attention",
            save_path: Optional[str] = None
        ):
            """
            Plot attention patterns for multiple heads.

            Args:
                attention_weights: Attention tensor
                num_heads: Number of heads to show (default: all)
                src_tokens: Source token labels
                tgt_tokens: Target token labels
                title: Plot title
                save_path: Path to save figure
            """
            n_heads = attention_weights.size(1)
            num_heads = num_heads or n_heads
            num_heads = min(num_heads, n_heads)

            # Determine grid layout
            cols = min(4, num_heads)
            rows = (num_heads + cols - 1) // cols

            fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows))
            axes = np.array(axes).flatten() if num_heads > 1 else [axes]

            for i in range(num_heads):
                attn = attention_weights[0, i].cpu().numpy()

                sns.heatmap(
                    attn,
                    ax=axes[i],
                    cmap="viridis",
                    xticklabels=False,
                    yticklabels=False,
                    cbar=False,
                    square=True
                )

                axes[i].set_title(f"Head {i}")

            # Hide unused axes
            for i in range(num_heads, len(axes)):
                axes[i].axis('off')

            fig.suptitle(title, fontsize=14)
            plt.tight_layout()

            if save_path:
                plt.savefig(save_path, dpi=150, bbox_inches='tight')

            plt.show()

        def attention_rollout(
            self,
            attention_weights_list: List[torch.Tensor],
            head_fusion: str = "mean"
        ) -> torch.Tensor:
            """
            Compute attention rollout across layers.

            Attention rollout propagates attention through layers
            to show where each position is "looking" across the
            entire network.

            Args:
                attention_weights_list: List of attention weights per layer
                head_fusion: How to combine heads ("mean", "max", "min")

            Returns:
                Rolled-out attention matrix
            """
            # Fuse heads
            if head_fusion == "mean":
                attention_fused = [attn.mean(dim=1) for attn in attention_weights_list]
            elif head_fusion == "max":
                attention_fused = [attn.max(dim=1)[0] for attn in attention_weights_list]
            else:
                attention_fused = [attn.min(dim=1)[0] for attn in attention_weights_list]

            # Start with identity (residual connections)
            rollout = torch.eye(attention_fused[0].size(-1), device=attention_fused[0].device)
            rollout = rollout.unsqueeze(0)  # Add batch dimension

            for attn in attention_fused:
                # Add residual connection
                attn = attn + torch.eye(attn.size(-1), device=attn.device)

                # Normalize rows
                attn = attn / attn.sum(dim=-1, keepdim=True)

                # Multiply with previous rollout
                rollout = torch.matmul(attn, rollout)

            return rollout

        def plot_layer_attention_comparison(
            self,
            attention_weights_list: List[torch.Tensor],
            head_idx: int = 0,
            title: str = "Attention Across Layers",
            save_path: Optional[str] = None
        ):
            """
            Compare attention patterns across layers.

            Args:
                attention_weights_list: List of attention weights per layer
                head_idx: Which head to visualize
                title: Plot title
                save_path: Path to save figure
            """
            n_layers = len(attention_weights_list)
            cols = min(4, n_layers)
            rows = (n_layers + cols - 1) // cols

            fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows))
            axes = np.array(axes).flatten() if n_layers > 1 else [axes]

            for i, attn in enumerate(attention_weights_list):
                attn_head = attn[0, head_idx].cpu().numpy()

                sns.heatmap(
                    attn_head,
                    ax=axes[i],
                    cmap="viridis",
                    xticklabels=False,
                    yticklabels=False,
                    cbar=False,
                    square=True
                )

                axes[i].set_title(f"Layer {i}")

            # Hide unused axes
            for i in range(n_layers, len(axes)):
                axes[i].axis('off')

            fig.suptitle(f"{title} - Head {head_idx}", fontsize=14)
            plt.tight_layout()

            if save_path:
                plt.savefig(save_path, dpi=150, bbox_inches='tight')

            plt.show()


# ==============================================================================
# Demo Dataset
# ==============================================================================

if TORCH_AVAILABLE:

    class DemoDataset(Dataset):
        """
        Simple demo dataset for testing the transformer.

        Generates synthetic sequence-to-sequence data.
        """

        def __init__(
            self,
            num_samples: int = 1000,
            max_length: int = 32,
            vocab_size: int = 100,
            pad_token_id: int = 0,
            bos_token_id: int = 1,
            eos_token_id: int = 2
        ):
            self.num_samples = num_samples
            self.max_length = max_length
            self.vocab_size = vocab_size
            self.pad_token_id = pad_token_id
            self.bos_token_id = bos_token_id
            self.eos_token_id = eos_token_id

            # Generate synthetic data
            self.data = self._generate_data()

        def _generate_data(self) -> List[Dict[str, torch.Tensor]]:
            """Generate synthetic sequence pairs."""
            data = []

            for _ in range(self.num_samples):
                # Random source length
                src_len = np.random.randint(5, self.max_length - 2)

                # Random source tokens (avoiding special tokens)
                src_tokens = torch.randint(
                    3, self.vocab_size, (src_len,)
                )

                # Target is reverse of source (simple task)
                tgt_tokens = torch.flip(src_tokens, [0])

                # Add BOS and EOS
                src_ids = torch.cat([
                    torch.tensor([self.bos_token_id]),
                    src_tokens,
                    torch.tensor([self.eos_token_id])
                ])

                tgt_ids = torch.cat([
                    torch.tensor([self.bos_token_id]),
                    tgt_tokens,
                    torch.tensor([self.eos_token_id])
                ])

                # Pad to max length
                src_ids = F.pad(
                    src_ids,
                    (0, self.max_length - len(src_ids)),
                    value=self.pad_token_id
                )

                tgt_ids = F.pad(
                    tgt_ids,
                    (0, self.max_length - len(tgt_ids)),
                    value=self.pad_token_id
                )

                data.append({
                    "src_input_ids": src_ids,
                    "tgt_input_ids": tgt_ids,
                    "labels": tgt_ids.clone()
                })

            return data

        def __len__(self) -> int:
            return len(self.data)

        def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
            return self.data[idx]


# ==============================================================================
# Main Demo
# ==============================================================================

def run_demo():
    """Run a demonstration of the QueenTransformer."""

    if not TORCH_AVAILABLE:
        logger.error("PyTorch not available. Cannot run demo.")
        return

    logger.info("=" * 60)
    logger.info("AIVA Queen Neural Transformer Demo")
    logger.info("=" * 60)

    # Configuration
    config = TransformerConfig(
        d_model=256,
        n_heads=8,
        n_encoder_layers=3,
        n_decoder_layers=3,
        d_ff=512,
        max_seq_length=64,
        vocab_size=100,
        dropout=0.1,
        pos_encoding_type="sinusoidal",
        pre_norm=True,
        activation="gelu"
    )

    logger.info(f"Config: {config.to_dict()}")

    # Create model
    model = QueenTransformer(config)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    logger.info(f"Model created with {model.count_parameters():,} parameters")
    logger.info(f"Device: {device}")

    # Create demo dataset
    logger.info("Creating demo dataset...")
    train_dataset = DemoDataset(
        num_samples=500,
        max_length=32,
        vocab_size=config.vocab_size
    )

    eval_dataset = DemoDataset(
        num_samples=100,
        max_length=32,
        vocab_size=config.vocab_size
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True
    )

    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=16
    )

    # Training config
    train_config = TrainingConfig(
        learning_rate=1e-4,
        num_epochs=3,
        batch_size=16,
        warmup_steps=50,
        logging_steps=20,
        eval_steps=50,
        use_amp=False  # Disable for demo simplicity
    )

    # Create trainer
    trainer = TransformerTrainer(
        model=model,
        train_config=train_config,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        device=device
    )

    # Train
    logger.info("Starting training...")
    history = trainer.train()

    logger.info(f"Training completed!")
    logger.info(f"Final train loss: {history['train_losses'][-1]:.4f}")
    if history['eval_losses']:
        logger.info(f"Final eval loss: {history['eval_losses'][-1]:.4f}")

    # Test generation
    logger.info("\nTesting generation...")

    # Create a test input
    test_src = torch.tensor([[1, 5, 10, 15, 20, 25, 2]]).to(device)  # BOS + tokens + EOS

    generated = model.generate(
        test_src,
        max_length=10,
        temperature=0.7
    )

    logger.info(f"Source: {test_src[0].tolist()}")
    logger.info(f"Generated: {generated[0].tolist()}")

    # Visualize attention if plotting available
    if PLOTTING_AVAILABLE:
        logger.info("\nVisualizing attention patterns...")
        visualizer = AttentionVisualizer(model)

        # Get attention weights
        test_tgt = torch.tensor([[1, 25, 20, 15, 10, 5, 2]]).to(device)
        attention_data = visualizer.get_attention_weights(test_src, test_tgt)

        # Plot encoder attention from first layer
        if attention_data["encoder_attentions"]:
            logger.info("Plotting encoder attention...")
            visualizer.plot_multi_head(
                attention_data["encoder_attentions"][0],
                num_heads=4,
                title="Encoder Layer 0 Attention"
            )

    logger.info("\n" + "=" * 60)
    logger.info("Demo completed successfully!")
    logger.info("=" * 60)


if __name__ == "__main__":
    run_demo()
