#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/vibevoice_acoustic_tokenizer/modular_vibevoice_acoustic_tokenizer.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_vibevoice_acoustic_tokenizer.py file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from ... import initialization as init
from ...activations import ACT2FN
from ...integrations import use_kernel_forward_from_hub
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple
from .configuration_vibevoice_acoustic_tokenizer import VibeVoiceAcousticTokenizerConfig


@dataclass
@auto_docstring
class VibeVoiceAcousticTokenizerOutput(ModelOutput):
    r"""
    audio (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`):
        Decoded audio.
    latents (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Projected latents (continuous representations for acoustic tokens) at the output of the encoder.
    padding_cache (`VibeVoiceAcousticTokenizerConv1dPaddingCache`, *optional*, returned when `use_cache=True` is passed):
        A [`VibeVoiceAcousticTokenizerConv1dPaddingCache`] instance containing cached convolution states for each decoder
        layer that can be passed to subsequent forward calls.
    """

    # NOTE(review): `forward` is wrapped in `can_return_tuple`, so field order likely fixes tuple positions - keep it stable.
    # Waveform produced by the decoder in `forward`.
    audio: torch.FloatTensor | None = None
    # Encoder output fed to the decoder (includes the VAE noise when `sample=True`).
    latents: torch.FloatTensor | None = None
    # Decoder-side cache only; string forward reference because the cache class is defined later in this module.
    padding_cache: Optional["VibeVoiceAcousticTokenizerConv1dPaddingCache"] = None


@dataclass
@auto_docstring
class VibeVoiceAcousticTokenizerEncoderOutput(ModelOutput):
    r"""
    latents (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Projected latents (continuous representations for acoustic tokens) at the output of the encoder.
    padding_cache (`VibeVoiceAcousticTokenizerConv1dPaddingCache`, *optional*, returned when `use_cache=True` is passed):
        A [`VibeVoiceAcousticTokenizerConv1dPaddingCache`] instance containing cached convolution states for each encoder
        layer that can be passed to subsequent forward calls.
    """

    # Returned by `encode`; noise is added to the latents when `sample=True`.
    latents: torch.FloatTensor | None = None
    # Encoder-side cache (None when no cache was requested or passed); string forward reference because the
    # cache class is defined later in this module.
    padding_cache: Optional["VibeVoiceAcousticTokenizerConv1dPaddingCache"] = None


@dataclass
@auto_docstring
class VibeVoiceAcousticTokenizerDecoderOutput(ModelOutput):
    r"""
    audio (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`):
        Decoded audio.
    padding_cache (`VibeVoiceAcousticTokenizerConv1dPaddingCache`, *optional*, returned when `use_cache=True` is passed):
        A [`VibeVoiceAcousticTokenizerConv1dPaddingCache`] instance containing cached convolution states for each decoder
        layer that can be passed to subsequent forward calls.
    """

    # Returned by `decode`: the decoder's output waveform.
    audio: torch.FloatTensor | None = None
    # Decoder-side cache (None when no cache was requested or passed); string forward reference because the
    # cache class is defined later in this module.
    padding_cache: Optional["VibeVoiceAcousticTokenizerConv1dPaddingCache"] = None


@use_kernel_forward_from_hub("RMSNorm")
class VibeVoiceAcousticTokenizerRMSNorm(nn.Module):
    """
    Root-mean-square normalization over the last dimension with a learned per-feature scale
    (same computation as T5LayerNorm: no mean subtraction, no bias).
    """

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The statistics are computed in float32; the normalized values are cast back before the learned scale.
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = (as_fp32 * inv_rms).to(original_dtype)
        return self.weight * normalized

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class VibeVoiceAcousticTokenizerFeedForward(nn.Module):
    """Position-wise MLP: expand by `config.ffn_expansion`, apply `config.hidden_act`, project back to `hidden_size`."""

    def __init__(self, config, hidden_size):
        super().__init__()
        inner_size = config.ffn_expansion * hidden_size
        # creation order (linear1, activation, linear2) is kept so parameter init and state_dict order are unchanged
        self.linear1 = nn.Linear(hidden_size, inner_size)
        self.activation = ACT2FN[config.hidden_act]
        self.linear2 = nn.Linear(inner_size, hidden_size)

    def forward(self, hidden_states):
        expanded = self.linear1(hidden_states)
        activated = self.activation(expanded)
        return self.linear2(activated)


