#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/glm_image/modular_glm_image.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_glm_image.py file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Optional

import torch.nn as nn
import torch.nn.functional as F

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available
from ...utils.generic import maybe_autocast, merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from .configuration_glm_image import GlmImageConfig, GlmImageTextConfig, GlmImageVisionConfig, GlmImageVQVAEConfig


if is_torch_available():
    import torch


class GlmImageVisionMLP(nn.Module):
    """Feed-forward sub-layer of a vision block.

    Two linear layers with the activation named by `config.hidden_act` in between:
    `hidden_size -> intermediate_size -> hidden_size`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        width, inner_width = config.hidden_size, config.intermediate_size
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(width, inner_width)
        self.fc2 = nn.Linear(inner_width, width)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `fc2(activation_fn(fc1(hidden_states)))`."""
        expanded = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(expanded)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension, i.e. the equivalent of
    torch.repeat_interleave(x, dim=1, repeats=n_rep). The input of shape (batch, num_key_value_heads, seqlen,
    head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim). With `n_rep == 1` the input tensor
    itself is returned.
    """
    # Unpacking first means a non-4D input fails here regardless of `n_rep`.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference scaled dot-product attention written with plain tensor ops.

    Args:
        module: Attention module; `num_key_value_groups` and `training` are read from it.
        query, key, value: 4D tensors (`repeat_kv` unpacks four dims and the last two key dims are transposed).
        attention_mask: Optional tensor added to the scaled scores before the softmax.
        scaling: Factor applied to the raw `query @ key^T` scores.
        dropout: Dropout probability applied to the attention weights while `module.training` is true.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has dims 1 and 2 swapped and is contiguous.
    """
    groups = module.num_key_value_groups
    key_states = repeat_kv(key, groups)
    value_states = repeat_kv(value, groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask

    # Softmax runs in float32, then the weights are cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
    return context, probs


class GlmImageVisionAttention(nn.Module):
    """Non-causal self-attention over a packed sequence of vision patches.

    The input is a flat `(seq_length, hidden_size)` sequence; `cu_seqlens` holds the cumulative boundaries of the
    segments inside it. Flash implementations receive those boundaries directly, every other implementation is
    run once per segment.
    """

    def __init__(self, config: GlmImageVisionConfig) -> None:
        super().__init__()
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.num_key_value_groups = 1  # needed for eager attention
        use_bias = config.attention_bias
        # Fused projection producing queries, keys and values in one matmul.
        self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=use_bias)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=use_bias)
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Attend within each segment delimited by `cu_seqlens` and return a `(seq_length, -1)` projection."""
        seq_length = hidden_states.shape[0]

        # (seq, 3 * hidden) -> (seq, 3, heads, head_dim); each of q/k/v then becomes (1, heads, seq, head_dim).
        fused = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1)
        query_states, key_states, value_states = (
            states.transpose(0, 1).unsqueeze(0) for states in fused.unbind(1)
        )

        attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        dropout = self.attention_dropout if self.training else 0.0
        segment_lengths = cu_seqlens[1:] - cu_seqlens[:-1]

        if "flash" in self.config._attn_implementation:
            # Flash Attention: hand over the cumulative lengths for variable-length attention
            max_seqlen = segment_lengths.max()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # Other implementations: slice the sequence axis per segment and attend to each slice on its own
            sizes = segment_lengths.tolist()
            q_chunks, k_chunks, v_chunks = (
                torch.split(states, sizes, dim=2) for states in (query_states, key_states, value_states)
            )

            attn_outputs = []
            for q, k, v in zip(q_chunks, k_chunks, v_chunks):
                chunk_output = attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                attn_outputs.append(chunk_output)
            # Stitch the per-segment outputs back together along the sequence axis.
            attn_output = torch.cat(attn_outputs, dim=1)

        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        return self.proj(attn_output)


class GlmImageVisionPatchEmbed(nn.Module):
    """Project flattened pixel patches to `hidden_size` vectors with a patch-sized convolution."""

    def __init__(self, config: GlmImageVisionConfig) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size
        # Kernel and stride are both one patch, so every patch yields exactly one output vector.
        patch_hw = [self.patch_size, self.patch_size]
        self.proj = nn.Conv2d(self.in_channels, self.embed_dim, kernel_size=patch_hw, stride=patch_hw)

    def forward(self, hidden_states) -> torch.Tensor:
        """Reshape the input into `(-1, in_channels, patch, patch)` and return `(-1, embed_dim)` embeddings."""
        patches = hidden_states.view(-1, self.in_channels, self.patch_size, self.patch_size)
        # Cast to the convolution weight dtype before projecting.
        patches = patches.to(dtype=self.proj.weight.dtype)
        return self.proj(patches).view(-1, self.embed_dim)


class GlmImageVisionEmbeddings(nn.Module):
    """Learned square grid of position embeddings, resampled at each patch location and added to the patches."""

    def __init__(self, config: GlmImageVisionConfig) -> None:
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # One learned position per patch of an `image_size` x `image_size` input.
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.interpolated_method = "bilinear"

    def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor:
        """
        Add position embeddings, sampled from the learned 2D table, to the patch embeddings.

        Args:
            embeddings: Input embeddings tensor
            lengths (list or torch.Tensor): Sequence lengths for each image in the batch.
            image_shapes (torch.Tensor): Tensor of shape [batch_size, 3] representing the image shapes (t, h, w).
            h_coords (torch.Tensor): Tensor of shape [total_seq] representing the h coordinate for each patch.
            w_coords (torch.Tensor): Tensor of shape [total_seq] representing the w coordinate for each patch.

        Returns:
            torch.Tensor: Embeddings with adapted position encoding added.
        """
        table = self.position_embedding.weight
        table_device = table.device
        num_positions, channels = table.shape

        # A plain list of lengths is turned into a tensor on the table's device
        if isinstance(lengths, list):
            lengths = torch.tensor(lengths, device=table_device, dtype=torch.long)

        # View the flat table as a (1, channels, side, side) float32 image for grid_sample
        side = int(num_positions**0.5)
        table_2d = table.view(side, side, channels).permute(2, 0, 1).unsqueeze(0)
        table_2d = table_2d.to(device=table_device, dtype=torch.float32)

        def per_patch_extent(column: int) -> torch.Tensor:
            # Repeat the image's extent (column 1 = h, column 2 = w) once per patch of that image.
            pieces = [image_shapes[idx, column].repeat(lengths[idx]) for idx in range(len(lengths))]
            return torch.cat(pieces).to(device=table_device, dtype=torch.float32)

        target_h = per_patch_extent(1)
        target_w = per_patch_extent(2)

        # Patch centres mapped to the [-1, 1] range expected by grid_sample
        norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
        norm_h = ((h_coords + 0.5) / target_h) * 2 - 1

        # (total_seq, 2) -> (1, total_seq, 1, 2), x (width) first as grid_sample requires
        grid = torch.stack((norm_w, norm_h), dim=-1)[None, :, None, :]

        # Sample the table with the configured interpolation mode (`self.interpolated_method`)
        sampled = F.grid_sample(
            table_2d, grid, mode=self.interpolated_method, align_corners=False, padding_mode="border"
        )

        # (1, channels, total_seq, 1) -> (total_seq, channels), back in the table dtype on the embeddings' device
        sampled = sampled.squeeze(0).squeeze(-1).permute(1, 0)
        position_term = sampled.to(table.dtype).to(embeddings.device)

        return embeddings + position_term


class GlmImageVisionBlock(GradientCheckpointingLayer):
    """Pre-norm vision transformer block: LayerNorm + attention, then LayerNorm + MLP, each with a residual."""

    def __init__(self, config: GlmImageVisionConfig) -> None:
        super().__init__()
        width, eps = config.hidden_size, config.layer_norm_eps
        self.norm1 = nn.LayerNorm(width, eps=eps)
        self.norm2 = nn.LayerNorm(width, eps=eps)
        self.attn = GlmImageVisionAttention(config)
        self.mlp = GlmImageVisionMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        r"""
        cu_seqlens (`torch.Tensor` of shape `(num_images_or_videos + 1,)`):
            The cumulative sequence lengths of each image or video feature.
        """
        # Attention sub-layer; extra kwargs are handed to the attention module untouched.
        attn_out = self.attn(self.norm1(hidden_states), cu_seqlens=cu_seqlens, **kwargs)
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-layer.
        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out


def rotate_half(x):
    """Split the last dimension at its midpoint and return `cat(-second_part, first_part)`."""
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the leading `cos.shape[-1]` channels of the query and key tensors.

    Channels beyond the rotary width are passed through unchanged, so both partial and full rotary embeddings
    are supported.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `q` and `k`. For
            example, with `cos`/`sin` of shape [batch_size, seq_len, head_dim], use `unsqueeze_dim=1` when `q` and
            `k` are [batch_size, heads, seq_len, head_dim] and `unsqueeze_dim=2` when they are
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]

    def _rotate_leading_channels(states):
        # Rotate the first `rotary_dim` channels and re-attach the remaining ones untouched.
        rotated, untouched = states[..., :rotary_dim], states[..., rotary_dim:]
        rotated = (rotated * cos) + (rotate_half(rotated) * sin)
        return torch.cat([rotated, untouched], dim=-1)

    return _rotate_leading_channels(q), _rotate_leading_channels(k)


@use_kernelized_func(apply_rotary_pos_emb)
class GlmImageTextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Causal self-attention of the text decoder: separate q/k/v projections (keys and values may use fewer heads
    than queries), rotary position embeddings on queries and keys, and an optional key/value cache.
    """

    def __init__(self, config: GlmImageTextConfig, layer_idx: int | None = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Falls back to hidden_size // num_attention_heads when the config has no `head_dim` attribute.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
        self.rope_parameters = config.rope_parameters

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run self-attention over `hidden_states`.

        Args:
            hidden_states: Input whose last dimension is the hidden size; the leading dims are preserved.
            position_embeddings: `(cos, sin)` pair applied to queries and keys.
            attention_mask: Mask forwarded to the attention implementation (may be `None`).
            past_key_values: Optional cache; when given it is updated with this layer's new keys and values and
                the returned (cached + new) states are used for attention.
            cache_position: Forwarded to the cache update.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            `(attn_output, attn_weights)`: the output projection of the attention result and whatever weights
            the selected attention implementation returns.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project, split into heads, and move the head axis in front of the sequence axis.
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position is passed along for the cache update
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge the heads back into one feature axis before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


@use_kernel_forward_from_hub("RMSNorm")
class GlmImageRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        GlmImageRMSNorm is equivalent to T5LayerNorm: inputs are scaled by the reciprocal root-mean-square of
        their last dimension and multiplied by a learned per-channel weight (initialised to ones).
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize over the last dimension in float32, cast back to the input dtype, then apply the weight."""
        input_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(input_dtype)

    def extra_repr(self):
        """Show the weight shape and epsilon in the module repr."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class GlmImageTextMLP(nn.Module):
    """Gated feed-forward layer of the text decoder with a fused gate/up projection."""

    def __init__(self, config):
        super().__init__()

        self.config = config
        width, inner_width = config.hidden_size, config.intermediate_size
        # A single matmul produces both the gate half and the up half.
        self.gate_up_proj = nn.Linear(width, 2 * inner_width, bias=False)
        self.down_proj = nn.Linear(inner_width, width, bias=False)
        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Return `down_proj(up * activation_fn(gate))`, with gate/up being the two halves of `gate_up_proj`."""
        gate, up = self.gate_up_proj(hidden_states).chunk(2, dim=-1)
        gated = up * self.activation_fn(gate)
        return self.down_proj(gated)


class GlmImageTextDecoderLayer(GradientCheckpointingLayer):
    # Text decoder layer with "sandwich" normalization: each sub-layer (self-attention, MLP) is normalized on
    # the way in AND on the way out before being added back to the residual stream.

    def __init__(self, config: GlmImageTextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GlmImageTextAttention(config, layer_idx)
        self.mlp = GlmImageTextMLP(config)
        # input_layernorm / post_self_attn_layernorm wrap the attention sub-layer,
        # post_attention_layernorm / post_mlp_layernorm wrap the MLP sub-layer.
        self.input_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_self_attn_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_mlp_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    @auto_docstring
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        cache_position: torch.LongTensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        # The return annotation reflects what is actually returned: the updated hidden states only
        # (the attention weights produced by `self_attn` are discarded here).
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = self.post_self_attn_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_mlp_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


@auto_docstring
class GlmImagePreTrainedModel(PreTrainedModel):
    # Common base class of the GlmImage models defined in this file. It only declares class-level
    # settings and delegates weight initialization to `PreTrainedModel`.
    config: GlmImageConfig
    base_model_prefix = "model"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True
    # Class names of the text decoder layer and the vision block
    # (presumably the units that must not be split across devices — confirm against `PreTrainedModel`).
    _no_split_modules = ["GlmImageTextDecoderLayer", "GlmImageVisionBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Output name -> module class whose outputs are recorded under that name.
    _can_record_outputs = {
        "hidden_states": GlmImageTextDecoderLayer,
        "attentions": GlmImageTextAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        # No model-specific initialization: defer entirely to the parent implementation.
        super()._init_weights(module)


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for GlmImage outputs, with hidden states and attentions.
    """
)
class GlmImageModelOutputWithPast(ModelOutput):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
        The rope index difference between sequence length and multimodal rope.
    """

    # All fields are optional and default to None.
    last_hidden_state: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    rope_deltas: torch.LongTensor | None = None


class GlmImageVQVAEVectorQuantizer(nn.Module):
    """
    A module for vector quantization using learned embedding vectors.

    This module implements the quantization process similar to the one described in
    the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
    input vectors into discrete codebook vectors, which are learned during training.
    Both the inputs and the codebook entries are L2-normalized before the nearest entry is looked up.
    """

    def __init__(self, config: GlmImageVQVAEConfig):
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        # Weight of the second loss term; 0.25 when the config does not define `beta`.
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)

    def forward(self, hidden_state: torch.Tensor):
        """Quantize a channels-first 4D feature map.

        Returns:
            `(quantized, loss, indices)`: the quantized map in the input layout, the scalar quantization loss and
            the flat tensor of selected codebook indices (one per spatial position).
        """
        # Channels-first -> channels-last, plus a flat (positions, embedding_dim) view of the same data
        channels_last = hidden_state.permute(0, 2, 3, 1).contiguous()
        flat = channels_last.view(-1, self.embedding_dim)

        # L2 normalize inputs and codebook
        channels_last = F.normalize(channels_last, p=2, dim=-1)
        flat = F.normalize(flat, p=2, dim=-1)
        codebook = F.normalize(self.embedding.weight, p=2, dim=-1)

        # Squared distances via (z - e)^2 = z^2 + e^2 - 2 e * z
        flat_sq = torch.sum(flat**2, dim=1, keepdim=True)
        code_sq = torch.sum(codebook**2, dim=1)
        cross = torch.einsum("bd,dn->bn", flat, codebook.transpose(0, 1))
        distances = flat_sq + code_sq - 2 * cross

        min_encoding_indices = torch.argmin(distances, dim=1)
        quantized = codebook[min_encoding_indices].view(channels_last.shape)

        # First term: gradient reaches only the input; second term (scaled by beta): only the codebook
        input_term = torch.mean((quantized.detach() - channels_last) ** 2)
        codebook_term = torch.mean((quantized - channels_last.detach()) ** 2)
        loss = input_term + self.beta * codebook_term

        # Forward value is the quantized vector, while the gradient flows straight to the input
        quantized = channels_last + (quantized - channels_last).detach()

        # Back to the channels-first layout of the input
        quantized = quantized.permute(0, 3, 1, 2).contiguous()

        return quantized, loss, min_encoding_indices


@dataclass
@auto_docstring
class GlmImageVQVAEModelOutput(BaseModelOutputWithPooling):
    r"""
    quantized_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
        Quantized last hidden state from the VQ-VAE model.
    image_tokens (`torch.LongTensor` of shape `(batch_size * height * width,)`):
        Codebook indices of the image tokens selected by the VQ-VAE quantizer, one per spatial position.
    embedding_loss (`torch.FloatTensor`):
        The embedding loss computed during quantization.
    """

    quantized_last_hidden_state: torch.FloatTensor | None = None
    image_tokens: torch.LongTensor | None = None
    embedding_loss: torch.FloatTensor | None = None


@auto_docstring(
    custom_intro="""
    The VQ-VAE model used in GlmImage for encoding/decoding images into discrete tokens.
    This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
    [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
    Taigman](https://huggingface.co/papers/2203.13131).
    """
)
class GlmImageVQVAE(GlmImagePreTrainedModel):
    config: GlmImageVQVAEConfig
    _no_split_modules = ["GlmImageVQVAEVectorQuantizer"]
    _can_record_outputs = {}

    def __init__(self, config: GlmImageVQVAEConfig):
        super().__init__(config)
        self.quantize = GlmImageVQVAEVectorQuantizer(config)
        # 1x1 convolutions mapping latent channels to the codebook width and back.
        self.quant_conv = nn.Conv2d(config.latent_channels, config.embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(config.embed_dim, config.latent_channels, 1)
        self.eval()  # GlmImage's VQ model is frozen
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def encode(self, hidden_states) -> GlmImageVQVAEModelOutput:
        # Project to the codebook width, then quantize.
        projected = self.quant_conv(hidden_states)
        quantized, embedding_loss, image_tokens = self.quantize(projected)
        # `last_hidden_state` carries the input as received (before `quant_conv`).
        return GlmImageVQVAEModelOutput(
            last_hidden_state=hidden_states,
            quantized_last_hidden_state=quantized,
            image_tokens=image_tokens,
            embedding_loss=embedding_loss,
        )


class GlmImageVisionModel(GlmImagePreTrainedModel):
    """Vision tower: patch embedding, interpolated 2D position embeddings and a stack of vision blocks."""

    config: GlmImageVisionConfig
    input_modalities = ("image",)
    _no_split_modules = ["GlmImageVisionBlock"]
    _can_record_outputs = {
        "hidden_states": GlmImageVisionBlock,
        "attentions": GlmImageVisionAttention,
    }
    main_input_name = "pixel_values"

    def __init__(self, config: GlmImageVisionConfig) -> None:
        super().__init__(config)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size

        self.embeddings = GlmImageVisionEmbeddings(config)
        self.patch_embed = GlmImageVisionPatchEmbed(config)

        self.blocks = nn.ModuleList([GlmImageVisionBlock(config) for _ in range(config.depth)])

        self.gradient_checkpointing = False
        self.head_dim = config.hidden_size // config.num_heads
        self.post_init()

    def rot_pos_emb(self, grid_thw):
        """Return the `(h, w)` grid coordinate of every patch, shape `(total_patches, 2)`.

        Within each image the coordinates are enumerated merge-window by merge-window
        (`spatial_merge_size` x `spatial_merge_size`) and the result is repeated `t` times.
        """
        merge = self.spatial_merge_size
        pos_ids = []
        for t, h, w in grid_thw:
            row_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            col_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            # Same re-ordering for both coordinate maps: group by merge window, then flatten.
            windowed = [
                ids.reshape(h // merge, merge, w // merge, merge).permute(0, 2, 1, 3).flatten()
                for ids in (row_ids, col_ids)
            ]
            pos_ids.append(torch.stack(windowed, dim=-1).repeat(t, 1))
        return torch.cat(pos_ids, dim=0)

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        pixel_values (`torch.Tensor` of shape `(total_patches, num_channels * patch_size * patch_size)`):
            Packed pixel values.
        grid_thw (`torch.Tensor` of shape `(num_images, 3)`):
            The temporal, height and width of feature shape of each image.

        Returns:
            `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states.
        """

        hidden_states = self.patch_embed(pixel_values)
        image_type_ids = self.rot_pos_emb(grid_thw)

        # One segment of h * w patches per temporal frame; cumulative boundaries start at 0.
        patches_per_frame = grid_thw[:, 1] * grid_thw[:, 2]
        cu_seqlens = torch.repeat_interleave(patches_per_frame, grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

        target_device = hidden_states.device
        h_coords = image_type_ids[:, 0].to(target_device)
        w_coords = image_type_ids[:, 1].to(target_device)
        hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, h_coords, w_coords)

        # Transformer blocks (no position_embeddings needed, already added above)
        for blk in self.blocks:
            hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens)

        # The final hidden states are wrapped in the output dataclass.
        return BaseModelOutputWithPooling(last_hidden_state=hidden_states)


