# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Iterable
from copy import deepcopy
from functools import lru_cache, partial
from typing import Any, Optional, Union

import numpy as np
from huggingface_hub.dataclasses import validate_typed_dict

from .image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from .image_transforms import (
    convert_to_rgb,
    get_resize_output_image_size,
    get_size_with_aspect_ratio,
    group_images_by_shape,
    reorder_images,
)
from .image_utils import (
    ChannelDimension,
    ImageInput,
    ImageType,
    SizeDict,
    get_image_size,
    get_image_size_for_max_height_width,
    get_image_type,
    infer_channel_dimension_format,
    make_flat_list_of_images,
    validate_kwargs,
    validate_preprocess_arguments,
)
from .processing_utils import ImagesKwargs, Unpack
from .utils import (
    TensorType,
    auto_docstring,
    is_torch_available,
    is_torchvision_available,
    is_vision_available,
    logging,
)
from .utils.import_utils import is_rocm_platform, is_torchdynamo_compiling


if is_vision_available():
    from .image_utils import PILImageResampling

if is_torch_available():
    import torch

if is_torchvision_available():
    from torchvision.transforms.v2 import functional as F

    from .image_utils import pil_torch_interpolation_mapping

else:
    pil_torch_interpolation_mapping = None


logger = logging.get_logger(__name__)


@lru_cache(maxsize=10)
def validate_fast_preprocess_arguments(
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, list[float]]] = None,
    image_std: Optional[Union[float, list[float]]] = None,
    do_center_crop: Optional[bool] = None,
    crop_size: Optional[SizeDict] = None,
    do_resize: Optional[bool] = None,
    size: Optional[SizeDict] = None,
    interpolation: Optional["F.InterpolationMode"] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    data_format: ChannelDimension = ChannelDimension.FIRST,
):
    """
    Validate the arguments commonly passed to an `ImageProcessorFast` `preprocess` method.

    The generic checks shared with the slow image processors run first. They are followed by the two restrictions
    specific to fast processors: outputs must be PyTorch tensors, and they must be channel-first.

    The function is wrapped in `lru_cache`, so every argument must be hashable.

    Raises:
        `ValueError`: If an argument, or a combination of arguments, is not supported.
    """
    shared_checks = {
        "do_rescale": do_rescale,
        "rescale_factor": rescale_factor,
        "do_normalize": do_normalize,
        "image_mean": image_mean,
        "image_std": image_std,
        "do_center_crop": do_center_crop,
        "crop_size": crop_size,
        "do_resize": do_resize,
        "size": size,
        "interpolation": interpolation,
    }
    validate_preprocess_arguments(**shared_checks)

    # Fast-processor specific restrictions.
    if not (return_tensors is None or return_tensors == "pt"):
        raise ValueError("Only returning PyTorch tensors is currently supported.")

    if data_format != ChannelDimension.FIRST:
        raise ValueError("Only channel first data format is currently supported.")


def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
    """
    Squeeze `tensor`, tolerating axes that cannot be squeezed.

    Without `axis`, every size-1 dimension is removed. With `axis`, only that dimension is squeezed. If the backend
    refuses with a `ValueError` (e.g. NumPy when the axis does not have size 1), the input is returned unchanged.
    """
    if axis is not None:
        try:
            squeezed = tensor.squeeze(axis=axis)
        except ValueError:
            # Axis not squeezable: keep the input as-is instead of failing.
            squeezed = tensor
        return squeezed
    return tensor.squeeze()


def max_across_indices(values: Iterable[Any]) -> list[Any]:
    """
    Return the element-wise maximum over a collection of equally-indexed sequences.

    For example, `[(1, 5), (3, 2)]` yields `[3, 5]`. Sequences are zipped, so the result is as long as the shortest.
    """
    return list(map(max, zip(*values)))


def get_max_height_width(images: list["torch.Tensor"]) -> tuple[int, ...]:
    """
    Get the maximum height and width across all images in a batch.

    The last two dimensions of each image are read as `(height, width)`. This covers channel-first `(C, H, W)`
    images as well as 2D `(H, W)` maps and batched `(B, C, H, W)` tensors.

    Args:
        images (`list[torch.Tensor]`):
            Images (anything exposing `.shape`) whose trailing two dimensions are height and width.

    Returns:
        `tuple[int, int]`: `(max_height, max_width)`.

    Raises:
        `ValueError`: If `images` is empty.
    """
    # Use `len` rather than truthiness so that a stacked tensor can also be passed.
    if len(images) == 0:
        raise ValueError("Cannot compute the maximum height and width of an empty list of images.")

    sizes = [tuple(img.shape[-2:]) for img in images]
    max_height = max(height for height, _ in sizes)
    max_width = max(width for _, width in sizes)
    return (max_height, max_width)