class VibeVoiceAcousticTokenizerConv1dPaddingCache:
    """
    Per-layer left-context store that lets the causal convolutions of the acoustic tokenizer run chunk by chunk.
    See: https://huggingface.co/papers/2005.06720 & https://huggingface.co/papers/2204.07064

    Slot `i` of `padding_cache` holds the trailing input samples that the convolution with `layer_idx == i` saw on
    its previous call (or `None` before its first call). `update` returns that stored context and replaces it with
    the trailing samples of the current input.
    """

    def __init__(
        self,
        num_layers: int,
        per_layer_padding: list[int],
        per_layer_padding_mode: list[str],
        per_layer_in_channels: list[int],
    ):
        # every per-layer list must describe exactly `num_layers` layers
        list_lengths = (len(per_layer_padding), len(per_layer_padding_mode), len(per_layer_in_channels))
        if any(length != num_layers for length in list_lengths):
            raise ValueError(
                f"Expected `num_layers` ({num_layers}) values in `per_layer_padding`, `per_layer_padding_mode` and `per_layer_in_channels`"
            )

        self.per_layer_padding = per_layer_padding
        self.per_layer_padding_mode = per_layer_padding_mode
        self.per_layer_in_channels = per_layer_in_channels

        self.padding_cache = [None for _ in range(num_layers)]

    def _cache_init(self, hidden_states: torch.Tensor, layer_idx: int):
        """
        Build the initial left context of layer `layer_idx` for its first call.

        Parameters:
            hidden_states (`torch.Tensor`):
                First input chunk seen by the layer; supplies batch size, dtype and device (and, in `"replicate"`
                mode, the first time step that is repeated).
            layer_idx (`int`):
                Index of the layer whose context is created.
        Returns:
            `torch.Tensor` of shape `(batch_size, in_channels, padding)` for that layer.
        Raises:
            NotImplementedError: if the layer's padding mode is neither `"constant"` nor `"replicate"`.
        """
        width = self.per_layer_padding[layer_idx]
        mode = self.per_layer_padding_mode[layer_idx]
        channels = self.per_layer_in_channels[layer_idx]
        shape = (hidden_states.shape[0], channels, width)
        factory_kwargs = {"device": hidden_states.device, "dtype": hidden_states.dtype}

        if mode == "constant":
            return torch.zeros(shape, **factory_kwargs)
        if mode == "replicate":
            # repeat the chunk's first time step across the whole context
            return torch.ones(shape, **factory_kwargs) * hidden_states[..., :1]
        raise NotImplementedError(f"Padding mode {mode} not supported")

    def update(self, hidden_states: torch.Tensor, layer_idx: int):
        """
        Return the stored left context of layer `layer_idx` and replace it with the tail of `hidden_states`.

        Parameters:
            hidden_states (`torch.Tensor`):
                Current input chunk of the layer. Its last `padding` time steps become the context for the next
                call; when the chunk is shorter, the newest cached steps are kept in front of it.
            layer_idx (`int`):
                Index of the layer being updated.
        Returns:
            `torch.Tensor`, the context to prepend to `hidden_states` for this call (freshly built by
            `_cache_init` on the layer's first call).
        """
        width = self.per_layer_padding[layer_idx]
        channels = self.per_layer_in_channels[layer_idx]

        previous_context = self.padding_cache[layer_idx]
        if previous_context is None:
            previous_context = self._cache_init(hidden_states, layer_idx)

        if width <= 0:
            next_context = torch.empty(
                hidden_states.shape[0], channels, 0, dtype=hidden_states.dtype, device=hidden_states.device
            )
        else:
            missing = width - hidden_states.shape[-1]
            if missing > 0:
                # chunk shorter than the context: top it up with the newest `missing` cached steps
                next_context = torch.cat([previous_context[:, :, -missing:], hidden_states], dim=-1)
            else:
                next_context = hidden_states[:, :, -width:]

        self.padding_cache[layer_idx] = next_context
        return previous_context


