#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_lightglue.py file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional

import torch
import torchvision.transforms.v2.functional as tvF

from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import BaseImageProcessorFast
from ...image_transforms import group_images_by_shape, reorder_images
from ...image_utils import (
    ImageInput,
    ImageType,
    PILImageResampling,
    SizeDict,
    get_image_type,
    is_pil_image,
    is_valid_image,
    is_vision_available,
    to_numpy_array,
)
from ...processing_utils import Unpack
from ...utils import TensorType, auto_docstring
from .image_processing_lightglue import LightGlueImageProcessorKwargs


if TYPE_CHECKING:
    from .modeling_lightglue import LightGlueKeypointMatchingOutput

if is_vision_available():
    from PIL import Image, ImageDraw


def _is_valid_image(image):
    """
    Tell whether `image` can be used as one member of an image pair.

    An input qualifies when it is a PIL image, or when it is a valid non-PIL image whose `shape` has exactly three
    dimensions.
    """
    if is_pil_image(image):
        return True
    if not is_valid_image(image):
        return False
    if get_image_type(image) == ImageType.PIL:
        return False
    # Array-like inputs must be a single image (3 dims), not a batch.
    return len(image.shape) == 3


def flatten_pair_images(images):
    """
    Validate image-pair input and return it as a flat list of images.

    Args:
        images:
            Either a list of exactly two images (a single pair) or a list of two-element lists (several pairs).
            Each image must be a `torch.Tensor` or be accepted by `_is_valid_image`.

    Returns:
        `list`: The single pair unchanged, or the pairs flattened in order (`[a0, a1, b0, b1, ...]`). An empty
        list yields an empty list.

    Raises:
        ValueError: If `images` is not a list or matches neither layout.
    """

    def _is_acceptable(image):
        # Tensors are accepted as-is; everything else goes through the generic image validator.
        return isinstance(image, torch.Tensor) or _is_valid_image(image)

    # Handle the pair validation and flattening similar to slow processor
    if isinstance(images, list):
        if len(images) == 2 and all(_is_acceptable(image) for image in images):
            # Single pair of images - keep as is, they'll be processed by the base class
            return images
        if all(
            isinstance(image_pair, list) and len(image_pair) == 2 and all(_is_acceptable(image) for image in image_pair)
            for image_pair in images
        ):
            # Multiple pairs - flatten them
            return [image for image_pair in images for image in image_pair]

    # Build ONE message string: passing several positional strings to ValueError makes the exception render
    # as a tuple instead of readable text.
    raise ValueError(
        "Input images must be one of the following:\n"
        " - A pair of PIL images.\n"
        " - A pair of 3D arrays.\n"
        " - A list of pairs of PIL images.\n"
        " - A list of pairs of 3D arrays."
    )


def is_grayscale(
    image: "torch.Tensor",
) -> bool:
    """
    Checks if an image is grayscale (all RGB channels are identical).

    Args:
        image (`torch.Tensor`):
            Image tensor whose channel axis is the third from the end (`(..., channels, height, width)`). Tensors
            with fewer than 3 dimensions are treated as grayscale.

    Returns:
        `bool`: `True` if the tensor has no channel axis, has a single channel, or if its first three channels hold
        identical values everywhere; `False` otherwise.
    """
    # The channel axis is always third from the end, whatever the number of leading batch dimensions.
    if image.ndim < 3 or image.shape[-3] == 1:
        return True
    first_channel = image[..., 0, :, :]
    second_channel = image[..., 1, :, :]
    # Convert to a plain bool so both exits of the function return the same type.
    if not bool(torch.all(first_channel == second_channel)):
        return False
    return bool(torch.all(second_channel == image[..., 2, :, :]))


def convert_to_grayscale(
    image: "torch.Tensor",
) -> "torch.Tensor":
    """
    Return a grayscale version of `image` (NTSC formula). Only `torch.Tensor` inputs are supported.

    The output deliberately keeps 3 channels (the same value repeated in each one) rather than a single channel,
    because of the issue discussed in:
    https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446

    Args:
        image (torch.Tensor):
            The image to convert.

    Returns:
        `torch.Tensor`: `image` itself when `is_grayscale` already reports it as grayscale, otherwise the 3-channel
        grayscale conversion.
    """
    already_gray = is_grayscale(image)
    return image if already_gray else tvF.rgb_to_grayscale(image, num_output_channels=3)