class GlmImageTextRotaryEmbedding(nn.Module):
    """
    Rotary position embedding for the text model, driven by 3-row (multimodal) position ids.

    `forward` returns the `(cos, sin)` pair; the frequency bands are taken from the three position rows in
    round-robin order according to `mrope_section`.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: GlmImageTextConfig, device=None):
        super().__init__()
        self.config = config
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.rope_type = self.config.rope_parameters["rope_type"]
        # "default" uses the local static method, every other type is looked up in the shared registry.
        if self.rope_type == "default":
            rope_init_fn: Callable = self.compute_default_rope_parameters
        else:
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Both buffers are non-persistent: they are rebuilt from the config instead of being checkpointed.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
        self.mrope_section = config.rope_parameters.get("mrope_section", [8, 12, 12])

    @staticmethod
    def compute_default_rope_parameters(
        config: GlmImageTextConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration. Despite the `None` default it is dereferenced unconditionally.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        rope_parameters = config.rope_parameters
        base = rope_parameters["rope_theta"]
        partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)

        # Prefer an explicit (truthy) `head_dim`; otherwise derive it from the hidden size.
        head_dim = getattr(config, "head_dim", None)
        if not head_dim:
            head_dim = config.hidden_size // config.num_attention_heads
        dim = int(head_dim * partial_rotary_factor)

        # Even indices are created as int64 first and only then cast to float on the target device.
        exponents = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim
        inv_freq = 1.0 / (base**exponents)

        attention_factor = 1.0  # Unused in this type of RoPE
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        # `position_ids` carries one row per grid axis, so the inverse frequencies are expanded to a
        # leading dimension of 3 as well.
        num_sequences = position_ids.shape[1]
        inv_freq_expanded = self.inv_freq[None, None, :, None].float()
        inv_freq_expanded = inv_freq_expanded.expand(3, num_sequences, -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

        if isinstance(x.device.type, str) and x.device.type != "mps":
            device_type = x.device.type
        else:
            device_type = "cpu"

        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = torch.matmul(inv_freq_expanded.float(), position_ids_expanded.float())
            freqs = freqs.transpose(2, 3)
            freqs = self.apply_mrope(freqs, self.mrope_section)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

    def apply_mrope(self, freqs, mrope_section):
        """Split the last dim into `mrope_section` chunks and take chunk `i` from position row `i % 3`."""
        per_section = freqs.split(mrope_section, dim=-1)
        picked = tuple(chunk[section_idx % 3] for section_idx, chunk in enumerate(per_section))
        return torch.cat(picked, dim=-1)


@auto_docstring
class GlmImageTextModel(GlmImagePreTrainedModel):
    config: GlmImageTextConfig
    input_modalities = ("text",)

    def __init__(self, config: GlmImageTextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        decoder_layers = [GlmImageTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = GlmImageTextRotaryEmbedding(config=config)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple | BaseModelOutputWithPast:
        # Exactly one of `input_ids` / `inputs_embeds` has to be supplied.
        has_input_ids = input_ids is not None
        has_inputs_embeds = inputs_embeds is not None
        if has_input_ids == has_inputs_embeds:
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # torch.jit.trace() doesn't support cache objects in the output
        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = 0
            if past_key_values is not None:
                past_seen_tokens = past_key_values.get_seq_length()
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # The hard-coded `3` stands for the temporal, height and width axes.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        # NOTE: text position ids are needed for packing. Qwen2-VL style positions are 3D, one row per
        # temporal/height/width grid axis. FA2-like packed masking can be activated in two scenarios:
        # 1. The user passes packed `position_ids` and no attention mask. The user is then expected to
        #    create correct position ids for all 3 grids and to prepend text-only position ids, which
        #    gives a `[4, bs, seq-len]` tensor.
        # 2. The user passes neither attention mask nor position ids. The positions are then prepared by
        #    the model (`get_rope_index`) as a `[4, bs, seq-len]` tensor with the text-only positions
        #    prepended, so that the mask is constructed correctly.
        # NOTE: omitting the text-only positions causes incorrect mask construction, do not change
        # `prepare_input_for_generation`.
        is_packed = position_ids.ndim == 3 and position_ids.shape[0] == 4
        if is_packed:
            text_position_ids = position_ids[0]
            position_ids = position_ids[1:]
        else:
            # Usual 3D positions (not packed): the mask must not be derived from position ids.
            text_position_ids = None

        # Create the masks
        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=text_position_ids,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=text_position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )


@auto_docstring
class GlmImageModel(GlmImagePreTrainedModel):
    # Composite model: vision tower (`visual`), VQ-VAE tokenizer (`vqmodel`) and text decoder
    # (`language_model`). NOTE: the instance caches `rope_deltas`, `_cached_decode_position_ids` and
    # `_prefill_len` between forward calls, so it is stateful across a generation run.
    base_model_prefix = "model"
    _checkpoint_conversion_mapping = {}
    # Reference: fix gemma3 grad acc #37208
    accepts_loss_kwargs = False
    _no_split_modules = ["GlmImageTextDecoderLayer", "GlmImageVisionBlock"]

    def __init__(self, config):
        super().__init__(config)
        self.visual = GlmImageVisionModel._from_config(config.vision_config)
        self.language_model = GlmImageTextModel._from_config(config.text_config)

        self.rope_deltas = None  # cache rope_deltas here
        self.vqmodel = GlmImageVQVAE._from_config(config.vq_config)

        # Per-sample caches for batch processing
        self._cached_decode_position_ids = None  # shape: [batch_size, 3, max_decode_len]
        self._prefill_len = None  # prefill sequence length (same for all samples in batch)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_rope_index(
        self,
        input_ids: torch.LongTensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        images_per_sample: torch.LongTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate the 3D rope index for image generation task with full batch support.

        Besides returning the prefill positions, this method refreshes two instance caches used while decoding:
        `self._prefill_len` and `self._cached_decode_position_ids`.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Required despite the `None` default.
            image_grid_thw (`torch.LongTensor` of shape `(total_images_in_batch, 3)`, *optional*):
                The temporal, height and width of feature shape of each image.
                Images are packed across all samples in the batch.
            images_per_sample (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Number of images (including target grids) for each sample in the batch.
                Used to split image_grid_thw by sample. When omitted, every sample is given the full
                `image_grid_thw`.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. It is only used when it is 2D and
                its length matches `input_ids`; any other mask is ignored.

        Returns:
            position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`):
                Position IDs for temporal, height, and width dimensions. Padded positions keep the value 1.
            mrope_position_deltas (`torch.Tensor` of shape `(batch_size, 1)`):
                Position deltas for multi-modal rotary position embedding (always zeros here).
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        dtype = input_ids.dtype

        image_start_token_id = self.config.image_start_token_id
        image_end_token_id = self.config.image_end_token_id

        position_ids = torch.ones(3, batch_size, seq_len, dtype=dtype, device=device)
        text_positions = torch.arange(seq_len, device=device)[None, :].repeat(3, 1)

        # Split image_grid_thw by sample if images_per_sample is provided
        if image_grid_thw is not None and images_per_sample is not None:
            grids_per_sample = torch.split(image_grid_thw, images_per_sample.tolist())
        elif image_grid_thw is not None:
            # Fallback: assume all grids belong to first sample (batch_size=1)
            grids_per_sample = [image_grid_thw] * batch_size
        else:
            grids_per_sample = [None] * batch_size

        # Only a 2D padding mask of matching length can be used to select valid tokens. The mask may have a
        # different length during assisted decoding, and a 4D mask must never be indexed as a padding mask.
        use_padding_mask = (
            attention_mask is not None and attention_mask.ndim == 2 and attention_mask.shape[1] == seq_len
        )

        # Per-sample caches for decode stage
        # NOTE(review): samples without any grid add no entry here, so in a batch mixing samples with and
        # without grids the stacked cache would have fewer rows than `batch_size` — confirm callers never mix.
        all_decode_position_ids = []

        for batch_idx in range(batch_size):
            curr_input_ids = input_ids[batch_idx]
            curr_grids = grids_per_sample[batch_idx]
            # Plain Python rows `[temporal, height, width]`; empty when the sample has no grids.
            curr_grids_list = curr_grids.tolist() if curr_grids is not None else []

            if use_padding_mask:
                valid_mask = attention_mask[batch_idx] == 1
                curr_input_ids_valid = curr_input_ids[valid_mask]
            else:
                curr_input_ids_valid = curr_input_ids
                valid_mask = None

            # Find image boundaries in this sample (converted once to Python ints)
            image_end_positions = torch.where(curr_input_ids_valid == image_end_token_id)[0].tolist()
            image_start_positions = (torch.where(curr_input_ids_valid == image_start_token_id)[0] + 1).tolist()
            num_complete_images = len(image_end_positions)

            current_pos = 0
            prev_image_end = 0
            curr_position_ids = []

            # Process complete images (source images in image-to-image task). `zip` drops a trailing
            # image start that has no matching end yet.
            for img_idx, (start, end) in enumerate(zip(image_start_positions, image_end_positions)):
                if img_idx >= len(curr_grids_list):
                    break
                # grid format is [temporal, height, width]
                _, height, width = curr_grids_list[img_idx]

                # Text tokens before this image
                llm_pos_length = start - prev_image_end
                llm_position_ids = text_positions[:, current_pos : current_pos + llm_pos_length]
                current_pos += llm_position_ids.shape[-1]

                # Image tokens with 2D spatial encoding
                # For an image with height H and width W:
                # - position_width cycles [0, 1, ..., W-1] for each row, repeated H times
                # - position_height stays constant per row, [0]*W, [1]*W, ..., [H-1]*W
                # (both offset by `current_pos`)
                image_seq_length = height * width
                position_width = torch.arange(current_pos, current_pos + width, device=device).repeat(height)
                position_height = torch.arange(current_pos, current_pos + height, device=device).repeat_interleave(
                    width
                )
                position_temporal = torch.full((image_seq_length,), current_pos, device=device, dtype=torch.long)
                vision_position_ids = torch.stack([position_temporal, position_height, position_width], dim=0)
                # An image only advances the running position by its longest side.
                current_pos += max(height, width)

                prev_image_end = end
                curr_position_ids.append(torch.cat([llm_position_ids, vision_position_ids], dim=-1))

            # Remaining text tokens (including the final image_start token for generation)
            end_position = len(curr_input_ids_valid) - prev_image_end
            llm_position_ids = text_positions[:, current_pos : current_pos + end_position]
            current_pos += llm_position_ids.shape[-1]
            curr_position_ids.append(llm_position_ids)

            # Concatenate all position ids for this sample
            curr_position_ids = torch.cat(curr_position_ids, dim=-1)

            # Store in the main position_ids tensor
            if valid_mask is not None:
                position_ids[:, batch_idx, valid_mask] = curr_position_ids
            else:
                position_ids[:, batch_idx, :] = curr_position_ids

            # Build decode position ids for this sample
            if curr_grids_list:
                # Grids without a completed image in the prompt are the ones still to be generated.
                num_decode_grids = max(len(curr_grids_list) - num_complete_images, 0)
                decode_pos = current_pos

                decode_temporal_list = []
                decode_height_list = []
                decode_width_list = []

                # The pending grids are walked from the last one backwards.
                for i in range(1, num_decode_grids + 1):
                    grid_idx = -i
                    h = curr_grids_list[grid_idx][1]
                    w = curr_grids_list[grid_idx][2]
                    total_tokens = h * w

                    h_indices = torch.arange(h, device=device).unsqueeze(1).expand(h, w).flatten()
                    w_indices = torch.arange(w, device=device).unsqueeze(0).expand(h, w).flatten()

                    decode_temporal_list.append(
                        torch.full((total_tokens,), decode_pos, device=device, dtype=torch.long)
                    )
                    decode_height_list.append(decode_pos + h_indices)
                    decode_width_list.append(decode_pos + w_indices)
                    decode_pos = decode_pos + max(h, w)

                # End marker
                end_marker = torch.tensor([decode_pos], device=device, dtype=torch.long)
                decode_temporal_list.append(end_marker)
                decode_height_list.append(end_marker)
                decode_width_list.append(end_marker)

                sample_decode_pos_ids = torch.stack(
                    [
                        torch.cat(decode_temporal_list, dim=0),
                        torch.cat(decode_height_list, dim=0),
                        torch.cat(decode_width_list, dim=0),
                    ],
                    dim=0,
                )
                all_decode_position_ids.append(sample_decode_pos_ids)

        # Store prefill length (same for all samples since input_ids is padded to same length)
        self._prefill_len = seq_len

        # Pad decode position ids to same length (repeating the last column) and stack
        if all_decode_position_ids:
            max_decode_len = max(x.shape[1] for x in all_decode_position_ids)
            padded_decode_pos_ids = [
                F.pad(pos_ids, (0, max_decode_len - pos_ids.shape[1]), mode="replicate")
                for pos_ids in all_decode_position_ids
            ]
            self._cached_decode_position_ids = torch.stack(padded_decode_pos_ids, dim=0)  # [batch, 3, max_decode_len]
        else:
            self._cached_decode_position_ids = None

        mrope_position_deltas = torch.zeros([batch_size, 1], dtype=dtype, device=device)

        return position_ids, mrope_position_deltas

    @can_return_tuple
    @auto_docstring
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_grid_thw: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input images.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        """
        # NOTE: `image_grid_thw` is dereferenced below, so it is effectively required despite its default.
        pixel_values = pixel_values.type(self.visual.dtype)
        vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs)
        # Split the packed hidden states back into one chunk per image and expose them as `pooler_output`.
        split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        image_embeds = torch.split(vision_outputs.last_hidden_state, split_sizes)
        vision_outputs.pooler_output = image_embeds

        return vision_outputs

    def get_placeholder_mask(
        self,
        input_ids: torch.LongTensor,
        image_ids: torch.LongTensor,
    ):
        """
        Locate the image placeholder tokens in `input_ids` and check that there is exactly one VQVAE image
        token available for each of them. The replacement itself is done by the caller.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, seq_len)`):
                Input token ids with image placeholders.
            image_ids (`torch.LongTensor` of shape `(num_images, num_tokens_per_image)` or flattened):
                Discrete token indices from the VQVAE codebook.

        Returns:
            special_image_mask (`torch.BoolTensor` of shape `(batch_size, seq_len)`):
                Mask indicating positions in input ids that will be replaced by actual image tokens.

        Raises:
            ValueError: If the number of placeholders differs from the number of image tokens.
        """

        special_image_mask = input_ids == self.config.image_token_id
        n_placeholder_tokens = special_image_mask.sum()
        # Count every element so that the 2D layout is handled the same way as the flattened one.
        n_image_tokens = image_ids.numel()

        if n_placeholder_tokens != n_image_tokens:
            raise ValueError(
                f"Number of image placeholder tokens ({n_placeholder_tokens.item()}) does not match "
                f"number of image tokens from VQVAE ({n_image_tokens})"
            )

        return special_image_mask

    def compute_3d_position_ids(
        self,
        input_ids: torch.Tensor | None,
        image_grid_thw: torch.Tensor | None,
        images_per_sample: torch.Tensor | None,
        inputs_embeds: torch.Tensor | None,
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None,
        cache_position: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """
        Build the `(3, batch_size, seq_len)` position ids for the current forward call.

        On prefill (or whenever no rope deltas are cached yet) the positions come from `get_rope_index`, which
        also refreshes the decode cache. On later steps they are sliced from that cache, or derived from
        `cache_position` when no cache exists.

        Returns:
            The position ids, or `None` when they cannot be built here and the language model has to infer them.
        """
        past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
        can_compute_mrope = input_ids is not None and image_grid_thw is not None

        if can_compute_mrope and (self.rope_deltas is None or past_key_values_length == 0):
            position_ids, rope_deltas = self.get_rope_index(
                input_ids,
                image_grid_thw=image_grid_thw,
                attention_mask=attention_mask,
                images_per_sample=images_per_sample,
            )
            self.rope_deltas = rope_deltas
        # Use pre-calculated rope-deltas to infer correct 3D position ids
        elif self.rope_deltas is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
            if self._cached_decode_position_ids is not None:
                # Without an explicit `cache_position`, the number of cached tokens is the index of the
                # first new token (this is also how the language model derives it).
                first_position = past_key_values_length if cache_position is None else cache_position[0].item()
                step = first_position - self._prefill_len
                position_ids = self._cached_decode_position_ids[:, :, step : step + seq_length].permute(1, 0, 2)
            elif cache_position is not None:
                position_ids = cache_position.view(1, 1, -1).repeat(3, batch_size, 1)
            else:
                # Nothing to derive the positions from. Let the language model build them.
                position_ids = None
        else:
            # Can't build correct 3D positions. Let the model infer it from `cache_position`
            position_ids = None
        return position_ids

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.Tensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        images_per_sample: torch.LongTensor | None = None,
        rope_deltas: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | GlmImageModelOutputWithPast:
        r"""
        image_grid_thw (`torch.LongTensor` of shape `(total_images_in_batch, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
            Images are packed across all samples in the batch.
        images_per_sample (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Number of images (including target grids) for each sample in the batch.
        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
            The rope index difference between sequence length and multimodal rope.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]

        if pixel_values is not None:
            # Source images are located through the token ids and sized through their grids, so both
            # are mandatory on this path. Fail early instead of crashing on `None` further down.
            if input_ids is None or image_grid_thw is None:
                raise ValueError("`pixel_values` requires both `input_ids` and `image_grid_thw` to be provided.")

            # Process source images (image-to-image mode)
            # Source images are identified by counting image_end_token_id in input_ids
            # Note: We must exclude padding tokens since pad_token_id == image_end_token_id
            if images_per_sample is not None:
                grids_per_sample = torch.split(image_grid_thw, images_per_sample.tolist())
                # Create mask for non-padding tokens (attention_mask=1 means non-padding)
                # Handle 4D attention mask (from static cache) by extracting diagonal
                if attention_mask is not None and attention_mask.ndim == 4:
                    non_pad_mask = torch.diagonal(attention_mask[:, 0], dim1=1, dim2=2)
                    if non_pad_mask.dtype.is_floating_point:
                        # Additive float mask: masked entries hold `finfo.min`, attended entries hold 0.
                        non_pad_mask = non_pad_mask / torch.finfo(non_pad_mask.dtype).min
                        non_pad_mask = (1.0 - non_pad_mask).int()
                    # Only keep columns matching input_ids length
                    non_pad_mask = non_pad_mask[:, -input_ids.shape[1] :]
                else:
                    non_pad_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids)

                source_grids_list = []
                is_image_end = input_ids == self.config.image_end_token_id
                is_non_pad = non_pad_mask == 1
                num_source_per_sample = (is_image_end & is_non_pad).sum(dim=1).tolist()
                for sample_idx in range(batch_size):
                    num_source = num_source_per_sample[sample_idx]
                    if num_source > 0:
                        # The first `num_source` grids of a sample belong to its source images.
                        source_grids_list.append(grids_per_sample[sample_idx][:num_source])
                if len(source_grids_list) == 0:
                    raise ValueError(
                        "pixel_values provided but no source images found in input_ids. "
                        "Ensure input_ids contains image_end_token_id for each source image."
                    )
                source_grids = torch.cat(source_grids_list, dim=0)
            else:
                # Fallback for batch_size=1: all but last grid are source images
                source_grids = image_grid_thw[:-1]

            # Encode the source images, quantize them into VQVAE token ids and write those ids over
            # the placeholder tokens.
            image_features = self.get_image_features(pixel_values, source_grids, return_dict=True)
            image_embeds = torch.cat(image_features.pooler_output, dim=0)
            image_ids = self.get_image_tokens(image_embeds, source_grids)
            image_ids = image_ids.view(-1).to(input_ids.device)
            special_image_mask = self.get_placeholder_mask(input_ids, image_ids)
            input_ids = input_ids.masked_scatter(special_image_mask, image_ids)

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if position_ids is None:
            position_ids = self.compute_3d_position_ids(
                input_ids=input_ids,
                image_grid_thw=image_grid_thw,
                images_per_sample=images_per_sample,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                cache_position=cache_position,
            )

        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )

        return GlmImageModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )

    def get_image_tokens(
        self,
        hidden_states: torch.FloatTensor,
        image_grid_thw: torch.LongTensor,
    ) -> torch.LongTensor:
        """
        Tokenizes image features into discrete tokens with VQVAE module.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(total_patches, hidden_size)`):
                The packed image features from vision encoder.
            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`):
                The temporal, height and width of feature shape of each image.

        Returns:
            image_tokens (`torch.LongTensor` of shape `(total_patches,)`):
                Discrete token indices from the VQVAE codebook.
        """
        hidden_size = hidden_states.shape[-1]
        # Each image owns `t * h * w` consecutive rows of the packed features.
        split_sizes = (image_grid_thw.prod(dim=-1)).tolist()
        hidden_states_list = torch.split(hidden_states, split_sizes, dim=0)

        all_image_toks = []
        for i, hs in enumerate(hidden_states_list):
            grid_t, grid_h, grid_w = image_grid_thw[i].tolist()
            # Restore the grid layout and move the feature dim to position 1 before encoding.
            hs = hs.view(grid_t, grid_h, grid_w, hidden_size)
            hs = hs.permute(0, 3, 1, 2).contiguous()
            vqmodel_outputs: GlmImageVQVAEModelOutput = self.vqmodel.encode(hs)
            all_image_toks.append(vqmodel_outputs.image_tokens)
        return torch.cat(all_image_toks, dim=0)


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for GlmImage causal language model (or autoregressive) outputs.
    """
)
class GlmImageCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
        The rope index difference between sequence length and multimodal rope.
    """

    # Every field defaults to `None`, so an output can be built from any subset of values.
    # The declaration order below is part of the output's layout; keep it stable.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    # `hidden_states` and `attentions` are not described in the docstring above.
    # NOTE(review): presumably per-layer tuples documented by `auto_docstring` — confirm.
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    rope_deltas: torch.LongTensor | None = None


class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin):
    """
    GLM-Image base model (`GlmImageModel`) with a bias-free linear head that projects the text hidden states
    onto `config.text_config.vision_vocab_size` classes, plus the generation hooks needed to handle visual
    inputs (`pixel_values`, `image_grid_thw`, `images_per_sample`) that are packed across the batch.
    """

    _checkpoint_conversion_mapping = {}
    _tied_weights_keys = {}
    # Reference: fix gemma3 grad acc #37208
    accepts_loss_kwargs = False
    base_model_prefix = "model"
    config: GlmImageConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = GlmImageModel(config)
        # The head predicts over the *vision* vocabulary, not the text vocabulary.
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vision_vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_grid_thw: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input images.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        """
        return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs)

    def get_image_tokens(self, hidden_states: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None):
        """Thin wrapper that forwards to `self.model.get_image_tokens` with the same arguments."""
        return self.model.get_image_tokens(hidden_states, image_grid_thw)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        pixel_values: torch.Tensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        images_per_sample: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | GlmImageCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Because the head projects onto the vision
            vocabulary, indices should either be in `[0, ..., config.text_config.vision_vocab_size - 1]` or -100
            (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only
            computed for the remaining tokens.
        image_grid_thw (`torch.LongTensor` of shape `(total_images_in_batch, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
            Images are packed across all samples in the batch.
        images_per_sample (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Number of images (including target grids) for each sample in the batch.
        logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
            If an `int`, logits are computed for the last `logits_to_keep` positions only (`0` keeps every
            position). Otherwise the value is used directly as an index along the sequence dimension.

        Example:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, GlmImageForConditionalGeneration

        >>> model = GlmImageForConditionalGeneration.from_pretrained("zai-org/GLM-Image")
        >>> processor = AutoProcessor.from_pretrained("zai-org/GLM-Image")

        >>> messages = [
        ...     {
        ...         "role": "user",
        ...         "content": [
        ...             {"type": "image"},
        ...             {"type": "text", "text": "Add a truck of this photo.<sop>28 40<eop>"},
        ...         ],
        ...     },
        ... ]
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        >>> inputs = processor(text=[text], images=[image], return_tensors="pt")

        >>> # Generate. The head of this model scores ids of the vision vocabulary
        >>> # (`config.text_config.vision_vocab_size`), so the new ids are image tokens rather than text.
        >>> generate_ids = model.generate(**inputs, max_new_tokens=30)
        ```"""
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            images_per_sample=images_per_sample,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
        # With the default `logits_to_keep=0` the slice is `slice(0, None)`, i.e. every position is kept.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # The class count handed to the loss must match the last dimension of `logits`, which is
            # `vision_vocab_size` (see `lm_head`), not the text `vocab_size`.
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vision_vocab_size
            )

        return GlmImageCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=outputs.rope_deltas,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        image_grid_thw=None,
        images_per_sample=None,
        is_first_iteration=False,
        **kwargs,
    ):
        """
        Build the per-step model inputs on top of the parent implementation.

        Differences from the parent result:
        - `position_ids` is always reset to `None`.
        - `images_per_sample` is attached as given.
        - `pixel_values` is dropped on every step after the first one when `use_cache` is enabled.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            is_first_iteration=is_first_iteration,
            use_cache=use_cache,
            **kwargs,
        )

        # NOTE(review): presumably the model recomputes multimodal positions itself when none are given -- confirm.
        model_inputs["position_ids"] = None
        model_inputs["images_per_sample"] = images_per_sample

        if not is_first_iteration and use_cache:
            model_inputs["pixel_values"] = None

        return model_inputs

    def _get_image_nums(
        self,
        input_ids: torch.LongTensor | None,
    ) -> torch.Tensor:
        """
        Get the number of images for each sample.
        For GLM-Image, only input_ids allow us to get the number of images.

        The count is the number of `config.image_start_token_id` occurrences in each row of `input_ids`.

        Returns:
            image_counts (`torch.LongTensor` of shape `(batch_size,)`)

        Raises:
            ValueError: If `input_ids` is `None`, since there is nothing to count from.
        """
        if input_ids is None:
            raise ValueError("`input_ids` is required to count the images of each sample.")

        is_image = input_ids == self.config.image_start_token_id

        return is_image.sum(dim=1)

    def _expand_inputs_for_generation(
        self,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        input_ids: torch.LongTensor | None = None,
        **model_kwargs,
    ) -> tuple[torch.LongTensor, dict[str, Any]]:
        """
        Repeat every generation input `expand_size` times per sample.

        Regular tensors in `model_kwargs` (and `input_ids`) are expanded with `repeat_interleave` along dim 0.
        The visual entries are packed across samples, so they are first split per sample and each per-sample
        chunk is repeated as a whole, which keeps them aligned with the expanded `input_ids`.

        Args:
            expand_size: Number of copies of each sample. `1` returns the inputs untouched.
            is_encoder_decoder: When `True`, `model_kwargs["encoder_outputs"]` is expanded as well.
            input_ids: Prompt token ids. Needed whenever `image_grid_thw` is present, because the number of
                source images of each sample is counted from `config.image_end_token_id` occurrences.
            **model_kwargs: Remaining generation inputs; modified in place and returned.

        Returns:
            The expanded `input_ids` and the expanded `model_kwargs`.

        Raises:
            ValueError: If `image_grid_thw` is given without `input_ids`, or if `is_encoder_decoder` is set
                without `encoder_outputs`.
        """
        # Overwritten -- Support for expanding tensors without a batch size dimension
        # e.g., pixel_values, image_grid_thw
        # pixel_values.shape[0] is sum(seqlen_images for samples)
        # image_grid_thw.shape[0] is sum(num_images for samples)

        if expand_size == 1:
            return input_ids, model_kwargs

        visual_keys = ["pixel_values", "image_grid_thw", "images_per_sample"]

        def _repeat_interleave_samples(x, lengths, repeat_times):
            # Split the packed tensor into per-sample chunks and tile each chunk along dim 0.
            samples = torch.split(x, lengths)
            repeat_args = [repeat_times] + [1] * (x.dim() - 1)
            return torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)

        def _expand_dict_for_generation_visual(dict_to_expand):
            image_grid_thw = dict_to_expand.get("image_grid_thw", None)
            if image_grid_thw is None:
                return dict_to_expand

            # Everything below needs the prompt tokens; fail with a clear message instead of an
            # AttributeError raised from comparing `None` against a token id.
            if input_ids is None:
                raise ValueError(
                    "`input_ids` is required to expand `pixel_values`/`image_grid_thw` for generation: "
                    "the number of source images of each sample is read from it."
                )

            images_per_sample = dict_to_expand.get("images_per_sample", None)

            # Use images_per_sample if available
            if images_per_sample is not None:
                image_nums = images_per_sample.tolist()
            else:
                # Try to infer from image_grid_thw / batch_size
                batch_size = input_ids.shape[0]
                total_grids = image_grid_thw.shape[0]
                if total_grids % batch_size != 0:
                    # Cannot evenly distribute grids - fall back to simple repeat_interleave
                    # This handles test cases where image_grid_thw has (batch_size + 1) rows
                    dict_to_expand["image_grid_thw"] = image_grid_thw.repeat_interleave(expand_size, dim=0)
                    if dict_to_expand.get("pixel_values") is not None:
                        dict_to_expand["pixel_values"] = dict_to_expand["pixel_values"].repeat_interleave(
                            expand_size, dim=0
                        )
                    return dict_to_expand
                image_nums = [total_grids // batch_size] * batch_size

            # Get source image counts per sample from image_end_token_id count
            source_image_nums = (input_ids == self.config.image_end_token_id).sum(dim=1).tolist()

            for key in dict_to_expand:
                if key == "pixel_values":
                    # Split images into samples based on source image counts
                    if dict_to_expand[key] is not None and sum(source_image_nums) > 0:
                        all_pixel_counts = image_grid_thw.prod(dim=1)
                        pixel_counts_per_sample = torch.split(all_pixel_counts, image_nums)
                        # Accumulate on-device in an integer dtype and convert to a list once at the end.
                        # Only the first `num_source` grids of a sample are counted as source images.
                        source_pixel_counts = torch.zeros(
                            len(pixel_counts_per_sample), dtype=torch.int64, device=image_grid_thw.device
                        )
                        for batch_idx, sample_pixel_counts in enumerate(pixel_counts_per_sample):
                            num_source = source_image_nums[batch_idx]
                            if num_source > 0:
                                source_pixel_counts[batch_idx] = sample_pixel_counts[:num_source].sum()
                        lengths = source_pixel_counts.tolist()

                        dict_to_expand[key] = _repeat_interleave_samples(
                            dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                        )
                elif key == "image_grid_thw":
                    # Expand all grids (source + target) per sample
                    dict_to_expand[key] = _repeat_interleave_samples(
                        dict_to_expand[key], lengths=image_nums, repeat_times=expand_size
                    )
                elif key == "images_per_sample":
                    # Simply repeat the counts
                    if dict_to_expand.get(key) is not None:
                        dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
            return dict_to_expand

        def _expand_dict_for_generation(dict_to_expand):
            # Visual keys were handled above; `cache_position` is deliberately left untouched.
            for key in dict_to_expand:
                if (
                    key != "cache_position"
                    and dict_to_expand[key] is not None
                    and isinstance(dict_to_expand[key], torch.Tensor)
                    and key not in visual_keys
                ):
                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
            return dict_to_expand

        # The visual expansion reads the *unexpanded* `input_ids`, so it must run first.
        model_kwargs = _expand_dict_for_generation_visual(model_kwargs)

        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)

        model_kwargs = _expand_dict_for_generation(model_kwargs)

        if is_encoder_decoder:
            if model_kwargs.get("encoder_outputs") is None:
                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])

        return input_ids, model_kwargs


# Names exported by a star-import of this module. The output dataclass
# `GlmImageCausalLMOutputWithPast` is intentionally or incidentally absent from this list.
__all__ = [
    "GlmImagePreTrainedModel",
    "GlmImageVQVAE",
    "GlmImageVisionModel",
    "GlmImageTextModel",
    "GlmImageModel",
    "GlmImageForConditionalGeneration",
]