def divide_to_patches(
    image: Union[np.ndarray, "torch.Tensor"], patch_size: int
) -> list[Union[np.ndarray, "torch.Tensor"]]:
    """
    Split a channel-first image into square tiles of side `patch_size`.

    Tiles are produced row by row (top to bottom, then left to right within a row). When the height or width is
    not a multiple of `patch_size`, the tiles on the bottom/right edge are smaller, because slicing is clipped.

    Args:
        image (`Union[np.array, "torch.Tensor"]`):
            The input image, channel dimension first.
        patch_size (`int`):
            The size of each patch.
    Returns:
        list: A list of Union[np.array, "torch.Tensor"] representing the patches.
    """
    img_height, img_width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
    return [
        image[:, top : top + patch_size, left : left + patch_size]
        for top in range(0, img_height, patch_size)
        for left in range(0, img_width, patch_size)
    ]


@auto_docstring
class BaseImageProcessorFast(BaseImageProcessor):
    r"""
    Base class for fast image processors using PyTorch and TorchVision for image transformations.

    This class provides a complete implementation for standard image preprocessing operations (resize, crop, rescale,
    normalize) with GPU support and batch processing optimizations. Most image processors can be implemented by simply
    setting class attributes; only processors requiring custom logic need to override methods.

    Basic Implementation
    --------------------

    For processors that only need standard operations (resize, center crop, rescale, normalize), define class
    attributes:

        class MyImageProcessorFast(BaseImageProcessorFast):
            resample = PILImageResampling.BILINEAR
            image_mean = IMAGENET_DEFAULT_MEAN
            image_std = IMAGENET_DEFAULT_STD
            size = {"height": 224, "width": 224}
            do_resize = True
            do_rescale = True
            do_normalize = True

    Custom Processing
    -----------------

    Override `_preprocess` (most common):
        For custom image processing logic, override `_preprocess`. This method receives a list of torch tensors with
        channel dimension first and should return a BatchFeature. Use `group_images_by_shape` and `reorder_images` for
        efficient batch processing:

            def _preprocess(
                self,
                images: list[torch.Tensor],
                do_resize: bool,
                size: SizeDict,
                # ... other parameters
                **kwargs,
            ) -> BatchFeature:
                # Group images by shape for batched operations
                grouped_images, indices = group_images_by_shape(images)
                processed_groups = {}

                for shape, stacked_images in grouped_images.items():
                    if do_resize:
                        stacked_images = self.resize(stacked_images, size)
                    # Custom processing here
                    processed_groups[shape] = stacked_images

                processed_images = reorder_images(processed_groups, indices)
                return BatchFeature(data={"pixel_values": torch.stack(processed_images)})

    Override `_preprocess_image_like_inputs` (for additional inputs):
        For processors handling multiple input types (e.g., images + segmentation maps), override this method:

            def _preprocess_image_like_inputs(
                self,
                images: ImageInput,
                segmentation_maps: Optional[ImageInput],
                do_convert_rgb: bool,
                input_data_format: ChannelDimension,
                device: Optional[torch.device] = None,
                **kwargs,
            ) -> BatchFeature:
                images = self._prepare_image_like_inputs(images, do_convert_rgb, input_data_format, device)
                batch_feature = self._preprocess(images, **kwargs)

                if segmentation_maps is not None:
                    # Process segmentation maps separately
                    maps = self._prepare_image_like_inputs(segmentation_maps, ...)
                    batch_feature["labels"] = self._preprocess(maps, ...)

                return batch_feature

    Override `_further_process_kwargs` (for custom kwargs formatting):
        To format custom kwargs before validation:

            def _further_process_kwargs(self, custom_param=None, **kwargs):
                kwargs = super()._further_process_kwargs(**kwargs)
                if custom_param is not None:
                    kwargs["custom_param"] = self._format_custom_param(custom_param)
                return kwargs

    Override `_validate_preprocess_kwargs` (for custom validation):
        To add custom validation logic:

            def _validate_preprocess_kwargs(self, custom_param=None, **kwargs):
                super()._validate_preprocess_kwargs(**kwargs)
                if custom_param is not None and custom_param < 0:
                    raise ValueError("custom_param must be non-negative")

    Override `_prepare_images_structure` (for nested inputs):
        By default, nested image lists are flattened. Override to preserve structure:

            def _prepare_images_structure(self, images, expected_ndims=3):
                # Custom logic to handle nested structure
                return images  # Return as-is or with custom processing

    Custom Parameters
    -----------------

    To add parameters beyond `ImagesKwargs`, create a custom kwargs class and set it as `valid_kwargs`:

        class MyImageProcessorKwargs(ImagesKwargs):
            custom_param: Optional[int]  # TypedDict fields cannot declare defaults; set them on the processor class
            another_param: Optional[bool]

        class MyImageProcessorFast(BaseImageProcessorFast):
            valid_kwargs = MyImageProcessorKwargs
            custom_param = 10  # default value

            def _preprocess(self, images, custom_param, **kwargs):
                # Use custom_param in processing
                ...

    Key Notes
    ---------

    - Images in `_preprocess` are always torch tensors with channel dimension first, regardless of input format
    - Arguments not provided by users default to class attribute values
    - Use batch processing utilities (`group_images_by_shape`, `reorder_images`) for GPU efficiency
    - Image loading, format conversion, and argument handling are automatic - focus only on processing logic
    """

    # Class-level defaults. Subclasses override these. For every key of `valid_kwargs`, `__init__` replaces the
    # default on the instance with the user-provided kwarg, or with a deepcopy of the default.
    resample = None  # PIL resampling filter; translated into a torchvision `interpolation` in `_further_process_kwargs`
    image_mean = None
    image_std = None
    size = None  # converted into a size dict by `__init__` when not None
    default_to_square = True  # forwarded to `get_size_dict` when a `size` is given to `__init__`
    crop_size = None
    do_resize = None
    do_center_crop = None
    do_pad = None
    pad_size = None
    do_rescale = None
    rescale_factor = 1 / 255  # default rescale multiplier (e.g. maps 0-255 pixel values to 0-1)
    do_normalize = None
    do_convert_rgb = None
    return_tensors = None
    data_format = ChannelDimension.FIRST  # only channel-first output passes `validate_fast_preprocess_arguments`
    input_data_format = None  # None means the channel layout is inferred per image in `_process_image`
    device = None  # default for the `device` kwarg of `preprocess`
    model_input_names = ["pixel_values"]  # NOTE(review): not read in this chunk; presumably the main output key
    image_seq_length = None  # NOTE(review): not used in this chunk; meaning presumably defined by subclasses
    valid_kwargs = ImagesKwargs  # TypedDict of accepted kwargs; its keys drive `__init__` and `preprocess` validation
    unused_kwargs = None  # optional kwarg names dropped (with a warning) by `filter_out_unused_kwargs`

    def __init__(self, **kwargs: Unpack[ImagesKwargs]):
        super().__init__(**kwargs)
        kwargs = self.filter_out_unused_kwargs(kwargs)

        # Size-like arguments are normalized into size dictionaries. `default_to_square` is only consumed
        # (popped) when a `size` is actually present.
        size = kwargs.pop("size", self.size)
        if size is None:
            self.size = None
        else:
            square_default = kwargs.pop("default_to_square", self.default_to_square)
            self.size = get_size_dict(size=size, default_to_square=square_default)

        crop_size = kwargs.pop("crop_size", self.crop_size)
        self.crop_size = None if crop_size is None else get_size_dict(crop_size, param_name="crop_size")

        pad_size = kwargs.pop("pad_size", self.pad_size)
        self.pad_size = None if pad_size is None else get_size_dict(size=pad_size, param_name="pad_size")

        # Every declared kwarg becomes an instance attribute: the user value when given, otherwise a private
        # copy of the current (class) default, so instances never share mutable defaults.
        kwarg_names = list(self.valid_kwargs.__annotations__.keys())
        for name in kwarg_names:
            value = kwargs.pop(name, None)
            if value is None:
                value = deepcopy(getattr(self, name, None))
            setattr(self, name, value)

        # Names accepted by `preprocess`.
        self._valid_kwargs_names = kwarg_names

    @property
    def is_fast(self) -> bool:
        """
        `bool`: Always `True` for this class, which is backed by PyTorch and TorchVision.
        """
        fast = True
        return fast

    def pad(
        self,
        images: list["torch.Tensor"],
        pad_size: SizeDict = None,
        fill_value: Optional[int] = 0,
        padding_mode: Optional[str] = "constant",
        return_mask: bool = False,
        disable_grouping: Optional[bool] = False,
        is_nested: Optional[bool] = False,
        **kwargs,
    ) -> Union[tuple["torch.Tensor", "torch.Tensor"], "torch.Tensor"]:
        """
        Pads images to `(pad_size["height"], pad_size["width"])` or to the largest size in the batch.

        Padding is added on the bottom and right edges only, so the original content stays anchored at the top-left.

        Args:
            images (`list[torch.Tensor]`):
                Images to pad.
            pad_size (`SizeDict`, *optional*):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
                When omitted, the maximum height and width across all images is used.
            fill_value (`int`, *optional*, defaults to `0`):
                The constant value used to fill the padded area.
            padding_mode (`str`, *optional*, defaults to "constant"):
                The padding mode to use. Can be any of the modes supported by
                `torch.nn.functional.pad` (e.g. constant, reflection, replication).
            return_mask (`bool`, *optional*, defaults to `False`):
                Whether to return a pixel mask to denote padded regions (1 for original pixels, 0 for padding).
            disable_grouping (`bool`, *optional*, defaults to `False`):
                Whether to disable grouping of images by size.
            is_nested (`bool`, *optional*, defaults to `False`):
                Whether `images` is a nested list (a list of lists of images). Forwarded to the grouping and
                reordering helpers.

        Returns:
            `Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]`: The padded images and pixel masks if `return_mask` is `True`.

        Raises:
            `ValueError`: If `pad_size` lacks height/width, or is smaller than an image.
        """
        if pad_size is not None:
            if not (pad_size.height and pad_size.width):
                raise ValueError(f"Pad size must contain 'height' and 'width' keys only. Got pad_size={pad_size}.")
            pad_size = (pad_size.height, pad_size.width)
        else:
            # `get_max_height_width` reads `.shape` from every element, so nested inputs must be flattened first.
            flat_images = [image for sublist in images for image in sublist] if is_nested else images
            pad_size = get_max_height_width(flat_images)

        grouped_images, grouped_images_index = group_images_by_shape(
            images, disable_grouping=disable_grouping, is_nested=is_nested
        )
        processed_images_grouped = {}
        processed_masks_grouped = {}
        for shape, stacked_images in grouped_images.items():
            image_size = stacked_images.shape[-2:]
            padding_height = pad_size[0] - image_size[0]
            padding_width = pad_size[1] - image_size[1]
            if padding_height < 0 or padding_width < 0:
                raise ValueError(
                    f"Padding dimensions are negative. Please make sure that the `pad_size` is larger than the "
                    f"image size. Got pad_size={pad_size}, image_size={image_size}."
                )
            if image_size != pad_size:
                # torchvision order is (left, top, right, bottom): pad only right and bottom.
                padding = (0, 0, padding_width, padding_height)
                stacked_images = F.pad(stacked_images, padding, fill=fill_value, padding_mode=padding_mode)
            processed_images_grouped[shape] = stacked_images

            if return_mask:
                # keep only one from the channel dimension in pixel mask
                stacked_masks = torch.zeros_like(stacked_images, dtype=torch.int64)[..., 0, :, :]
                stacked_masks[..., : image_size[0], : image_size[1]] = 1
                processed_masks_grouped[shape] = stacked_masks

        processed_images = reorder_images(processed_images_grouped, grouped_images_index, is_nested=is_nested)
        if return_mask:
            processed_masks = reorder_images(processed_masks_grouped, grouped_images_index, is_nested=is_nested)
            return processed_images, processed_masks

        return processed_images

    def resize(
        self,
        image: "torch.Tensor",
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"] = None,
        antialias: bool = True,
        **kwargs,
    ) -> "torch.Tensor":
        """
        Resize an image according to the keys set on `size`.

        The supported `size` layouts are tried in this order:
        1. `shortest_edge` + `longest_edge`: aspect-ratio preserving, bounded by both edges.
        2. `shortest_edge` only: aspect-ratio preserving, shortest edge matched.
        3. `max_height` + `max_width`: aspect-ratio preserving, fitted inside the box.
        4. `height` + `width`: exact output size.

        Args:
            image (`torch.Tensor`):
                Image to resize.
            size (`SizeDict`):
                Target size specification (see the layouts above).
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
            antialias (`bool`, *optional*, defaults to `True`):
                Whether to use antialiasing.

        Returns:
            `torch.Tensor`: The resized image.

        Raises:
            `ValueError`: If `size` matches none of the supported layouts.
        """
        if interpolation is None:
            interpolation = F.InterpolationMode.BILINEAR

        if size.shortest_edge and size.longest_edge:
            target_size = get_size_with_aspect_ratio(image.size()[-2:], size.shortest_edge, size.longest_edge)
        elif size.shortest_edge:
            target_size = get_resize_output_image_size(
                image,
                size=size.shortest_edge,
                default_to_square=False,
                input_data_format=ChannelDimension.FIRST,
            )
        elif size.max_height and size.max_width:
            target_size = get_image_size_for_max_height_width(image.size()[-2:], size.max_height, size.max_width)
        elif size.height and size.width:
            target_size = (size.height, size.width)
        else:
            raise ValueError(
                "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
                f" {size}."
            )

        # This is a workaround to avoid a bug in torch.compile when dealing with uint8 on AMD MI3XX GPUs
        # Tracked in PyTorch issue: https://github.com/pytorch/pytorch/issues/155209
        # TODO: remove this once the bug is fixed (detected with torch==2.7.0+git1fee196, torchvision==0.22.0+9eb57cd)
        needs_rocm_workaround = is_torchdynamo_compiling() and is_rocm_platform()
        if needs_rocm_workaround:
            return self.compile_friendly_resize(image, target_size, interpolation, antialias)
        return F.resize(image, target_size, interpolation=interpolation, antialias=antialias)

    @staticmethod
    def compile_friendly_resize(
        image: "torch.Tensor",
        new_size: tuple[int, int],
        interpolation: Optional["F.InterpolationMode"] = None,
        antialias: bool = True,
    ) -> "torch.Tensor":
        """
        Resize with `F.resize`, routing uint8 tensors through float32 so the operation works under torch.compile.

        Non-uint8 tensors are resized directly. uint8 tensors are scaled into float, resized, scaled back, clipped to
        [0, 255], rounded, and cast back to uint8.
        """
        if image.dtype != torch.uint8:
            return F.resize(image, new_size, interpolation=interpolation, antialias=antialias)

        # 256 is used on purpose instead of 255 to avoid numerical differences
        # see https://github.com/huggingface/transformers/pull/38540#discussion_r2127165652
        as_float = image.float() / 256
        resized = F.resize(as_float, new_size, interpolation=interpolation, antialias=antialias) * 256
        # torch.where is used on purpose instead of torch.clamp to avoid bug in torch.compile
        # see https://github.com/huggingface/transformers/pull/38540#discussion_r2126888471
        resized = torch.where(resized > 255, 255, resized)
        resized = torch.where(resized < 0, 0, resized)
        return resized.round().to(torch.uint8)

    def rescale(
        self,
        image: "torch.Tensor",
        scale: float,
        **kwargs,
    ) -> "torch.Tensor":
        """
        Multiply every value of `image` by `scale` (`image * scale`).

        Args:
            image (`torch.Tensor`):
                Image to rescale.
            scale (`float`):
                Multiplier applied to every pixel value.

        Returns:
            `torch.Tensor`: A new, rescaled tensor.
        """
        rescaled_image = image * scale
        return rescaled_image

    def normalize(
        self,
        image: "torch.Tensor",
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        **kwargs,
    ) -> "torch.Tensor":
        """
        Standardize an image channel-wise: `(image - mean) / std`, delegated to torchvision's `F.normalize`.

        Args:
            image (`torch.Tensor`):
                Image to normalize.
            mean (`torch.Tensor`, `float` or `Iterable[float]`):
                Per-channel mean to subtract.
            std (`torch.Tensor`, `float` or `Iterable[float]`):
                Per-channel standard deviation to divide by.

        Returns:
            `torch.Tensor`: The normalized image.
        """
        normalized_image = F.normalize(image, mean, std)
        return normalized_image

    @staticmethod
    @lru_cache(maxsize=10)
    def _fuse_mean_std_and_rescale_factor_cached(
        do_normalize: Optional[bool],
        image_mean: Optional[Union[float, tuple[float, ...]]],
        image_std: Optional[Union[float, tuple[float, ...]]],
        do_rescale: Optional[bool],
        rescale_factor: Optional[float],
        device: Optional["torch.device"],
    ) -> tuple:
        """
        Cached worker for `_fuse_mean_std_and_rescale_factor`.

        It is a staticmethod so the cache is not keyed on (and does not keep alive) processor instances. All
        arguments must be hashable. The returned tensors are shared between calls and must not be mutated in place.
        """
        if do_rescale and do_normalize:
            # Fold the rescale into normalization: (x * r - mean) / std == (x - mean / r) / (std / r)
            image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
            image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
            do_rescale = False
        return image_mean, image_std, do_rescale

    def _fuse_mean_std_and_rescale_factor(
        self,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        device: Optional["torch.device"] = None,
    ) -> tuple:
        """
        Fold rescaling into the normalization statistics when both rescale and normalize are requested.

        Returns:
            `tuple`: `(image_mean, image_std, do_rescale)`. When both flags are set, mean/std become tensors on
            `device`, divided by `rescale_factor`, and `do_rescale` is `False`. Otherwise the inputs are returned
            unchanged, except that lists become tuples.
        """
        # The cached worker needs hashable arguments; lists (e.g. from a direct call) would raise TypeError.
        if isinstance(image_mean, list):
            image_mean = tuple(image_mean)
        if isinstance(image_std, list):
            image_std = tuple(image_std)
        return self._fuse_mean_std_and_rescale_factor_cached(
            do_normalize, image_mean, image_std, do_rescale, rescale_factor, device
        )

    def rescale_and_normalize(
        self,
        images: "torch.Tensor",
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Union[float, list[float]],
        image_std: Union[float, list[float]],
    ) -> "torch.Tensor":
        """
        Rescale and/or normalize images.

        When both steps are requested, the rescale is folded into the normalization statistics by
        `_fuse_mean_std_and_rescale_factor`, and a single normalization pass is applied to a float32 copy.
        """
        fused_mean, fused_std, rescale_still_needed = self._fuse_mean_std_and_rescale_factor(
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            device=images.device,
        )
        if do_normalize:
            # Any requested rescale is already part of fused_mean / fused_std.
            return self.normalize(images.to(dtype=torch.float32), fused_mean, fused_std)
        if rescale_still_needed:
            return self.rescale(images, rescale_factor)
        return images

    def center_crop(
        self,
        image: "torch.Tensor",
        size: SizeDict,
        **kwargs,
    ) -> "torch.Tensor":
        """
        Note: override torchvision's center_crop to have the same behavior as the slow processor.
        Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
        any edge, the image is padded with 0's and then center cropped.

        Args:
            image (`"torch.Tensor"`):
                Image to center crop; the last two dimensions are treated as (height, width).
            size (`SizeDict`):
                Size of the output image; must define `height` and `width`.

        Returns:
            `torch.Tensor`: The center cropped image.

        Raises:
            `ValueError`: If `size.height` or `size.width` is missing.
        """
        if size.height is None or size.width is None:
            # Format the SizeDict itself (as `resize` and `pad` do) instead of calling `.keys()`, which is not
            # guaranteed to exist on `SizeDict`; otherwise this error path could itself raise AttributeError.
            raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size}.")
        image_height, image_width = image.shape[-2:]
        crop_height, crop_width = size.height, size.width

        if crop_width > image_width or crop_height > image_height:
            # Pad symmetrically; for an odd difference the extra pixel goes to the right/bottom.
            padding_ltrb = [
                (crop_width - image_width) // 2 if crop_width > image_width else 0,
                (crop_height - image_height) // 2 if crop_height > image_height else 0,
                (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
                (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
            ]
            image = F.pad(image, padding_ltrb, fill=0)  # PIL uses fill value 0
            image_height, image_width = image.shape[-2:]
            if crop_width == image_width and crop_height == image_height:
                return image

        # Truncating division of the offset, as in the slow processor.
        crop_top = int((image_height - crop_height) / 2.0)
        crop_left = int((image_width - crop_width) / 2.0)
        return F.crop(image, crop_top, crop_left, crop_height, crop_width)

    def convert_to_rgb(
        self,
        image: ImageInput,
    ) -> ImageInput:
        """
        Convert `image` to RGB by delegating to the module-level `convert_to_rgb` helper.

        Only `PIL.Image.Image` inputs are converted; any other input is returned as is.

        Args:
            image (ImageInput):
                The image to convert.

        Returns:
            ImageInput: The RGB image, or the untouched input.
        """
        rgb_image = convert_to_rgb(image)
        return rgb_image

    def filter_out_unused_kwargs(self, kwargs: dict):
        """
        Drop every name listed in `self.unused_kwargs` from `kwargs`, logging via `logger.warning_once` for each one
        that was present.

        `kwargs` is modified in place and also returned.
        """
        for kwarg_name in self.unused_kwargs or ():
            if kwarg_name not in kwargs:
                continue
            logger.warning_once(f"This processor does not use the `{kwarg_name}` parameter. It will be ignored.")
            del kwargs[kwarg_name]
        return kwargs

    def _prepare_images_structure(
        self,
        images: ImageInput,
        expected_ndims: int = 3,
    ) -> ImageInput:
        """
        Bring the raw `images` input into the structure used by the rest of the pipeline.

        `str` entries (URL/local path) are handled by `fetch_images`. The result is then flattened by
        `make_flat_list_of_images`.

        Args:
            images (`ImageInput`):
                The input images to process.
            expected_ndims (`int`, *optional*, defaults to 3):
                Expected number of dimensions of a single image; forwarded to `make_flat_list_of_images`.

        Returns:
            `ImageInput`: A flat list of images.
        """
        fetched_images = self.fetch_images(images)
        return make_flat_list_of_images(fetched_images, expected_ndims=expected_ndims)

    def _process_image(
        self,
        image: ImageInput,
        do_convert_rgb: Optional[bool] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        device: Optional["torch.device"] = None,
    ) -> "torch.Tensor":
        """
        Convert a single image into a channel-first `torch.Tensor`.

        Args:
            image (`ImageInput`):
                A PIL image, NumPy array or torch tensor.
            do_convert_rgb (`bool`, *optional*):
                Whether to call `self.convert_to_rgb` on the image first.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                Channel layout of `image`; inferred when `None`. Not applied to 2D images, which get a leading
                channel axis and are therefore already channel-first.
            device (`torch.device`, *optional*):
                Device to move the resulting tensor to.

        Returns:
            `torch.Tensor`: The image with the channel dimension first.

        Raises:
            `ValueError`: If the image is not a PIL image, NumPy array or torch tensor.
        """
        image_type = get_image_type(image)
        if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
            raise ValueError(f"Unsupported input image type {image_type}")

        if do_convert_rgb:
            image = self.convert_to_rgb(image)

        if image_type == ImageType.PIL:
            image = F.pil_to_tensor(image)
        elif image_type == ImageType.NUMPY:
            # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays.
            # `np.ascontiguousarray` only copies when needed; without it, `torch.from_numpy` rejects arrays with
            # negative strides (e.g. a channel flip `arr[..., ::-1]`).
            image = torch.from_numpy(np.ascontiguousarray(image))

        if image.ndim == 2:
            # Add a leading channel axis. The result is channel-first whatever `input_data_format` says, so the
            # channel-last permute below must not be applied (it would turn (1, H, W) into (W, 1, H)).
            image = image.unsqueeze(0)
            input_data_format = ChannelDimension.FIRST

        # Infer the channel dimension format if not provided
        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)

        if input_data_format == ChannelDimension.LAST:
            # We force the channel dimension to be first for torch tensors as this is what torchvision expects.
            image = image.permute(2, 0, 1).contiguous()

        # Now that we have torch tensors, we can move them to the right device
        if device is not None:
            image = image.to(device)

        return image

    def _prepare_image_like_inputs(
        self,
        images: ImageInput,
        do_convert_rgb: Optional[bool] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        device: Optional["torch.device"] = None,
        expected_ndims: int = 3,
    ) -> list["torch.Tensor"]:
        """
        Structure image-like inputs and convert each one into a channel-first tensor via `_process_image`.

        Args:
            images (`ImageInput`):
                The image-like inputs to process.
            do_convert_rgb (`bool`, *optional*):
                Whether to convert the images to RGB.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The input data format of the images.
            device (`torch.device`, *optional*):
                The device to put the processed images on.
            expected_ndims (`int`, *optional*):
                The expected number of dimensions for the images. (can be 2 for segmentation maps etc.)

        Returns:
            List[`torch.Tensor`]: The processed images. A nested (list-of-lists) structure is kept if present.
        """
        prepared = self._prepare_images_structure(images, expected_ndims=expected_ndims)
        process_one = self._process_image

        def to_tensor(img):
            return process_one(img, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device)

        # Nesting is assumed to be consistent, so checking the first element is enough.
        # `len` (not truthiness) is used so that overrides returning a tensor batch still work.
        if len(prepared) > 0 and isinstance(prepared[0], (list, tuple)):
            return [[to_tensor(img) for img in group] for group in prepared]
        return [to_tensor(img) for img in prepared]

    def _further_process_kwargs(
        self,
        size: Optional[SizeDict] = None,
        crop_size: Optional[SizeDict] = None,
        pad_size: Optional[SizeDict] = None,
        default_to_square: Optional[bool] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        data_format: Optional[ChannelDimension] = None,
        **kwargs,
    ) -> dict:
        """
        Update kwargs that need further processing before being validated.
        Can be overridden by subclasses to customize the processing of kwargs.

        - `size`, `crop_size` and `pad_size` become `SizeDict`s.
        - List `image_mean`/`image_std` become tuples, so they are hashable for the `lru_cache`-decorated
          validation and mean/std fusing helpers.
        - `data_format` defaults to channel-first.
        - `resample` is replaced by the equivalent torchvision `interpolation`.

        Returns:
            `dict`: The remaining kwargs together with the processed values.
        """
        if size is not None:
            size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
        if crop_size is not None:
            crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
        if pad_size is not None:
            pad_size = SizeDict(**get_size_dict(size=pad_size, param_name="pad_size"))
        if isinstance(image_mean, list):
            image_mean = tuple(image_mean)
        if isinstance(image_std, list):
            image_std = tuple(image_std)
        if data_format is None:
            data_format = ChannelDimension.FIRST

        kwargs["size"] = size
        kwargs["crop_size"] = crop_size
        kwargs["pad_size"] = pad_size
        kwargs["image_mean"] = image_mean
        kwargs["image_std"] = image_std
        kwargs["data_format"] = data_format

        # torch resize uses interpolation instead of resample.
        # `int` is checked first and short-circuits: with pillow < 9.1.0, resample is an int and PILImageResampling
        # is a module, so `isinstance(x, PILImageResampling)` would raise
        # `TypeError: isinstance() arg 2 must be a type or tuple of types`.
        # A missing `resample` (e.g. from a subclass override) is treated as None instead of raising KeyError.
        resample = kwargs.pop("resample", None)
        if isinstance(resample, int) or (resample is not None and isinstance(resample, PILImageResampling)):
            kwargs["interpolation"] = pil_torch_interpolation_mapping[resample]
        else:
            kwargs["interpolation"] = resample

        return kwargs

    def _validate_preprocess_kwargs(
        self,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, tuple[float]]] = None,
        image_std: Optional[Union[float, tuple[float]]] = None,
        do_resize: Optional[bool] = None,
        size: Optional[SizeDict] = None,
        do_center_crop: Optional[bool] = None,
        crop_size: Optional[SizeDict] = None,
        interpolation: Optional["F.InterpolationMode"] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = None,
        **kwargs,
    ):
        """
        Validate the processed `preprocess` kwargs with `validate_fast_preprocess_arguments`.

        Kwargs that are not validated here are ignored. Subclasses can extend this to check their own parameters.
        """
        # Keyword order is kept stable so the lru_cache on the validator sees identical keys.
        checked_arguments = dict(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            interpolation=interpolation,
            return_tensors=return_tensors,
            data_format=data_format,
        )
        validate_fast_preprocess_arguments(**checked_arguments)

    @auto_docstring
    def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[ImagesKwargs]) -> BatchFeature:
        # Positional `args` bypass validation and are forwarded untouched, so their order in the
        # `preprocess` and `_preprocess` signatures has to stay in sync.
        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names)

        # Type-check whatever the caller passed against the processor's TypedDict.
        validate_typed_dict(self.valid_kwargs, kwargs)

        # Every valid kwarg the caller left out falls back to the instance attribute of the same name,
        # or to None when the instance has no such attribute.
        fallback_values = {
            name: getattr(self, name, None) for name in self._valid_kwargs_names if name not in kwargs
        }
        kwargs = {**kwargs, **fallback_values}

        # These options only matter when turning the raw inputs into tensors, so they are set aside.
        preparation_kwargs = {
            "do_convert_rgb": kwargs.pop("do_convert_rgb"),
            "input_data_format": kwargs.pop("input_data_format"),
            "device": kwargs.pop("device"),
        }

        # Normalize the remaining kwargs (size dicts, mean/std tuples, resample -> interpolation),
        # then validate the normalized values.
        kwargs = self._further_process_kwargs(**kwargs)
        self._validate_preprocess_kwargs(**kwargs)

        # `data_format` is only consumed by validation and is not forwarded any further.
        del kwargs["data_format"]

        return self._preprocess_image_like_inputs(images, *args, **preparation_kwargs, **kwargs)

    def _preprocess_image_like_inputs(
        self,
        images: ImageInput,
        *args,
        do_convert_rgb: bool,
        input_data_format: ChannelDimension,
        device: Optional[Union[str, "torch.device"]] = None,
        **kwargs: Unpack[ImagesKwargs],
    ) -> BatchFeature:
        """
        Prepare image-like inputs with `_prepare_image_like_inputs`, then hand them to `_preprocess`.

        This is the hook subclasses override when image-like inputs other than plain images must also be
        processed (segmentation maps, depth maps, etc.). Extra positional `args` and the remaining `kwargs`
        are passed through to `_preprocess` unchanged.
        """
        prepared_images = self._prepare_image_like_inputs(
            images=images,
            do_convert_rgb=do_convert_rgb,
            input_data_format=input_data_format,
            device=device,
        )
        return self._preprocess(prepared_images, *args, **kwargs)

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_center_crop: bool,
        crop_size: SizeDict,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float], tuple[float, ...]]],
        image_std: Optional[Union[float, list[float], tuple[float, ...]]],
        do_pad: Optional[bool],
        pad_size: Optional[SizeDict],
        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """
        Run the default fast-processor pipeline over a list of image tensors.

        The steps, in order, are: optional resize, optional center crop, fused rescale/normalize, and
        optional padding. Images are grouped by shape with `group_images_by_shape`, so each transform runs
        on a stack of same-shaped images. The results are then put back in the original order with
        `reorder_images`.

        Args:
            images: Image tensors to process.
            do_resize / size / interpolation: Whether and how to resize each group of images.
            do_center_crop / crop_size: Whether and how to center crop.
            do_rescale / rescale_factor / do_normalize / image_mean / image_std: Passed to
                `rescale_and_normalize`. Mean/std usually arrive as tuples because `preprocess`
                converts lists to tuples, but lists are still accepted.
            do_pad / pad_size: Whether to pad the processed images, and the target pad size.
            disable_grouping: Forwarded to `group_images_by_shape` and `pad`.
            return_tensors: Tensor type used to build the returned `BatchFeature`.
            **kwargs: Unused here. Accepted so that callers can pass the full processed kwargs dict.

        Returns:
            A `BatchFeature` holding the processed images under `"pixel_values"`.
        """
        # Group images by size for batched resizing
        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
                stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        # Regroup before the remaining transforms: needed when do_resize is False, or when resize
        # produced images with different sizes.
        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:
                stacked_images = self.center_crop(stacked_images, crop_size)
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            processed_images_grouped[shape] = stacked_images
        processed_images = reorder_images(processed_images_grouped, grouped_images_index)

        if do_pad:
            processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)

        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

    def to_dict(self):
        """
        Serialize the processor to a dict. Drop `None` entries that simply mirror a `None`, or absent, class default.

        A `None` value is kept only when the class defines a non-`None` default for that key, meaning the
        instance explicitly overrode it with `None`. The internal bookkeeping keys `_valid_processor_keys` and
        `_valid_kwargs_names` are always removed.
        """
        encoder_dict = super().to_dict()

        filtered_dict = {}
        for key, value in encoder_dict.items():
            if value is None:
                # A missing class attribute and a `None` class default are treated alike: the entry is dropped.
                # The check uses identity against None rather than `!=` against a string sentinel. `!=` on an
                # arbitrary class attribute (e.g. an array-like default) may not return a plain bool.
                class_default = getattr(type(self), key, None)
                if class_default is None:
                    continue
            filtered_dict[key] = value

        filtered_dict.pop("_valid_processor_keys", None)
        filtered_dict.pop("_valid_kwargs_names", None)
        return filtered_dict
