#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/pe_audio/modular_pe_audio.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_pe_audio.py file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...configuration_utils import PreTrainedConfig
from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
from ...masking_utils import create_bidirectional_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPooling, MaskedLMOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import maybe_autocast, merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from ..auto import AutoModel
from .configuration_pe_audio import PeAudioConfig, PeAudioEncoderConfig


class Snake1d(nn.Module):
    """
    Snake activation with one learnable frequency `alpha` per channel.

    Computes `x + sin(alpha * x) ** 2 / alpha` (with a small epsilon added to the divisor).
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # Shaped (1, channels, 1) so it broadcasts over the batch and the flattened trailing dims.
        self.alpha = nn.Parameter(torch.ones(1, hidden_dim, 1))

    def forward(self, hidden_states):
        original_shape = hidden_states.shape
        # Collapse everything after the channel axis so `alpha` lines up with dim 1.
        flat = hidden_states.reshape(original_shape[0], original_shape[1], -1)

        inv_alpha = (self.alpha + 1e-9).reciprocal()  # epsilon guards against alpha == 0
        periodic = torch.sin(self.alpha * flat).pow(2)
        activated = flat + inv_alpha * periodic

        return activated.reshape(original_shape)


class PeAudioDacResidualUnit(nn.Module):
    """
    Residual unit made of two Snake1d + Conv1d stages: a dilated kernel-7 convolution followed by a
    pointwise (kernel-1) convolution. The result is added back onto the input.
    """

    def __init__(self, dimension: int = 16, dilation: int = 1):
        super().__init__()
        kernel_size = 7
        # Symmetric padding that compensates for the dilated kernel's receptive field.
        same_padding = ((kernel_size - 1) * dilation) // 2

        self.snake1 = Snake1d(dimension)
        self.conv1 = nn.Conv1d(
            dimension, dimension, kernel_size=kernel_size, dilation=dilation, padding=same_padding
        )
        self.snake2 = Snake1d(dimension)
        self.conv2 = nn.Conv1d(dimension, dimension, kernel_size=1)

    def forward(self, hidden_state):
        """
        Run the residual unit.

        Args:
            hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
                Input tensor.

        Returns:
            `torch.Tensor` of shape `(batch_size, channels, time_steps)`:
                Sum of the (possibly cropped) input and the convolutional branch.
        """
        branch = self.conv1(self.snake1(hidden_state))
        branch = self.conv2(self.snake2(branch))

        # If the branch came out shorter than the input, crop the skip path symmetrically to match.
        crop = (hidden_state.shape[-1] - branch.shape[-1]) // 2
        skip = hidden_state[..., crop:-crop] if crop > 0 else hidden_state
        return skip + branch


class PeAudioDacEncoderBlock(nn.Module):
    """
    Encoder block of the PE_AUDIO_DAC encoder: three residual units with dilations 1, 3 and 9,
    a Snake1d activation, then a strided convolution that doubles the channel count.
    """

    def __init__(self, config: PreTrainedConfig, stride: int = 1, stride_index: int = 1):
        super().__init__()

        out_channels = config.encoder_hidden_size * 2**stride_index
        in_channels = out_channels // 2

        self.res_unit1 = PeAudioDacResidualUnit(in_channels, dilation=1)
        self.res_unit2 = PeAudioDacResidualUnit(in_channels, dilation=3)
        self.res_unit3 = PeAudioDacResidualUnit(in_channels, dilation=9)
        self.snake1 = Snake1d(in_channels)
        # Downsampling convolution: kernel is twice the stride, padded by ceil(stride / 2).
        self.conv1 = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=2 * stride,
            stride=stride,
            padding=math.ceil(stride / 2),
        )

    def forward(self, hidden_state):
        for residual_unit in (self.res_unit1, self.res_unit2, self.res_unit3):
            hidden_state = residual_unit(hidden_state)
        return self.conv1(self.snake1(hidden_state))


class PeAudioDacEncoder(nn.Module):
    """
    PE_AUDIO_DAC encoder.

    Stacks an input convolution, one `PeAudioDacEncoderBlock` per entry of
    `config.downsampling_ratios`, a Snake1d activation and an output convolution projecting to
    `config.hidden_size` channels.
    """

    def __init__(self, config: PreTrainedConfig):
        super().__init__()

        strides = config.downsampling_ratios
        # Input convolution: a single input channel is lifted to `encoder_hidden_size` channels.
        self.conv1 = nn.Conv1d(1, config.encoder_hidden_size, kernel_size=7, padding=3)

        # Encoder blocks double the channel count as they downsample by `stride`
        # (block i, 1-based, outputs encoder_hidden_size * 2**i channels).
        self.block = nn.ModuleList(
            [
                PeAudioDacEncoderBlock(config, stride=stride, stride_index=stride_index)
                for stride_index, stride in enumerate(strides, start=1)
            ]
        )

        # Derive the final width from the number of blocks instead of relying on the loop variable
        # leaking out of the loop, which is unbound when `downsampling_ratios` is empty.
        d_model = config.encoder_hidden_size * 2 ** len(self.block)
        self.snake1 = Snake1d(d_model)
        self.conv2 = nn.Conv1d(d_model, config.hidden_size, kernel_size=3, padding=1)

    def forward(self, hidden_state):
        """Apply the input conv, every encoder block in order, then Snake1d and the output conv."""
        hidden_state = self.conv1(hidden_state)

        for module in self.block:
            hidden_state = module(hidden_state)

        hidden_state = self.snake1(hidden_state)
        hidden_state = self.conv2(hidden_state)

        return hidden_state


class PeAudioEncoderEmbedder(nn.Module):
    """
    Turns raw audio into encoder input embeddings: DAC encoder -> 1x1 bottleneck conv -> linear
    projection to the encoder hidden size. Also subsamples the padding mask by the DAC hop length.
    """

    def __init__(self, config: PeAudioEncoderConfig):
        super().__init__()
        dac_config = config.dac_config
        self.dac_encoder = PeAudioDacEncoder(dac_config)
        self.bottleneck = nn.Conv1d(dac_config.hidden_size, dac_config.codebook_dim, 1)
        self.data_proj = nn.Linear(dac_config.codebook_dim, config.hidden_size)
        self.config = config

    def forward(
        self,
        input_values: torch.Tensor,
        padding_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # The codec front-end runs without gradient tracking and with cuDNN disabled.
        with torch.no_grad(), torch.backends.cudnn.flags(enabled=False):
            codec_latents = self.bottleneck(self.dac_encoder(input_values))

        # Move channels last so the linear projection acts on the feature axis.
        inputs_embeds = self.data_proj(codec_latents.transpose(1, 2))

        if padding_mask is None:
            return inputs_embeds, None

        # Keep one mask entry per `hop_length` input samples.
        hop_length = self.config.dac_config.hop_length
        return inputs_embeds, padding_mask[:, ::hop_length]


class PeAudioContrastiveHead(nn.Module):
    """Projection head: LayerNorm (eps=1e-6) followed by a bias-free linear map to `out_dim`."""

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
    ) -> None:
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=in_dim, eps=1e-6)
        self.proj = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.FloatTensor:
        normalized = self.layer_norm(x)
        projected = self.proj(normalized)
        return projected


class PeAudioMaskedGroupNorm(nn.GroupNorm):
    """
    GroupNorm whose per-group statistics are computed over unmasked positions only.

    Without a mask this is a plain `nn.GroupNorm`. With a mask, the output is additionally
    multiplied by the mask.
    """

    def forward(self, x, padding_mask=None):
        if padding_mask is None:
            return super().forward(x)

        batch_size, num_channels, seq_len = x.shape
        channels_per_group = num_channels // self.num_groups
        # (batch, groups, channels_per_group, seq_len); the mask must hold as many elements as `x`.
        grouped_shape = (batch_size, -1, channels_per_group, seq_len)

        grouped = x.view(grouped_shape)
        valid = padding_mask.reshape(grouped_shape).bool()

        # Statistics are reduced over the channel-in-group and time axes, skipping masked entries.
        reduce_kwargs = {"mask": valid, "dim": (2, 3), "keepdim": True}
        mean = torch.masked.mean(grouped, **reduce_kwargs)
        var = torch.masked.var(grouped, unbiased=False, **reduce_kwargs)

        normalized = ((grouped - mean) / torch.sqrt(var + self.eps)).view(x.shape)

        if self.affine:
            scale = self.weight.view(1, -1, 1)
            shift = self.bias.view(1, -1, 1)
            normalized = normalized * scale + shift

        return normalized * padding_mask


class PeAudioConvBlock1d(nn.Module):
    """Masked GroupNorm (single group) -> SiLU -> kernel-3 "same"-padded Conv1d, width preserved."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.groupnorm = PeAudioMaskedGroupNorm(num_groups=1, num_channels=width)
        self.activation = nn.SiLU()
        self.project = nn.Conv1d(in_channels=width, out_channels=width, kernel_size=3, padding="same")

    def forward(self, x, padding_mask=None):
        normalized = self.groupnorm(x, padding_mask=padding_mask)
        activated = self.activation(normalized)
        return self.project(activated)