class VibeVoiceAcousticTokenizerCausalConv1d(nn.Module):
    """
    `nn.Conv1d` whose input is left-padded so no output frame depends on later input frames.

    Without a padding cache the left context is `causal_padding` zeros. With a cache, the context comes from
    `padding_cache.update(...)` for this layer's `layer_idx`, which allows streaming chunk by chunk.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        layer_idx: int | None = None,
    ):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, dilation=dilation, groups=groups)
        # left context: span of the dilated kernel minus one, reduced by (stride - 1)
        self.causal_padding = (kernel_size - 1) * dilation - (stride - 1)
        if self.causal_padding < 0:
            raise ValueError(
                f"Invalid causal padding {self.causal_padding} for kernel_size={kernel_size}, "
                f"dilation={dilation}, stride={stride}."
            )
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        padding_cache: VibeVoiceAcousticTokenizerConv1dPaddingCache | None = None,
    ) -> torch.Tensor:
        if padding_cache is None:
            # stateless call: zero left-padding on the time axis only
            padded = nn.functional.pad(hidden_states, (self.causal_padding, 0))
        else:
            # streaming call: prepend the context stored for this layer by the previous chunk
            history = padding_cache.update(hidden_states, self.layer_idx)
            padded = torch.cat([history, hidden_states], dim=-1)
        return self.conv(padded)


class VibeVoiceAcousticTokenizerCausalConvTranspose1d(nn.Module):
    """ConvTranspose1d with built-in causal padding and optional streaming support through a cache."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        layer_idx: int | None = None,
    ):
        super().__init__()
        self.convtr = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)

        self.stride = stride
        self.layer_idx = layer_idx
        self.padding_total = kernel_size - stride
        self.causal_padding = kernel_size - 1

    def forward(
        self,
        hidden_states: torch.Tensor,
        padding_cache: Optional["VibeVoiceAcousticTokenizerConv1dPaddingCache"] = None,
    ) -> torch.Tensor:
        time_dim = hidden_states.shape[-1]

        if padding_cache is not None:
            layer_padding = padding_cache.update(hidden_states, self.layer_idx)
            hidden_states = torch.cat([layer_padding, hidden_states], dim=-1)
        hidden_states = self.convtr(hidden_states)

        # Remove extra padding at the right side
        if self.padding_total > 0:
            hidden_states = hidden_states[..., : -self.padding_total]

        if padding_cache is not None and layer_padding.shape[2] != 0:
            # For first chunk (layer_padding.shape[2] == 0) return full output
            # for subsequent chunks return only new output
            expected_new_output = time_dim * self.stride
            if hidden_states.shape[2] >= expected_new_output:
                hidden_states = hidden_states[:, :, -expected_new_output:]
        return hidden_states


class VibeVoiceAcousticTokenizerConvNext1dLayer(nn.Module):
    """
    Pre-norm residual block on channels-first `(batch, channels, time)` tensors.

    It applies a depthwise causal convolution ("mixer") and then a position-wise feed-forward network. Each branch
    is scaled by a learned per-channel vector before its residual add.
    """

    def __init__(self, config, hidden_size, dilation=1, stride=1, layer_idx=None):
        super().__init__()

        eps = config.rms_norm_eps
        # creation order is unchanged so parameter initialization consumes the RNG identically
        self.norm = VibeVoiceAcousticTokenizerRMSNorm(hidden_size, eps=eps)
        self.ffn_norm = VibeVoiceAcousticTokenizerRMSNorm(hidden_size, eps=eps)
        self.ffn = VibeVoiceAcousticTokenizerFeedForward(config, hidden_size)
        # layer-scale vectors, one value per channel
        self.gamma = nn.Parameter(config.layer_scale_init_value * torch.ones(hidden_size), requires_grad=True)
        self.ffn_gamma = nn.Parameter(config.layer_scale_init_value * torch.ones(hidden_size), requires_grad=True)
        # groups == channels makes this convolution depthwise
        self.mixer = VibeVoiceAcousticTokenizerCausalConv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=config.kernel_size,
            groups=hidden_size,
            dilation=dilation,
            stride=stride,
            layer_idx=layer_idx,
        )

    def forward(self, hidden_states, padding_cache=None):
        # token mixing: RMSNorm normalizes the last axis, so move channels there and back
        normed = self.norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        mixed = self.mixer(normed, padding_cache=padding_cache)
        hidden_states = hidden_states + mixed * self.gamma[:, None]

        # channel mixing: the FFN also runs channels-last; only its output is moved back
        normed = self.ffn_norm(hidden_states.transpose(1, 2))
        ffn_out = self.ffn(normed).transpose(1, 2)
        return hidden_states + ffn_out * self.ffn_gamma[:, None]


