# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import math
import operator
import os
import re
from functools import reduce

from ..distributed import DistributedConfig
from ..utils import is_torch_greater_or_equal, logging
from ..utils.generic import GeneralInterface
from ..utils.import_utils import is_torch_available


if is_torch_available():
    import torch
    import torch.distributed as dist
    from torch import nn

    # Cache this result as it's a C FFI call which can be pretty time-consuming
    _torch_distributed_available = torch.distributed.is_available()


logger = logging.get_logger(__name__)


def initialize_tensor_parallelism(
    tp_plan: str | dict[str, str] | None, tp_size: int | None = None, device_mesh=None, device_map=None
):
    r"""
    Sets up the device mesh and initializes the backend for tensor parallelism.
    This function is called when the model is loaded and the TP plan is set to 'auto'.

    When `device_mesh` is `None`, the accelerator type is detected, `torch.distributed` is initialized from the
    `RANK`, `LOCAL_RANK` and `WORLD_SIZE` environment variables if it is not initialized yet, and a 1-D device
    mesh of size `tp_size` (defaulting to the world size) is created. When a `device_mesh` is passed, it is used
    as-is, except that the `"tp"` sub-mesh is selected if the mesh has more than one dimension.

    Args:
        tp_plan: The tensor parallel plan. In this function it is only used for argument validation.
        tp_size: Requested tensor parallel size. Overwritten by `device_mesh.size()` when a mesh is passed.
        device_mesh: Optional pre-built device mesh.
        device_map: Must be `None` whenever `tp_plan` is set; it is recomputed here.

    Returns:
        A `(device_map, device_mesh, tp_size)` tuple.

    Raises:
        ValueError: If `tp_size` is passed without `tp_plan`, if both `tp_plan` and `device_map` are passed, or if
            an n-d `device_mesh` has no `"tp"` dimension.
        OSError: If torch is older than 2.5 (only checked when no mesh is passed), or if the automatic
            `torch.distributed` initialization fails for any reason.
    """
    if tp_size is not None and tp_plan is None:
        raise ValueError("tp_plan has to be set when tp_size is passed.")
    if tp_plan is not None and device_map is not None:
        raise ValueError("`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization.")
    if device_mesh is None:
        if not is_torch_greater_or_equal("2.5"):
            raise OSError("Tensor parallel is only supported for `torch>=2.5`.")

        # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
        device_type = torch._C._get_accelerator().type
        if device_type == "mps":
            device_type = "cpu"  # fallback
        # Device namespace module (e.g. `torch.cuda`), used below for `set_device` / `current_device`.
        current_device = getattr(torch, device_type)
        if not torch.distributed.is_initialized():
            try:
                # These environment variables are expected to be set by the launcher; a missing one raises
                # inside the `try` and is reported through the `OSError` below.
                rank = int(os.environ["RANK"])
                local_rank = int(os.environ["LOCAL_RANK"])
                world_size = int(os.environ["WORLD_SIZE"])

                # Unknown device types map to `backend=None`.
                backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl"}
                backend = backend_map.get(device_type)

                torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
                current_device = getattr(torch, device_type)
                if device_type != "cpu":
                    current_device.set_device(local_rank)

            except Exception as e:
                raise OSError(
                    "We tried to initialize torch.distributed for you, but it failed. Make "
                    "sure you init torch distributed in your script to use `tp_plan`."
                ) from e

        if device_type != "cpu":
            # Also runs when the process group was initialized by the caller, so `LOCAL_RANK` must be set
            # in that case too.
            current_device.set_device(int(os.environ["LOCAL_RANK"]))
            index = current_device.current_device()
            tp_device = torch.device(device_type, index)
            device_map = tp_device
            # Silence output for non-primary ranks
            # NOTE: this replaces `sys.stdout` / `sys.stderr` for the whole process and the handles are never closed.
            if index > 0:
                import sys

                sys.stdout = open(os.devnull, "w")
                sys.stderr = open(os.devnull, "w")

        else:
            tp_device = torch.device(device_type)
            # `device_type` is the non-empty string "cpu" in this branch, so `device_map` ends up as "cpu".
            device_map = device_type or {}

        tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
        device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
    else:
        if device_mesh.ndim > 1:
            if "tp" not in device_mesh.mesh_dim_names:
                raise ValueError(
                    "When using `tp_plan` and n-d `device_mesh`, it must contain a 'tp' dimension. "
                    "Please provide a valid `device_mesh`."
                )
            # Only the tensor-parallel sub-mesh is used from here on.
            device_mesh = device_mesh["tp"]
        # A user-provided mesh takes precedence over any `tp_size` argument.
        tp_size = device_mesh.size()
        device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")

    return device_map, device_mesh, tp_size


def replace_layer_number_by_wildcard(name: str) -> str:
    """
    Replace the numbers in the `name` by wildcards, only if they are in-between dots (`.`) or if they are between
    a dot (`.`) and the end of the string.
    This matches how modules are named/numbered when using a nn.ModuleList or nn.Sequential, but will NOT match
    numbers in a parameter name itself, e.g. if the param is named `"w1"` or `"w2"`.

    Consecutive numeric components are all replaced, e.g. `"blocks.0.1.weight"` becomes `"blocks.*.*.weight"`.

    Args:
        name: A dotted module or parameter name.

    Returns:
        The name with every `.<digits>` component replaced by `.*`.
    """
    # The delimiter after the digits is matched with a lookahead so that it is NOT consumed: the dot that closes
    # one numeric component is also the dot that opens the next one, and must stay available for the next match.
    return re.sub(r"\.\d+(?=\.|$)", ".*", name)


def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weight=True) -> str | None:
    """
    Look up the TP style that applies to `parameter_name` in `tp_plan`.

    Layer indices in the name are first replaced by wildcards, then the plan is searched for an exact entry.
    When `is_weight` is True and no exact entry exists, the entry of the parent module (the name with its last
    dotted component removed) is used instead, so that a plan entry for a module also covers its `.weight` and
    `.bias`. Pass `is_weight=False` to disable this fallback (e.g. when looking up modules themselves).

    Args:
        parameter_name: Fully qualified parameter (or module) name.
        tp_plan: Mapping from (wildcarded) names to TP styles.
        is_weight: Whether falling back to the parent module entry is allowed.

    Returns:
        The matching TP style, or `None` when nothing in the plan applies.
    """
    wildcard_name = replace_layer_number_by_wildcard(parameter_name)

    # An exact entry always wins.
    if wildcard_name in tp_plan:
        return tp_plan[wildcard_name]

    if not is_weight:
        return None

    # Fall back to the entry of the owning module, if the name has a parent at all.
    parent_name, separator, _ = wildcard_name.rpartition(".")
    if separator and parent_name in tp_plan:
        return tp_plan[parent_name]
    return None


# =============================================================================
# Tensor Sharding Utilities
# =============================================================================


if is_torch_available():
    # Maps the dtype strings returned by `get_dtype()` on a lazily-loaded tensor slice (see `get_packed_weights`)
    # to the corresponding torch dtypes. Only defined when torch is installed.
    str_to_dtype = {
        "BOOL": torch.bool,
        "U8": torch.uint8,
        "I8": torch.int8,
        "I16": torch.int16,
        "F16": torch.float16,
        "BF16": torch.bfloat16,
        "I32": torch.int32,
        "F32": torch.float32,
        "F64": torch.float64,
        "I64": torch.int64,
        "F8_E4M3": torch.float8_e4m3fn,
    }


