#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/sam2_video/modular_sam2_video.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_sam2_video.py file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from typing import Union

import numpy as np
import torch

from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...utils import TensorType, auto_docstring
from ...utils.import_utils import requires
from ...video_utils import VideoInput
from .modeling_sam2_video import Sam2VideoInferenceSession


@requires(backends=("torch",))
@auto_docstring
class Sam2VideoProcessor(ProcessorMixin):
    def __init__(
        self, image_processor, video_processor, target_size: int | None = None, point_pad_value: int = -10, **kwargs
    ):
        r"""
        target_size (`int`, *optional*):
            Side length (in pixels) of the square coordinate space that input points and bounding boxes are
            rescaled to before being handed to the model. When omitted, the `"height"` entry of the image
            processor's `size` configuration is used instead.
        point_pad_value (`int`, *optional*, defaults to -10):
            Filler value written into padded point/label slots when prompts of different lengths are batched
            together. Padded positions keep this value during coordinate normalization.
        """
        super().__init__(image_processor, video_processor, **kwargs)
        self.point_pad_value = point_pad_value
        if target_size is None:
            # No explicit size requested: reuse the height configured on the image processor.
            target_size = self.image_processor.size["height"]
        self.target_size = target_size

    @auto_docstring
    def __call__(
        self,
        images: ImageInput | None = None,
        segmentation_maps: ImageInput | None = None,
        input_points: list[list[list[list[float]]]] | torch.Tensor | None = None,
        input_labels: list[list[list[int]]] | torch.Tensor | None = None,
        input_boxes: list[list[list[float]]] | torch.Tensor | None = None,
        original_sizes: list[list[float]] | torch.Tensor | None = None,
        return_tensors: str | TensorType | None = None,
        **kwargs,
    ) -> BatchEncoding:
        r"""
        segmentation_maps (`ImageInput`, *optional*):
            Segmentation maps forwarded to the image processor together with `images`.
        input_points (`list[list[list[list[float]]]]`, `torch.Tensor`, *optional*):
            Point prompts, nested as [image level, object level, point level, point coordinates].
        input_labels (`list[list[list[int]]]`, `torch.Tensor`, *optional*):
            Labels of the point prompts, nested as [image level, object level, point level].
        input_boxes (`list[list[list[float]]]`, `torch.Tensor`, *optional*):
            Box prompts, nested as [image level, box level, box coordinates].
        original_sizes (`list[list[float]]`, `torch.Tensor`, *optional*):
            The original sizes of the images. Required when `images` is not provided.

        Returns:
            A [`BatchEncoding`] with the following fields:
            - `pixel_values` (`torch.Tensor`): The processed image(s).
            - `original_sizes` (`list[list[float]]`): The original sizes of the images.
            - `labels` (`torch.Tensor`): The processed segmentation maps (if provided).
            - `input_points` (`torch.Tensor`): The processed points.
            - `input_labels` (`torch.Tensor`): The processed labels.
            - `input_boxes` (`torch.Tensor`): The processed bounding boxes.
        """
        if images is None and original_sizes is None:
            raise ValueError("Either images or original_sizes must be provided")

        if images is not None:
            encoding_image_processor = self.image_processor(
                images,
                segmentation_maps=segmentation_maps,
                return_tensors=return_tensors,
                **kwargs,
            )
        else:
            # Prompt-only call: there are no pixels to process, only the sizes are carried along.
            if isinstance(original_sizes, torch.Tensor):
                original_sizes = original_sizes.cpu().tolist()
            encoding_image_processor = BatchEncoding({"original_sizes": original_sizes}, tensor_type=return_tensors)

        # From here on, always use the sizes stored in the encoding.
        original_sizes = encoding_image_processor["original_sizes"]
        if images is not None:
            size_count = len(original_sizes)
            # `len(images)` is only evaluated when more than one size came back.
            if size_count != 1 and size_count != len(images):
                raise ValueError(
                    "original_sizes must be of length 1 or len(images). If you are passing a single image, you must pass a single original_size."
                )

        # Nothing else to do when no prompt of any kind was supplied.
        if input_points is None and input_labels is None and input_boxes is None:
            return encoding_image_processor

        # Validate each prompt type and bring it into nested-list form.
        processed_points = self._validate_single_input(
            input_points,
            expected_depth=4,
            input_name="points",
            expected_format="[image level, object level, point level, point coordinates]",
            expected_coord_size=2,
        )
        processed_labels = self._validate_single_input(
            input_labels,
            expected_depth=3,
            input_name="labels",
            expected_format="[image level, object level, point level]",
        )
        processed_boxes = self._validate_single_input(
            input_boxes,
            expected_depth=3,
            input_name="boxes",
            expected_format="[image level, box level, box coordinates]",
            expected_coord_size=4,
        )

        have_points = processed_points is not None
        have_labels = processed_labels is not None
        have_boxes = processed_boxes is not None

        # Largest extent per nesting level; this is what padding has to reach.
        points_max_dims = self._get_nested_dimensions(processed_points)[:3] if have_points else None
        labels_max_dims = self._get_nested_dimensions(processed_labels)[:3] if have_labels else None
        boxes_max_dims = self._get_nested_dimensions(processed_boxes)[:2] if have_boxes else None

        # Points and labels must describe the same (image, object, point) layout.
        if have_points and have_labels and points_max_dims != labels_max_dims:
            raise ValueError(
                "Input points and labels have inconsistent dimensions. Please ensure they have the same dimensions."
            )

        # Boxes are never padded, so every image must carry the same number of boxes.
        if have_boxes and len(processed_boxes) >= 2:
            required_box_count = boxes_max_dims[1]
            if any(len(img_boxes) < required_box_count for img_boxes in processed_boxes):
                raise ValueError(
                    "Input boxes have inconsistent dimensions that would require padding, "
                    "but boxes cannot be padded due to model limitations. "
                    "Please ensure all images have the same number of boxes."
                )

        if have_points:
            # Pad to a rectangular layout, then rescale the real (non-padded) coordinates in place.
            padded_points = self._pad_nested_list(processed_points, points_max_dims + [2])
            final_points = torch.tensor(padded_points, dtype=torch.float32)
            self._normalize_tensor_coordinates(final_points, original_sizes, preserve_padding=True)
            encoding_image_processor.update({"input_points": final_points})

        if have_labels:
            padded_labels = self._pad_nested_list(processed_labels, labels_max_dims)
            final_labels = torch.tensor(padded_labels, dtype=torch.int64)
            encoding_image_processor.update({"input_labels": final_labels})

        if have_boxes:
            final_boxes = torch.tensor(processed_boxes, dtype=torch.float32)
            self._normalize_tensor_coordinates(final_boxes, original_sizes, is_bounding_box=True)
            encoding_image_processor.update({"input_boxes": final_boxes})

        return encoding_image_processor

    def _normalize_coordinates(
        self, target_size: int, coords: "torch.Tensor", original_size, is_bounding_box=False
    ) -> "torch.Tensor":
        """
        Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format.

        Args:
            target_size (`int`):
                The target size of the image.
            coords (`torch.Tensor`):
                The coordinates to be normalized.
            original_size (`tuple`):
                The original size of the image.
            is_bounding_box (`bool`, *optional*, defaults to `False`):
                Whether the coordinates are bounding boxes.
        """
        old_h, old_w = original_size
        new_h, new_w = target_size, target_size
        coords = deepcopy(coords).float()

        if is_bounding_box:
            coords = coords.reshape(-1, 2, 2)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)

        if is_bounding_box:
            coords = coords.reshape(-1, 4)

        return coords

    def _convert_to_nested_list(self, data, expected_depth, current_depth=0):
        """
        Recursively convert various input formats (tensors, numpy arrays, lists) to nested lists.

        Args:
            data: Input data in any format
            expected_depth: Expected nesting depth
            current_depth: Current depth in recursion

        Returns:
            Nested list representation of the data
        """
        if data is None:
            return None

        # Convert tensor/numpy to list if we're at a leaf level or if it's a multi-dimensional array
        if isinstance(data, torch.Tensor):  # PyTorch tensor
            if current_depth == expected_depth - 2 or len(data.shape) <= 2:  # At coordinate level or small tensor
                return data.numpy().tolist()
            else:
                return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data]
        elif isinstance(data, np.ndarray):  # NumPy array
            if current_depth == expected_depth - 2 or len(data.shape) <= 2:  # At coordinate level or small array
                return data.tolist()
            else:
                return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data]
        elif isinstance(data, list):
            if current_depth == expected_depth:
                # We've reached the expected depth, return as is
                return data
            else:
                # Continue recursion
                return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data]
        elif isinstance(data, (int, float)):
            return data
        else:
            raise TypeError(f"Unsupported data type: {type(data)}")

    def _get_nested_dimensions(self, nested_list, max_dims=None):
        """
        Get the maximum dimensions at each level of nesting.

        Args:
            nested_list (`list`):
                Nested list structure.
            max_dims (`list`, *optional*):
                Current maximum dimensions (for recursion).

        Returns:
            `list`: A list of maximum dimensions for each nesting level.
        """
        if max_dims is None:
            max_dims = []

        if not isinstance(nested_list, list):
            return max_dims

        if len(max_dims) == 0:
            max_dims.append(len(nested_list))
        else:
            max_dims[0] = max(max_dims[0], len(nested_list))

        if len(nested_list) > 0:
            for item in nested_list:
                if isinstance(item, list):
                    sub_dims = self._get_nested_dimensions(item)
                    # Merge sub_dims into max_dims
                    for i, dim in enumerate(sub_dims):
                        if i + 1 >= len(max_dims):
                            max_dims.append(dim)
                        else:
                            max_dims[i + 1] = max(max_dims[i + 1], dim)

        return max_dims

    def _pad_nested_list(self, nested_list, target_dims, current_level=0, pad_value=None):
        """
        Recursively pad a nested list to match target dimensions.

        Args:
            nested_list (`list`):
                Nested list to pad.
            target_dims (`list`):
                Target dimensions for each level.
            current_level (`int`, *optional*, defaults to 0):
                Current nesting level.
            pad_value (`int`, *optional*):
                Value to use for padding.

        Returns:
            `list`: The padded nested list.
        """
        if pad_value is None:
            pad_value = self.point_pad_value

        if current_level >= len(target_dims):
            return nested_list

        # Ensure we have a list
        if not isinstance(nested_list, list):
            nested_list = [nested_list]

        # Pad current level
        current_size = len(nested_list)
        target_size = target_dims[current_level]

        # Pad with appropriate values
        if current_level == len(target_dims) - 1:
            # At the coordinate level, pad with pad_value
            nested_list.extend([pad_value] * (target_size - current_size))
        else:
            # At higher levels, pad with nested structures
            if current_size > 0:
                # Create appropriately sized template
                if current_level < len(target_dims) - 2:
                    # For non-coordinate levels, create empty nested structure
                    template_dims = target_dims[current_level + 1 :]
                    template = self._create_empty_nested_structure(template_dims, pad_value)
                else:
                    # For coordinate level, create list of pad_values
                    template = [pad_value] * target_dims[current_level + 1]

                nested_list.extend([deepcopy(template) for _ in range(target_size - current_size)])
            else:
                # Create from scratch
                template_dims = target_dims[current_level + 1 :]
                template = self._create_empty_nested_structure(template_dims, pad_value)
                nested_list.extend([deepcopy(template) for _ in range(target_size)])

        # Recursively pad sublists
        if current_level < len(target_dims) - 1:
            for i in range(len(nested_list)):
                if isinstance(nested_list[i], list):
                    nested_list[i] = self._pad_nested_list(nested_list[i], target_dims, current_level + 1, pad_value)

        return nested_list

    def _create_empty_nested_structure(self, dims, pad_value):
        """
        Create an empty nested structure with given dimensions filled with pad_value.

        Args:
            dims (`list`):
                The dimensions of the nested structure.
            pad_value (`int`):
                The value to fill the structure with.
        """
        if len(dims) == 1:
            return [pad_value] * dims[0]
        else:
            return [self._create_empty_nested_structure(dims[1:], pad_value) for _ in range(dims[0])]

    def _get_nesting_level(self, input_list):
        """
        Get the nesting level of a list structure.

        Args:
            input_list (`list`):
                The list to get the nesting level of.
        """
        if isinstance(input_list, list):
            if len(input_list) == 0:
                return 1
            return 1 + self._get_nesting_level(input_list[0])
        elif isinstance(input_list, (np.ndarray, torch.Tensor)):
            # For arrays/tensors, the nesting level is the number of dimensions
            return len(input_list.shape)
        return 0

    def _validate_single_input(
        self,
        data: torch.Tensor | np.ndarray | list,
        expected_depth: int,
        input_name: str,
        expected_format: str,
        expected_coord_size: int | None = None,
    ) -> list:
        """
                Validate a single input by ensuring proper nesting and raising an error if the input is not valid.

                Args:
                    data (`torch.Tensor`, `np.ndarray`, or `list`):
                        Input data to process.
                    expected_depth (`int`):
                        Expected nesting depth.
                    input_name (`str`):
                        Name of the input for error messages.
                    expected_format (`str`):
                        The expected format of the input.
                    expected_coord_size (`int`, *optional*):
                        Expected coordinate size (2 for points, 4 for boxes, None for labels).
        .
        """
        if data is None:
            return None

        # Handle tensors and numpy arrays first
        if isinstance(data, (torch.Tensor, np.ndarray)):
            # For tensors/arrays, we can directly check the number of dimensions
            if data.ndim != expected_depth:
                raise ValueError(
                    f"Input {input_name} must be a tensor/array with {expected_depth} dimensions. The expected nesting format is {expected_format}. Got {data.ndim} dimensions."
                )
            elif expected_coord_size is not None:
                if data.shape[-1] != expected_coord_size:
                    raise ValueError(
                        f"Input {input_name} must be a tensor/array with {expected_coord_size} as the last dimension, got {data.shape[-1]}."
                    )
            return self._convert_to_nested_list(data, expected_depth)

        # Handle nested lists
        if isinstance(data, list):
            current_depth = self._get_nesting_level(data)
            if current_depth != expected_depth:
                raise ValueError(
                    f"Input {input_name} must be a nested list with {expected_depth} levels. The expected nesting format is {expected_format}. Got {current_depth} levels."
                )
            return self._convert_to_nested_list(data, expected_depth)

    def _normalize_tensor_coordinates(self, tensor, original_sizes, is_bounding_box=False, preserve_padding=False):
        """
        Helper method to normalize coordinates in a tensor across multiple images.

        Args:
            tensor (`torch.Tensor`):
                Input tensor with coordinates.
            original_sizes (`list`):
                Original image sizes.
            is_bounding_box (`bool`, *optional*, defaults to `False`):
                Whether coordinates are bounding boxes.
            preserve_padding (`bool`, *optional*, defaults to `False`):
                Whether to preserve padding values (for points).
        """
        if preserve_padding:
            # For points: avoid normalizing pad values
            mask = tensor != self.point_pad_value
            coord_mask = mask.all(dim=-1, keepdim=True)

        for img_idx in range(len(original_sizes)):
            if img_idx < tensor.shape[0]:
                original_size = original_sizes[img_idx] if img_idx < len(original_sizes) else original_sizes[0]
                normalized_coords = self._normalize_coordinates(
                    self.target_size, tensor[img_idx], original_size, is_bounding_box=is_bounding_box
                )

                if preserve_padding:
                    # Only update non-padded values
                    img_mask = coord_mask[img_idx]
                    tensor[img_idx] = torch.where(
                        img_mask.expand_as(tensor[img_idx]), normalized_coords, tensor[img_idx]
                    )
                else:
                    tensor[img_idx] = normalized_coords

    def post_process_masks(
        self,
        masks,
        original_sizes,
        mask_threshold=0.0,
        binarize=True,
        max_hole_area=0.0,
        max_sprinkle_area=0.0,
        apply_non_overlapping_constraints=False,
        **kwargs,
    ):
        """
        Remove padding and upscale masks to the original image size.

        Args:
            masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
                Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
            original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
                The original sizes of each image before it was resized to the model's expected input shape, in (height,
                width) format.
            mask_threshold (`float`, *optional*, defaults to 0.0):
                Threshold for binarization and post-processing operations.
            binarize (`bool`, *optional*, defaults to `True`):
                Whether to binarize the masks.
            max_hole_area (`float`, *optional*, defaults to 0.0):
                The maximum area of a hole to fill.
            max_sprinkle_area (`float`, *optional*, defaults to 0.0):
                The maximum area of a sprinkle to fill.
            apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`):
                Whether to apply non-overlapping constraints to the masks.

        Returns:
            (`torch.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width)
            is given by original_size.
        """
        return self.image_processor.post_process_masks(
            masks,
            original_sizes,
            mask_threshold,
            binarize,
            max_hole_area,
            max_sprinkle_area,
            apply_non_overlapping_constraints,
            **kwargs,
        )

    @property
    def model_input_names(self):
        image_processor_input_names = self.image_processor.model_input_names
        return list(image_processor_input_names + ["original_sizes"])

    def init_video_session(
        self,
        video: VideoInput | None = None,
        inference_device: Union[str, "torch.device"] = "cpu",
        inference_state_device: Union[str, "torch.device"] | None = None,
        processing_device: Union[str, "torch.device"] | None = None,
        video_storage_device: Union[str, "torch.device"] | None = None,
        max_vision_features_cache_size: int = 1,
        dtype: torch.dtype = torch.float32,
    ):
        """
        Create a `Sam2VideoInferenceSession`, optionally pre-loaded with a processed video.

        When `video` is given it is run through the video processor on `processing_device`, and the first
        processed video together with its original height and width is handed to the session. Without a
        video (streaming use) the session is created with no frames and no size.

        Args:
            video (`VideoInput`, *optional*):
                The video to process. No need to provide when streaming.
            inference_device (`str` or `torch.device`, *optional*, defaults to "cpu"):
                The device to use for inference. Also the fallback for every other device argument left unset.
            inference_state_device (`str` or `torch.device`, *optional*):
                The device to store the inference state on.
            processing_device (`str` or `torch.device`, *optional*):
                The device to use for video processing.
            video_storage_device (`str` or `torch.device`, *optional*):
                The device to store the processed video frames on.
            max_vision_features_cache_size (`int`, *optional*, defaults to 1):
                The maximum number of vision features to cache.
            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                The torch dtype to use for the whole session.

        Returns:
            `Sam2VideoInferenceSession`: The newly created session.
        """
        # Every device left unset falls back to the inference device.
        if video_storage_device is None:
            video_storage_device = inference_device
        if inference_state_device is None:
            inference_state_device = inference_device
        if processing_device is None:
            processing_device = inference_device

        frames = None
        height = None
        width = None
        if video is not None:
            processed_video = self.video_processor(videos=video, device=processing_device, return_tensors="pt")
            frames = processed_video.pixel_values_videos[0]
            first_size = processed_video.original_sizes[0]
            height = first_size[0]
            width = first_size[1]

        return Sam2VideoInferenceSession(
            video=frames,
            video_height=height,
            video_width=width,
            inference_device=inference_device,
            video_storage_device=video_storage_device,
            inference_state_device=inference_state_device,
            dtype=dtype,
            max_vision_features_cache_size=max_vision_features_cache_size,
        )

    def add_inputs_to_inference_session(
        self,
        inference_session: Sam2VideoInferenceSession,
        frame_idx: int,
        obj_ids: list[int] | int,
        input_points: list[list[list[list[float]]]] | torch.Tensor | None = None,
        input_labels: list[list[list[int]]] | torch.Tensor | None = None,
        input_boxes: list[list[list[float]]] | torch.Tensor | None = None,
        input_masks: np.ndarray | torch.Tensor | list[np.ndarray] | list[torch.Tensor] | None = None,
        original_size: tuple[int, int] | None = None,
        clear_old_inputs: bool = True,
    ) -> Sam2VideoInferenceSession:
        """
        Process new points, boxes, or masks for a video frame and add them to the inference session.

        Args:
            inference_session (`Sam2VideoInferenceSession`):
                The inference session for the video.
            frame_idx (`int`):
                The index of the frame to process.
            obj_ids (`list[int]` or `int`):
                The object ID(s) to associate with the points or box.
                These can be any integers and can be reused later on to specify an object.
            input_points (`list[list[list[list[float]]]]`, `torch.Tensor`, *optional*):
                The points to add to the frame.
            input_labels (`list[list[list[int]]]`, `torch.Tensor`, *optional*):
                The labels for the points.
            input_boxes (`list[list[list[float]]]`, `torch.Tensor`, *optional*):
                The bounding boxes to add to the frame.
            input_masks (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, or `list[torch.Tensor]`, *optional*):
                The mask(s) to add to the frame.
            original_size (`tuple[int, int]`, *optional*):
                The original size of the video. Provide when streaming.
            clear_old_inputs (`bool`, *optional*, defaults to `True`):
                Whether to clear old inputs for the object.
        """

        if isinstance(obj_ids, int):
            obj_ids = [obj_ids]

        # Validate inputs
        if (input_points is not None) != (input_labels is not None):
            raise ValueError("points and labels must be provided together")
        if input_points is None and input_boxes is None and input_masks is None:
            raise ValueError("at least one of points, boxes, or masks must be provided as input")
        if input_masks is not None and (input_points is not None or input_boxes is not None):
            raise ValueError("masks cannot be provided together with points or boxes")

        if input_masks is not None:
            return self.process_new_mask_for_video_frame(inference_session, frame_idx, obj_ids, input_masks)
        else:
            return self.process_new_points_or_boxes_for_video_frame(
                inference_session,
                frame_idx,
                obj_ids,
                input_points,
                input_labels,
                input_boxes,
                original_size,
                clear_old_inputs,
            )

    def process_new_points_or_boxes_for_video_frame(
        self,
        inference_session: Sam2VideoInferenceSession,
        frame_idx: int,
        obj_ids: list[int],
        input_points: list[list[list[list[float]]]] | torch.Tensor | None = None,
        input_labels: list[list[list[int]]] | torch.Tensor | None = None,
        input_boxes: list[list[list[float]]] | torch.Tensor | None = None,
        original_size: tuple[int, int] | None = None,
        clear_old_inputs: bool = True,
    ) -> Sam2VideoInferenceSession:
        """
        Process new points or boxes for a video frame and add them to the inference session.

        Args:
            inference_session (`Sam2VideoInferenceSession`):
                The inference session for the video.
            frame_idx (`int`):
                The index of the frame to process.
            obj_ids (`list[int]`):
                The object ID(s) to associate with the points or box.
                These can be any integers and can be reused later on to specify an object.
            input_points (`list[list[list[list[float]]]]`, `torch.Tensor`, *optional*):
                The points to add to the frame.
            input_labels (`list[list[list[int]]]`, `torch.Tensor`, *optional*):
                The labels for the points.
            input_boxes (`list[list[list[float]]]`, `torch.Tensor`, *optional*):
                The bounding boxes to add to the frame.
            original_size (`tuple[int, int]`, *optional*):
                The original size of the video. Provide when streaming.
            clear_old_inputs (`bool`, *optional*, defaults to `True`):
                Whether to clear old inputs for the object.
        """
        if original_size is not None:
            inference_session.video_height = original_size[0]
            inference_session.video_width = original_size[1]
        elif inference_session.video_height is None or inference_session.video_width is None:
            raise ValueError("original_size must be provided when adding points or boxes on a first streamed frame")

        original_sizes = [[inference_session.video_height, inference_session.video_width]]

        encoded_inputs = self(
            input_points=input_points,
            input_labels=input_labels,
            input_boxes=input_boxes,
            original_sizes=original_sizes,
            return_tensors="pt",
        )
        input_points = encoded_inputs.get("input_points", None)
        input_labels = encoded_inputs.get("input_labels", None)
        input_boxes = encoded_inputs.get("input_boxes", None)

        if input_points is not None:
            if input_points.shape[1] != len(obj_ids):
                raise ValueError(
                    f"Number of object ids ({len(obj_ids)}) does not match number of points ({input_points.shape[1]})"
                )
        else:
            input_points = torch.zeros(1, len(obj_ids), 0, 2, dtype=torch.float32)
        if input_labels is not None:
            if input_labels.shape[1] != len(obj_ids):
                raise ValueError(
                    f"Number of object ids ({len(obj_ids)}) does not match number of labels ({input_labels.shape[1]})"
                )
        else:
            input_labels = torch.zeros(1, len(obj_ids), 0, dtype=torch.int32)
        if input_boxes is not None:
            if input_boxes.shape[1] != len(obj_ids):
                raise ValueError(
                    f"Number of object ids ({len(obj_ids)}) does not match number of boxes ({input_boxes.shape[1]})"
                )

        if input_boxes is not None:
            if not clear_old_inputs:
                raise ValueError(
                    "cannot add box without clearing old points, since "
                    "box prompt must be provided before any point prompt "
                    "(please use clear_old_points=True instead)"
                )
            box_coords = input_boxes.reshape(1, -1, 2, 2)
            box_labels = torch.tensor([2, 3], dtype=torch.int32).repeat(1, box_coords.shape[1], 1)
            input_points = torch.cat([box_coords, input_points], dim=2)
            input_labels = torch.cat([box_labels, input_labels], dim=2)

        for obj_id, idx in zip(obj_ids, range(len(obj_ids))):
            obj_idx = inference_session.obj_id_to_idx(obj_id)
            input_points_for_obj = input_points[:, idx, :, :].unsqueeze(1)
            input_labels_for_obj = input_labels[:, idx, :].unsqueeze(1)
            # Handle existing points
            if not clear_old_inputs:
                existing_points = inference_session.point_inputs_per_obj[obj_idx].get(frame_idx, None)
                if existing_points is not None:
                    # Concatenate with existing points
                    input_points_for_obj = torch.cat(
                        [existing_points["point_coords"].to(input_points_for_obj.device), input_points_for_obj], dim=2
                    )
                    input_labels_for_obj = torch.cat(
                        [existing_points["point_labels"].to(input_labels_for_obj.device), input_labels_for_obj], dim=2
                    )
            point_inputs = {
                "point_coords": input_points_for_obj,
                "point_labels": input_labels_for_obj,
            }

            inference_session.add_point_inputs(obj_idx, frame_idx, point_inputs)
            inference_session.remove_mask_inputs(obj_idx, frame_idx)  # Clear any mask inputs

        inference_session.obj_with_new_inputs = obj_ids

    def process_new_mask_for_video_frame(
        self,
        inference_session: Sam2VideoInferenceSession,
        frame_idx: int,
        obj_ids: list[int],
        input_masks: np.ndarray | torch.Tensor | list[np.ndarray] | list[torch.Tensor],
    ):
        """
        Register one mask prompt per object on a given frame of the inference session.

        Each mask is brought to 4 dimensions, moved to the session's inference device as float, resized to
        `(self.target_size, self.target_size)` when its last two dimensions differ from that size, and stored as
        the mask input of its object for `frame_idx`. Point inputs stored for the same object and frame are removed.

        Args:
            inference_session (`Sam2VideoInferenceSession`):
                The inference session that receives the mask inputs.
            frame_idx (`int`):
                The index of the frame the masks belong to.
            obj_ids (`list[int]`):
                The object ID(s) the masks are associated with, one per mask and in the same order.
                These can be any integers and can be reused later on to specify an object.
            input_masks (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, or `list[torch.Tensor]`):
                The mask(s) to add to the frame. A value that is not a list is treated as a single mask.
                Each mask must have between 2 and 4 dimensions.

        Raises:
            ValueError: If the number of masks differs from the number of object ids, or if a mask has fewer
                than 2 or more than 4 dimensions.
        """
        # A lone array/tensor is handled as a one-element list of masks.
        masks = input_masks if isinstance(input_masks, list) else [input_masks]
        if len(obj_ids) != len(masks):
            raise ValueError(f"Number of object ids ({len(obj_ids)}) does not match number of masks ({len(masks)})")

        for obj_id, raw_mask in zip(obj_ids, masks):
            obj_idx = inference_session.obj_id_to_idx(obj_id)
            target_device = inference_session.inference_device

            # Non-tensor masks are converted to boolean tensors; tensors are taken as they are.
            mask_tensor = raw_mask if isinstance(raw_mask, torch.Tensor) else torch.tensor(raw_mask, dtype=torch.bool)

            ndim = mask_tensor.dim()
            if not 2 <= ndim <= 4:
                raise ValueError(f"Mask has an unsupported number of dimensions: {ndim}")
            # Prepend singleton dimensions until the mask is 4D.
            while mask_tensor.dim() < 4:
                mask_tensor = mask_tensor.unsqueeze(0)

            height, width = mask_tensor.shape[-2:]
            mask_float = mask_tensor.to(target_device).float()

            if height == self.target_size and width == self.target_size:
                # Already at the target resolution: use the float mask unchanged.
                mask_inputs = mask_float
            else:
                resized = torch.nn.functional.interpolate(
                    mask_float,
                    size=(self.target_size, self.target_size),
                    mode="bilinear",
                    align_corners=False,
                    antialias=True,
                )
                # Interpolation yields fractional values; threshold back to a 0/1 mask.
                mask_inputs = (resized >= 0.5).float()

            inference_session.add_mask_inputs(obj_idx, frame_idx, mask_inputs)
            # A mask prompt replaces any point prompt for this object on this frame.
            inference_session.remove_point_inputs(obj_idx, frame_idx)

        inference_session.obj_with_new_inputs = obj_ids


# Names exported by `from <this module> import *`: only the processor class.
__all__ = ["Sam2VideoProcessor"]