class VibeVoiceAcousticTokenizerEncoderStem(nn.Module):
    """Encoder entry: map `config.channels` input channels to `config.num_filters`, then apply `config.depths[0]` ConvNeXt blocks."""

    def __init__(self, config):
        super().__init__()

        width = config.num_filters
        # cache slot 0 belongs to this conv; the blocks take slots 1..depths[0]
        self.conv = VibeVoiceAcousticTokenizerCausalConv1d(
            in_channels=config.channels,
            out_channels=width,
            kernel_size=config.kernel_size,
            layer_idx=0,
        )
        block_slots = range(1, config.depths[0] + 1)
        self.stage = nn.ModuleList(
            [VibeVoiceAcousticTokenizerConvNext1dLayer(config, hidden_size=width, layer_idx=slot) for slot in block_slots]
        )

    def forward(self, hidden_states, padding_cache=None):
        for module in (self.conv, *self.stage):
            hidden_states = module(hidden_states, padding_cache=padding_cache)
        return hidden_states


class VibeVoiceAcousticTokenizerEncoderLayer(nn.Module):
    """Encoder stage `stage_idx`: a strided causal conv that doubles the channel count, followed by ConvNeXt blocks."""

    def __init__(self, config, stage_idx):
        super().__init__()

        group = stage_idx + 1  # config.depths[0] is reserved for the stem
        # cache slots already used by the stem and earlier stages (one conv plus its blocks per group)
        first_slot = sum(depth + 1 for depth in config.depths[:group])
        in_width = int(config.num_filters * 2**stage_idx)
        out_width = int(config.num_filters * 2**group)
        ratio = config.downsampling_ratios[stage_idx]

        self.conv = VibeVoiceAcousticTokenizerCausalConv1d(
            in_channels=in_width,
            out_channels=out_width,
            kernel_size=int(ratio * 2),
            stride=ratio,
            layer_idx=first_slot,
        )
        self.stage = nn.ModuleList(
            [
                VibeVoiceAcousticTokenizerConvNext1dLayer(config, hidden_size=out_width, layer_idx=first_slot + offset)
                for offset in range(1, config.depths[group] + 1)
            ]
        )

    def forward(self, hidden_states, padding_cache=None):
        for module in (self.conv, *self.stage):
            hidden_states = module(hidden_states, padding_cache=padding_cache)
        return hidden_states


class VibeVoiceAcousticTokenizerEncoder(nn.Module):
    """Stem, downsampling stages and head; maps audio `(batch, channels, time)` to latents `(batch, frames, hidden_size)`."""

    def __init__(self, config):
        super().__init__()

        num_stages = len(config.downsampling_ratios)
        self.stem = VibeVoiceAcousticTokenizerEncoderStem(config)
        self.conv_layers = nn.ModuleList(
            VibeVoiceAcousticTokenizerEncoderLayer(config, stage_idx) for stage_idx in range(num_stages)
        )
        # the head uses the cache slot right after every stem/stage conv and block
        self.head = VibeVoiceAcousticTokenizerCausalConv1d(
            in_channels=int(config.num_filters * 2**num_stages),
            out_channels=config.hidden_size,
            kernel_size=config.kernel_size,
            layer_idx=sum(depth + 1 for depth in config.depths),
        )

    def forward(self, hidden_states, padding_cache=None):
        hidden_states = self.stem(hidden_states, padding_cache=padding_cache)
        for downsample in self.conv_layers:
            hidden_states = downsample(hidden_states, padding_cache=padding_cache)
        latents = self.head(hidden_states, padding_cache=padding_cache)
        # channels-first conv output -> (batch, frames, hidden_size)
        return latents.transpose(1, 2)


class VibeVoiceAcousticTokenizerDecoderStem(nn.Module):
    """Decoder entry: map `config.hidden_size` latent channels to the widest decoder width, then apply `config.decoder_depths[0]` ConvNeXt blocks."""

    def __init__(self, config):
        super().__init__()

        width = int(config.num_filters * 2 ** (len(config.decoder_depths) - 1))
        # cache slot 0 belongs to this conv; the blocks take slots 1..decoder_depths[0]
        self.conv = VibeVoiceAcousticTokenizerCausalConv1d(
            in_channels=config.hidden_size,
            out_channels=width,
            kernel_size=config.kernel_size,
            layer_idx=0,
        )
        block_slots = range(1, config.decoder_depths[0] + 1)
        self.stage = nn.ModuleList(
            [VibeVoiceAcousticTokenizerConvNext1dLayer(config, hidden_size=width, layer_idx=slot) for slot in block_slots]
        )

    def forward(self, hidden_states, padding_cache=None):
        for module in (self.conv, *self.stage):
            hidden_states = module(hidden_states, padding_cache=padding_cache)
        return hidden_states