def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]:
    """
    Convert block count or proportions to block sizes.

    This function accepts

    - The number of blocks (int), in which case the block size is
      total_size//blocks; or
    - A list of block sizes (list[int]).

    In the second case, if sum(blocks) < total_size, the ratios between
    the block sizes will be preserved. For instance, if blocks is
    [2, 1, 1] and total_size is 1024, the returned block sizes are
    [512, 256, 256].
    """
    if isinstance(blocks, list):
        total_blocks = sum(blocks)
        assert total_size % total_blocks == 0, f"Cannot split {total_size} in proportional blocks: {blocks}"
        part_size = total_size // total_blocks
        return [part_size * block for block in blocks]
    else:
        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
        single_size = total_size // blocks
        return [single_size] * blocks


def get_packed_weights(param, empty_param, device_mesh, rank, dim):
    """
    When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share.
    So if you have: gate_proj       ( 16, 5120, 8190)
    and             up_proj         ( 16, 5120, 8190)
    packed as       gate_up_proj    ( 16, 5120, 2 * 8190)
    and you shard along the last dimension across TP_size (Tensor Parallelism size), each shard must receive
    the matching slice of both the gate and the up projection.

    Let's take TP_size = 4 for an example:

    Packed tensor `gate_up_proj`
    ---------------------------------------------------------------
    [ G0  G1  G2  G3 | G4  G5  G6  G7 | ... | U0  U1  U2  U3 | U4  U5  U6  U7 | ... ]
     ↑─────────────↑   ↑─────────────↑        ↑─────────────↑  ↑─────────────↑
       Gate Slice 0      Gate Slice 1            Up Slice 0       Up Slice 1

    Explanation:
    - The first half of the tensor (left of the center) holds the gate_proj values.
    - The second half (right of the center) holds the up_proj values.
    - For TP=4, we divide each half into 4 slices. In this example, we show two slices for brevity.
    - Each shard receives one slice from the gate part and the corresponding slice from the up part.

    For instance:
    • Shard 0 gets: [ Gate Slice 0, Up Slice 0 ] = [ G0, G1, G2, G3, U0, U1, U2, U3 ]
    • Shard 1 gets: [ Gate Slice 1, Up Slice 1 ] = [ G4, G5, G6, G7, U4, U5, U6, U7 ]
    • … and so on.

    This ensures that each shard receives an equal portion of both gate and up projections, maintaining consistency across tensor parallelism.

    Args:
        param: Lazily-loaded tensor slice. It must expose `get_dtype()` (returning a dtype string such as
            `"BF16"`) and support indexing.
        empty_param: Tensor used as shape reference for the full packed parameter.
        device_mesh: Device mesh; `device_mesh.size()` is used as the number of shards.
        rank: Index of the shard to extract.
        dim: Dimension of `empty_param` to shard along. May be negative; it is resolved against the actual
            number of dimensions of `empty_param`, so `-2` selects dim 0 of a 2-D weight and dim 1 of a 3-D one.

    Returns:
        A tensor holding this rank's gate slice followed by its up slice along `dim`. It is returned as
        `float16` when the stored dtype is `F8_E4M3` or `F8_E5M2`, and in the stored dtype otherwise.

    Raises:
        ValueError: If `dim` is out of range for `empty_param`.
    """
    slice_ = param
    ndim = len(empty_param.shape)
    # Resolve negative dims against the real rank of the parameter instead of assuming a 3-D layout.
    shard_dim = dim + ndim if dim < 0 else dim
    if not 0 <= shard_dim < ndim:
        raise ValueError(f"Unsupported dim {dim} for a parameter with {ndim} dimensions")

    total_size = empty_param.shape[shard_dim]
    world_size = device_mesh.size()
    block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=2)

    # Collect, for each packed block (gate, then up), the indices owned by `rank`.
    tensors_slices = []
    block_offset = 0
    for block_size in block_sizes:
        shard_block_size = block_size // world_size
        start = block_offset + rank * shard_block_size
        tensors_slices += range(start, start + shard_block_size)
        block_offset += block_size

    slice_dtype = slice_.get_dtype()
    # Handle float8 dtypes by converting to float16 before slicing
    # Without upcasting, the slicing causes : RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn'
    casted = slice_dtype in ("F8_E4M3", "F8_E5M2")
    if casted:
        slice_ = slice_[...].to(torch.float16)

    if shard_dim == ndim - 1:
        index = (Ellipsis, tensors_slices)
    else:
        index = (slice(None),) * shard_dim + (tensors_slices, Ellipsis)
    tensor = slice_[index]

    if casted:
        # The upcast tensor is returned as-is; it is not cast back to float8 here.
        return tensor
    return tensor.to(str_to_dtype[slice_dtype])


