diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 18218534c0cc27..b626ff3dd717ec 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -112,17 +112,9 @@ def values(self): def items(self): return self.data.items() - def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None): - """ - Convert the inner content to tensors. - - Args: - tensor_type (`str` or [`~utils.TensorType`], *optional*): - The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If - `None`, no modification is done. - """ + def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] = None): if tensor_type is None: - return self + return None, None # Convert to TensorType if not isinstance(tensor_type, TensorType): @@ -167,6 +159,21 @@ def as_tensor(value, dtype=None): return np.asarray(value, dtype=dtype) is_tensor = is_numpy_array + return is_tensor, as_tensor + + def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None): + """ + Convert the inner content to tensors. + + Args: + tensor_type (`str` or [`~utils.TensorType`], *optional*): + The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If + `None`, no modification is done. + """ + if tensor_type is None: + return self + + is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type) # Do the tensor conversion in batch for key, value in self.items(): diff --git a/src/transformers/models/fuyu/image_processing_fuyu.py b/src/transformers/models/fuyu/image_processing_fuyu.py index 2d83e18af40788..0e415980c97fdd 100644 --- a/src/transformers/models/fuyu/image_processing_fuyu.py +++ b/src/transformers/models/fuyu/image_processing_fuyu.py @@ -1,27 +1,182 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
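The `feature_extraction_utils.py` refactor above splits `BatchFeature.convert_to_tensors` so that the framework-specific `is_tensor`/`as_tensor` helpers can be reused by subclasses holding nested lists of arrays. Below is a minimal sketch of that reuse pattern; it is not part of the diff, the `NestedBatchFeature` name and example data are illustrative only, and it assumes a transformers build containing `_get_is_as_tensor_fns` plus an installed torch backend.

```python
import numpy as np

from transformers.feature_extraction_utils import BatchFeature


class NestedBatchFeature(BatchFeature):
    """Sketch: convert list-valued entries element-wise instead of stacking them."""

    def convert_to_tensors(self, tensor_type=None):
        if tensor_type is None:
            return self
        # Reuse the shared helper introduced in this diff.
        is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
        for key, value in self.items():
            if isinstance(value, list):
                self[key] = [elem if is_tensor(elem) else as_tensor(elem) for elem in value]
            elif not is_tensor(value):
                self[key] = as_tensor(value)
        return self


# Variable-length entries stay a Python list; each element becomes a torch.Tensor.
features = NestedBatchFeature({"values": [np.zeros((2, 3)), np.ones((4, 3))]})
features.convert_to_tensors("pt")
print(type(features["values"][0]))  # <class 'torch.Tensor'>
```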
+"""Image processor class for Fuyu.""" + import math -from typing import List, Union +from typing import Dict, List, Optional, Union import numpy as np -from ...image_processing_utils import BaseImageProcessor +from ...image_processing_utils import BaseImageProcessor, BatchFeature from ...image_transforms import ( - normalize, pad, resize, + to_channel_dimension_format, +) +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + is_valid_image, + make_list_of_images, + to_numpy_array, +) +from ...utils import ( + TensorType, + is_torch_available, + is_torch_device, + is_torch_dtype, + logging, + requires_backends, ) -from ...image_utils import to_numpy_array -from ...utils import is_torch_available, is_vision_available, logging, requires_backends - -if is_vision_available(): - import PIL if is_torch_available(): import torch + logger = logging.get_logger(__name__) +def make_list_of_list_of_images( + images: Union[List[List[ImageInput]], List[ImageInput], ImageInput] +) -> List[List[ImageInput]]: + if is_valid_image(images): + return [[images]] + + if isinstance(images, list) and all(isinstance(image, list) for image in images): + return images + + if isinstance(images, list): + return [make_list_of_images(image) for image in images] + + raise ValueError("images must be a list of list of images or a list of images or an image.") + + +class FuyuBatchFeature(BatchFeature): + """ + BatchFeature class for Fuyu image processor and processor. + + The outputs dictionary from the processors contains a mix of tensors and lists of tensors. + """ + + def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None): + """ + Convert the inner content to tensors. + + Args: + tensor_type (`str` or [`~utils.TensorType`], *optional*): + The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If + `None`, no modification is done. + """ + if tensor_type is None: + return self + + is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type=tensor_type) + + def _convert_tensor(elem): + if is_tensor(elem): + return elem + return as_tensor(elem) + + def _safe_convert_tensor(elem): + try: + return _convert_tensor(elem) + except: # noqa E722 + if key == "overflowing_values": + raise ValueError("Unable to create tensor returning overflowing values of different lengths. ") + raise ValueError( + "Unable to create tensor, you should probably activate padding " + "with 'padding=True' to have batched tensors with the same length." + ) + + # Do the tensor conversion in batch + for key, value in self.items(): + if isinstance(value, list) and isinstance(value[0], list): + # List[List[Any]] -> List[List[Tensor]] + self[key] = [[_safe_convert_tensor(elem) for elem in elems] for elems in value] + elif isinstance(value, list): + # List[Any] -> List[Tensor] + self[key] = [_safe_convert_tensor(elem) for elem in value] + else: + # Any -> Tensor + self[key] = _safe_convert_tensor(value) + return self + + def to(self, *args, **kwargs) -> "BatchFeature": + """ + Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in + different `dtypes` and sending the `BatchFeature` to a different `device`. + + Args: + args (`Tuple`): + Will be passed to the `to(...)` function of the tensors. + kwargs (`Dict`, *optional*): + Will be passed to the `to(...)` function of the tensors. 
+ + Returns: + [`BatchFeature`]: The same instance after modification. + """ + requires_backends(self, ["torch"]) + import torch # noqa + + new_data = {} + device = kwargs.get("device") + # Check if the args are a device or a dtype + if device is None and len(args) > 0: + # device should be always the first argument + arg = args[0] + if is_torch_dtype(arg): + # The first argument is a dtype + pass + elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int): + device = arg + else: + # it's something else + raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.") + + def _to(elem): + # check if v is a floating point + if torch.is_floating_point(elem): + # cast and send to device + return elem.to(*args, **kwargs) + if device is not None: + return elem.to(device=device) + + return elem + + # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` + for k, v in self.items(): + if isinstance(v, list) and isinstance(v[0], list): + # Data structure is a list of lists + new_v = [] + for elems in v: + new_v.append([_to(elem) for elem in elems]) + new_data[k] = new_v + elif isinstance(v, list): + # Data structure is a list + new_data[k] = [_to(elem) for elem in v] + else: + new_data[k] = _to(v) + self.data = new_data + return self + + class FuyuImageProcessor(BaseImageProcessor): """ This class should handle the image processing part before the main FuyuForCausalLM. In particular, it should @@ -29,9 +184,9 @@ class FuyuImageProcessor(BaseImageProcessor): - Processing Images: Taking a batch of images as input. If the images are variable-sized, it resizes them based on the desired patch - dimensions. The image output is always img_h ........................................... 1080 img_w - ........................................... 1920 Then, it patches up these images using the patchify_image - function. + dimensions. The image output is always img_h, img_w of (1080, 1920) + + Then, it patches up these images using the patchify_image function. - Creating Image Input IDs: For each patch, a placeholder ID is given to identify where these patches belong in a token sequence. For @@ -40,6 +195,32 @@ class FuyuImageProcessor(BaseImageProcessor): - Image Patch Indices: For each image patch, the code maintains an index where these patches should be inserted in a token stream. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image to `size`. + size (`Dict[str, int]`, *optional*, defaults to `{"height": 1080, "width": 1920}`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image to `size`. + padding_value (`float`, *optional*, defaults to 1.0): + The value to pad the image with. + padding_mode (`str`, *optional*, defaults to `"constant"`): + The padding mode to use when padding the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. + image_mean (`float`, *optional*, defaults to 0.5): + The mean to use when normalizing the image. + image_std (`float`, *optional*, defaults to 0.5): + The standard deviation to use when normalizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image. 
+ rescale_factor (`float`, *optional*, defaults to `1 / 255`): + The factor to use when rescaling the image. + patch_size (`Dict[str, int]`, *optional*, defaults to `{"height": 30, "width": 30}`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches. """ model_input_names = [ @@ -51,204 +232,483 @@ class FuyuImageProcessor(BaseImageProcessor): ] def __init__( - self, target_height=1080, target_width=1920, padding_value=1.0, padding_mode: str = "constant", **kwargs + self, + do_resize: bool = True, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_pad: bool = True, + padding_value: float = 1.0, + padding_mode: str = "constant", + do_normalize: bool = True, + image_mean: Union[float, List[float]] = 0.5, + image_std: Union[float, List[float]] = 0.5, + do_rescale: bool = True, + rescale_factor: float = 1 / 255, + patch_size: Optional[Dict[str, int]] = None, + **kwargs, ): super().__init__(**kwargs) - self.target_width = target_width - self.target_height = target_height + self.do_resize = do_resize + self.size = size if size is not None else {"height": 1080, "width": 1920} + self.resample = resample + self.do_pad = do_pad self.padding_value = padding_value self.padding_mode = padding_mode + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.patch_size = patch_size if patch_size is not None else {"height": 30, "width": 30} + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. - def get_num_patches(self, img_h: int, img_w: int, patch_dim_h: int, patch_dim_w: int) -> int: - """Calculate number of patches required to encode an image.""" - if img_h % patch_dim_h != 0: - raise ValueError(f"{img_h=} must be divisible by {patch_dim_h=}") - if img_w % patch_dim_w != 0: - raise ValueError(f"{img_w=} must be divisible by {patch_dim_w=}") + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + Returns: + `np.ndarray`: The resized image. + """ + image_height, image_width = get_image_size(image, input_data_format) + target_height, target_width = size["height"], size["width"] - num_patches_per_dim_h = img_h // patch_dim_h - num_patches_per_dim_w = img_w // patch_dim_w - num_patches = num_patches_per_dim_h * num_patches_per_dim_w + if image_width <= target_width and image_height <= target_height: + return image + + height_scale_factor = target_height / image_height + width_scale_factor = target_width / image_width + optimal_scale_factor = min(height_scale_factor, width_scale_factor) + new_height = int(image_height * optimal_scale_factor) + new_width = int(image_width * optimal_scale_factor) + + scaled_image = resize( + image=image, + size=(new_height, new_width), + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + return scaled_image + + def pad_image( + self, + image: np.ndarray, + size: Dict[str, int], + mode: str = "constant", + constant_values: float = 1.0, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Pad an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to pad. + size (`Dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + data_format (`ChannelDimension` or `str`, *optional*): + The data format of the output image. If unset, the same format as the input image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + image_height, image_width = get_image_size(image, input_data_format) + target_height, target_width = size["height"], size["width"] + padding_top = 0 + padding_left = 0 + padding_bottom = target_height - image_height + padding_right = target_width - image_width + padded_image = pad( + image, + padding=((padding_top, padding_bottom), (padding_left, padding_right)), + mode=mode, + constant_values=constant_values, + data_format=data_format, + input_data_format=input_data_format, + ) + return padded_image + + def preprocess( + self, + images, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_pad: Optional[bool] = None, + padding_value: Optional[float] = None, + padding_mode: Optional[str] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[float] = None, + image_std: Optional[float] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + patch_size: Optional[Dict[str, int]] = None, + data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + return_tensors: Optional[TensorType] = None, + ): + """ + + Utility function to preprocess the images and extract necessary information about original formats. + + Args: + images (`ImageInput`): + Images to preprocess. Expects a single image, a list or images or a list of lists of images. Pixel + values range from 0 to 255, or between 0 and 1 if `do_rescale` is `False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image to `size`. 
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image to `size`. + padding_value (`float`, *optional*, defaults to `self.padding_value`): + The value to pad the image with. + padding_mode (`str`, *optional*, defaults to `self.padding_mode`): + The padding mode to use when padding the image. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float`, *optional*, defaults to `self.image_mean`): + The mean to use when normalizing the image. + image_std (`float`, *optional*, defaults to `self.image_std`): + The standard deviation to use when normalizing the image. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + The factor to use when rescaling the image. + patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format of the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 
+        """
+
+        do_resize = do_resize if do_resize is not None else self.do_resize
+        size = size if size is not None else self.size
+        resample = resample if resample is not None else self.resample
+        do_pad = do_pad if do_pad is not None else self.do_pad
+        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+        image_mean = image_mean if image_mean is not None else self.image_mean
+        image_std = image_std if image_std is not None else self.image_std
+        padding_value = padding_value if padding_value is not None else self.padding_value
+        padding_mode = padding_mode if padding_mode is not None else self.padding_mode
+        patch_size = patch_size if patch_size is not None else self.patch_size
+
+        if isinstance(images, list) and any(isinstance(elem, list) and len(elem) >= 2 for elem in images):
+            raise ValueError("Multiple images for a single sample are not yet supported.")
+
+        batch_images = make_list_of_list_of_images(images)
+
+        if do_resize and size is None:
+            raise ValueError("Size must be specified if do_resize is True.")
+
+        if do_rescale and rescale_factor is None:
+            raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+        if do_normalize and (image_mean is None or image_std is None):
+            raise ValueError("image_mean and image_std must be specified if do_normalize is True.")
+
+        # All transformations expect numpy arrays.
+        batch_images = [[to_numpy_array(image) for image in images] for images in batch_images]
+
+        if is_scaled_image(batch_images[0][0]) and do_rescale:
+            logger.warning_once(
+                "It looks like you are trying to rescale already rescaled images. If the input"
+                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+            )
+
+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(batch_images[0][0]) + + original_image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images] + + if do_resize: + batch_images = [ + [self.resize(image, size=size, input_data_format=input_data_format) for image in images] + for images in batch_images + ] + + image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images] + image_unpadded_heights = [[image_size[0]] for image_size in image_sizes] + image_unpadded_widths = [[image_size[1]] for image_size in image_sizes] + + # scale_h is the same as scale_w + image_scale_factors = [ + [resized_size[0] / original_size[0]] + for original_size, resized_size in zip(original_image_sizes, image_sizes) + ] + + if do_pad: + batch_images = [ + [ + self.pad_image( + image, + size=size, + mode=padding_mode, + constant_values=padding_value, + input_data_format=input_data_format, + ) + for image in images + ] + for images in batch_images + ] + + if do_rescale: + batch_images = [ + [self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) for image in images] + for images in batch_images + ] + + if do_normalize: + batch_images = [ + [ + self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + for images in batch_images + ] + + if data_format is not None: + batch_images = [ + [to_channel_dimension_format(image, data_format, input_data_format) for image in images] + for images in batch_images + ] + + data = { + "images": batch_images, + "image_unpadded_heights": image_unpadded_heights, + "image_unpadded_widths": image_unpadded_widths, + "image_scale_factors": image_scale_factors, + } + return FuyuBatchFeature(data=data, tensor_type=return_tensors) + + def get_num_patches(self, image_height: int, image_width: int, patch_size: Dict[str, int] = None) -> int: + """ + Calculate number of patches required to encode an image. + + Args: + image_height (`int`): + Height of the image. + image_width (`int`): + Width of the image. + patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches. + """ + patch_size = patch_size if patch_size is not None else self.patch_size + patch_height, patch_width = self.patch_size["height"], self.patch_size["width"] + + if image_height % patch_height != 0: + raise ValueError(f"{image_height=} must be divisible by {patch_height}") + if image_width % patch_width != 0: + raise ValueError(f"{image_width=} must be divisible by {patch_width}") + + num_patches_per_dim_h = image_height // patch_height + num_patches_per_dim_w = image_width // patch_width + num_patches = num_patches_per_dim_h * num_patches_per_dim_w return num_patches - def patchify_image(self, image: "torch.Tensor", patch_dim_h: int, patch_dim_w: int) -> "torch.Tensor": + def patchify_image(self, image: "torch.Tensor", patch_size: Optional[Dict[str, int]] = None) -> "torch.Tensor": """ Convert an image into a tensor of patches. Args: - image: Image to convert. Shape: [batch, channels, height, width] - patch_dim_h: Height of each patch. - patch_dim_w: Width of each patch. + image (`torch.Tensor`): + Image to convert. Shape: [batch, channels, height, width] + patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches. 
""" requires_backends(self, ["torch"]) + patch_size = patch_size if patch_size is not None else self.patch_size + patch_height, patch_width = patch_size["height"], patch_size["width"] # TODO refer to https://github.com/ArthurZucker/transformers/blob/0f0a3fe5ca5697ee58faeb5b53f049af720b5e98/src/transformers/models/vit_mae/modeling_vit_mae.py#L871 # torch implementation is faster but does not handle non-squares - batch_size, channels, height, width = image.shape - unfolded_along_height = image.unfold(2, patch_dim_h, patch_dim_h) - patches = unfolded_along_height.unfold(3, patch_dim_w, patch_dim_w) - - patches_reshaped = patches.contiguous().view(batch_size, channels, -1, patch_dim_h, patch_dim_w) - - patches_final = patches_reshaped.permute(0, 2, 3, 4, 1).reshape( - batch_size, -1, channels * patch_dim_h * patch_dim_w - ) - - return patches_final + batch_size, channels, _, _ = image.shape + unfolded_along_height = image.unfold(2, patch_height, patch_height) + patches = unfolded_along_height.unfold(3, patch_width, patch_width) + patches = patches.contiguous() + patches = patches.view(batch_size, channels, -1, patch_height, patch_width) + patches = patches.permute(0, 2, 3, 4, 1) + patches = patches.reshape(batch_size, -1, channels * patch_height * patch_width) + return patches - def process_images_for_model_input( + def preprocess_with_tokenizer_info( self, image_input: "torch.Tensor", image_present: "torch.Tensor", image_unpadded_h: "torch.Tensor", image_unpadded_w: "torch.Tensor", - image_patch_dim_h: int, - image_patch_dim_w: int, image_placeholder_id: int, image_newline_id: int, variable_sized: bool, - ) -> dict: + patch_size: Optional[Dict[str, int]] = None, + ) -> FuyuBatchFeature: """Process images for model input. In particular, variable-sized images are handled here. Args: - image_input: [batch_size, 1, c, h, w] tensor of images padded to model input size. - image_present: [batch_size, 1] tensor of 1s and 0s indicating whether an image is present. - image_unpadded_h: [batch_size, 1] tensor of unpadded image heights. - image_unpadded_w: [batch_size, 1] tensor of unpadded image widths. - image_patch_dim_h: The height of the image patches. - image_patch_dim_w: The width of the image patches. - image_placeholder_id: The id of the image placeholder token. - image_newline_id: The id of the image newline token. - variable_sized: Whether to process images as variable-sized. + image_input (`torch.Tensor` of shape [batch_size, subsequence_size, num_channels, height, width]): + Tensor of images padded to model input size. + image_present (`torch.Tensor` of shape [batch_size, subsequence_size, num_images]): + Tensor of 1s and 0s indicating whether an image is present. + image_unpadded_h (`torch.Tensor` of shape [batch_size, subsequence_size]): + Tensor of unpadded image heights. + image_unpadded_w (`torch.Tensor` of shape [batch_size, subsequence_size]): + Tensor of unpadded image widths. + image_placeholder_id (int): + The id of the image placeholder token. Comes from an associated tokenizer. + image_newline_id (int): + The id of the image newline token. Comes from an associated tokenizer. + variable_sized (bool): + Whether to process images as variable-sized. + patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`): + Size of the patches. """ requires_backends(self, ["torch"]) + + patch_size = patch_size if patch_size is not None else self.patch_size + patch_height, patch_width = patch_size["height"], patch_size["width"] + # Only images that are present. 
images: List[List[torch.Tensor]] = [] - image_patches: List[List[torch.Tensor]] = [] + batch_image_patches: List[List[torch.Tensor]] = [] # Image input ids for every subsequence, including ones with no image present. - image_input_ids: List[List[torch.Tensor]] = [] - for bi in range(image_input.shape[0]): - images.append([]) - image_input_ids.append([]) - image_patches.append([]) - for si in range(image_input.shape[1]): - if image_present[bi, si]: - image = image_input[bi, si] + batch_image_input_ids: List[List[torch.Tensor]] = [] + for batch_index in range(image_input.shape[0]): + image_input_ids = [] + image_patches = [] + for subseq_index in range(image_input.shape[1]): + if image_present[batch_index, subseq_index]: + image = image_input[batch_index, subseq_index] + image_height, image_width = image.shape[1], image.shape[2] if variable_sized: # The min() is required here due to floating point issues: # math.ceil(torch.tensor(300).cuda() / 30) == 11 new_h = min( - image.shape[1], math.ceil(image_unpadded_h[bi, si] / image_patch_dim_h) * image_patch_dim_h + image_height, + math.ceil(image_unpadded_h[batch_index, subseq_index] / patch_height) * patch_height, ) new_w = min( - image.shape[2], math.ceil(image_unpadded_w[bi, si] / image_patch_dim_w) * image_patch_dim_w + image_width, + math.ceil(image_unpadded_w[batch_index, subseq_index] / patch_width) * patch_width, ) image = image[:, :new_h, :new_w] - images[bi].append(image) - num_patches = self.get_num_patches( - img_h=image.shape[1], - img_w=image.shape[2], - patch_dim_h=image_patch_dim_h, - patch_dim_w=image_patch_dim_w, + image_height, image_width = new_h, new_w + + num_patches = self.get_num_patches(image_height=image_height, image_width=image_width) + tensor_of_image_ids = torch.full( + [num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device ) - ids = torch.full([num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device) - patches = self.patchify_image( - image=image.unsqueeze(0), patch_dim_h=image_patch_dim_h, patch_dim_w=image_patch_dim_w - ).squeeze(0) + patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0) + assert num_patches == patches.shape[0] + if variable_sized: # Now terminate each line with |NEWLINE|. - ids = ids.reshape(-1, new_w // image_patch_dim_w) - ids = torch.cat( - [ - ids, - torch.full( - [ids.shape[0], 1], image_newline_id, dtype=torch.int32, device=image_input.device - ), - ], - dim=1, + tensor_of_image_ids = tensor_of_image_ids.reshape(-1, image_width // patch_width) + newline_ids = torch.full( + [tensor_of_image_ids.shape[0], 1], + image_newline_id, + dtype=torch.int32, + device=image_input.device, ) - ids = ids.reshape(-1) - image_input_ids[bi].append(ids) - image_patches[bi].append(patches) + tensor_of_image_ids = torch.cat([tensor_of_image_ids, newline_ids], dim=1) + tensor_of_image_ids = tensor_of_image_ids.reshape(-1) + + images.append([image]) + image_input_ids.append(tensor_of_image_ids) + image_patches.append(patches) else: - image_input_ids[bi].append(torch.tensor([], dtype=torch.int32, device=image_input.device)) + image_input_ids.append(torch.tensor([], dtype=torch.int32, device=image_input.device)) + + batch_image_input_ids.append(image_input_ids) + batch_image_patches.append(image_patches) # Create image_patch_input_indices, where non-negative values correspond to image patches to be inserted in # the stream. 
image_patch_indices_per_batch: List[List[torch.Tensor]] = [] image_patch_indices_per_subsequence: List[List[torch.Tensor]] = [] - for bi in range(len(image_input_ids)): - image_patch_indices_per_batch.append([]) - image_patch_indices_per_subsequence.append([]) + + for sample_image_input_ids in batch_image_input_ids: index_offset = 0 - for si in range(len(image_input_ids[bi])): + per_batch_indices = [] + per_subsequence_indices = [] + for subseq_image_input_ids in sample_image_input_ids: # Indices of image patches. - num_patches = torch.count_nonzero(image_input_ids[bi][si] == image_placeholder_id) + patches_mask = subseq_image_input_ids == image_placeholder_id + num_patches = torch.count_nonzero(patches_mask) indices = torch.arange( - num_patches, - dtype=image_input_ids[bi][si].dtype, - device=image_input_ids[bi][si].device, + num_patches, dtype=subseq_image_input_ids.dtype, device=subseq_image_input_ids.device ) # Place those indices in the image input ids token stream, with -1 representing non-index tokens. - indices_in_stream_per_batch = torch.full_like(image_input_ids[bi][si], -1) - indices_in_stream_per_subsequence = torch.full_like(image_input_ids[bi][si], -1) - indices_in_stream_per_batch[ - torch.nonzero(image_input_ids[bi][si] == image_placeholder_id, as_tuple=True)[0] - ] = (indices + index_offset) - indices_in_stream_per_subsequence[ - torch.nonzero(image_input_ids[bi][si] == image_placeholder_id, as_tuple=True)[0] - ] = indices - - image_patch_indices_per_batch[bi].append(indices_in_stream_per_batch) - image_patch_indices_per_subsequence[bi].append(indices_in_stream_per_subsequence) - index_offset += num_patches - - return { - "images": images, - "image_input_ids": image_input_ids, - "image_patches": image_patches, - "image_patch_indices_per_batch": image_patch_indices_per_batch, - "image_patch_indices_per_subsequence": image_patch_indices_per_subsequence, - } + indices_in_stream_per_batch = torch.full_like(subseq_image_input_ids, -1) + indices_in_stream_per_subsequence = torch.full_like(subseq_image_input_ids, -1) + patches_inds = torch.nonzero(patches_mask, as_tuple=True)[0] - def _scale_to_target_aspect_ratio(self, image: np.ndarray) -> np.ndarray: - image_height, image_width, _ = image.shape - if image_width <= self.target_width and image_height <= self.target_height: - return image - - height_scale_factor = self.target_height / image_height - width_scale_factor = self.target_width / image_width - optimal_scale_factor = min(height_scale_factor, width_scale_factor) + indices_in_stream_per_batch[patches_inds] = indices + index_offset + indices_in_stream_per_subsequence[patches_inds] = indices - new_height = int(image_height * optimal_scale_factor) - new_width = int(image_width * optimal_scale_factor) - - scaled_image = resize(image=image, size=(new_height, new_width)) - return np.array(scaled_image) - - def _pad_to_target_size(self, image: np.ndarray) -> np.ndarray: - image_height, image_width, _ = image.shape - - padding_top = 0 - padding_left = 0 - padding_bottom = self.target_height - image_height - padding_right = self.target_width - image_width + per_batch_indices.append(indices_in_stream_per_batch) + per_subsequence_indices.append(indices_in_stream_per_subsequence) + index_offset += num_patches - padded_image = pad( - image, - ((padding_top, padding_bottom), (padding_left, padding_right)), - mode=self.padding_mode, - constant_values=self.padding_value, + image_patch_indices_per_batch.append(per_batch_indices) + 
image_patch_indices_per_subsequence.append(per_subsequence_indices) + + return FuyuBatchFeature( + data={ + "images": images, + "image_input_ids": batch_image_input_ids, + "image_patches": batch_image_patches, + "image_patch_indices_per_batch": image_patch_indices_per_batch, + "image_patch_indices_per_subsequence": image_patch_indices_per_subsequence, + } ) - return padded_image - - def apply_transformation(self, image: Union[np.ndarray, PIL.Image.Image]) -> np.ndarray: - if isinstance(image, PIL.Image.Image): - image = to_numpy_array(image) - scaled_image = self._scale_to_target_aspect_ratio(image) - padded_image = self._pad_to_target_size(scaled_image) - normalized_padded_image = normalize(padded_image, 0.5, 0.5) - return normalized_padded_image diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 89127843befe08..345d0a0e92a5ee 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -257,8 +257,10 @@ def forward( if inputs_embeds is None: inputs_embeds = self.language_model.get_input_embeddings()(input_ids) if image_patches is not None and past_key_values is None: - patch_embeddings = self.vision_embed_tokens(image_patches.to(self.vision_embed_tokens.weight.dtype)) - patch_embeddings = patch_embeddings.to(inputs_embeds.device) + patch_embeddings = [ + self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)).squeeze(0) + for patch in image_patches + ] inputs_embeds = self.gather_continuous_embeddings( word_embeddings=inputs_embeds, continuous_embeddings=patch_embeddings, diff --git a/src/transformers/models/fuyu/processing_fuyu.py b/src/transformers/models/fuyu/processing_fuyu.py index ea660b072d721a..e0f362a6c8763b 100644 --- a/src/transformers/models/fuyu/processing_fuyu.py +++ b/src/transformers/models/fuyu/processing_fuyu.py @@ -1,45 +1,50 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
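For context on the `modeling_fuyu.py` change above, where `image_patches` are now embedded per sample and passed as a list to `gather_continuous_embeddings`: the sketch below shows, under assumed shapes, how those per-sample patch embeddings end up scattered into the word-embedding stream using `image_patches_indices` (non-negative entries select a patch, `-1` marks text positions). It is a conceptual stand-in, not the code in this diff.

```python
import torch


def gather_patch_embeddings(word_embeddings, patch_embeddings, image_patches_indices):
    """Sketch of the gather step.

    word_embeddings: [batch, seq_len, hidden]
    patch_embeddings: list of per-sample tensors of shape [num_patches_i, hidden]
    image_patches_indices: [batch, seq_len], -1 for text tokens
    """
    output = word_embeddings.clone()
    for i, patches in enumerate(patch_embeddings):
        # Token positions that should receive a patch embedding.
        dst = torch.nonzero(image_patches_indices[i] >= 0, as_tuple=True)[0]
        # Which patch each of those positions points at.
        src = image_patches_indices[i][dst].long()
        output[i, dst] = patches[src]
    return output


hidden = 4
words = torch.zeros(1, 6, hidden)
patches = [torch.arange(2 * hidden, dtype=torch.float32).reshape(2, hidden)]
indices = torch.tensor([[-1, 0, 1, -1, -1, -1]])
print(gather_patch_embeddings(words, patches, indices).shape)  # torch.Size([1, 6, 4])
```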
+""" +Image/Text processor class for GIT +""" import re -from typing import Any, Iterable, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np -from ...image_utils import ( - ChannelDimension, - get_image_size, - infer_channel_dimension_format, - is_scaled_image, - to_numpy_array, -) from ...processing_utils import ProcessorMixin -from ...utils import is_torch_available, is_vision_available, logging +from ...tokenization_utils_base import PaddingStrategy, TruncationStrategy +from ...utils import TensorType, is_torch_available, logging, requires_backends -if is_torch_available() and is_vision_available(): - from .image_processing_fuyu import FuyuImageProcessor +if is_torch_available(): + from .image_processing_fuyu import FuyuBatchFeature logger = logging.get_logger(__name__) -if is_vision_available(): - import PIL if is_torch_available(): import torch -BBOX_OPEN_STRING = "<0x00>" # -BBOX_CLOSE_STRING = "<0x01>" # -POINT_OPEN_STRING = "<0x02>" # -POINT_CLOSE_STRING = "<0x03>" # TEXT_REPR_BBOX_OPEN = "" TEXT_REPR_BBOX_CLOSE = "" TEXT_REPR_POINT_OPEN = "" TEXT_REPR_POINT_CLOSE = "" -TOKEN_BBOX_OPEN_STRING = BBOX_OPEN_STRING = "<0x00>" # -BBOX_CLOSE_STRING = "<0x01>" # -TOKEN_BBOX_CLOSE_STRING = TOKEN_POINT_OPEN_STRING = POINT_OPEN_STRING = "<0x02>" # -TOKEN_POINT_CLOSE_STRING = POINT_CLOSE_STRING = "<0x03>" # +TOKEN_BBOX_OPEN_STRING = "<0x00>" # +TOKEN_BBOX_CLOSE_STRING = "<0x01>" # +TOKEN_POINT_OPEN_STRING = "<0x02>" # +TOKEN_POINT_CLOSE_STRING = "<0x03>" # BEGINNING_OF_ANSWER_STRING = "<0x04>" # @@ -87,18 +92,16 @@ def construct_full_unpacked_stream( all_bi_stream = [] - for bi in range(batch_size): + for batch_index in range(batch_size): all_si_stream = [] # First, construct full token stream (including image placeholder tokens) and loss mask for each subsequence # and append to lists. We use lists rather than tensors because each subsequence is variable-sized. - for si in range(num_sub_sequences): - image_adjustment = image_tokens[bi][si] - si_stream = torch.cat([image_adjustment, input_stream[bi, si]], dim=0) - num_real_tokens = image_adjustment.shape[0] + num_real_text_tokens[bi][si] - - all_si_stream.append(si_stream[:num_real_tokens]) - # Combine all subsequences for this batch entry. Still using a list because each batch entry is variable-sized. + # TODO Remove this logic in a subsequent release since subsequences are not supported. 
+ image_adjustment = image_tokens[batch_index][0] + subsequence_stream = torch.cat([image_adjustment, input_stream[batch_index, 0]], dim=0) + num_real_tokens = image_adjustment.shape[0] + num_real_text_tokens[batch_index][0] + all_si_stream.append(subsequence_stream[:num_real_tokens]) all_bi_stream.append(torch.cat(all_si_stream, dim=0)) return all_bi_stream @@ -137,7 +140,7 @@ def _segment_prompt_into_text_token_conversions(prompt: str) -> List: return prompt_text_list -def _transform_coordinates_and_tokenize(prompt: str, transformed_image, tokenizer) -> List[int]: +def _transform_coordinates_and_tokenize(prompt: str, scale_factor: float, tokenizer) -> List[int]: """ This function transforms the prompt in the following fashion: - and to their respective token mappings @@ -161,7 +164,7 @@ def _transform_coordinates_and_tokenize(prompt: str, transformed_image, tokenize for elem in prompt_text_list: if elem[1]: # This is a location, we need to tokenize it - within_tag_tokenized = _transform_within_tags(elem[0], transformed_image, tokenizer) + within_tag_tokenized = _transform_within_tags(elem[0], scale_factor, tokenizer) # Surround the text with the open and close tags transformed_prompt_tokens.extend(within_tag_tokenized) else: @@ -169,7 +172,7 @@ def _transform_coordinates_and_tokenize(prompt: str, transformed_image, tokenize return transformed_prompt_tokens -def _transform_within_tags(text: str, transformed_image, tokenizer) -> List[int]: +def _transform_within_tags(text: str, scale_factor: float, tokenizer) -> List[int]: """ Given a bounding box of the fashion 1, 2, 3, 4 | 1, 2 This function is responsible for converting 1, 2, 3, 4 into tokens of 1 2 3 4 without any commas. @@ -188,16 +191,14 @@ def _transform_within_tags(text: str, transformed_image, tokenizer) -> List[int] num_ints = [float(num.strip()) for num in num_int_strs] # scale to transformed image siz if len(num_ints) == 2: - num_ints_translated = scale_point_to_transformed_image( - x=num_ints[0], y=num_ints[1], transformed_image=transformed_image - ) + num_ints_translated = scale_point_to_transformed_image(x=num_ints[0], y=num_ints[1], scale_factor=scale_factor) elif len(num_ints) == 4: num_ints_translated = scale_bbox_to_transformed_image( top=num_ints[0], left=num_ints[1], bottom=num_ints[2], right=num_ints[3], - transformed_image=transformed_image, + scale_factor=scale_factor, ) else: raise ValueError(f"Invalid number of ints: {len(num_ints)}") @@ -209,7 +210,7 @@ def _transform_within_tags(text: str, transformed_image, tokenizer) -> List[int] def _tokenize_prompts_with_image_and_batch( tokenizer, prompts: List[List[str]], - transformed_images: Optional[List[List["torch.Tensor"]]], + scale_factors: Optional[List[List["torch.Tensor"]]], max_tokens_to_generate: int, max_position_embeddings: int, add_BOS: bool, # Same issue with types as above @@ -223,13 +224,13 @@ def _tokenize_prompts_with_image_and_batch( """ # If not tool use, tranform the coordinates while tokenizing - if transformed_images is not None: + if scale_factors is not None: transformed_prompt_tokens = [] - for prompt_seq, transformed_image_seq in zip(prompts, transformed_images): + for prompt_seq, scale_factor_seq in zip(prompts, scale_factors): transformed_prompt_tokens.append( [ - _transform_coordinates_and_tokenize(prompt, transformed_image, tokenizer) - for prompt, transformed_image in zip(prompt_seq, transformed_image_seq) + _transform_coordinates_and_tokenize(prompt, scale_factor.item(), tokenizer) + for prompt, scale_factor in zip(prompt_seq, 
scale_factor_seq) ] ) else: @@ -260,7 +261,7 @@ def _tokenize_prompts_with_image_and_batch( # Number of tokens in the each sample of the batch. samples_length = min(max_prompt_len + max_tokens_to_generate, max_position_embeddings) if max_prompt_len + max_tokens_to_generate > max_position_embeddings: - print( + logger.warning( f"Max subsequence prompt length of {max_prompt_len} + max tokens to generate {max_tokens_to_generate}", f"exceeds context length of {max_position_embeddings}. Will generate as many tokens as possible.", ) @@ -279,86 +280,30 @@ def _tokenize_prompts_with_image_and_batch( return prompts_tokens_tensor, prompts_length_tensor -def original_to_transformed_h_coords(self, original_coords): - # apply crop - cropped_coords = ( - self._clamp_coords(original_coords, min_value=self.crop_top, max_value=self.crop_bottom) - self.crop_top - ) - # apply scale - scaled_coords = self._scale_coords(cropped_coords, scale=self.scaled_h / self.original_h) - # apply pad - return scaled_coords + self.padding_top +# Simplified assuming self.crop_top = self.padding_top = 0 +def original_to_transformed_h_coords(original_coords, scale_h): + return np.round(original_coords * scale_h).astype(np.int32) -def original_to_transformed_w_coords(self, original_coords): - # apply crop - cropped_coords = ( - self._clamp_coords(original_coords, min_value=self.crop_left, max_value=self.crop_right) - self.crop_left - ) - # apply scale - scaled_coords = self._scale_coords(cropped_coords, scale=self.scaled_w / self.original_w) - # apply pad - return scaled_coords + self.padding_left +# Simplified assuming self.crop_left = self.padding_left = 0 +def original_to_transformed_w_coords(original_coords, scale_w): + return np.round(original_coords * scale_w).astype(np.int32) -def scale_point_to_transformed_image(x: float, y: float) -> List[int]: - x_scaled = original_to_transformed_w_coords(np.array([x / 2]))[0] - y_scaled = original_to_transformed_h_coords(np.array([y / 2]))[0] +def scale_point_to_transformed_image(x: float, y: float, scale_factor: float) -> List[int]: + x_scaled = original_to_transformed_w_coords(np.array([x / 2]), scale_factor)[0] + y_scaled = original_to_transformed_h_coords(np.array([y / 2]), scale_factor)[0] return [x_scaled, y_scaled] -def scale_bbox_to_transformed_image(top: float, left: float, bottom: float, right: float) -> List[int]: - top_scaled = original_to_transformed_w_coords(np.array([top / 2]))[0] - left_scaled = original_to_transformed_h_coords(np.array([left / 2]))[0] - bottom_scaled = original_to_transformed_w_coords(np.array([bottom / 2]))[0] - right_scaled = original_to_transformed_h_coords(np.array([right / 2]))[0] - return [top_scaled, left_scaled, bottom_scaled, right_scaled] - - -# Copied from transformers.models.detr.image_processing_detr.max_across_indices -def max_across_indices(values: Iterable[Any]) -> List[Any]: - """ - Return the maximum value across all indices of an iterable of values. - """ - return [max(values_i) for values_i in zip(*values)] - - -# Copied from transformers.models.detr.image_processing_detr.get_max_height_width -def get_max_height_width( - images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None +def scale_bbox_to_transformed_image( + top: float, left: float, bottom: float, right: float, scale_factor: float ) -> List[int]: - """ - Get the maximum height and width across all images in a batch. 
- """ - if input_data_format is None: - input_data_format = infer_channel_dimension_format(images[0]) - - if input_data_format == ChannelDimension.FIRST: - _, max_height, max_width = max_across_indices([img.shape for img in images]) - elif input_data_format == ChannelDimension.LAST: - max_height, max_width, _ = max_across_indices([img.shape for img in images]) - else: - raise ValueError(f"Invalid channel dimension format: {input_data_format}") - return (max_height, max_width) - - -# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask -def make_pixel_mask( - image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None -) -> np.ndarray: - """ - Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. - - Args: - image (`np.ndarray`): - Image to make the pixel mask for. - output_size (`Tuple[int, int]`): - Output size of the mask. - """ - input_height, input_width = get_image_size(image, channel_dim=input_data_format) - mask = np.zeros(output_size, dtype=np.int64) - mask[:input_height, :input_width] = 1 - return mask + top_scaled = original_to_transformed_w_coords(np.array([top / 2]), scale_factor)[0] + left_scaled = original_to_transformed_h_coords(np.array([left / 2]), scale_factor)[0] + bottom_scaled = original_to_transformed_w_coords(np.array([bottom / 2]), scale_factor)[0] + right_scaled = original_to_transformed_h_coords(np.array([right / 2]), scale_factor)[0] + return [top_scaled, left_scaled, bottom_scaled, right_scaled] class FuyuProcessor(ProcessorMixin): @@ -384,42 +329,148 @@ def __init__(self, image_processor, tokenizer): self.tokenizer = tokenizer self.max_tokens_to_generate = 10 self.max_position_embeddings = 16384 # TODO Can't derive this from model files: where to set it? 
- self.image_processor = FuyuImageProcessor() - - def _process_images(self, images): - """Utility function to preprocess the images and extract necessary information about original formats.""" - batch_images = [] - image_unpadded_heights = [] - image_unpadded_widths = [] - - for image in images: - image = to_numpy_array(image) - if not is_scaled_image(image): - image = image / 255.0 - channel_dimension = infer_channel_dimension_format(image, 3) - if channel_dimension == ChannelDimension.FIRST: - width_index = 2 - height_index = 1 - elif channel_dimension == ChannelDimension.LAST: - width_index = 1 - height_index = 0 - - image_unpadded_widths.append([image.shape[width_index]]) - image_unpadded_heights.append([image.shape[height_index]]) - - # Reproduct adept padding sampler - padded_image = self.image_processor.apply_transformation(image) - - tensor_img = torch.Tensor(padded_image).permute(2, 0, 1) - batch_images.append([tensor_img]) - - return batch_images, torch.Tensor(image_unpadded_heights), torch.Tensor(image_unpadded_widths) - - def __call__(self, text=None, images=None, return_tensors=None, **kwargs): + self.pad_token_id = 0 + self.dummy_image_index = -1 + + def _left_pad_inputs_with_attention_mask(self, model_inputs: List[Dict], return_attention_mask: bool): + max_length_input_ids = max(entry["input_ids"].shape[1] for entry in model_inputs) + max_length_image_patch_indices = max(entry["image_patches_indices"].shape[1] for entry in model_inputs) + + batched_inputs = {"input_ids": [], "image_patches": [], "image_patches_indices": [], "attention_mask": []} + + for entry in model_inputs: + for key, tensor in entry.items(): + if key == "input_ids": + num_padding_tokens = max_length_input_ids - tensor.shape[1] + padded_input_ids = torch.cat( + [ + torch.full((tensor.shape[0], num_padding_tokens), self.pad_token_id, dtype=torch.long), + tensor, + ], + dim=1, + ) + batched_inputs[key].append(padded_input_ids) + + attention_mask = torch.cat( + [torch.zeros(tensor.shape[0], num_padding_tokens, dtype=torch.long), torch.ones_like(tensor)], + dim=1, + ) + batched_inputs["attention_mask"].append(attention_mask) + + elif key == "image_patches": + # For image_patches, we don't pad but just append them to the list. + batched_inputs[key].append(tensor) + + else: # for image_patches_indices + num_padding_indices = max_length_image_patch_indices - tensor.shape[1] + padded_indices = torch.cat( + [ + torch.full( + (tensor.shape[0], num_padding_indices), self.dummy_image_index, dtype=torch.long + ), + tensor, + ], + dim=1, + ) + batched_inputs[key].append(padded_indices) + batched_keys = ["input_ids", "image_patches_indices"] + if return_attention_mask: + batched_keys.append("attention_mask") + for key in batched_keys: + batched_inputs[key] = torch.cat(batched_inputs[key], dim=0) + + return batched_inputs + + def get_sample_encoding( + self, + prompts, + scale_factors, + image_unpadded_heights, + image_unpadded_widths, + image_placeholder_id, + image_newline_id, + tensor_batch_images, + ): + image_present = torch.ones(1, 1, 1) + model_image_input = self.image_processor.preprocess_with_tokenizer_info( + image_input=tensor_batch_images, + image_present=image_present, + image_unpadded_h=image_unpadded_heights, + image_unpadded_w=image_unpadded_widths, + image_placeholder_id=image_placeholder_id, + image_newline_id=image_newline_id, + variable_sized=True, + ) + # FIXME max_tokens_to_generate is embedded into this processor's call. 
+ prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch( + tokenizer=self.tokenizer, + prompts=prompts, + scale_factors=scale_factors, + max_tokens_to_generate=self.max_tokens_to_generate, + max_position_embeddings=self.max_position_embeddings, + add_BOS=True, + add_beginning_of_answer_token=True, + ) + image_padded_unpacked_tokens = construct_full_unpacked_stream( + num_real_text_tokens=prompts_length, + input_stream=prompt_tokens, + image_tokens=model_image_input["image_input_ids"], + batch_size=1, + num_sub_sequences=self.subsequence_length, + ) + # Construct inputs for image patch indices. + unpacked_image_patch_indices_per_batch = construct_full_unpacked_stream( + num_real_text_tokens=prompts_length, + input_stream=torch.full_like(prompt_tokens, -1), + image_tokens=model_image_input["image_patch_indices_per_batch"], + batch_size=1, + num_sub_sequences=self.subsequence_length, + ) + max_prompt_length = max(x.shape[-1] for x in image_padded_unpacked_tokens) + max_seq_len_batch = min(max_prompt_length + self.max_tokens_to_generate, self.max_position_embeddings) + tokens_to_place = min(max_seq_len_batch, max(0, image_padded_unpacked_tokens[0].shape[0])) + + # Use same packing logic for the image patch indices. + image_patch_input_indices = full_unpacked_stream_to_tensor( + all_bi_tokens_to_place=[tokens_to_place], + full_unpacked_stream=unpacked_image_patch_indices_per_batch, + fill_value=-1, + batch_size=1, + new_seq_len=max_seq_len_batch, + offset=0, + ) + image_patches_tensor = torch.stack([img[0] for img in model_image_input["image_patches"]]) + batch_encoding = { + "input_ids": image_padded_unpacked_tokens[0].unsqueeze(0), + "image_patches": image_patches_tensor, + "image_patches_indices": image_patch_input_indices, + } + return batch_encoding + + def __call__( + self, + text=None, + images=None, + add_special_tokens: bool = True, + return_attention_mask: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_token_type_ids: bool = False, + return_length: bool = False, + verbose: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> "FuyuBatchFeature": """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to - encode the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + encode the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to FuyuImageProcessor's [`~FuyuImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring of the above two methods for more information. @@ -433,130 +484,211 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs): tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a number of channels, H and W are image height and width. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. 
- - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. - Returns: - [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + [`FuyuBatchEncoding`]: A [`FuyuBatchEncoding`] with the following fields: - - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. - - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when - `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not - `None`). - - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **input_ids** -- Tensor of token ids to be fed to a model. Returned when `text` is not `None`. + - **image_patches** -- List of Tensor of image patches. Returned when `images` is not `None`. + - **image_patches_indices** -- Tensor of indices where patch embeddings have to be inserted by the model. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model when + `return_attention_mask=True`. """ + requires_backends(self, ["torch"]) + + # --- Check input validity --- + if not return_attention_mask: + raise ValueError("`return_attention_mask=False` is not supported for this model.") if text is None and images is None: - raise ValueError("You have to specify either text or images. Both cannot be none.") + raise ValueError("You have to specify either text or images. Both cannot be None.") + if text is not None and images is None: + logger.warning("You are processing a text with no associated image. Make sure it is intended.") + self.current_processor = self.tokenizer + text_encoding = self.tokenizer( + text=text, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_token_type_ids=return_token_type_ids, + return_length=return_length, + verbose=verbose, + return_tensors=return_tensors, + **kwargs, + ) + return text_encoding + + if text is None and images is not None: + logger.warning("You are processing an image with no associated text. Make sure it is intended.") + prompts = [[""]] if text is not None and images is not None: if isinstance(text, str): prompts = [[text]] elif isinstance(text, list): prompts = [[text_seq] for text_seq in text] - batch_images = [] - if isinstance(images, PIL.Image.Image): - images = [images] - if isinstance(images, list): - batch_images, image_unpadded_heights, image_unpadded_widths = self._process_images(images) - # image_unpadded_heights = image_unpadded_heights.unsqueeze(0) - # image_unpadded_widths = image_unpadded_widths.unsqueeze(0) - else: - raise ValueError("images must be a list of ndarrays or PIL Images to be processed.") - - # Note: the original adept code has a handling of image_unpadded_h and w, but it doesn't seem to hold - # when there are several different size subsequences per batch. The current implementation reflects - # that limitation and should be documented. - # - self.subsequence_length = 1 # Each batch contains only one sequence. - self.batch_size = len(batch_images) - # FIXME max_tokens_to_generate is embedded into this processor's call. 
- prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch( - tokenizer=self.tokenizer, - prompts=prompts, - transformed_images=batch_images, - max_tokens_to_generate=self.max_tokens_to_generate, - max_position_embeddings=self.max_position_embeddings, - add_BOS=True, - add_beginning_of_answer_token=True, - ) - # same so far - - # This is 1 if there is an image per subsequence, else 0. [batch, 1, presence] - # the remainder of current image processing logic assumes subsequence_size = 1. - # Here it is OK as the model cannot handle > 1 subsequences - # the image could be absent however and image presence should be inferred from user batch input - # hence this code assumes the images are present. Use an assert? - - image_present = torch.ones(self.batch_size, 1, 1) - - image_placeholder_id = self.tokenizer("|SPEAKER|", add_special_tokens=False)["input_ids"][1] - image_newline_id = self.tokenizer("|NEWLINE|", add_special_tokens=False)["input_ids"][1] - tensor_batch_images = torch.stack([img[0] for img in batch_images]).unsqueeze(1) - model_image_input = self.image_processor.process_images_for_model_input( - image_input=tensor_batch_images, - image_present=image_present, - image_unpadded_h=image_unpadded_heights, - image_unpadded_w=image_unpadded_widths, - image_patch_dim_h=30, - image_patch_dim_w=30, + + # --- Preprocess images using self.image_processor --- + + # FIXME - We hard code "pt" here because the rest of the processing assumes torch tensors + image_encoding = self.image_processor.preprocess(images, return_tensors="pt") + batch_images = image_encoding["images"] + image_unpadded_heights = image_encoding["image_unpadded_heights"] + image_unpadded_widths = image_encoding["image_unpadded_widths"] + scale_factors = image_encoding["image_scale_factors"] + self.subsequence_length = 1 # Each batch contains only one sequence. + self.batch_size = len(batch_images) + + # --- Use self.tokenizer to get the ids of special tokens to insert into image ids --- + + image_placeholder_id = self.tokenizer("|SPEAKER|", add_special_tokens=False)["input_ids"][1] + image_newline_id = self.tokenizer("|NEWLINE|", add_special_tokens=False)["input_ids"][1] + tensor_batch_images = torch.stack([img[0] for img in batch_images]).unsqueeze(1) + + # --- Use self.image_processor again to obtain the full token ids and batch inputs --- + all_encodings = [] + + for prompt, scale_factor, image_unpadded_height, image_unpadded_width, tensor_batch_image in zip( + prompts, scale_factors, image_unpadded_heights, image_unpadded_widths, tensor_batch_images + ): + sample_encoding = self.get_sample_encoding( + prompts=[prompt], + scale_factors=[scale_factor], + image_unpadded_heights=torch.tensor([image_unpadded_height]), + image_unpadded_widths=torch.tensor([image_unpadded_width]), image_placeholder_id=image_placeholder_id, image_newline_id=image_newline_id, - variable_sized=True, + tensor_batch_images=tensor_batch_image.unsqueeze(0), ) + all_encodings.append(sample_encoding) + batch_encoding = self._left_pad_inputs_with_attention_mask( + model_inputs=all_encodings, return_attention_mask=return_attention_mask + ) + return FuyuBatchFeature(data=batch_encoding) - image_padded_unpacked_tokens = construct_full_unpacked_stream( - num_real_text_tokens=prompts_length, - input_stream=prompt_tokens, - image_tokens=model_image_input["image_input_ids"], - batch_size=self.batch_size, - num_sub_sequences=self.subsequence_length, - ) - # Construct inputs for image patch indices. 
- unpacked_image_patch_indices_per_batch = construct_full_unpacked_stream( - num_real_text_tokens=prompts_length, - input_stream=torch.full_like(prompt_tokens, -1), - image_tokens=model_image_input["image_patch_indices_per_batch"], - batch_size=self.batch_size, - num_sub_sequences=self.subsequence_length, - ) - max_prompt_length = max(x.shape[-1] for x in image_padded_unpacked_tokens) - max_seq_len_batch = min(max_prompt_length + self.max_tokens_to_generate, self.max_position_embeddings) - all_bi_tokens_to_place = [] - for bi in range(self.batch_size): - tokens_to_place = min(max_seq_len_batch, max(0, image_padded_unpacked_tokens[bi].shape[0])) - all_bi_tokens_to_place.append(tokens_to_place) - - # Use same packing logic for the image patch indices. - image_patch_input_indices = full_unpacked_stream_to_tensor( - all_bi_tokens_to_place=all_bi_tokens_to_place, - full_unpacked_stream=unpacked_image_patch_indices_per_batch, - fill_value=-1, - batch_size=self.batch_size, - new_seq_len=max_seq_len_batch, - offset=0, - ) + def post_process_box_coordinates(self, outputs, target_sizes=None): + """ + Transforms raw coordinates detected by [`FuyuForCausalLM`] to the original images' coordinate space. + Coordinates will be returned in "box" format, with the following pattern: + `top, left, bottom, right` + + Point coordinates are not supported yet. - image_patches_tensor = torch.stack([img[0] for img in model_image_input["image_patches"]]).unsqueeze(1) - return { - "input_ids": image_padded_unpacked_tokens[0].unsqueeze(0), - "image_patches": image_patches_tensor[0][0].unsqueeze(0), - "image_patches_indices": image_patch_input_indices, - } + Args: + outputs ([`GenerateOutput`]): + Raw outputs from `generate`. + target_sizes (`torch.Tensor`, *optional*): + Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in + the batch. If set, found coordinates in the output sequence are rescaled to the target sizes. If left + to None, coordinates will not be rescaled. + + Returns: + `GenerateOutput`: Same output type returned by `generate`, with output token ids replaced with + boxed and possible rescaled coordinates. 
+ """ + + def scale_factor_to_fit(original_size, target_size=None): + height, width = original_size + if target_size is None: + max_height = self.image_processor.size["height"] + max_width = self.image_processor.size["width"] + else: + max_height, max_width = target_size + if width <= max_width and height <= max_height: + return 1.0 + return min(max_height / height, max_width / width) + + def find_delimiters_pair(tokens, start_token, end_token): + start_id = self.tokenizer.convert_tokens_to_ids(start_token) + end_id = self.tokenizer.convert_tokens_to_ids(end_token) + + starting_positions = (tokens == start_id).nonzero(as_tuple=True)[0] + ending_positions = (tokens == end_id).nonzero(as_tuple=True)[0] + + if torch.any(starting_positions) and torch.any(ending_positions): + return (starting_positions[0], ending_positions[0]) + return (None, None) + + def tokens_to_boxes(tokens, original_size): + while (pair := find_delimiters_pair(tokens, TOKEN_BBOX_OPEN_STRING, TOKEN_BBOX_CLOSE_STRING)) != ( + None, + None, + ): + start, end = pair + if end != start + 5: + continue + + # Retrieve transformed coordinates from tokens + coords = self.tokenizer.convert_ids_to_tokens(tokens[start + 1 : end]) + + # Scale back to original image size and multiply by 2 + scale = scale_factor_to_fit(original_size) + top, left, bottom, right = [2 * int(float(c) / scale) for c in coords] + + # Replace the IDs so they get detokenized right + replacement = f" {TEXT_REPR_BBOX_OPEN}{top}, {left}, {bottom}, {right}{TEXT_REPR_BBOX_CLOSE}" + replacement = self.tokenizer.tokenize(replacement)[1:] + replacement = self.tokenizer.convert_tokens_to_ids(replacement) + replacement = torch.tensor(replacement).to(tokens) + + tokens = torch.cat([tokens[:start], replacement, tokens[end + 1 :]], 0) + return tokens + + def tokens_to_points(tokens, original_size): + while (pair := find_delimiters_pair(tokens, TOKEN_POINT_OPEN_STRING, TOKEN_POINT_CLOSE_STRING)) != ( + None, + None, + ): + start, end = pair + if end != start + 3: + continue + + # Retrieve transformed coordinates from tokens + coords = self.tokenizer.convert_ids_to_tokens(tokens[start + 1 : end]) + + # Scale back to original image size and multiply by 2 + scale = scale_factor_to_fit(original_size) + x, y = [2 * int(float(c) / scale) for c in coords] + + # Replace the IDs so they get detokenized right + replacement = f" {TEXT_REPR_POINT_OPEN}{x}, {y}{TEXT_REPR_POINT_CLOSE}" + replacement = self.tokenizer.tokenize(replacement)[1:] + replacement = self.tokenizer.convert_tokens_to_ids(replacement) + replacement = torch.tensor(replacement).to(tokens) + + tokens = torch.cat([tokens[:start], replacement, tokens[end + 1 :]], 0) + return tokens + + if target_sizes is None: + target_sizes = ((self.image_processor.size["height"], self.image_processor.size["width"]),) * len(outputs) + elif target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + if len(outputs) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as output sequences") + + results = [] + for seq, size in zip(outputs, target_sizes): + seq = tokens_to_boxes(seq, size) + seq = tokens_to_points(seq, size) + results.append(seq) + + return results def batch_decode(self, *args, **kwargs): """ - This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. 
Please refer to the docstring of this method for more information. """ return self.tokenizer.batch_decode(*args, **kwargs) def decode(self, *args, **kwargs): """ - This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to the docstring of this method for more information. """ return self.tokenizer.decode(*args, **kwargs) diff --git a/tests/models/fuyu/test_image_processing_fuyu.py b/tests/models/fuyu/test_image_processing_fuyu.py index 73f0936aacf13d..a9930e2fb81297 100644 --- a/tests/models/fuyu/test_image_processing_fuyu.py +++ b/tests/models/fuyu/test_image_processing_fuyu.py @@ -24,7 +24,8 @@ @require_torchvision class TestFuyuImageProcessor(unittest.TestCase): def setUp(self): - self.processor = FuyuImageProcessor(target_height=160, target_width=320, padding_value=1.0) + self.size = {"height": 160, "width": 320} + self.processor = FuyuImageProcessor(size=self.size, padding_value=1.0) self.batch_size = 3 self.channels = 3 self.height = 300 @@ -38,29 +39,25 @@ def setUp(self): self.sample_image_pil = Image.fromarray(self.sample_image) def test_patches(self): - expected_num_patches = self.processor.get_num_patches( - img_h=self.height, img_w=self.width, patch_dim_h=self.image_patch_dim_h, patch_dim_w=self.image_patch_dim_w - ) + expected_num_patches = self.processor.get_num_patches(image_height=self.height, image_width=self.width) - patches_final = self.processor.patchify_image( - image=self.image_input, patch_dim_h=self.image_patch_dim_h, patch_dim_w=self.image_patch_dim_w - ) + patches_final = self.processor.patchify_image(image=self.image_input) assert ( patches_final.shape[1] == expected_num_patches ), f"Expected {expected_num_patches} patches, got {patches_final.shape[1]}." 
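As a quick reference for the API change exercised in `test_patches`, here is a standalone sketch (not part of the patch). It assumes the default 30x30 patch size that the removed `image_patch_dim_h`/`image_patch_dim_w` keyword arguments used to pass explicitly, and an input whose sides divide evenly by that patch size.

```python
# Patch dimensions now come from the processor itself rather than from per-call arguments.
import torch

from transformers import FuyuImageProcessor

processor = FuyuImageProcessor(size={"height": 160, "width": 320}, padding_value=1.0)

image = torch.rand(1, 3, 300, 300)  # (batch, channels, height, width)
num_patches = processor.get_num_patches(image_height=300, image_width=300)
patches = processor.patchify_image(image=image)
assert patches.shape[1] == num_patches  # (300 // 30) * (300 // 30) = 100 patches
```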
def test_scale_to_target_aspect_ratio(self): # (h:450, w:210) fitting (160, 320) -> (160, 210*160/450) - scaled_image = self.processor._scale_to_target_aspect_ratio(self.sample_image) + scaled_image = self.processor.resize(self.sample_image, size=self.size) self.assertEqual(scaled_image.shape[0], 160) self.assertEqual(scaled_image.shape[1], 74) def test_apply_transformation_numpy(self): - transformed_image = self.processor.apply_transformation(self.sample_image) - self.assertEqual(transformed_image.shape[0], 160) - self.assertEqual(transformed_image.shape[1], 320) + transformed_image = self.processor.preprocess(self.sample_image).images[0][0] + self.assertEqual(transformed_image.shape[1], 160) + self.assertEqual(transformed_image.shape[2], 320) def test_apply_transformation_pil(self): - transformed_image = self.processor.apply_transformation(self.sample_image_pil) - self.assertEqual(transformed_image.shape[0], 160) - self.assertEqual(transformed_image.shape[1], 320) + transformed_image = self.processor.preprocess(self.sample_image_pil).images[0][0] + self.assertEqual(transformed_image.shape[1], 160) + self.assertEqual(transformed_image.shape[2], 320) diff --git a/tests/models/fuyu/test_modeling_fuyu.py b/tests/models/fuyu/test_modeling_fuyu.py index b9c061e7a00448..9fb6820e45ffb1 100644 --- a/tests/models/fuyu/test_modeling_fuyu.py +++ b/tests/models/fuyu/test_modeling_fuyu.py @@ -3,7 +3,7 @@ import requests -from transformers import AutoTokenizer, FuyuConfig, is_torch_available, is_vision_available +from transformers import FuyuConfig, is_torch_available, is_vision_available from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device from ...test_modeling_common import ids_tensor, random_attention_mask @@ -14,7 +14,7 @@ if is_torch_available() and is_vision_available(): - from transformers import FuyuImageProcessor, FuyuProcessor + from transformers import FuyuProcessor if is_torch_available(): @@ -267,11 +267,8 @@ class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin) all_model_classes = ("FuyuForCausalLM") if is_torch_available() else () def setUp(self): - self.pretrained_model_name = "huggingface/new_model_release_weights" - tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_name) - image_processor = FuyuImageProcessor() - - self.processor = FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer) + self.pretrained_model_name = "adept/fuyu-8b" + self.processor = FuyuProcessor.from_pretrained(self.pretrained_model_name) self.model = FuyuForCausalLM.from_pretrained(self.pretrained_model_name) self.bus_image_url = ( "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png" @@ -280,9 +277,8 @@ def setUp(self): @slow def test_model_8b_chat_greedy_generation_bus_captioning(self): - EXPECTED_TEXT_COMPLETION = """A bus parked on the side of a road.|ENDOFTEXT|""" + EXPECTED_TEXT_COMPLETION = """A blue bus parked on the side of a road.|ENDOFTEXT|""" text_prompt_coco_captioning = "Generate a coco-style caption.\n" - model_inputs_bus_captioning = self.processor(text=text_prompt_coco_captioning, images=self.bus_image_pil) generated_tokens = self.model.generate(**model_inputs_bus_captioning, max_new_tokens=10) text = self.processor.tokenizer.batch_decode(generated_tokens) @@ -297,7 +293,7 @@ def test_model_8b_chat_greedy_generation_bus_captioning(self): """ @slow - @require_torch_gpu + @require_torch_accelerator def test_model_8b_chat_greedy_generation_bus_color(self): 
EXPECTED_TEXT_COMPLETION = "The bus is blue.\n|ENDOFTEXT|" text_prompt_bus_color = "What color is the bus?\n" @@ -314,7 +310,7 @@ def test_model_8b_chat_greedy_generation_bus_color(self): self.assertEqual(EXPECTED_TEXT_COMPLETION, clean_sequence) @slow - @require_torch_gpu + @require_torch_accelerator def test_model_8b_chat_greedy_generation_chart_vqa(self): # fmt: off EXPECTED_TEXT_TOKENS = ["The","life expectancy","at","birth","of male","s in","","20","18","is","","80",".","7",".","\n","|ENDOFTEXT|",] @@ -340,7 +336,7 @@ def test_model_8b_chat_greedy_generation_chart_vqa(self): self.assertEqual(expected_text_completion, clean_sequence) @slow - @require_torch_gpu + @require_torch_accelerator def test_model_8b_chat_greedy_generation_bounding_box(self): EXPECTED_TEXT_COMPLETION = "\x00194213202244\x01|ENDOFTEXT|" text_prompt_bbox = "When presented with a box, perform OCR to extract text contained within it. If provided with text, generate the corresponding bounding box.\\nWilliams" # noqa: E231 diff --git a/tests/models/fuyu/test_processing_fuyu.py b/tests/models/fuyu/test_processing_fuyu.py index 1c75b2b0ae318a..459386952c3ed9 100644 --- a/tests/models/fuyu/test_processing_fuyu.py +++ b/tests/models/fuyu/test_processing_fuyu.py @@ -26,16 +26,14 @@ class FuyuProcessingTest(unittest.TestCase): # TODO Which mixins do we add here """ """ def setUp(self): - pretrained_model_name = "huggingface/pre_release_model" - tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name) - image_processor = FuyuImageProcessor() + pretrained_model_name = "adept/fuyu-8b" + self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name) + self.image_processor = FuyuImageProcessor() - processor = FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer) - text_prompt = "Generate a coco-style caption.\\n" + self.processor = FuyuProcessor(image_processor=self.image_processor, tokenizer=self.tokenizer) + self.text_prompt = "Generate a coco-style caption.\\n" bus_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png" - bus_image_pil = Image.open(io.BytesIO(requests.get(bus_image_url).content)) - - self.one_image_bus_model_inputs = processor(text=text_prompt, images=bus_image_pil) + self.bus_image_pil = Image.open(io.BytesIO(requests.get(bus_image_url).content)) def test_fuyu_processing(self): """ @@ -44,11 +42,119 @@ def test_fuyu_processing(self): # fmt: off EXPECTED_IMAGE_PATCH_INPUTS = torch.Tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, -1, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, -1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, -1, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, -1, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, -1, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, -1, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, -1, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, -1, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, -1, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, -1, 220, 221, 222, 223, 224, 225, 
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, -1, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, -1, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, -1, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,]]).to(torch.int64) EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS = torch.Tensor([[71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 1, 128340, 71374, 71389, 120412, 71377, 71835, 71374, 73615, 71375, 71399, 71435, 71122,]]).to(torch.int64) + + one_image_bus_model_inputs = self.processor(text=self.text_prompt, images=self.bus_image_pil) + + # fmt: on + torch.testing.assert_close(one_image_bus_model_inputs["image_patches_indices"], EXPECTED_IMAGE_PATCH_INPUTS) + torch.testing.assert_close(one_image_bus_model_inputs["input_ids"], EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS) + + def test_fuyu_processing_no_image(self): + """ + Test to check processor works with just text input + """ + processor_outputs = self.processor(text=self.text_prompt) + tokenizer_outputs = self.tokenizer(self.text_prompt) + 
self.assertEqual(processor_outputs["input_ids"], tokenizer_outputs["input_ids"]) + + def test_fuyu_processing_no_text(self): + """ + Test to check processor works with just image input + """ + # fmt: off + EXPECTED_IMAGE_PATCH_INPUTS = torch.Tensor([ + [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, -1, 22, 23, 24, 25, 26, + 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + 41, 42, 43, -1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, + 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, -1, 66, + 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, -1, 88, 89, 90, 91, 92, 93, + 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, + 108, 109, -1, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, + 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, -1, 132, 133, + 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, -1, 154, 155, 156, 157, 158, 159, 160, + 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, + 175, -1, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, + 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, -1, 198, 199, 200, + 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, + 215, 216, 217, 218, 219, -1, 220, 221, 222, 223, 224, 225, 226, 227, + 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, + -1, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, + 255, 256, 257, 258, 259, 260, 261, 262, 263, -1, 264, 265, 266, 267, + 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, + 282, 283, 284, 285, -1, 286, 287, 288, 289, 290, 291, 292, 293, 294, + 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1] + ]).to(torch.int64) + # fmt: on + + processor_outputs = self.processor(images=self.bus_image_pil) + self.assertTrue((processor_outputs["image_patches_indices"] == EXPECTED_IMAGE_PATCH_INPUTS).all()) + + def test_fuyu_processing_multiple_image_sample(self): + """ + Test to check processor works with multiple image inputs for a single text input + """ + # fmt: off + SINGLE_IMAGE_PATCH_INPUTS = torch.Tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, -1, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, -1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, -1, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, -1, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, -1, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, -1, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, -1, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, -1, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, -1, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, -1, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, -1, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, -1, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 
274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, -1, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,]]).to(torch.int64) + SINGLE_PADDED_UNPACKED_TOKEN_INPUTS = torch.Tensor([[71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 1, 128340, 71374, 71389, 120412, 71377, 71835, 71374, 73615, 71375, 71399, 71435, 71122,]]).to(torch.int64) + + SINGLE_RESIZED_IMAGE_PATCH_INPUTS = torch.Tensor([[ 0, 1, 2, -1, 3, 4, 5, -1, 6, 7, 8, -1, 9, 10, 11, -1, 12, 13, 14, -1, 15, 16, 17, -1, 18, 19, 20, -1, 21, 22, 23, -1, 24, 25, 26, -1, 27, 28, 29, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]]) + SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS = torch.Tensor([[ 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 1, 128340, 71374, 71389, 120412, 71377, 71835, 71374, 73615, 71375, 71399, 71435, 71122]]) # fmt: on - torch.testing.assert_close( - self.one_image_bus_model_inputs["image_patches_indices"], 
EXPECTED_IMAGE_PATCH_INPUTS + + # Batch of two images - equally sized + images = [self.bus_image_pil, self.bus_image_pil] + processor_outputs = self.processor(text=[self.text_prompt, self.text_prompt], images=images) + + self.assertTrue( + ( + processor_outputs["image_patches_indices"] + == torch.cat([SINGLE_IMAGE_PATCH_INPUTS, SINGLE_IMAGE_PATCH_INPUTS], dim=0) + ).all() + ) + self.assertTrue( + ( + processor_outputs["input_ids"] + == torch.cat([SINGLE_PADDED_UNPACKED_TOKEN_INPUTS, SINGLE_PADDED_UNPACKED_TOKEN_INPUTS], dim=0) + ).all() ) - torch.testing.assert_close(self.one_image_bus_model_inputs["input_ids"], EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS) + + # Processes single images with different sizes as expected + images = [self.bus_image_pil] + processor_outputs = self.processor(text=self.text_prompt, images=images) + self.assertTrue((processor_outputs["image_patches_indices"] == SINGLE_IMAGE_PATCH_INPUTS).all()) + self.assertTrue((processor_outputs["input_ids"] == SINGLE_PADDED_UNPACKED_TOKEN_INPUTS).all()) + + images = [self.bus_image_pil.resize((64, 300))] + processor_outputs = self.processor(text=self.text_prompt, images=images) + self.assertTrue((processor_outputs["image_patches_indices"] == SINGLE_RESIZED_IMAGE_PATCH_INPUTS).all()) + self.assertTrue((processor_outputs["input_ids"] == SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS).all()) + + # Batch of two images - different sizes. Left-pads the smaller image inputs + images = [self.bus_image_pil, self.bus_image_pil.resize((64, 300))] + processor_outputs = self.processor(text=[self.text_prompt, self.text_prompt], images=images) + + padding_len_patch = SINGLE_IMAGE_PATCH_INPUTS.shape[1] - SINGLE_RESIZED_IMAGE_PATCH_INPUTS.shape[1] + padded_single_resized_image_patch = torch.cat( + [torch.ones([1, padding_len_patch]) * -1, SINGLE_RESIZED_IMAGE_PATCH_INPUTS], dim=1 + ) + expected_image_patch_inputs = torch.cat([SINGLE_IMAGE_PATCH_INPUTS, padded_single_resized_image_patch], dim=0) + + padding_len_token = ( + SINGLE_PADDED_UNPACKED_TOKEN_INPUTS.shape[1] - SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS.shape[1] + ) + padded_single_resized_padded_unpacked_token_inputs = torch.cat( + [torch.zeros([1, padding_len_token]), SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS], dim=1 + ) + expected_padded_unpacked_token_inputs = torch.cat( + [SINGLE_PADDED_UNPACKED_TOKEN_INPUTS, padded_single_resized_padded_unpacked_token_inputs], dim=0 + ) + + self.assertTrue((processor_outputs["image_patches_indices"] == expected_image_patch_inputs).all()) + self.assertTrue((processor_outputs["input_ids"] == expected_padded_unpacked_token_inputs).all()) @require_torch @@ -97,7 +203,6 @@ def setUp(self): """ Adding a mix of present and absent images. 
""" - self.image_processor = FuyuImageProcessor() self.image_input = torch.randn([1, 1, 3, 64, 64]) self.image_present = torch.tensor([[1]]) @@ -108,19 +213,19 @@ def setUp(self): self.image_placeholder_id = 999 self.image_newline_id = 888 self.variable_sized = True + self.image_processor = FuyuImageProcessor( + patch_size={"height": self.image_patch_dim_h, "width": self.image_patch_dim_w} + ) def test_process_images_for_model_input_fixed_sized(self): self.variable_sized = False - result = self.image_processor.process_images_for_model_input( + result = self.image_processor.preprocess_with_tokenizer_info( image_input=self.image_input, image_present=self.image_present, image_unpadded_h=self.image_unpadded_h, image_unpadded_w=self.image_unpadded_w, - image_patch_dim_h=self.image_patch_dim_h, - image_patch_dim_w=self.image_patch_dim_w, image_placeholder_id=self.image_placeholder_id, image_newline_id=self.image_newline_id, variable_sized=self.variable_sized, ) - print(result["images"][0][0]) self.assertEqual(result["images"][0][0].shape, torch.Size([3, 64, 64]))