class VibeVoiceAcousticTokenizerDecoderLayer(nn.Module):
    """Decoder stage `stage_idx`: a causal transposed conv that halves the channel count, followed by ConvNeXt blocks."""

    def __init__(self, config, stage_idx):
        super().__init__()

        group = stage_idx + 1  # config.decoder_depths[0] is reserved for the stem
        num_groups = len(config.decoder_depths)
        # cache slots already used by the stem and earlier stages (one conv plus its blocks per group)
        first_slot = sum(depth + 1 for depth in config.decoder_depths[:group])
        in_width = int(config.num_filters * 2 ** (num_groups - 1 - stage_idx))
        out_width = int(config.num_filters * 2 ** (num_groups - 2 - stage_idx))
        ratio = config.upsampling_ratios[stage_idx]

        self.convtr = VibeVoiceAcousticTokenizerCausalConvTranspose1d(
            in_channels=in_width,
            out_channels=out_width,
            kernel_size=int(ratio * 2),
            stride=ratio,
            layer_idx=first_slot,
        )
        self.stage = nn.ModuleList(
            [
                VibeVoiceAcousticTokenizerConvNext1dLayer(config, hidden_size=out_width, layer_idx=first_slot + offset)
                for offset in range(1, config.decoder_depths[group] + 1)
            ]
        )

    def forward(self, hidden_states, padding_cache=None):
        for module in (self.convtr, *self.stage):
            hidden_states = module(hidden_states, padding_cache=padding_cache)
        return hidden_states


class VibeVoiceAcousticTokenizerDecoder(nn.Module):
    """Stem, upsampling stages and head; maps channels-first latents `(batch, hidden_size, frames)` to audio `(batch, channels, time)`."""

    def __init__(self, config):
        super().__init__()

        num_stages = len(config.upsampling_ratios)
        self.stem = VibeVoiceAcousticTokenizerDecoderStem(config)
        self.conv_layers = nn.ModuleList(
            VibeVoiceAcousticTokenizerDecoderLayer(config, stage_idx) for stage_idx in range(num_stages)
        )
        # the head uses the cache slot right after every stem/stage conv and block
        self.head = VibeVoiceAcousticTokenizerCausalConv1d(
            in_channels=config.num_filters,
            out_channels=config.channels,
            kernel_size=config.kernel_size,
            layer_idx=sum(depth + 1 for depth in config.decoder_depths),
        )

    def forward(self, hidden_states, padding_cache=None):
        hidden_states = self.stem(hidden_states, padding_cache=padding_cache)
        for upsample in self.conv_layers:
            hidden_states = upsample(hidden_states, padding_cache=padding_cache)
        audio = self.head(hidden_states, padding_cache=padding_cache)
        return audio


@auto_docstring
class VibeVoiceAcousticTokenizerPreTrainedModel(PreTrainedModel):
    config: VibeVoiceAcousticTokenizerConfig
    base_model_prefix = "vibevoice_acoustic_tokenizer"
    main_input_name = "input_values"
    _no_split_modules = ["VibeVoiceAcousticTokenizerEncoder", "VibeVoiceAcousticTokenizerDecoder"]

    def _init_weights(self, module):
        # Run the generic initialization first; ConvNeXt blocks then get their layer-scale vectors reset
        # to `config.layer_scale_init_value`.
        super()._init_weights(module)
        if not isinstance(module, VibeVoiceAcousticTokenizerConvNext1dLayer):
            return
        layer_scale = self.config.layer_scale_init_value
        for scale_param in (module.gamma, module.ffn_gamma):
            init.constant_(scale_param, layer_scale)