def repack_weights(
    packed_parameter: torch.Tensor,
    sharded_dim: int,  # The dimension index in the global tensor that was sharded
    world_size: int,
    num_blocks: int = 2,
) -> torch.Tensor:
    """
    Reorder a tensor that was rebuilt from sharded packed weights back into its canonical packed layout.

    When a packed weight (e.g. gate_proj and up_proj fused together) is sharded, every rank holds
    `[G_r, U_r]`. Concatenating the shards (e.g. with `DTensor.full_tensor()`) therefore gives the interleaved
    layout `[G0, U0, G1, U1, ...]` along the sharded dimension. This function turns it into
    `[G0, G1, ..., U0, U1, ...]`, i.e. it is the inverse of `get_packed_weights`.

    Args:
        packed_parameter: The reconstructed tensor. It is reinterpreted with `.view`, so its memory layout must
            allow that (e.g. obtained via `.full_tensor().contiguous()`).
        sharded_dim: The dimension of `packed_parameter` that was sharded. May be negative.
        world_size: The tensor parallel world size.
        num_blocks: The number of projections that were packed together. Only 2 is supported.

    Returns:
        The reordered tensor, with the same shape as `packed_parameter`.

    Raises:
        ValueError: If `num_blocks` is not 2.
    """

    if num_blocks != 2:
        raise ValueError(
            "Num blocks different from 2 is not supported yet. This is most likely a bug in your implementation as we only pack gate and up projections together."
        )

    full_shape = packed_parameter.shape
    axis = sharded_dim + packed_parameter.ndim if sharded_dim < 0 else sharded_dim

    # Size of one rank's chunk of a single projection along the sharded axis.
    chunk = (full_shape[axis] // num_blocks) // world_size

    leading = full_shape[:axis]
    trailing = full_shape[axis + 1 :]

    # Split the sharded axis into (rank, projection, chunk): this is the order in which the data is laid out.
    unfolded = packed_parameter.view(*leading, world_size, num_blocks, chunk, *trailing)

    # Swap the rank and projection axes so that all chunks of one projection become adjacent ...
    rank_axis = len(leading)
    block_axis = rank_axis + 1
    regrouped = unfolded.transpose(rank_axis, block_axis)

    # ... and fold (projection, rank, chunk) back into a single axis.
    return regrouped.reshape(full_shape)


def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: int | None = None):
    """
    Generalized tensor sharding across a multi-dimensional device mesh.
    Extract only the fraction of the parameter owned by the given `rank` when the parameter would have gone sharding at provided `dim`.
    Extraction follows the pytorch `Shard` placement so that sharding and materializing back to full tensor follows `Shard` semantics.
    `Shard` follows torch.chunk style sharding of the tensor. We demonstrate some cases below on how sharding happens including some edge cases
    such as some ranks having an empty tensor as shard. Below implementation is robut to all these cases.

    Case (1)
    empty_param                 (16, 5120, 8190)
    dim                         0
    device_mesh.size()          4
    rank 0 gets					(4, 5120, 8190)			 (0 ... 4, 5120, 8190)
    rank 1 gets					(4, 5120, 8190)			 (4 ... 8, 5120, 8190)
    rank 2 gets					(4, 5120, 8190)			 (8 ... 12, 5120, 8190)
    rank 3 gets					(4, 5120, 8190)			 (12 ... 16, 5120, 8190)

    Case (2)
    empty_param                 (16, 5120, 8190)
    dim                         0
    device_mesh.size()          14
    rank 0 gets					(2, 5120, 8190)			 (0 ... 2, 5120, 8190)
    rank 1 gets					(2, 5120, 8190)			 (2 ... 4, 5120, 8190)
    rank 2 gets					(2, 5120, 8190)			 (4 ... 6, 5120, 8190)
    rank 3 gets					(2, 5120, 8190)			 (6 ... 8, 5120, 8190)
    rank 4 gets					(2, 5120, 8190)			 (8 ... 10, 5120, 8190)
    rank 5 gets					(2, 5120, 8190)			 (10 ... 12, 5120, 8190)
    rank 6 gets					(2, 5120, 8190)			 (12 ... 14, 5120, 8190)
    rank 7 gets					(2, 5120, 8190)			 (14 ... 16, 5120, 8190)
    rank 8 gets					(0, 5120, 8190)
    rank 9 gets					(0, 5120, 8190)
    rank 10 gets			    (0, 5120, 8190)
    rank 11 gets				(0, 5120, 8190)
    rank 12 gets				(0, 5120, 8190)
    rank 13 gets				(0, 5120, 8190)

    Case (3)
    empty_param                 (16, 5120, 8190)
    dim                         0
    device_mesh.size()          3
    rank 0 gets					(6, 5120, 8190)			 (0 ... 6, 5120, 8190)
    rank 1 gets					(6, 5120, 8190)			 (6 ... 12, 5120, 8190)
    rank 2 gets					(4, 5120, 8190)			 (12 ... 16, 5120, 8190)

    In case (2), empty shards are returned with appropriate dimension to allow for operations to work smoothly.
    Args:
        param (torch.Tensor): The tensor to shard.
        empty_param (torch.Tensor): A tensor used for shape reference.
        device_mesh (torch.Tensor): Shape [d_0, ..., d_n] representing the mesh.
        rank (int): Global rank of the current process/device.
        dim (int): Dimension along which to shard the tensor.
    """
    param_dim = empty_param.ndim
    mesh_shape = device_mesh.shape
    world_size = reduce(operator.mul, mesh_shape)
    # Get param shape: works for both torch.Tensor and safetensors TensorInfo
    param_shape = list(param.shape) if isinstance(param, torch.Tensor) else param.get_shape()
    if dim < 0:
        dim = param_dim + dim
    if empty_param.dim() == 3 and dim == 1 and len(param_shape) == 2:
        dim = 0
    elif empty_param.dim() == 3 and dim == 2 and len(param_shape) == 2:
        dim = 1

    shard_size = math.ceil(param_shape[dim] / world_size)
    start = rank * shard_size
    end = min(start + shard_size, param_shape[dim])

    if dim >= param_dim:
        raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")

    if rank >= world_size:
        raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")

    # we have the full tensor not 1 part of it.
    # in that case, we just assume that the weight was properly saved
    # and thus because we TP if the layer is colwise it should not use this. Layer should be packed_colwise
    # to inform that it needs to read form a packed tensor. It will also take care of the module list thingy.
    # here we take care of potential chunking / layer split / layer chunking.
    # The only "hard" case is? if we collect q,k,v -> merge it into qkv. In that case
    # actually we still shard dim=0 does not change
    # so only case is if the dim of the empty param is 3 and the shard dim is 0 -> we put the
    # tensor on a certain device (with the input tensor_index)
    if tensor_idx is not None and empty_param.dim() == 3 and dim == 0 and len(param_shape) == 2:
        # special case we don't "shard" just send this entire tensor to the correct rank.
        if start <= tensor_idx < end:
            # this tensor does need to be materialized on this device:
            return param[:]
        else:
            return torch.empty([], dtype=torch.int64, device=rank)

    slice_indices = [slice(None)] * len(param_shape)

    if start < param_shape[dim]:
        slice_indices[dim] = slice(start, end)
        param = param[tuple(slice_indices)]
        if isinstance(param, list):  # TODO handle the modulelist case!
            param = [p[:] for p in param]
        return param

    param_shape[dim] = 0
    return torch.empty(tuple(param_shape), dtype=torch.int64)  # empty allocates memory....


def _split_along_last_dim(x, world_size):
    """Return `x` cut into `world_size` chunks along its last dimension (`torch.chunk` semantics)."""
    pieces = x.chunk(world_size, dim=-1)
    return pieces


# =============================================================================
# Distributed Communication Primitives
# =============================================================================
#
# Naming convention:
#   - Functions describe their FORWARD behavior
#   - Backward behavior is the "conjugate" operation for gradient flow
#
# Available operations:
#   ┌────────────────────┬─────────────────────┬─────────────────────┐
#   │ Function           │ Forward             │ Backward            │
#   ├────────────────────┼─────────────────────┼─────────────────────┤
#   │ all_reduce_forward │ all-reduce (sum)    │ identity            │
#   │ all_reduce_backward│ identity            │ all-reduce (sum)    │
#   │ all_gather         │ all-gather          │ split (local chunk) │
#   │ split              │ split (local chunk) │ all-gather          │
#   │ reduce_scatter     │ reduce-scatter      │ all-gather          │
#   └────────────────────┴─────────────────────┴─────────────────────┘
# ===================


class _AllReduceBackward(torch.autograd.Function):
    """
    Identity in the forward pass, sum all-reduce of the gradient in the backward pass.

    Used before colwise layers (the `f` operator in Megatron).
    """

    @staticmethod
    def forward(ctx, x, device_mesh):
        # Nothing to compute: only remember the mesh for the backward pass.
        ctx.device_mesh = device_mesh
        return x

    @staticmethod
    def backward(ctx, grad_output):
        mesh = ctx.device_mesh
        # With a single rank there is nothing to reduce; otherwise the gradient is summed in place.
        if mesh.size() != 1:
            dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=mesh.get_group())
        # The second value is the (non-existent) gradient for `device_mesh`.
        return grad_output, None


class _AllReduceForward(torch.autograd.Function):
    """
    Sum all-reduce in the forward pass, identity in the backward pass.

    Used after rowwise layers (the `g` operator in Megatron).
    """

    @staticmethod
    def forward(ctx, x, device_mesh):
        # With a single rank there is nothing to reduce; otherwise `x` is summed across the group in place.
        if device_mesh.size() != 1:
            dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient passes through untouched; `None` is the gradient for `device_mesh`.
        return grad_output, None


class _AllGather(torch.autograd.Function):
    """
    All-gather along the last dimension in the forward pass, local split in the backward pass.

    Turns an output that is sharded on its last dimension into the full tensor on every rank.
    """

    @staticmethod
    def forward(ctx, x, device_mesh):
        ctx.device_mesh = device_mesh
        num_ranks = device_mesh.size()
        if num_ranks == 1:
            return x

        concat_dim = x.dim() - 1
        local_rank = device_mesh.get_local_rank()
        process_group = device_mesh.get_group()

        x = x.contiguous()
        # One receive buffer per rank; the local slot is the input itself.
        gathered = [torch.empty_like(x) for _ in range(num_ranks)]
        gathered[local_rank] = x
        dist.all_gather(gathered, x, group=process_group)
        return torch.cat(gathered, dim=concat_dim).contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        mesh = ctx.device_mesh
        num_ranks = mesh.size()
        if num_ranks == 1:
            return grad_output, None

        # Keep only the slice of the gradient that corresponds to this rank's shard.
        local_rank = mesh.get_local_rank()
        local_grad = _split_along_last_dim(grad_output, num_ranks)[local_rank]
        return local_grad.contiguous(), None