class PeAudioResnetBlock1d(nn.Module):
    """Two `PeAudioConvBlock1d` in sequence with a residual connection around both."""

    def __init__(self, config):
        super().__init__()
        self.block1 = PeAudioConvBlock1d(config)
        self.block2 = PeAudioConvBlock1d(config)

    def forward(self, hidden_states, padding_mask=None):
        """
        Args:
            hidden_states: (batch_size, seq_len, hidden_size)
            padding_mask: (batch_size, seq_len)
        Returns:
            hidden_states: (batch_size, seq_len, hidden_size)
        """
        # Convolutions expect channels first:
        # (batch_size, seq_len, hidden_size) -> (batch_size, hidden_size, seq_len)
        channels_first = hidden_states.transpose(1, 2)

        # Broadcast the per-position mask across the channel axis.
        mask = None if padding_mask is None else padding_mask.unsqueeze(1).expand_as(channels_first)

        branch = channels_first
        for conv_block in (self.block1, self.block2):
            branch = conv_block(branch, padding_mask=mask)

        # Residual sum, then back to (batch_size, seq_len, hidden_size).
        return (channels_first + branch).transpose(1, 2)


class PeAudioEncoderPatchEmbedder(nn.Module):
    """Prepends a learnable class token to the sequence and runs it through a 1D ResNet block."""

    def __init__(self, config):
        super().__init__()
        self.resnet_block = PeAudioResnetBlock1d(config)
        self.class_embedding = nn.Parameter(torch.randn(1, 1, config.hidden_size))

    def forward(self, inputs_embeds, padding_mask=None):
        # One copy of the class token per batch element, placed at sequence position 0.
        batch_size = inputs_embeds.size(0)
        class_tokens = self.class_embedding.expand(batch_size, -1, -1)
        hidden_states = torch.cat([class_tokens, inputs_embeds], dim=1)

        if padding_mask is not None:
            # The class token reuses the mask value of the first position.
            # TODO: any reason why we take padding_mask[0] and not just 1?
            first_column = padding_mask[:, [0]]
            padding_mask = torch.cat([first_column, padding_mask], dim=1)

        return self.resnet_block(hidden_states, padding_mask=padding_mask), padding_mask


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis, i.e. the equivalent of
    torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        # Nothing to repeat: hand back the very same tensor.
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the heads.
    expanded = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain PyTorch scaled dot-product attention.

    Key/value heads are repeated `module.num_key_value_groups` times to match the query heads,
    `attention_mask` (if given) is added to the scaled scores, and the softmax is computed in
    float32 before being cast back to the query dtype.

    Returns:
        `(attn_output, attn_weights)` where `attn_output` has its head and sequence axes swapped
        (dims 1 and 2) and is contiguous.
    """
    num_groups = module.num_key_value_groups
    expanded_keys = repeat_kv(key, num_groups)
    expanded_values = repeat_kv(value, num_groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, expanded_values).transpose(1, 2).contiguous()
    return context, probs


def stack_freqs(cos: torch.Tensor, sin: torch.Tensor):
    """
    Build 2x2 rotation matrices `[[cos, -sin], [sin, cos]]` from the first half of the last axis
    of `cos` and `sin`. The result has shape `(*cos.shape[:-1], dim // 2, 2, 2)`.
    """
    half = cos.size(-1) // 2
    cos_half = cos.narrow(-1, 0, half)
    sin_half = sin.narrow(-1, 0, half)
    # Row-major entries of each rotation matrix, reshaped to (..., 2, 2) below.
    entries = torch.stack((cos_half, -sin_half, sin_half, cos_half), dim=-1)
    return entries.view(*cos_half.size(), 2, 2)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """
    Rotate consecutive feature pairs of `q` and `k` with the 2x2 matrices built from `cos`/`sin`.

    `unsqueeze_dim` is the axis inserted into the rotation matrices so they broadcast against
    `q` and `k`. The fixed reduction/flatten axes assume 4-dimensional `q` and `k`.
    """
    rotation = stack_freqs(cos, sin).unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # View the last axis as (pairs, 1, 2) so each pair multiplies its own 2x2 matrix.
        pairs = states.reshape(*states.shape[:-1], -1, 1, 2)
        return (pairs * rotation).sum(5).flatten(3)

    return _rotate(q), _rotate(k)


@use_kernel_forward_from_hub("RMSNorm")
class PeAudioEncoderRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        Root-mean-square layer norm with a learnable per-feature scale
        (PeAudioEncoderRMSNorm is equivalent to T5LayerNorm).
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        original_dtype = hidden_states.dtype
        # Statistics are computed in float32, then the result is cast back before scaling.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


@use_kernelized_func(apply_rotary_pos_emb)
class PeAudioEncoderAttention(nn.Module):
    """
    Multi-headed, non-causal self-attention with per-head RMSNorm on queries and keys and rotary
    position embeddings.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False

        query_width = config.num_attention_heads * self.head_dim
        key_value_width = config.num_key_value_heads * self.head_dim

        self.q_proj = nn.Linear(config.hidden_size, query_width, bias=config.attention_bias)
        self.k_proj = nn.Linear(config.hidden_size, key_value_width, bias=config.attention_bias)
        self.v_proj = nn.Linear(config.hidden_size, key_value_width, bias=config.attention_bias)
        self.o_proj = nn.Linear(query_width, config.hidden_size, bias=config.attention_bias)

        # The norms act on the head dimension only, so they are applied right after the
        # projections are split into heads and need no extra reshape.
        self.q_norm = PeAudioEncoderRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = PeAudioEncoderRMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, self.head_dim)

        # Project, split into heads, normalise q/k per head, then move the head axis to dim 1.
        queries = self.q_proj(hidden_states).view(per_head_shape)
        query_states = self.q_norm(queries).transpose(1, 2)
        keys = self.k_proj(hidden_states).view(per_head_shape)
        key_states = self.k_norm(keys).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(per_head_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Use the configured attention backend, falling back to the eager implementation.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        # Attention dropout is only active in training mode.
        dropout = self.attention_dropout if self.training else 0.0
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge the heads back into a single feature axis before the output projection.
        merged = attn_output.reshape(*leading_shape, -1).contiguous()
        return self.o_proj(merged), attn_weights


class PeAudioEncoderMLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, all bias-free."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)


class PeAudioEncoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer layer: RMSNorm -> self-attention -> residual, RMSNorm -> MLP -> residual."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = PeAudioEncoderAttention(config=config, layer_idx=layer_idx)

        self.mlp = PeAudioEncoderMLP(config)
        self.input_layernorm = PeAudioEncoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = PeAudioEncoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        cache_position: torch.LongTensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # Self-attention sub-layer; the attention weights are discarded here.
        attn_input = self.input_layernorm(hidden_states)
        attn_output, _ = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-layer.
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_output


@auto_docstring
class PeAudioPreTrainedModel(PreTrainedModel):
    config: PeAudioConfig
    base_model_prefix = "audio_model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["PeAudioEncoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Module classes whose outputs are recorded as `hidden_states` / `attentions`.
    _can_record_outputs = {
        "hidden_states": PeAudioEncoderLayer,
        "attentions": PeAudioEncoderAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """
        Initialise `module` with the parent-class scheme, then apply the PeAudio-specific
        overrides for the class embedding, Conv1d, Snake1d, ConvTranspose1d and Embedding modules.
        """
        super()._init_weights(module)

        if hasattr(self.config, "initializer_range"):
            std = self.config.initializer_range
        else:
            # 0.02 is the standard default value across the library
            std = getattr(self.config.get_text_config(), "initializer_range", 0.02)

        if isinstance(module, PeAudioEncoderPatchEmbedder):
            # Scale the class token's std by 1/sqrt(embed_dim).
            embed_dim = module.class_embedding.shape[-1]
            init.normal_(module.class_embedding, mean=0.0, std=embed_dim**-0.5 * std)
        if isinstance(module, nn.Conv1d):
            init.trunc_normal_(module.weight, std=0.02)
            # Convolutions built with `bias=False` have no bias tensor to initialise.
            if module.bias is not None:
                init.constant_(module.bias, 0)
        elif isinstance(module, Snake1d):
            init.ones_(module.alpha)
        elif isinstance(module, nn.ConvTranspose1d):
            module.reset_parameters()
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=0.02)


@dataclass
@auto_docstring(
    custom_intro="""
    Class for outputs of [`PeAudioEncoder`].
    """
)
class PeAudioEncoderOutput(BaseModelOutputWithPooling):
    # Extra fields on top of `BaseModelOutputWithPooling`.
    # NOTE(review): `PeAudioEncoder.forward` in this file never sets `codec_features`, so it keeps
    # its `None` default there — confirm whether it is meant to be populated.
    codec_features: torch.FloatTensor | None = None
    # `PeAudioEncoder.forward` stores here the padding mask returned by its embedder
    # (the input mask subsampled by the DAC hop length), or `None` when no mask was given.
    output_mask: tuple[torch.FloatTensor] | None = None


class PeAudioEncoderRotaryEmbedding(nn.Module):
    """Computes the rotary position embedding `cos`/`sin` tables for given position ids."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: PeAudioEncoderConfig, device=None):
        super().__init__()
        self.config = config
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        # Pick the frequency initialiser matching the configured RoPE variant.
        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable
        if self.rope_type == "default":
            rope_init_fn = self.compute_default_rope_parameters
        else:
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Non-persistent buffers: they follow the module across devices but are not checkpointed.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: PeAudioEncoderConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        rope_theta = config.rope_parameters["rope_theta"]
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        # inv_freq[i] = 1 / theta ** (2i / head_dim) for each even feature index 2i.
        even_indices = torch.arange(0, head_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
        inv_freq = 1.0 / (rope_theta ** (even_indices / head_dim))

        # The scaling factor is unused in this type of RoPE.
        return inv_freq, 1.0

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        batch_size = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        # Autocast is disabled below; "mps" (or a non-string device type) falls back to "cpu".
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            angles = (inv_freq.float() @ positions.float()).transpose(1, 2)
            # Duplicate the angles so the tables span the full head dimension.
            emb = torch.cat((angles, angles), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


@auto_docstring(
    custom_intro="""
    The PeAudio Encoder model.
    """
)
class PeAudioEncoder(PeAudioPreTrainedModel):
    config: PeAudioEncoderConfig
    main_input_name = "input_values"
    base_model_prefix = "audio_model.audio_encoder"

    def __init__(self, config: PeAudioEncoderConfig):
        super().__init__(config)
        self.embedder = PeAudioEncoderEmbedder(config)
        self.patch_embedder = PeAudioEncoderPatchEmbedder(config)
        self.layers = nn.ModuleList(
            [PeAudioEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = PeAudioEncoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = PeAudioEncoderRotaryEmbedding(config=config)
        self.output = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.gradient_checkpointing = False

        self.post_init()

    @can_return_tuple
    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_values: torch.Tensor,
        padding_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple | BaseModelOutputWithPooling:
        # Embed the audio; `padding_mask` comes back subsampled by the embedder.
        inputs_embeds, padding_mask = self.embedder(input_values, padding_mask=padding_mask)
        # Prepend the class token; `token_mask` is the padding mask extended by one position.
        hidden_states, token_mask = self.patch_embedder(inputs_embeds, padding_mask=padding_mask)

        attention_mask = None
        if token_mask is not None:
            attention_mask = create_bidirectional_mask(
                config=self.config,
                inputs_embeds=hidden_states,
                attention_mask=token_mask,
            )

        # Positions 0..seq_len-1, shared by every batch element.
        seq_len = hidden_states.shape[1]
        position_ids = torch.arange(seq_len, device=hidden_states.device).unsqueeze(0)
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for encoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        projected = self.output(self.norm(hidden_states))

        # Position 0 is the class token (pooled output); the rest are the per-frame states.
        return PeAudioEncoderOutput(
            last_hidden_state=projected[:, 1:],
            pooler_output=projected[:, 0],
            output_mask=padding_mask,
        )


# TODO: not sure about the typing for text_model_output
@dataclass
# @auto_docstring
class PeAudioOutput(ModelOutput):
    # Contrastive loss; only set by the models' `forward` when `return_loss` is truthy.
    loss: torch.FloatTensor | None = None
    # Scaled and biased audio-text similarity logits.
    logits_audio_text: torch.FloatTensor | None = None
    # Outputs of the text and audio contrastive heads.
    text_audio_embeds: torch.FloatTensor | None = None
    audio_embeds: torch.FloatTensor | None = None
    # Raw outputs of the text model and the audio encoder.
    text_outputs: BaseModelOutputWithPooling = None
    audio_outputs: BaseModelOutputWithPooling = None

    def to_tuple(self) -> tuple[Any]:
        """Convert to a tuple, recursively converting the nested text/audio model outputs."""
        nested_keys = ["text_outputs", "audio_outputs"]
        values = []
        for key in self.keys():
            if key in nested_keys:
                values.append(getattr(self, key).to_tuple())
            else:
                values.append(self[key])
        return tuple(values)


class PeAudioModel(PeAudioPreTrainedModel):
    """
    Audio-text contrastive model.

    Combines a text model (built with `AutoModel.from_config`) and a `PeAudioEncoder`, each followed
    by a `PeAudioContrastiveHead` projecting into a space of size `config.text_config.hidden_size`.
    Pairwise similarities are scaled and shifted by learnable scalars.
    """

    def __init__(self, config: PeAudioConfig):
        super().__init__(config)
        self.text_model = AutoModel.from_config(config.text_config)
        self.audio_encoder = PeAudioEncoder(config.audio_config)

        self.text_audio_head = PeAudioContrastiveHead(config.text_config.hidden_size, config.text_config.hidden_size)
        self.audio_head = PeAudioContrastiveHead(config.audio_config.hidden_size, config.text_config.hidden_size)

        self.text_audio_logit_scale = nn.Parameter(torch.zeros(1))
        self.text_audio_logit_bias = nn.Parameter(torch.zeros(1))

        self.post_init()

    def get_text_audio_embeds(self, input_ids, attention_mask=None):
        """Return the projected text embedding (first token of the last hidden state) per input."""
        # TODO: naming can be improved here...
        text_outputs: MaskedLMOutput = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # `hidden_states` is read below, so it must be requested explicitly (as `forward` does).
            output_hidden_states=True,
            return_dict=True,
        )
        text_audio_embeds = text_outputs.hidden_states[-1][:, 0]
        return self.text_audio_head(text_audio_embeds)

    def get_audio_embeds(self, input_values, padding_mask=None):
        """Return the projected pooled (class-token) audio embedding per input."""
        audio_outputs: BaseModelOutputWithPooling = self.audio_encoder(
            input_values=input_values,
            padding_mask=padding_mask,
            return_dict=True,
        )
        audio_embeds = audio_outputs.pooler_output
        return self.audio_head(audio_embeds)

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.Tensor,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        padding_mask: torch.Tensor | None = None,
        return_loss: bool | None = None,
        **kwargs,
    ) -> PeAudioOutput:
        """
        Encode audio and text and compute their pairwise similarity logits.

        Args:
            input_ids: Token ids forwarded to the text model.
            input_values: Audio input forwarded to the audio encoder.
            attention_mask: Optional attention mask for the text model.
            padding_mask: Optional padding mask for the audio encoder.
            return_loss: When truthy, also compute the sigmoid contrastive loss, treating the
                i-th audio and i-th text as the matching pair.
            **kwargs: Forwarded to both the audio encoder and the text model.

        Returns:
            `PeAudioOutput` whose `logits_audio_text` has shape `(n_audio, n_text)`.
        """
        audio_outputs: BaseModelOutputWithPooling = self.audio_encoder(
            input_values=input_values, padding_mask=padding_mask, **kwargs
        )

        # The text embedding is read from `hidden_states`, so always request them.
        kwargs["output_hidden_states"] = True
        text_outputs: MaskedLMOutput = self.text_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

        audio_embeds = audio_outputs.pooler_output
        audio_embeds = self.audio_head(audio_embeds)

        text_audio_embeds = text_outputs.hidden_states[-1][:, 0]
        text_audio_embeds = self.text_audio_head(text_audio_embeds)

        logits_audio_text = audio_embeds @ text_audio_embeds.T
        logits_audio_text = logits_audio_text * self.text_audio_logit_scale.to(
            logits_audio_text.device
        ) + self.text_audio_logit_bias.to(logits_audio_text.device)

        loss = None
        if return_loss:
            # Sigmoid contrastive loss: +1 for matching (diagonal) pairs and -1 for every other
            # pair. With 0/1 labels the negatives would only contribute a constant and never be
            # pushed apart.
            n_audio, n_text = logits_audio_text.shape
            eye = torch.eye(n_audio, n_text, device=logits_audio_text.device)
            labels = 2 * eye - 1
            loss = -F.logsigmoid(labels * logits_audio_text).sum() / n_audio

        return PeAudioOutput(
            logits_audio_text=logits_audio_text,
            text_audio_embeds=text_audio_embeds,
            audio_embeds=audio_embeds,
            text_outputs=text_outputs,
            audio_outputs=audio_outputs,
            loss=loss,
        )


# TODO: underline in documentation that logits output shape is
# 1. Model: (n_audio, n_text)
# 2. Frame-level: (n_audio, n_text, n_frames)
class PeAudioFrameLevelModel(PeAudioModel):
    """
    Frame-level variant of `PeAudioModel`: every audio frame (rather than the pooled class token)
    is compared against each text embedding, giving logits of shape `(n_audio, n_text, n_frames)`.
    """

    def get_audio_embeds(self, input_values, padding_mask=None):
        """Return the projected per-frame audio embeddings."""
        audio_outputs: BaseModelOutputWithPooling = self.audio_encoder(
            input_values=input_values,
            padding_mask=padding_mask,
            return_dict=True,
        )
        audio_embeds = audio_outputs.last_hidden_state
        audio_embeds = self.audio_head(audio_embeds)
        return audio_embeds

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.Tensor,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        padding_mask: torch.Tensor | None = None,
        return_loss: bool | None = None,
        **kwargs,
    ) -> PeAudioOutput:
        """
        Encode audio and text and compute per-frame similarity logits.

        Args:
            input_ids: Token ids forwarded to the text model.
            input_values: Audio input forwarded to the audio encoder.
            attention_mask: Optional attention mask for the text model.
            padding_mask: Optional padding mask for the audio encoder.
            return_loss: When truthy, also compute the sigmoid contrastive loss, treating the
                i-th audio and i-th text as the matching pair for every frame.
            **kwargs: Forwarded to both the audio encoder and the text model.

        Returns:
            `PeAudioOutput` whose `logits_audio_text` has shape `(n_audio, n_text, n_frames)`.
        """
        audio_outputs: BaseModelOutputWithPooling = self.audio_encoder(
            input_values=input_values, padding_mask=padding_mask, **kwargs
        )
        # The text embedding is read from `hidden_states`, so always request them.
        kwargs["output_hidden_states"] = True
        text_outputs: MaskedLMOutput = self.text_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

        audio_embeds = audio_outputs.last_hidden_state
        audio_embeds = self.audio_head(audio_embeds)

        text_audio_embeds = text_outputs.hidden_states[-1][:, 0]
        text_audio_embeds = self.text_audio_head(text_audio_embeds)

        # (n_audio, n_frames, dim) @ (dim, n_text) -> (n_audio, n_frames, n_text) -> (n_audio, n_text, n_frames)
        logits_audio_text = (audio_embeds @ text_audio_embeds.T).transpose(1, 2)
        logits_audio_text = logits_audio_text * self.text_audio_logit_scale.to(
            logits_audio_text.device
        ) + self.text_audio_logit_bias.to(logits_audio_text.device)

        loss = None
        if return_loss:
            # Sigmoid contrastive loss with +1 for matching (audio i, text i) pairs and -1 otherwise.
            # The trailing singleton axis broadcasts the pair labels over the frame axis; a bare
            # (n, n) matrix would instead be aligned with the (n_text, n_frames) axes.
            n_audio, n_text = logits_audio_text.shape[:2]
            eye = torch.eye(n_audio, n_text, device=logits_audio_text.device)
            labels = (2 * eye - 1).unsqueeze(-1)
            loss = -F.logsigmoid(labels * logits_audio_text).sum() / n_audio

        return PeAudioOutput(
            logits_audio_text=logits_audio_text,
            text_audio_embeds=text_audio_embeds,
            audio_embeds=audio_embeds,
            text_outputs=text_outputs,
            audio_outputs=audio_outputs,
            loss=loss,
        )


# Names exported by `from <module> import *`.
__all__ = ["PeAudioFrameLevelModel", "PeAudioModel", "PeAudioEncoder"]