@auto_docstring(
    custom_intro="""
    VibeVoice acoustic tokenizer with an encoder and decoder for continuous acoustic tokens.
    """
)
class VibeVoiceAcousticTokenizerModel(VibeVoiceAcousticTokenizerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.encoder = VibeVoiceAcousticTokenizerEncoder(config)
        self.decoder = VibeVoiceAcousticTokenizerDecoder(config)
        self.post_init()

    @staticmethod
    def _build_conv_padding_cache(coder: nn.Module) -> VibeVoiceAcousticTokenizerConv1dPaddingCache:
        """
        Create an empty constant-mode padding cache covering every causal convolution of `coder`.

        `coder` is `self.encoder` or `self.decoder`. Convolutions are listed in the order of the `layer_idx`
        values they were built with:
            1. the stem conv, then the stem blocks;
            2. for each stage, its resampling conv (`conv` in the encoder, `convtr` in the decoder), then its blocks;
            3. finally the head.
        """
        causal_convs = [coder.stem.conv, *(block.mixer for block in coder.stem.stage)]
        for layer in coder.conv_layers:
            resampler = layer.convtr if isinstance(layer, VibeVoiceAcousticTokenizerDecoderLayer) else layer.conv
            causal_convs.append(resampler)
            causal_convs.extend(block.mixer for block in layer.stage)
        causal_convs.append(coder.head)

        per_layer_padding = [conv.causal_padding for conv in causal_convs]
        # the wrapped torch module is `.convtr` for transposed convolutions and `.conv` otherwise
        per_layer_in_channels = [
            conv.convtr.in_channels
            if isinstance(conv, VibeVoiceAcousticTokenizerCausalConvTranspose1d)
            else conv.conv.in_channels
            for conv in causal_convs
        ]
        return VibeVoiceAcousticTokenizerConv1dPaddingCache(
            num_layers=len(causal_convs),
            per_layer_padding=per_layer_padding,
            per_layer_padding_mode=["constant"] * len(causal_convs),
            per_layer_in_channels=per_layer_in_channels,
        )

    @can_return_tuple
    @auto_docstring
    def encode(self, input_values, padding_cache=None, use_cache=None, sample=True):
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`):
            Input audio waveform to be encoded into latent representation.
        padding_cache (`VibeVoiceAcousticTokenizerConv1dPaddingCache`, *optional*):
            Cache object for streaming mode to maintain convolution states across layers. When given, it is used
            regardless of `use_cache`.
        use_cache (`bool`, *optional*):
            Whether to use caching for convolution states. If set and no `padding_cache` is given, a new one is created.
        sample (`bool`, *optional*, defaults to `True`):
            Whether to sample from the VAE. If False, no noise is added.
        """
        if use_cache and padding_cache is None:
            padding_cache = self._build_conv_padding_cache(self.encoder)

        latents = self.encoder(input_values, padding_cache=padding_cache)

        if sample:
            # one random noise scale per batch element, broadcast over frames and channels
            noise_std = self.config.vae_std * torch.randn(latents.shape[0], device=latents.device, dtype=latents.dtype)
            latents = latents + noise_std[:, None, None] * torch.randn_like(latents)
        return VibeVoiceAcousticTokenizerEncoderOutput(latents=latents, padding_cache=padding_cache)

    @can_return_tuple
    @auto_docstring
    def decode(self, latents, padding_cache=None, use_cache=False):
        r"""
        latents (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Latent representation (e.g. the `latents` returned by `encode`) to be decoded back into audio.
        padding_cache (`VibeVoiceAcousticTokenizerConv1dPaddingCache`, *optional*):
            Cache object for streaming mode to maintain convolution states across layers. When given, it is used
            regardless of `use_cache`.
        use_cache (`bool`, *optional*):
            Whether to use caching for convolution states. If set and no `padding_cache` is given, a new one is created.
        """
        if use_cache and padding_cache is None:
            padding_cache = self._build_conv_padding_cache(self.decoder)

        # the decoder convolutions expect channels-first input: (batch, hidden_size, frames)
        latents = latents.permute(0, 2, 1)
        audio = self.decoder(latents, padding_cache=padding_cache)
        return VibeVoiceAcousticTokenizerDecoderOutput(audio=audio, padding_cache=padding_cache)

    @can_return_tuple
    @auto_docstring
    def forward(self, input_values, padding_cache=None, use_cache=False, sample=True, **kwargs):
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`):
            Input audio waveform to be encoded into latent representation.
        padding_cache (`VibeVoiceAcousticTokenizerConv1dPaddingCache`, *optional*):
            Cache object for streaming mode to maintain convolution states across layers. Only used by the decoder;
            the encoder always runs without a cache in `forward`.
        use_cache (`bool`, *optional*):
            Whether to use caching for the decoder convolution states.
        sample (`bool`, *optional*):
            Whether to sample from the VAE latent distribution. If False, no noise is added to the latents.
        """
        encoder_output = self.encode(input_values, sample=sample)
        decoder_output = self.decode(encoder_output.latents, padding_cache=padding_cache, use_cache=use_cache)
        return VibeVoiceAcousticTokenizerOutput(
            audio=decoder_output.audio,
            latents=encoder_output.latents,
            padding_cache=decoder_output.padding_cache,
        )


# Names exported by `from <module> import *`: only the model and its pre-trained base class are public.
__all__ = ["VibeVoiceAcousticTokenizerModel", "VibeVoiceAcousticTokenizerPreTrainedModel"]