class _Split(torch.autograd.Function):
    """
    Local split along the last dimension in the forward pass, all-gather in the backward pass.

    Turns a replicated input into this rank's shard.
    """

    @staticmethod
    def forward(ctx, x, device_mesh):
        ctx.device_mesh = device_mesh
        num_ranks = device_mesh.size()
        if num_ranks == 1:
            return x

        # Keep only the chunk that belongs to this rank.
        local_rank = device_mesh.get_local_rank()
        local_chunk = _split_along_last_dim(x, num_ranks)[local_rank]
        return local_chunk.contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        mesh = ctx.device_mesh
        num_ranks = mesh.size()
        if num_ranks == 1:
            return grad_output, None

        concat_dim = grad_output.dim() - 1
        local_rank = mesh.get_local_rank()
        process_group = mesh.get_group()

        grad_output = grad_output.contiguous()
        # One receive buffer per rank; the local slot is the local gradient itself.
        gathered = [torch.empty_like(grad_output) for _ in range(num_ranks)]
        gathered[local_rank] = grad_output
        dist.all_gather(gathered, grad_output, group=process_group)
        return torch.cat(gathered, dim=concat_dim).contiguous(), None


class _ReduceScatter(torch.autograd.Function):
    """
    Reduce-scatter along the last dimension in the forward pass, all-gather in the backward pass.

    For sequence parallel.
    """

    @staticmethod
    def forward(ctx, x, device_mesh):
        ctx.device_mesh = device_mesh
        num_ranks = device_mesh.size()
        if num_ranks == 1:
            return x

        scatter_dim = x.dim() - 1
        process_group = device_mesh.get_group()

        # One input chunk per rank; each rank receives the sum of "its" chunk over all ranks.
        send_chunks = list(x.chunk(num_ranks, dim=scatter_dim))
        local_shape = list(x.shape)
        local_shape[scatter_dim] //= num_ranks
        reduced = torch.empty(local_shape, dtype=x.dtype, device=x.device)

        dist.reduce_scatter(reduced, send_chunks, op=dist.ReduceOp.SUM, group=process_group)
        return reduced

    @staticmethod
    def backward(ctx, grad_output):
        mesh = ctx.device_mesh
        num_ranks = mesh.size()
        if num_ranks == 1:
            return grad_output, None

        concat_dim = grad_output.dim() - 1
        local_rank = mesh.get_local_rank()
        process_group = mesh.get_group()

        grad_output = grad_output.contiguous()
        # One receive buffer per rank; the local slot is the local gradient itself.
        gathered = [torch.empty_like(grad_output) for _ in range(num_ranks)]
        gathered[local_rank] = grad_output
        dist.all_gather(gathered, grad_output, group=process_group)
        return torch.cat(gathered, dim=concat_dim).contiguous(), None


# =============================================================================
# Convenience wrappers
# =============================================================================


def all_reduce_backward(x, device_mesh):
    """
    Identity forward, all-reduce backward. Use before colwise layers.

    Returns `x` unchanged; during backpropagation the incoming gradient is sum-all-reduced over the
    process group of `device_mesh` (skipped when the mesh has a single rank).
    """
    return _AllReduceBackward.apply(x, device_mesh)


def all_reduce_forward(x, device_mesh):
    """
    All-reduce forward, identity backward. Use after rowwise layers.

    Sum-all-reduces `x` over the process group of `device_mesh` (skipped when the mesh has a single rank).
    NOTE: the reduction is performed in place on `x`. Gradients pass through unchanged.
    """
    return _AllReduceForward.apply(x, device_mesh)


def all_gather(x, device_mesh):
    """
    All-gather forward, split backward.

    Concatenates the `x` of every rank of `device_mesh` along the last dimension; during backpropagation
    each rank keeps only its own chunk of the gradient. Returns `x` unchanged when the mesh has a single rank.
    """
    return _AllGather.apply(x, device_mesh)


def split(x, device_mesh):
    """
    Split forward, all-gather backward.

    Returns this rank's chunk of `x` along the last dimension; during backpropagation the gradients of all
    ranks of `device_mesh` are concatenated back along that dimension. Returns `x` unchanged when the mesh
    has a single rank.
    """
    return _Split.apply(x, device_mesh)


def reduce_scatter(x, device_mesh):
    """
    Reduce-scatter forward, all-gather backward.

    Sums `x` over the ranks of `device_mesh` and returns this rank's chunk of the result along the last
    dimension; during backpropagation the gradients of all ranks are concatenated back along that dimension.
    Returns `x` unchanged when the mesh has a single rank.
    """
    return _ReduceScatter.apply(x, device_mesh)


def distribute_module(
    module: nn.Module,
    device_mesh=None,
    input_fn=None,
    output_fn=None,
) -> nn.Module:
    """
    Attach the tensor parallel input/output hooks to `module`.

    Adapted from torch's function of the same name, without the communications (partitioning) and without
    the buffer registering, which is similarly not efficient.

    Args:
        module: The module to hook.
        device_mesh: Passed through to both callbacks as their last argument.
        input_fn: Optional `input_fn(module, inputs, device_mesh)`, registered as a forward pre-hook.
        output_fn: Optional `output_fn(module, outputs, device_mesh)`, registered as a forward hook.

    Returns:
        The same `module`, for chaining.
    """

    def _pre_forward(mod, inputs):
        return input_fn(mod, inputs, device_mesh)

    def _post_forward(mod, inputs, outputs):
        # The hook receives the inputs too, but the callback only cares about the outputs.
        return output_fn(mod, outputs, device_mesh)

    if input_fn is not None:
        module.register_forward_pre_hook(_pre_forward)
    if output_fn is not None:
        module.register_forward_hook(_post_forward)
    return module


class TensorParallelLayer:
    """General tensor parallel layer for transformers"""

    device_mesh = None
    rank = None
    empty_param = None

    def __init__(self, device_mesh=None, rank=None, empty_param=None):
        self.rank = rank
        self.device_mesh = device_mesh
        self.empty_param = empty_param

    @staticmethod
    def _prepare_input_fn(mod, inputs, device_mesh): ...

    @staticmethod
    def _prepare_output_fn(mod, outputs, device_mesh): ...

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        raise NotImplementedError

    def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
        distribute_module(
            module,
            device_mesh,
            self._prepare_input_fn,
            self._prepare_output_fn,
        )

    def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
        """
        Compute the expected shape after TP sharding for a given full shape.

        Args:
            full_shape: The full (unsharded) parameter shape

        Returns:
            The expected sharded shape for this rank
        """
        # Default: no sharding, return full shape
        return tuple(full_shape)


class ColwiseParallel(TensorParallelLayer):
    """
    Column-wise parallel: weight is sharded on dim -2 (output features).
    Forward: input replicated -> output sharded on last dim.
    If gather_output=True, output is all-gathered to produce full tensor.

    Args:
        gather_output: Whether to all-gather the sharded output of the module.
        **kwargs: Forwarded to `TensorParallelLayer` (`device_mesh`, `rank`, `empty_param`).
    """

    def __init__(self, gather_output: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.gather_output = gather_output

    def _prepare_input_fn(self, mod, inputs, device_mesh):
        # Only the first positional input is used. Forward is the identity; the gradient flowing back to the
        # replicated input is summed over the ranks.
        input_tensor = inputs[0] if inputs else inputs
        return all_reduce_backward(input_tensor, device_mesh)

    def _prepare_output_fn(self, mod, outputs, device_mesh):
        if self.gather_output:
            return all_gather(outputs, device_mesh)
        return outputs

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        """Return this rank's shard of `param`, moved to `device` / `dtype`."""
        ndim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
        # If only 1 dim, shard this one (usually it's a `bias`)
        shard_dim = -1 if ndim == 1 else -2
        parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, shard_dim)
        return parameter.to(device=device, dtype=dtype)

    def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
        """
        Compute the shape of this rank's shard for a parameter of shape `full_shape`.

        The sharded dimension is split in chunks of `ceil(size / world_size)`; ranks past the end of the data
        get a size of 0, matching the empty shards returned by `get_tensor_shard`.
        """
        world_size = self.device_mesh.size()
        shape = list(full_shape)
        # Colwise shards dim -2, but 1D tensors (bias) shard on dim -1
        dim = len(shape) - 1 if len(shape) == 1 else len(shape) - 2
        shard_size = math.ceil(shape[dim] / world_size)
        start = self.rank * shard_size
        end = min(start + shard_size, shape[dim])
        # `start` can exceed the dimension size for trailing ranks: clamp so the size is never negative.
        shape[dim] = max(end - start, 0)
        return tuple(shape)