@auto_docstring
class LightGlueImageProcessorFast(BaseImageProcessorFast):
    # Default preprocessing settings, declared as class attributes.
    resample = PILImageResampling.BILINEAR
    size = {"height": 480, "width": 640}
    default_to_square = False
    do_resize = True
    do_rescale = True
    rescale_factor = 1 / 255
    do_normalize = None
    valid_kwargs = LightGlueImageProcessorKwargs

    # No docstring here on purpose: the class is decorated with `auto_docstring`.
    def __init__(self, **kwargs: Unpack[LightGlueImageProcessorKwargs]):
        super().__init__(**kwargs)

    # Thin override that only narrows the kwargs typing; the work is delegated to the parent class.
    @auto_docstring
    def preprocess(self, images: ImageInput, **kwargs: Unpack[LightGlueImageProcessorKwargs]) -> BatchFeature:
        return super().preprocess(images, **kwargs)

    def _prepare_images_structure(
        self,
        images: ImageInput,
        **kwargs,
    ) -> ImageInput:
        """
        Validate the image-pair input and flatten it into a flat list of images.

        `images` is first passed through the inherited `fetch_images`, then through `flatten_pair_images`, which
        raises a `ValueError` when the input is neither a pair nor a list of pairs.
        """
        # we need to handle image pairs validation and flattening
        images = self.fetch_images(images)
        return flatten_pair_images(images)

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        size: dict[str, int] | SizeDict,
        rescale_factor: float,
        do_rescale: bool,
        do_resize: bool,
        interpolation: Optional["tvF.InterpolationMode"],
        do_grayscale: bool,
        disable_grouping: bool,
        return_tensors: str | TensorType,
        **kwargs,
    ) -> BatchFeature:
        """
        Resize, rescale and optionally convert to grayscale a flat list of images, then regroup them into pairs.

        Args:
            images (`list[torch.Tensor]`):
                Flat list of images in which consecutive images (`2k`, `2k + 1`) form a pair.
            size (`dict[str, int]` or `SizeDict`):
                Target size forwarded to `self.resize` when `do_resize` is set.
            rescale_factor (`float`):
                Factor forwarded to `self.rescale` when `do_rescale` is set.
            do_rescale (`bool`):
                Whether to rescale the images.
            do_resize (`bool`):
                Whether to resize the images.
            interpolation (`tvF.InterpolationMode`, *optional*):
                Interpolation mode forwarded to `self.resize`.
            do_grayscale (`bool`):
                Whether to apply `convert_to_grayscale`.
            disable_grouping (`bool`):
                Forwarded to `group_images_by_shape`.
            return_tensors (`str` or `TensorType`):
                Tensor type handed to the returned `BatchFeature`.

        Returns:
            `BatchFeature`: Holds `pixel_values`, a list with one stacked tensor per image pair.
        """
        # First pass: resize, batching together the images that share a shape.
        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        processed_images_grouped = {}

        for shape, stacked_images in grouped_images.items():
            if do_resize:
                stacked_images = self.resize(stacked_images, size=size, interpolation=interpolation)
            processed_images_grouped[shape] = stacked_images
        resized_images = reorder_images(processed_images_grouped, grouped_images_index)

        # Second pass: regroup (shapes may have changed after resizing), then rescale / grayscale.
        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_rescale:
                stacked_images = self.rescale(stacked_images, rescale_factor)
            if do_grayscale:
                stacked_images = convert_to_grayscale(stacked_images)
            processed_images_grouped[shape] = stacked_images

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)

        # Convert back to pairs format
        image_pairs = [processed_images[i : i + 2] for i in range(0, len(processed_images), 2)]

        # Stack each pair into a single tensor to match slow processor format
        stacked_pairs = [torch.stack(pair, dim=0) for pair in image_pairs]

        # Return in same format as slow processor

        return BatchFeature(data={"pixel_values": stacked_pairs}, tensor_type=return_tensors)

    def post_process_keypoint_matching(
        self,
        outputs: "LightGlueKeypointMatchingOutput",
        target_sizes: TensorType | list[tuple],
        threshold: float = 0.0,
    ) -> list[dict[str, torch.Tensor]]:
        """
        Converts the raw output of [`LightGlueKeypointMatchingOutput`] into lists of matched keypoints and matching
        scores, with keypoint coordinates absolute to the original image sizes.

        Args:
            outputs ([`LightGlueKeypointMatchingOutput`]):
                Raw outputs of the model.
            target_sizes (`torch.Tensor` or `list[tuple[tuple[int, int]]]`):
                Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`tuple[int, int]`) containing the
                target size `(height, width)` of each image in the batch. This must be the original image size (before
                any processing).
            threshold (`float`, *optional*, defaults to 0.0):
                Threshold to filter out the matches with low scores.

        Returns:
            `list[Dict]`: A list of dictionaries, one per image pair, each containing:
                - `keypoints0`: the matched keypoints in the first image of the pair,
                - `keypoints1`: the corresponding matched keypoints in the second image of the pair,
                - `matching_scores`: the score of each match.

        Raises:
            ValueError: If the number of target sizes differs from the batch dimension of `outputs.mask`, or if
                `target_sizes` does not hold one `(height, width)` entry for each of the two images of every pair.
        """
        if outputs.mask.shape[0] != len(target_sizes):
            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")
        if not all(len(target_size) == 2 for target_size in target_sizes):
            raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

        # The sizes are multiplied with the keypoints below, so they must live on the same device.
        device = outputs.keypoints.device
        if isinstance(target_sizes, list):
            image_pair_sizes = torch.tensor(target_sizes, device=device)
        else:
            # Check the rank first so that a malformed tensor raises the intended ValueError instead of an
            # IndexError on `shape[2]`.
            if target_sizes.ndim != 3 or target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2:
                raise ValueError(
                    "Each element of target_sizes must contain the size (h, w) of each image of the batch"
                )
            image_pair_sizes = target_sizes.to(device)

        keypoints = outputs.keypoints.clone()
        # Sizes are given as (height, width); flip them to (width, height) before scaling the keypoints.
        keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2)
        keypoints = keypoints.to(torch.int32)

        results = []
        # Only the matches / scores of the first image of each pair (index 0) are used.
        for mask_pair, keypoints_pair, matches, scores in zip(
            outputs.mask, keypoints, outputs.matches[:, 0], outputs.matching_scores[:, 0]
        ):
            mask0 = mask_pair[0] > 0
            mask1 = mask_pair[1] > 0
            keypoints0 = keypoints_pair[0][mask0]
            keypoints1 = keypoints_pair[1][mask1]
            matches0 = matches[mask0]
            scores0 = scores[mask0]

            # Filter out matches with low scores, invalid matches, and out-of-bounds indices
            valid_matches = (scores0 > threshold) & (matches0 > -1) & (matches0 < keypoints1.shape[0])

            matched_keypoints0 = keypoints0[valid_matches]
            matched_keypoints1 = keypoints1[matches0[valid_matches]]
            matching_scores = scores0[valid_matches]

            results.append(
                {
                    "keypoints0": matched_keypoints0,
                    "keypoints1": matched_keypoints1,
                    "matching_scores": matching_scores,
                }
            )

        return results

    def visualize_keypoint_matching(
        self,
        images,
        keypoint_matching_output: list[dict[str, torch.Tensor]],
    ) -> list["Image.Image"]:
        """
        Plots the image pairs side by side with the detected keypoints as well as the matching between them.

        Args:
            images:
                Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2
                images or a list of list of 2 images list with pixel values ranging from 0 to 255. Each image is
                read as `(height, width, channels)` once converted to a numpy array.
            keypoint_matching_output (List[Dict[str, torch.Tensor]]]):
                A post processed keypoint matching output, i.e. dictionaries with the `keypoints0`, `keypoints1`
                and `matching_scores` keys.

        Returns:
            `List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
            keypoints as well as the matching between them.
        """
        from .image_processing_lightglue import validate_and_format_image_pairs

        images = validate_and_format_image_pairs(images)
        images = [to_numpy_array(image) for image in images]
        image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]

        results = []
        for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
            height0, width0 = image_pair[0].shape[:2]
            height1, width1 = image_pair[1].shape[:2]
            # Black canvas large enough for both images: first image on the left, second one on the right.
            plot_image = torch.zeros((max(height0, height1), width0 + width1, 3), dtype=torch.uint8)
            plot_image[:height0, :width0] = torch.from_numpy(image_pair[0])
            plot_image[:height1, width0:] = torch.from_numpy(image_pair[1])

            plot_image_pil = Image.fromarray(plot_image.numpy())
            draw = ImageDraw.Draw(plot_image_pil)

            keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
            keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
            for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
                keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
            ):
                color = self._get_color(matching_score)
                # x coordinates of the second image are shifted by `width0` since it sits right of the first one.
                draw.line(
                    (keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
                    fill=color,
                    width=3,
                )
                draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
                draw.ellipse(
                    (keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
                    fill="black",
                )

            results.append(plot_image_pil)
        return results

    def _get_color(self, score):
        """Maps a score to an `(r, g, b)` color: a score of 0 gives pure red, a score of 1 gives pure green."""
        r = int(255 * (1 - score))
        g = int(255 * score)
        b = 0
        return r, g, b


# Public API of this module: only the fast image processor class is exported.
__all__ = ["LightGlueImageProcessorFast"]