class RowwiseParallel(TensorParallelLayer):
    """
    Row-wise parallel: weight is sharded on dim -1 (input features).
    Forward: input (optionally split) -> output partial -> all-reduce to replicate.

    Args:
        split_input: If True, splits replicated input before matmul. Use when input
                     comes from a non-parallelizable operation (chunk/slice).
                     Default False (expects pre-sharded input from colwise layer).
        **kwargs: Forwarded to `TensorParallelLayer` (`device_mesh`, `rank`, `empty_param`).
    """

    def __init__(self, split_input: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.split_input = split_input

    def _prepare_input_fn(self, mod, inputs, device_mesh):
        # Every rank only computes a partial sum, so the bias must be added once, AFTER the all-reduce:
        # stash it on the module and disable it for the module's own forward.
        if hasattr(mod, "bias") and mod.bias is not None:
            mod._bias = mod.bias
            mod.bias = None

        # Only the first positional input is used.
        input_tensor = inputs[0] if inputs else inputs

        if self.split_input:
            # Input is replicated, split it to match sharded weight
            return split(input_tensor, device_mesh)
        return input_tensor

    def _prepare_output_fn(self, mod, outputs, device_mesh):
        outputs = all_reduce_forward(outputs, device_mesh)
        # Add back the bias that was stashed by `_prepare_input_fn`.
        if hasattr(mod, "_bias") and mod._bias is not None:
            outputs = outputs + mod._bias
        return outputs

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        """Return this rank's shard of `param`, moved to `device` / `dtype`."""
        ndim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
        if ndim == 1:
            # If only 1 dim, it should not be sharded (usually it's a `bias`)
            parameter = param[...]
        else:
            parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
        return parameter.to(device=device, dtype=dtype)

    def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
        """
        Compute the shape of this rank's shard for a parameter of shape `full_shape`.

        The last dimension is split in chunks of `ceil(size / world_size)`; ranks past the end of the data
        get a size of 0, matching the empty shards returned by `get_tensor_shard`.
        """
        # 1D tensors (bias) are NOT sharded in rowwise
        if len(full_shape) == 1:
            return tuple(full_shape)
        world_size = self.device_mesh.size()
        shape = list(full_shape)
        dim = len(shape) - 1
        shard_size = math.ceil(shape[dim] / world_size)
        start = self.rank * shard_size
        end = min(start + shard_size, shape[dim])
        # `start` can exceed the dimension size for trailing ranks: clamp so the size is never negative.
        shape[dim] = max(end - start, 0)
        return tuple(shape)


class PackedColwiseParallel(ColwiseParallel):
    """Column-wise parallel layer for fused ("packed") weights such as gate_up_proj."""

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        """Return this rank's shard of `param`, moved to `device` / `dtype`."""
        ndim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())

        # 1-D tensors (usually a `bias`) are sharded on their only dimension.
        if ndim == 1:
            shard = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, -1)
            return shard.to(device=device, dtype=dtype)

        target_ndim = len(self.get_expected_sharded_shape(self.empty_param.shape))
        # Fewer dims than the target means the input is still unpacked (e.g. a gate_proj that will be
        # concatenated into gate_up_proj afterwards): a regular shard is enough. Otherwise the input is
        # already packed and every packed block has to be sharded separately.
        shard_fn = get_tensor_shard if ndim < target_ndim else get_packed_weights
        shard = shard_fn(param, self.empty_param, self.device_mesh, self.rank, -2)
        return shard.to(device=device, dtype=dtype)


class PackedRowwiseParallel(RowwiseParallel):
    """Row-wise parallel style for fused ("packed") weights."""

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        """
        Return this rank's shard of `param`, moved to `device` and cast to `dtype`.

        `param` is either a `torch.Tensor` or an object exposing `get_shape()`. 1D inputs (usually a `bias`)
        are returned whole; other inputs are sharded on the last dimension.
        """
        is_tensor = isinstance(param, torch.Tensor)
        ndim = param.dim() if is_tensor else len(param.get_shape())

        if ndim == 1:
            # A single dimension: keep it unsharded
            return param[...].to(device=device, dtype=dtype)

        # An input whose last dim is smaller than the target's is still unpacked. This happens when
        # MergeModulelist + Concatenate build a fused weight such as gate_up_proj.
        source_shape = param.shape if is_tensor else param.get_shape()
        expected_packed_dim = 0
        if self.empty_param.dim() >= 1:
            expected_packed_dim = self.empty_param.shape[-1]
        actual_dim = 0
        if len(source_shape) >= 1:
            actual_dim = source_shape[-1]

        # Unpacked input -> plain shard; already packed input -> packed-aware slicing
        shard_fn = get_tensor_shard if actual_dim < expected_packed_dim else get_packed_weights
        shard = shard_fn(param, self.empty_param, self.device_mesh, self.rank, -1)
        return shard.to(device=device, dtype=dtype)


class EmbeddingParallel(TensorParallelLayer):
    """
    EmbeddingParallel: shards the embedding table and handles masked lookups for vocab parallelism.

    When the table is sharded on dim 0 (the default), each rank only holds a slice of the rows. Token ids
    outside the local slice are redirected to local row 0 before the lookup, and the corresponding output
    rows are zeroed before `all_reduce_forward` combines the per-rank outputs.
    """

    def __init__(self, *, embedding_dim_sharding: int = 0, **kwargs):
        super().__init__(**kwargs)
        # Dimension of the weight that is split across ranks (0 = vocab dimension)
        self.embedding_dim_sharding = embedding_dim_sharding

    def _prepare_input_fn(self, mod, inputs, device_mesh):
        """Translate global token ids to local row indices and remember which ones are not local."""
        input_tensor = inputs[0] if inputs else inputs

        # Only vocab-parallel (dim 0) needs masking and offsetting
        if self.embedding_dim_sharding != 0:
            return input_tensor

        rank = device_mesh.get_local_rank()

        # Use weight.shape[0] to get the actual local (sharded) size, not num_embeddings
        # which may not be updated after sharding
        # NOTE(review): the offset below is `rank * local_size`, which assumes every rank holds the same
        # number of rows - confirm behaviour when the vocab size is not divisible by the world size
        per_partition_size = mod.weight.shape[0]
        vocab_start_index = rank * per_partition_size
        vocab_end_index = vocab_start_index + per_partition_size

        # True for tokens whose row lives on another rank; stashed for `_prepare_output_fn`
        input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index)
        mod._input_mask = input_mask

        # The subtraction already allocates a new tensor, so the caller's ids are left untouched
        masked_input = input_tensor - vocab_start_index
        masked_input[input_mask] = 0  # Any valid local index works, the row is zeroed afterwards

        return masked_input

    def _prepare_output_fn(self, mod, outputs, device_mesh):
        """Zero the rows of non-local tokens, then all-reduce the embeddings."""
        if self.embedding_dim_sharding == 0 and hasattr(mod, "_input_mask"):
            input_mask = mod._input_mask
            # Use multiplication instead of in-place assignment to preserve gradients.
            # The mask is cast to the output dtype: a float32 mask would promote half-precision
            # outputs to float32.
            keep = (~input_mask).unsqueeze(-1).to(outputs.dtype)
            outputs = outputs * keep
            del mod._input_mask

        return all_reduce_forward(outputs, device_mesh)

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        """
        Return this rank's shard of `param`, moved to `device` and cast to `dtype`.

        1D inputs are sharded on their only dimension, other inputs on `self.embedding_dim_sharding`.
        """
        ndim = param.dim() if isinstance(param, torch.Tensor) else len(param.get_shape())
        shard_dim = -1 if ndim == 1 else self.embedding_dim_sharding
        parameter = get_tensor_shard(param, self.empty_param, self.device_mesh, self.rank, shard_dim)
        return parameter.to(device=device, dtype=dtype)

    def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
        """
        Return the shape of this rank's shard for a tensor of shape `full_shape`.

        The sharded dimension (`self.embedding_dim_sharding`, or -1 for 1D tensors) is split into
        `ceil(size / world_size)`-sized chunks and chunk number `self.rank` is kept.
        """
        world_size = self.device_mesh.size()
        shape = list(full_shape)
        dim = -1 if len(shape) == 1 else self.embedding_dim_sharding
        dim = len(shape) + dim if dim < 0 else dim
        shard_size = math.ceil(shape[dim] / world_size)
        start = self.rank * shard_size
        end = min(start + shard_size, shape[dim])
        # Trailing ranks can start past the end of the dimension: report an empty shard, not a negative size
        shape[dim] = max(end - start, 0)
        return tuple(shape)


class SequenceParallel(TensorParallelLayer):
    """
    Sequence-parallel style: activations are sharded on the sequence dimension around the layer,
    while the weights stay replicated on every rank.
    """

    def __init__(self, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False, **kwargs):
        # `use_local_output` and `use_dtensor` are accepted but not stored by this class
        super().__init__(**kwargs)
        self.sequence_dim = sequence_dim

    def _prepare_input_fn(self, mod, inputs, device_mesh):
        # The layer needs the whole input: all-gather the sharded tensor before running it
        # (the matching reduce-scatter happens on the output)
        if inputs:
            sharded_input = inputs[0]
        else:
            sharded_input = inputs
        return all_gather(sharded_input, device_mesh)

    def _prepare_output_fn(self, mod, outputs, device_mesh):
        # Hand the output back to `reduce_scatter` once the layer has run
        scattered = reduce_scatter(outputs, device_mesh)
        return scattered

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        # Weights are replicated: take the whole tensor and only move / cast it
        replicated = param[...]
        return replicated.to(device=device, dtype=dtype)


class GroupedGemmParallel(TensorParallelLayer):
    """
    Applies Expert Parallelism to MoE experts by loading the correct experts on each device.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor | None:
        """
        Select the experts that belong to `device` and move them there.

        Args:
            param: Source tensor, or an object exposing `get_shape()` and slicing.
            tensor_idx: Index of the expert held by `param` when experts are stored one tensor per expert;
                `None` when `param` stacks all experts on dim 0.
            device: Target device. Its index (or the integer itself) is also used as the position of this
                rank among the expert shards, so it must be provided.
            dtype: Target dtype.

        Returns:
            The local experts on `device` in `dtype`, or `None` when `param` is a single expert that does
            not belong to this device.

        Raises:
            ValueError: If the number of experts is not divisible by the device mesh size.
        """
        num_devices = self.device_mesh.size()
        global_num_experts = self.empty_param.shape[0]
        if global_num_experts % num_devices != 0:
            raise ValueError(
                f"Global number of experts must be divisible by number of devices: {global_num_experts} % {num_devices} != 0"
            )
        local_num_experts = global_num_experts // num_devices
        if isinstance(device, torch.device):
            device = device.index if device.index is not None else 0
        start = device * local_num_experts
        end = start + local_num_experts

        if tensor_idx is None:  # a bias or a weight, but already merged
            return param[start:end].to(device=device, dtype=dtype)

        if start <= tensor_idx < end:
            # special case we don't "shard" just send this entire tensor to the correct rank.
            # The requested dtype is applied here as well, like in every other branch.
            return param[:].to(device=device, dtype=dtype)

        shape = param.shape if isinstance(param, torch.Tensor) else param.get_shape()
        if len(shape) >= 1:
            # this expert lives on another device: nothing to materialize here
            return None
        # bias case
        return param[:].to(device=device, dtype=dtype)

    def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
        """Return `full_shape` with dim 0 (experts dimension) divided by the device mesh size."""
        world_size = self.device_mesh.size()
        shape = list(full_shape)
        shape[0] = shape[0] // world_size
        return tuple(shape)


class RouterParallel(TensorParallelLayer):
    """
    Reshapes the router outputs so that each rank can run its local experts (expert parallelism).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @staticmethod
    def _prepare_input_fn(mod, inputs, device_mesh):
        if inputs:
            return inputs[0]
        return inputs

    @staticmethod
    def _prepare_output_fn(mod, outputs, device_mesh):
        """
        Restrict the router outputs to the experts held by the current rank.

        Example with 4 tokens, top_k = 4, 128 experts and EP = 8, i.e. num_local_experts = 128 / 8 = 16.
        Given these router_indices:

        [ 52,  42, 119,  67],
        [102,  89,  61,  40],
        [ 82, 103,   4,  34],
        [ 93,  23, 109,  11],

        the rank owning each selected expert is:

        [3, 2, 7, 4],
        [6, 5, 3, 2],
        [5, 6, 0, 2],
        [5, 1, 6, 0],

        So on rank 0, every entry owned by another rank is replaced by 16 (num_local_experts):

        [ 16, 16, 16, 16],
        [ 16, 16, 16, 16],
        [ 16, 16,  4, 16],
        [ 16, 16, 16, 11],

        On the other ranks the kept indices are additionally reduced modulo num_local_experts, because the
        next operation one-hot encodes the router index vector: the result directly tells which local
        expert is hit. The scores are restricted to the columns of the local experts in the same spirit.

        The rather naive training loop used for device_map "auto" follows a similar logic. Here each rank
        behaves as if it was alone and computes its own part of the hidden states. Invalid indices are
        set to num_local_experts so that the computation skips them after one-hot encoding.
        """
        ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size()
        if mod.num_experts % ep_size != 0:
            raise ValueError(
                f"The number of experts must be divisible by number of ep_size: {mod.num_experts} % {ep_size} != 0"
            )
        num_local_experts = mod.num_experts // ep_size
        router_logits, router_scores, router_indices = outputs

        # Scatter the selected scores into a zero tensor shaped like the logits, then keep the local columns
        dense_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_scores)
        first_local = ep_rank * num_local_experts
        local_scores = dense_scores[:, first_local : first_local + num_local_experts]

        # Flag with -1 every selected expert that is owned by another rank
        owned_elsewhere = (router_indices // num_local_experts) != ep_rank
        local_indices = router_indices.masked_fill(owned_elsewhere, -1)

        # As -1 % 1 is 0, we can only use mask fill when num_local_experts is 1
        if num_local_experts > 1:
            local_indices = torch.fmod(local_indices, num_local_experts)
        else:
            is_positive = local_indices > 0
            is_negative = local_indices < 0
            local_indices = local_indices.masked_fill(is_positive, 0).masked_fill(is_negative, -1)

        # The -1 flags become num_local_experts, the "no local expert" index
        local_indices = local_indices.masked_fill(local_indices == -1, num_local_experts)
        return router_logits, local_scores, local_indices

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        # Nothing is sharded here: take the whole tensor and only move / cast it
        whole = param[...]
        return whole.to(device=device, dtype=dtype)


class MoeTensorParalellExperts(TensorParallelLayer):
    """
    Tensor-parallel style for a MoE experts module. It takes care of the gradient / output sync:
        - all_reduce_backward on hidden_states (for colwise gate_up_proj gradient)
        - all_reduce_backward on top_k_weights (for router gradient)
        - all_reduce_forward on output (for partial expert outputs)
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @staticmethod
    def _prepare_input_fn(mod, inputs, device_mesh):
        # Positional layout of the inputs: (hidden_states, top_k_index, top_k_weights)
        hidden_states, top_k_index, top_k_weights = inputs[0], inputs[1], inputs[2]

        # Needed for a correct colwise (gate_up_proj) gradient
        synced_hidden_states = all_reduce_backward(hidden_states, device_mesh)

        # Needed for a correct router gradient:
        # ∂L/∂routing_weights = ∂L/∂output * partial_expert_output, and partial_expert_output
        # differs on each GPU before the all-reduce
        synced_weights = all_reduce_backward(top_k_weights, device_mesh)

        return (synced_hidden_states, top_k_index, synced_weights)

    @staticmethod
    def _prepare_output_fn(mod, outputs, device_mesh):
        # Sum the partial expert outputs across GPUs
        reduced = all_reduce_forward(outputs, device_mesh)
        return reduced

    def shard_tensor(
        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
    ) -> torch.Tensor:
        # No sharding at this level: the individual weight tensors (gate_up_proj / down_proj)
        # are sharded by packed_colwise / rowwise. Only move / cast the whole tensor.
        whole = param[...]
        return whole.to(device=device, dtype=dtype)


class ParallelInterface(GeneralInterface):
    """
    Registry of the tensor-parallel styles, keyed by the style name used in TP plans.

    On top of the name -> layer mapping, it carries two tables giving, for each style, the dimension on
    which weights and biases are sharded (`None` meaning "not sharded").
    """

    # Class-level mapping, so that a call to `register` can be reflected into all other files correctly, even if
    # a new instance is created (in order to locally override a given entry).
    # The styles are only instantiated when torch and torch.distributed are available; otherwise the
    # mapping starts empty.
    _global_mapping = (
        {
            "embedding_rowwise": EmbeddingParallel(embedding_dim_sharding=0),
            "colwise_gather_output": ColwiseParallel(gather_output=True),
            "colwise": ColwiseParallel(),
            "rowwise": RowwiseParallel(),
            "rowwise_split_input": RowwiseParallel(split_input=True),
            "packed_colwise": PackedColwiseParallel(),
            "packed_rowwise": PackedRowwiseParallel(),
            "sequence_parallel": SequenceParallel(),
            "grouped_gemm": GroupedGemmParallel(),
            "ep_router": RouterParallel(),
            "moe_tp_experts": MoeTensorParalellExperts(),
        }
        if is_torch_available() and _torch_distributed_available
        else {}
    )

    # Map plan names to sharding dimensions for weights
    # For weights: colwise shards dim -2, rowwise shards dim -1
    # For embedding: rowwise shards dim 0 (vocab), colwise shards dim -2 (hidden)
    # `None` means the weight is not sharded. Styles missing from this table ("grouped_gemm",
    # "ep_router", "moe_tp_experts") are left untouched by `gather_state_dict_for_save`.
    plan_to_weight_dim: dict[str, int | None] = {
        "colwise": -2,
        "colwise_gather_output": -2,
        "packed_colwise": -2,
        "rowwise": -1,
        "rowwise_split_input": -1,
        "packed_rowwise": -1,
        "embedding_rowwise": 0,
        "sequence_parallel": None,
    }

    # Bias sharding: colwise shards bias, rowwise doesn't (bias is replicated and all-reduced)
    # `None` means the bias is kept whole on every rank.
    plan_to_bias_dim: dict[str, int | None] = {
        "colwise": -1,
        "colwise_gather_output": -1,
        "packed_colwise": -1,
        "rowwise": None,
        "rowwise_split_input": None,
        "packed_rowwise": None,
        "embedding_rowwise": None,
        "sequence_parallel": None,
    }


# Module-level registry instance used by the functions below to resolve a TP style name into its layer
ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()


# =============================================================================
# High-Level API Functions
# =============================================================================


def gather_full_tensor(local_tensor: torch.Tensor, shard_dim: int, device_mesh) -> torch.Tensor:
    """
    All-gather a sharded tensor along the specified dimension to reconstruct the full tensor.

    Every rank is expected to contribute a shard of the same shape, as the receive buffers are
    allocated from `local_tensor`.

    Args:
        local_tensor: The local shard of the tensor on this rank
        shard_dim: The dimension along which the tensor was sharded (negative values are supported)
        device_mesh: The device mesh for distributed communication

    Returns:
        The full reconstructed tensor (same on all ranks)
    """
    # `mesh_dim_names` is None for a mesh created without names: treat it as "no named dimension"
    dim_names = device_mesh.mesh_dim_names or ()
    if "tp" in dim_names:
        # In case of TP+DP configuration, the TP group should be used for gathering, not the full DP group.
        # The number of shards to receive is the size of that group, not the size of the whole mesh.
        process_group = device_mesh.get_group("tp")
        world_size = dist.get_world_size(group=process_group)
    else:
        process_group = None
        world_size = device_mesh.size()

    # Normalize negative dimension
    if shard_dim < 0:
        shard_dim = local_tensor.ndim + shard_dim

    # Gather all shards
    local_tensor = local_tensor.contiguous()
    gathered_tensors = [torch.empty_like(local_tensor) for _ in range(world_size)]
    dist.all_gather(gathered_tensors, local_tensor, group=process_group)

    # Concatenate along the shard dimension
    return torch.cat(gathered_tensors, dim=shard_dim)


def gather_state_dict_for_save(
    state_dict: dict[str, torch.Tensor],
    tp_plan: dict[str, str],
    device_mesh,
    tp_size: int,
) -> dict[str, torch.Tensor]:
    """
    Rebuild the full tensors of a tensor-parallel state dict so that it can be saved.

    Each entry whose TP style shards it is all-gathered along its shard dimension (and repacked for the
    packed styles). Entries without a plan, with a style that has no known weight dimension, or that are
    replicated are passed through untouched.

    Args:
        state_dict: The model state dict with local sharded tensors
        tp_plan: The tensor parallel plan mapping layer patterns to shard styles
        device_mesh: The device mesh for distributed communication
        tp_size: The tensor parallel world size

    Returns:
        State dict with full (gathered) tensors
    """
    # The tables live on the ParallelInterface (can be extended by users)
    weight_dims = ALL_PARALLEL_STYLES.plan_to_weight_dim
    bias_dims = ALL_PARALLEL_STYLES.plan_to_bias_dim

    gathered = {}
    for key, tensor in state_dict.items():
        # "a.b.weight" -> module path "a.b" and leaf "weight"; a key without dot is its own module path
        module_path, dot, leaf = key.rpartition(".")
        if not dot:
            module_path, leaf = key, None

        wildcard_key = re.sub(r"\d+", "*", key)
        wildcard_module = re.sub(r"\d+", "*", module_path)

        # Lookup order: the full key (nn.Parameter without `.weight` suffix, e.g. MoE experts like
        # "model.layers.*.mlp.experts.gate_up_proj"), then the module, then the parent of the module
        candidates = [wildcard_key, wildcard_module]
        if "." in wildcard_module:
            candidates.append(wildcard_module.rsplit(".", 1)[0])

        plan = None
        for candidate in candidates:
            if candidate in tp_plan:
                plan = tp_plan[candidate]
                break

        if plan is None or plan not in weight_dims:
            # Not sharded, keep as-is
            gathered[key] = tensor
            continue

        # Biases and weights do not shard on the same dimension
        dims = bias_dims if leaf == "bias" else weight_dims
        shard_dim = dims.get(plan)
        if shard_dim is None:
            # Replicated, keep as-is
            gathered[key] = tensor
            continue

        full = gather_full_tensor(tensor, shard_dim, device_mesh)
        if plan in ("packed_colwise", "packed_rowwise"):
            # Packed weights need to be put back in their original layout
            full = repack_weights(full, shard_dim, tp_size, 2)
        gathered[key] = full.contiguous()

    return gathered


def add_tensor_parallel_hooks_to_module(
    model, module, tp_plan, layer_name, current_module_plan, device_mesh, parameter_name=None
):
    r"""
    Prepare `module` for tensor parallelism according to `current_module_plan`.

    The plan name is resolved into a `TensorParallelLayer` through `ALL_PARALLEL_STYLES`, and its
    `prepare_module_tp` is called on the module. This is the place where the `pre_forward` and
    `post_forwards` hooks are added; they are defined for each `TensorParallelLayer` as
    `_prepare_input_fn` and `_prepare_output_fn`.

    The module is then tagged with `_hf_tp_plan`, and its `__repr__` attribute is replaced by a callable
    appending the plan to the original representation. Nothing happens when `current_module_plan` is `None`.

    Args:
        model: The model owning `module` (unused here).
        module: The module to prepare.
        tp_plan: The full TP plan (unused here).
        layer_name: Name of the module, only used in the diagnostic message.
        current_module_plan: Name of the TP style to apply to `module`, or `None`.
        device_mesh: The device mesh forwarded to `prepare_module_tp`.
        parameter_name: Unused.
    """
    if current_module_plan is None:
        return

    tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
    try:
        tp_layer.prepare_module_tp(module, device_mesh)
    except NotImplementedError as e:
        print(
            f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}"
        )

    module._hf_tp_plan = current_module_plan
    # Bind the original `__repr__` BEFORE overriding the attribute: looking it up from inside the
    # replacement would resolve to the replacement itself and recurse forever.
    original_repr = module.__repr__
    module.__repr__ = lambda: f"{original_repr()}\nTP Plan: {current_module_plan}"


def shard_and_distribute_module(
    model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh
):
    r"""
    This function is called in `from_pretrained` when loading a model's checkpoints.
    It receives the pointer to the parameter (or the parameter itself) and takes care of "sharding".
    All processes run this function, so they just load the partition of the tensor that they require.

    Main use cases:
    - column / rowwise parallelism, you just shard all the weights of the layer (weight and bias)
    - packed layers: you slice the weights, then shard like above
    - custom operation:
        - you want to add an all-gather at the end of a local layer.
        - you want to have a layer that is isolated from the rest of the world (because torch.DTensor does not work well with `.view` for instance)

    Args:
        model: The model being loaded; its `tp_plan` gives the style of each parameter.
        param: The checkpoint tensor (or a sliceable pointer to it).
        empty_param: The parameter of the model that `param` is loaded into.
        parameter_name: Fully qualified name of the parameter. A name without dot designates a parameter
            held by `model` itself.
        param_casting_dtype: The dtype to cast the loaded tensor to.
        is_contiguous: Whether to make the sharded tensor contiguous.
        rank: Rank of the current process, also used as target device.
        device_mesh: The device mesh used for sharding.

    Returns:
        The `torch.nn.Parameter` that was set on the owning module.
    """
    # "a.b.weight" -> ("a.b", "weight"). For a name without dot the module path is "", which
    # `get_submodule` resolves to the model itself.
    param_name, _, param_type = parameter_name.rpartition(".")
    tp_plan = model.tp_plan or {}
    module_to_tp = model.get_submodule(param_name)
    rank = int(rank)
    current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan)

    if dist.get_rank() == 0:
        if current_shard_plan is None:
            logger.info(f"Tensor sharding plan for {param_name} not found, using default 'replicate' plan.")
        else:
            logger.info(f"Tensor sharding plan for {param_name}: {current_shard_plan}")

    if current_shard_plan is not None:
        # Resolved outside of the `try` so that `tp_layer` is always bound in the handler below
        tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
        try:
            tp_layer.empty_param = empty_param
            tp_layer.device_mesh = device_mesh
            tp_layer.rank = rank
            param = tp_layer.shard_tensor(param, tensor_idx=None, dtype=param_casting_dtype, device=rank)
            if is_contiguous:
                param = param.contiguous()
        except NotImplementedError as e:
            print(
                f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
            )
    else:
        param = param[:].to(param_casting_dtype)

    # SUPER IMPORTANT we have to use setattr
    # otherwise loading is crazy slow
    if not isinstance(param, torch.nn.Parameter):
        param = torch.nn.Parameter(param, requires_grad=empty_param.is_floating_point())
    setattr(module_to_tp, param_type, param)
    return param


def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
    """
    Verify the TP plan of the model, log a warning if the layers that were not sharded and the rules that were not applied.
    """

    if tp_plan is None:
        return

    generic_keys = {replace_layer_number_by_wildcard(key) for key in expected_keys}
    unsharded_layers = set(generic_keys)
    unused_rules = tp_plan.copy()

    for key in generic_keys:
        param_name = key.rsplit(".", 1)[0] if "." in key else key
        generic_param_name = re.sub(r"\d+", "*", param_name)

        if generic_param_name in tp_plan:
            unused_rules.pop(generic_param_name, None)
            unsharded_layers.discard(key)
        elif "." in generic_param_name and (parent_param_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
            unused_rules.pop(parent_param_name, None)
            unsharded_layers.discard(key)

    if len(unused_rules) > 0:
        logger.warning(f"The following TP rules were not applied on any of the layers: {unused_rules}")
    if len(unsharded_layers) > 0:
        logger.warning(f"The following layers were not sharded: {', '.join(unsharded_layers)}")


def distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size):
    """
    Distribute a model according to the TP plan.

    Records the TP size, device mesh and (optional) distributed config on the model, then prepares every
    module that has not been prepared yet with the style the plan gives it.

    Args:
        model: The model to distribute; it is modified in place.
        tp_plan: The requested plan. Only a `dict` overrides `model.tp_plan`.
        distributed_config: A `DistributedConfig`, a dict to build one from, or `None`.
        device_mesh: The device mesh used for tensor parallelism.
        tp_size: The tensor parallel size.

    Returns:
        The same `model`.

    Raises:
        ValueError: If the plan uses a style that is not registered in `ALL_PARALLEL_STYLES`.
    """
    model._tp_size = tp_size
    model._device_mesh = device_mesh
    if distributed_config is not None:
        if isinstance(distributed_config, dict):
            distributed_config = DistributedConfig.from_dict(distributed_config)
        model.config.distributed_config = distributed_config
    # Set the new requested tp_plan on the model
    if isinstance(tp_plan, dict):
        model.tp_plan = tp_plan
    model_plan = model.tp_plan
    if model_plan is not None and _torch_distributed_available:
        # Validate the whole plan before touching any module
        for v in model_plan.values():
            if v not in ALL_PARALLEL_STYLES:
                raise ValueError(f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}")
        for name, module in model.named_modules():
            if not getattr(module, "_is_hooked", False):
                plan = _get_parameter_tp_plan(parameter_name=name, tp_plan=model_plan, is_weight=False)
                add_tensor_parallel_hooks_to_module(
                    model=model,
                    module=module,
                    tp_plan=model_plan,
                    # Forward the real module name so that the "not supported" diagnostic can tell
                    # which layer is concerned
                    layer_name=name,
                    current_module_plan=plan,
                    device_mesh=device_mesh,
                )
            # Mark the module so that it is not prepared a second time
            module._is_hooked = True
    return model
