diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py
index 18218534c0cc27..b626ff3dd717ec 100644
--- a/src/transformers/feature_extraction_utils.py
+++ b/src/transformers/feature_extraction_utils.py
@@ -112,17 +112,9 @@ def values(self):
def items(self):
return self.data.items()
- def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
- """
- Convert the inner content to tensors.
-
- Args:
- tensor_type (`str` or [`~utils.TensorType`], *optional*):
- The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
- `None`, no modification is done.
- """
+ def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] = None):
if tensor_type is None:
- return self
+ return None, None
# Convert to TensorType
if not isinstance(tensor_type, TensorType):
@@ -167,6 +159,21 @@ def as_tensor(value, dtype=None):
return np.asarray(value, dtype=dtype)
is_tensor = is_numpy_array
+ return is_tensor, as_tensor
+
+ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
+ """
+ Convert the inner content to tensors.
+
+ Args:
+ tensor_type (`str` or [`~utils.TensorType`], *optional*):
+ The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
+ `None`, no modification is done.
+ """
+ if tensor_type is None:
+ return self
+
+ is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
# Do the tensor conversion in batch
for key, value in self.items():
diff --git a/src/transformers/models/fuyu/image_processing_fuyu.py b/src/transformers/models/fuyu/image_processing_fuyu.py
index 2d83e18af40788..0e415980c97fdd 100644
--- a/src/transformers/models/fuyu/image_processing_fuyu.py
+++ b/src/transformers/models/fuyu/image_processing_fuyu.py
@@ -1,27 +1,182 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Fuyu."""
+
import math
-from typing import List, Union
+from typing import Dict, List, Optional, Union
import numpy as np
-from ...image_processing_utils import BaseImageProcessor
+from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import (
- normalize,
pad,
resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ is_valid_image,
+ make_list_of_images,
+ to_numpy_array,
+)
+from ...utils import (
+ TensorType,
+ is_torch_available,
+ is_torch_device,
+ is_torch_dtype,
+ logging,
+ requires_backends,
)
-from ...image_utils import to_numpy_array
-from ...utils import is_torch_available, is_vision_available, logging, requires_backends
-
-if is_vision_available():
- import PIL
if is_torch_available():
import torch
+
logger = logging.get_logger(__name__)
+def make_list_of_list_of_images(
+ images: Union[List[List[ImageInput]], List[ImageInput], ImageInput]
+) -> List[List[ImageInput]]:
+ if is_valid_image(images):
+ return [[images]]
+
+ if isinstance(images, list) and all(isinstance(image, list) for image in images):
+ return images
+
+ if isinstance(images, list):
+ return [make_list_of_images(image) for image in images]
+
+ raise ValueError("images must be a list of list of images or a list of images or an image.")
+
+
+class FuyuBatchFeature(BatchFeature):
+ """
+ BatchFeature class for Fuyu image processor and processor.
+
+ The outputs dictionary from the processors contains a mix of tensors and lists of tensors.
+ """
+
+ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
+ """
+ Convert the inner content to tensors.
+
+ Args:
+ tensor_type (`str` or [`~utils.TensorType`], *optional*):
+ The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
+ `None`, no modification is done.
+ """
+ if tensor_type is None:
+ return self
+
+ is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type=tensor_type)
+
+ def _convert_tensor(elem):
+ if is_tensor(elem):
+ return elem
+ return as_tensor(elem)
+
+ def _safe_convert_tensor(elem):
+ try:
+ return _convert_tensor(elem)
+ except: # noqa E722
+ if key == "overflowing_values":
+ raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
+ raise ValueError(
+ "Unable to create tensor, you should probably activate padding "
+ "with 'padding=True' to have batched tensors with the same length."
+ )
+
+ # Do the tensor conversion in batch
+ for key, value in self.items():
+ if isinstance(value, list) and isinstance(value[0], list):
+ # List[List[Any]] -> List[List[Tensor]]
+ self[key] = [[_safe_convert_tensor(elem) for elem in elems] for elems in value]
+ elif isinstance(value, list):
+ # List[Any] -> List[Tensor]
+ self[key] = [_safe_convert_tensor(elem) for elem in value]
+ else:
+ # Any -> Tensor
+ self[key] = _safe_convert_tensor(value)
+ return self
+
+ def to(self, *args, **kwargs) -> "BatchFeature":
+ """
+ Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
+ different `dtypes` and sending the `BatchFeature` to a different `device`.
+
+ Args:
+ args (`Tuple`):
+ Will be passed to the `to(...)` function of the tensors.
+ kwargs (`Dict`, *optional*):
+ Will be passed to the `to(...)` function of the tensors.
+
+ Returns:
+ [`BatchFeature`]: The same instance after modification.
+ """
+ requires_backends(self, ["torch"])
+ import torch # noqa
+
+ new_data = {}
+ device = kwargs.get("device")
+ # Check if the args are a device or a dtype
+ if device is None and len(args) > 0:
+ # device should be always the first argument
+ arg = args[0]
+ if is_torch_dtype(arg):
+ # The first argument is a dtype
+ pass
+ elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
+ device = arg
+ else:
+ # it's something else
+ raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
+
+ def _to(elem):
+ # check if v is a floating point
+ if torch.is_floating_point(elem):
+ # cast and send to device
+ return elem.to(*args, **kwargs)
+ if device is not None:
+ return elem.to(device=device)
+
+ return elem
+
+ # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
+ for k, v in self.items():
+ if isinstance(v, list) and isinstance(v[0], list):
+ # Data structure is a list of lists
+ new_v = []
+ for elems in v:
+ new_v.append([_to(elem) for elem in elems])
+ new_data[k] = new_v
+ elif isinstance(v, list):
+ # Data structure is a list
+ new_data[k] = [_to(elem) for elem in v]
+ else:
+ new_data[k] = _to(v)
+ self.data = new_data
+ return self
+
+
class FuyuImageProcessor(BaseImageProcessor):
"""
This class should handle the image processing part before the main FuyuForCausalLM. In particular, it should
@@ -29,9 +184,9 @@ class FuyuImageProcessor(BaseImageProcessor):
- Processing Images:
Taking a batch of images as input. If the images are variable-sized, it resizes them based on the desired patch
- dimensions. The image output is always img_h ........................................... 1080 img_w
- ........................................... 1920 Then, it patches up these images using the patchify_image
- function.
+ dimensions. The image output is always img_h, img_w of (1080, 1920)
+
+ Then, it patches up these images using the patchify_image function.
- Creating Image Input IDs:
For each patch, a placeholder ID is given to identify where these patches belong in a token sequence. For
@@ -40,6 +195,32 @@ class FuyuImageProcessor(BaseImageProcessor):
- Image Patch Indices:
For each image patch, the code maintains an index where these patches should be inserted in a token stream.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image to `size`.
+ size (`Dict[str, int]`, *optional*, defaults to `{"height": 1080, "width": 1920}`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Whether to pad the image to `size`.
+ padding_value (`float`, *optional*, defaults to 1.0):
+ The value to pad the image with.
+ padding_mode (`str`, *optional*, defaults to `"constant"`):
+ The padding mode to use when padding the image.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image.
+ image_mean (`float`, *optional*, defaults to 0.5):
+ The mean to use when normalizing the image.
+ image_std (`float`, *optional*, defaults to 0.5):
+ The standard deviation to use when normalizing the image.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `1 / 255`):
+ The factor to use when rescaling the image.
+ patch_size (`Dict[str, int]`, *optional*, defaults to `{"height": 30, "width": 30}`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
"""
model_input_names = [
@@ -51,204 +232,483 @@ class FuyuImageProcessor(BaseImageProcessor):
]
def __init__(
- self, target_height=1080, target_width=1920, padding_value=1.0, padding_mode: str = "constant", **kwargs
+ self,
+ do_resize: bool = True,
+ size: Optional[Dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_pad: bool = True,
+ padding_value: float = 1.0,
+ padding_mode: str = "constant",
+ do_normalize: bool = True,
+ image_mean: Union[float, List[float]] = 0.5,
+ image_std: Union[float, List[float]] = 0.5,
+ do_rescale: bool = True,
+ rescale_factor: float = 1 / 255,
+ patch_size: Optional[Dict[str, int]] = None,
+ **kwargs,
):
super().__init__(**kwargs)
- self.target_width = target_width
- self.target_height = target_height
+ self.do_resize = do_resize
+ self.size = size if size is not None else {"height": 1080, "width": 1920}
+ self.resample = resample
+ self.do_pad = do_pad
self.padding_value = padding_value
self.padding_mode = padding_mode
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.patch_size = patch_size if patch_size is not None else {"height": 30, "width": 30}
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
- def get_num_patches(self, img_h: int, img_w: int, patch_dim_h: int, patch_dim_w: int) -> int:
- """Calculate number of patches required to encode an image."""
- if img_h % patch_dim_h != 0:
- raise ValueError(f"{img_h=} must be divisible by {patch_dim_h=}")
- if img_w % patch_dim_w != 0:
- raise ValueError(f"{img_w=} must be divisible by {patch_dim_w=}")
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ Returns:
+ `np.ndarray`: The resized image.
+ """
+ image_height, image_width = get_image_size(image, input_data_format)
+ target_height, target_width = size["height"], size["width"]
- num_patches_per_dim_h = img_h // patch_dim_h
- num_patches_per_dim_w = img_w // patch_dim_w
- num_patches = num_patches_per_dim_h * num_patches_per_dim_w
+ if image_width <= target_width and image_height <= target_height:
+ return image
+
+ height_scale_factor = target_height / image_height
+ width_scale_factor = target_width / image_width
+ optimal_scale_factor = min(height_scale_factor, width_scale_factor)
+ new_height = int(image_height * optimal_scale_factor)
+ new_width = int(image_width * optimal_scale_factor)
+
+ scaled_image = resize(
+ image=image,
+ size=(new_height, new_width),
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+ return scaled_image
+
+ def pad_image(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ mode: str = "constant",
+ constant_values: float = 1.0,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Pad an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to pad.
+ size (`Dict[str, int]`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The data format of the output image. If unset, the same format as the input image is used.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ image_height, image_width = get_image_size(image, input_data_format)
+ target_height, target_width = size["height"], size["width"]
+ padding_top = 0
+ padding_left = 0
+ padding_bottom = target_height - image_height
+ padding_right = target_width - image_width
+ padded_image = pad(
+ image,
+ padding=((padding_top, padding_bottom), (padding_left, padding_right)),
+ mode=mode,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ return padded_image
+
+ def preprocess(
+ self,
+ images,
+ do_resize: Optional[bool] = None,
+ size: Optional[Dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_pad: Optional[bool] = None,
+ padding_value: Optional[float] = None,
+ padding_mode: Optional[str] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[float] = None,
+ image_std: Optional[float] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ patch_size: Optional[Dict[str, int]] = None,
+ data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ return_tensors: Optional[TensorType] = None,
+ ):
+ """
+
+ Utility function to preprocess the images and extract necessary information about original formats.
+
+ Args:
+ images (`ImageInput`):
+ Images to preprocess. Expects a single image, a list or images or a list of lists of images. Pixel
+ values range from 0 to 255, or between 0 and 1 if `do_rescale` is `False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image to `size`.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+ Whether to pad the image to `size`.
+ padding_value (`float`, *optional*, defaults to `self.padding_value`):
+ The value to pad the image with.
+ padding_mode (`str`, *optional*, defaults to `self.padding_mode`):
+ The padding mode to use when padding the image.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float`, *optional*, defaults to `self.image_mean`):
+ The mean to use when normalizing the image.
+ image_std (`float`, *optional*, defaults to `self.image_std`):
+ The standard deviation to use when normalizing the image.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ The factor to use when rescaling the image.
+ patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format of the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ resample = resample if resample is not None else self.resample
+ do_pad = do_pad if do_pad is not None else self.do_pad
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ padding_value = padding_value if padding_value is not None else self.padding_value
+ padding_mode = padding_mode if padding_mode is not None else self.padding_mode
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ patch_size = patch_size if patch_size is not None else self.patch_size
+
+ if isinstance(images, list) and any(isinstance(elem, list) and len(elem) >= 2 for elem in images):
+ raise ValueError("Multiple images for a single sample are not yet supported.")
+
+ batch_images = make_list_of_list_of_images(images)
+
+ if do_resize and size is None:
+ raise ValueError("Size must be specified if do_resize is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ if do_normalize and image_mean is None or image_std is None:
+ raise ValueError("image_mean and image_std must be specified if do_normalize is True.")
+
+ # All transformations expect numpy arrays.
+ batch_images = [[to_numpy_array(image) for image in images] for images in batch_images]
+
+ if is_scaled_image(batch_images[0][0]) and do_rescale:
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(batch_images[0][0])
+
+ original_image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
+
+ if do_resize:
+ batch_images = [
+ [self.resize(image, size=size, input_data_format=input_data_format) for image in images]
+ for images in batch_images
+ ]
+
+ image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
+ image_unpadded_heights = [[image_size[0]] for image_size in image_sizes]
+ image_unpadded_widths = [[image_size[1]] for image_size in image_sizes]
+
+ # scale_h is the same as scale_w
+ image_scale_factors = [
+ [resized_size[0] / original_size[0]]
+ for original_size, resized_size in zip(original_image_sizes, image_sizes)
+ ]
+
+ if do_pad:
+ batch_images = [
+ [
+ self.pad_image(
+ image,
+ size=size,
+ mode=padding_mode,
+ constant_values=padding_value,
+ input_data_format=input_data_format,
+ )
+ for image in images
+ ]
+ for images in batch_images
+ ]
+
+ if do_rescale:
+ batch_images = [
+ [self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) for image in images]
+ for images in batch_images
+ ]
+
+ if do_normalize:
+ batch_images = [
+ [
+ self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+ for image in images
+ ]
+ for images in batch_images
+ ]
+
+ if data_format is not None:
+ batch_images = [
+ [to_channel_dimension_format(image, data_format, input_data_format) for image in images]
+ for images in batch_images
+ ]
+
+ data = {
+ "images": batch_images,
+ "image_unpadded_heights": image_unpadded_heights,
+ "image_unpadded_widths": image_unpadded_widths,
+ "image_scale_factors": image_scale_factors,
+ }
+ return FuyuBatchFeature(data=data, tensor_type=return_tensors)
+
+ def get_num_patches(self, image_height: int, image_width: int, patch_size: Dict[str, int] = None) -> int:
+ """
+ Calculate number of patches required to encode an image.
+
+ Args:
+ image_height (`int`):
+ Height of the image.
+ image_width (`int`):
+ Width of the image.
+ patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
+ """
+ patch_size = patch_size if patch_size is not None else self.patch_size
+ patch_height, patch_width = self.patch_size["height"], self.patch_size["width"]
+
+ if image_height % patch_height != 0:
+ raise ValueError(f"{image_height=} must be divisible by {patch_height}")
+ if image_width % patch_width != 0:
+ raise ValueError(f"{image_width=} must be divisible by {patch_width}")
+
+ num_patches_per_dim_h = image_height // patch_height
+ num_patches_per_dim_w = image_width // patch_width
+ num_patches = num_patches_per_dim_h * num_patches_per_dim_w
return num_patches
- def patchify_image(self, image: "torch.Tensor", patch_dim_h: int, patch_dim_w: int) -> "torch.Tensor":
+ def patchify_image(self, image: "torch.Tensor", patch_size: Optional[Dict[str, int]] = None) -> "torch.Tensor":
"""
Convert an image into a tensor of patches.
Args:
- image: Image to convert. Shape: [batch, channels, height, width]
- patch_dim_h: Height of each patch.
- patch_dim_w: Width of each patch.
+ image (`torch.Tensor`):
+ Image to convert. Shape: [batch, channels, height, width]
+ patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
"""
requires_backends(self, ["torch"])
+ patch_size = patch_size if patch_size is not None else self.patch_size
+ patch_height, patch_width = patch_size["height"], patch_size["width"]
# TODO refer to https://github.com/ArthurZucker/transformers/blob/0f0a3fe5ca5697ee58faeb5b53f049af720b5e98/src/transformers/models/vit_mae/modeling_vit_mae.py#L871
# torch implementation is faster but does not handle non-squares
- batch_size, channels, height, width = image.shape
- unfolded_along_height = image.unfold(2, patch_dim_h, patch_dim_h)
- patches = unfolded_along_height.unfold(3, patch_dim_w, patch_dim_w)
-
- patches_reshaped = patches.contiguous().view(batch_size, channels, -1, patch_dim_h, patch_dim_w)
-
- patches_final = patches_reshaped.permute(0, 2, 3, 4, 1).reshape(
- batch_size, -1, channels * patch_dim_h * patch_dim_w
- )
-
- return patches_final
+ batch_size, channels, _, _ = image.shape
+ unfolded_along_height = image.unfold(2, patch_height, patch_height)
+ patches = unfolded_along_height.unfold(3, patch_width, patch_width)
+ patches = patches.contiguous()
+ patches = patches.view(batch_size, channels, -1, patch_height, patch_width)
+ patches = patches.permute(0, 2, 3, 4, 1)
+ patches = patches.reshape(batch_size, -1, channels * patch_height * patch_width)
+ return patches
- def process_images_for_model_input(
+ def preprocess_with_tokenizer_info(
self,
image_input: "torch.Tensor",
image_present: "torch.Tensor",
image_unpadded_h: "torch.Tensor",
image_unpadded_w: "torch.Tensor",
- image_patch_dim_h: int,
- image_patch_dim_w: int,
image_placeholder_id: int,
image_newline_id: int,
variable_sized: bool,
- ) -> dict:
+ patch_size: Optional[Dict[str, int]] = None,
+ ) -> FuyuBatchFeature:
"""Process images for model input. In particular, variable-sized images are handled here.
Args:
- image_input: [batch_size, 1, c, h, w] tensor of images padded to model input size.
- image_present: [batch_size, 1] tensor of 1s and 0s indicating whether an image is present.
- image_unpadded_h: [batch_size, 1] tensor of unpadded image heights.
- image_unpadded_w: [batch_size, 1] tensor of unpadded image widths.
- image_patch_dim_h: The height of the image patches.
- image_patch_dim_w: The width of the image patches.
- image_placeholder_id: The id of the image placeholder token.
- image_newline_id: The id of the image newline token.
- variable_sized: Whether to process images as variable-sized.
+ image_input (`torch.Tensor` of shape [batch_size, subsequence_size, num_channels, height, width]):
+ Tensor of images padded to model input size.
+ image_present (`torch.Tensor` of shape [batch_size, subsequence_size, num_images]):
+ Tensor of 1s and 0s indicating whether an image is present.
+ image_unpadded_h (`torch.Tensor` of shape [batch_size, subsequence_size]):
+ Tensor of unpadded image heights.
+ image_unpadded_w (`torch.Tensor` of shape [batch_size, subsequence_size]):
+ Tensor of unpadded image widths.
+ image_placeholder_id (int):
+ The id of the image placeholder token. Comes from an associated tokenizer.
+ image_newline_id (int):
+ The id of the image newline token. Comes from an associated tokenizer.
+ variable_sized (bool):
+ Whether to process images as variable-sized.
+ patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
+ Size of the patches.
"""
requires_backends(self, ["torch"])
+
+ patch_size = patch_size if patch_size is not None else self.patch_size
+ patch_height, patch_width = patch_size["height"], patch_size["width"]
+
# Only images that are present.
images: List[List[torch.Tensor]] = []
- image_patches: List[List[torch.Tensor]] = []
+ batch_image_patches: List[List[torch.Tensor]] = []
# Image input ids for every subsequence, including ones with no image present.
- image_input_ids: List[List[torch.Tensor]] = []
- for bi in range(image_input.shape[0]):
- images.append([])
- image_input_ids.append([])
- image_patches.append([])
- for si in range(image_input.shape[1]):
- if image_present[bi, si]:
- image = image_input[bi, si]
+ batch_image_input_ids: List[List[torch.Tensor]] = []
+ for batch_index in range(image_input.shape[0]):
+ image_input_ids = []
+ image_patches = []
+ for subseq_index in range(image_input.shape[1]):
+ if image_present[batch_index, subseq_index]:
+ image = image_input[batch_index, subseq_index]
+ image_height, image_width = image.shape[1], image.shape[2]
if variable_sized:
# The min() is required here due to floating point issues:
# math.ceil(torch.tensor(300).cuda() / 30) == 11
new_h = min(
- image.shape[1], math.ceil(image_unpadded_h[bi, si] / image_patch_dim_h) * image_patch_dim_h
+ image_height,
+ math.ceil(image_unpadded_h[batch_index, subseq_index] / patch_height) * patch_height,
)
new_w = min(
- image.shape[2], math.ceil(image_unpadded_w[bi, si] / image_patch_dim_w) * image_patch_dim_w
+ image_width,
+ math.ceil(image_unpadded_w[batch_index, subseq_index] / patch_width) * patch_width,
)
image = image[:, :new_h, :new_w]
- images[bi].append(image)
- num_patches = self.get_num_patches(
- img_h=image.shape[1],
- img_w=image.shape[2],
- patch_dim_h=image_patch_dim_h,
- patch_dim_w=image_patch_dim_w,
+ image_height, image_width = new_h, new_w
+
+ num_patches = self.get_num_patches(image_height=image_height, image_width=image_width)
+ tensor_of_image_ids = torch.full(
+ [num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device
)
- ids = torch.full([num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device)
- patches = self.patchify_image(
- image=image.unsqueeze(0), patch_dim_h=image_patch_dim_h, patch_dim_w=image_patch_dim_w
- ).squeeze(0)
+ patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0)
+ assert num_patches == patches.shape[0]
+
if variable_sized:
# Now terminate each line with |NEWLINE|.
- ids = ids.reshape(-1, new_w // image_patch_dim_w)
- ids = torch.cat(
- [
- ids,
- torch.full(
- [ids.shape[0], 1], image_newline_id, dtype=torch.int32, device=image_input.device
- ),
- ],
- dim=1,
+ tensor_of_image_ids = tensor_of_image_ids.reshape(-1, image_width // patch_width)
+ newline_ids = torch.full(
+ [tensor_of_image_ids.shape[0], 1],
+ image_newline_id,
+ dtype=torch.int32,
+ device=image_input.device,
)
- ids = ids.reshape(-1)
- image_input_ids[bi].append(ids)
- image_patches[bi].append(patches)
+ tensor_of_image_ids = torch.cat([tensor_of_image_ids, newline_ids], dim=1)
+ tensor_of_image_ids = tensor_of_image_ids.reshape(-1)
+
+ images.append([image])
+ image_input_ids.append(tensor_of_image_ids)
+ image_patches.append(patches)
else:
- image_input_ids[bi].append(torch.tensor([], dtype=torch.int32, device=image_input.device))
+ image_input_ids.append(torch.tensor([], dtype=torch.int32, device=image_input.device))
+
+ batch_image_input_ids.append(image_input_ids)
+ batch_image_patches.append(image_patches)
# Create image_patch_input_indices, where non-negative values correspond to image patches to be inserted in
# the stream.
image_patch_indices_per_batch: List[List[torch.Tensor]] = []
image_patch_indices_per_subsequence: List[List[torch.Tensor]] = []
- for bi in range(len(image_input_ids)):
- image_patch_indices_per_batch.append([])
- image_patch_indices_per_subsequence.append([])
+
+ for sample_image_input_ids in batch_image_input_ids:
index_offset = 0
- for si in range(len(image_input_ids[bi])):
+ per_batch_indices = []
+ per_subsequence_indices = []
+ for subseq_image_input_ids in sample_image_input_ids:
# Indices of image patches.
- num_patches = torch.count_nonzero(image_input_ids[bi][si] == image_placeholder_id)
+ patches_mask = subseq_image_input_ids == image_placeholder_id
+ num_patches = torch.count_nonzero(patches_mask)
indices = torch.arange(
- num_patches,
- dtype=image_input_ids[bi][si].dtype,
- device=image_input_ids[bi][si].device,
+ num_patches, dtype=subseq_image_input_ids.dtype, device=subseq_image_input_ids.device
)
# Place those indices in the image input ids token stream, with -1 representing non-index tokens.
- indices_in_stream_per_batch = torch.full_like(image_input_ids[bi][si], -1)
- indices_in_stream_per_subsequence = torch.full_like(image_input_ids[bi][si], -1)
- indices_in_stream_per_batch[
- torch.nonzero(image_input_ids[bi][si] == image_placeholder_id, as_tuple=True)[0]
- ] = (indices + index_offset)
- indices_in_stream_per_subsequence[
- torch.nonzero(image_input_ids[bi][si] == image_placeholder_id, as_tuple=True)[0]
- ] = indices
-
- image_patch_indices_per_batch[bi].append(indices_in_stream_per_batch)
- image_patch_indices_per_subsequence[bi].append(indices_in_stream_per_subsequence)
- index_offset += num_patches
-
- return {
- "images": images,
- "image_input_ids": image_input_ids,
- "image_patches": image_patches,
- "image_patch_indices_per_batch": image_patch_indices_per_batch,
- "image_patch_indices_per_subsequence": image_patch_indices_per_subsequence,
- }
+ indices_in_stream_per_batch = torch.full_like(subseq_image_input_ids, -1)
+ indices_in_stream_per_subsequence = torch.full_like(subseq_image_input_ids, -1)
+ patches_inds = torch.nonzero(patches_mask, as_tuple=True)[0]
- def _scale_to_target_aspect_ratio(self, image: np.ndarray) -> np.ndarray:
- image_height, image_width, _ = image.shape
- if image_width <= self.target_width and image_height <= self.target_height:
- return image
-
- height_scale_factor = self.target_height / image_height
- width_scale_factor = self.target_width / image_width
- optimal_scale_factor = min(height_scale_factor, width_scale_factor)
+ indices_in_stream_per_batch[patches_inds] = indices + index_offset
+ indices_in_stream_per_subsequence[patches_inds] = indices
- new_height = int(image_height * optimal_scale_factor)
- new_width = int(image_width * optimal_scale_factor)
-
- scaled_image = resize(image=image, size=(new_height, new_width))
- return np.array(scaled_image)
-
- def _pad_to_target_size(self, image: np.ndarray) -> np.ndarray:
- image_height, image_width, _ = image.shape
-
- padding_top = 0
- padding_left = 0
- padding_bottom = self.target_height - image_height
- padding_right = self.target_width - image_width
+ per_batch_indices.append(indices_in_stream_per_batch)
+ per_subsequence_indices.append(indices_in_stream_per_subsequence)
+ index_offset += num_patches
- padded_image = pad(
- image,
- ((padding_top, padding_bottom), (padding_left, padding_right)),
- mode=self.padding_mode,
- constant_values=self.padding_value,
+ image_patch_indices_per_batch.append(per_batch_indices)
+ image_patch_indices_per_subsequence.append(per_subsequence_indices)
+
+ return FuyuBatchFeature(
+ data={
+ "images": images,
+ "image_input_ids": batch_image_input_ids,
+ "image_patches": batch_image_patches,
+ "image_patch_indices_per_batch": image_patch_indices_per_batch,
+ "image_patch_indices_per_subsequence": image_patch_indices_per_subsequence,
+ }
)
- return padded_image
-
- def apply_transformation(self, image: Union[np.ndarray, PIL.Image.Image]) -> np.ndarray:
- if isinstance(image, PIL.Image.Image):
- image = to_numpy_array(image)
- scaled_image = self._scale_to_target_aspect_ratio(image)
- padded_image = self._pad_to_target_size(scaled_image)
- normalized_padded_image = normalize(padded_image, 0.5, 0.5)
- return normalized_padded_image
diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py
index 89127843befe08..345d0a0e92a5ee 100644
--- a/src/transformers/models/fuyu/modeling_fuyu.py
+++ b/src/transformers/models/fuyu/modeling_fuyu.py
@@ -257,8 +257,10 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if image_patches is not None and past_key_values is None:
- patch_embeddings = self.vision_embed_tokens(image_patches.to(self.vision_embed_tokens.weight.dtype))
- patch_embeddings = patch_embeddings.to(inputs_embeds.device)
+ patch_embeddings = [
+ self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)).squeeze(0)
+ for patch in image_patches
+ ]
inputs_embeds = self.gather_continuous_embeddings(
word_embeddings=inputs_embeds,
continuous_embeddings=patch_embeddings,
diff --git a/src/transformers/models/fuyu/processing_fuyu.py b/src/transformers/models/fuyu/processing_fuyu.py
index ea660b072d721a..e0f362a6c8763b 100644
--- a/src/transformers/models/fuyu/processing_fuyu.py
+++ b/src/transformers/models/fuyu/processing_fuyu.py
@@ -1,45 +1,50 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Image/Text processor class for GIT
+"""
import re
-from typing import Any, Iterable, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
import numpy as np
-from ...image_utils import (
- ChannelDimension,
- get_image_size,
- infer_channel_dimension_format,
- is_scaled_image,
- to_numpy_array,
-)
from ...processing_utils import ProcessorMixin
-from ...utils import is_torch_available, is_vision_available, logging
+from ...tokenization_utils_base import PaddingStrategy, TruncationStrategy
+from ...utils import TensorType, is_torch_available, logging, requires_backends
-if is_torch_available() and is_vision_available():
- from .image_processing_fuyu import FuyuImageProcessor
+if is_torch_available():
+ from .image_processing_fuyu import FuyuBatchFeature
logger = logging.get_logger(__name__)
-if is_vision_available():
- import PIL
if is_torch_available():
import torch
-BBOX_OPEN_STRING = "<0x00>" #
-BBOX_CLOSE_STRING = "<0x01>" #
-POINT_OPEN_STRING = "<0x02>" #
-POINT_CLOSE_STRING = "<0x03>" #
TEXT_REPR_BBOX_OPEN = ""
TEXT_REPR_BBOX_CLOSE = ""
TEXT_REPR_POINT_OPEN = ""
TEXT_REPR_POINT_CLOSE = ""
-TOKEN_BBOX_OPEN_STRING = BBOX_OPEN_STRING = "<0x00>" #
-BBOX_CLOSE_STRING = "<0x01>" #
-TOKEN_BBOX_CLOSE_STRING = TOKEN_POINT_OPEN_STRING = POINT_OPEN_STRING = "<0x02>" #
-TOKEN_POINT_CLOSE_STRING = POINT_CLOSE_STRING = "<0x03>" #
+TOKEN_BBOX_OPEN_STRING = "<0x00>" #
+TOKEN_BBOX_CLOSE_STRING = "<0x01>" #
+TOKEN_POINT_OPEN_STRING = "<0x02>" #
+TOKEN_POINT_CLOSE_STRING = "<0x03>" #
BEGINNING_OF_ANSWER_STRING = "<0x04>" #
@@ -87,18 +92,16 @@ def construct_full_unpacked_stream(
all_bi_stream = []
- for bi in range(batch_size):
+ for batch_index in range(batch_size):
all_si_stream = []
# First, construct full token stream (including image placeholder tokens) and loss mask for each subsequence
# and append to lists. We use lists rather than tensors because each subsequence is variable-sized.
- for si in range(num_sub_sequences):
- image_adjustment = image_tokens[bi][si]
- si_stream = torch.cat([image_adjustment, input_stream[bi, si]], dim=0)
- num_real_tokens = image_adjustment.shape[0] + num_real_text_tokens[bi][si]
-
- all_si_stream.append(si_stream[:num_real_tokens])
- # Combine all subsequences for this batch entry. Still using a list because each batch entry is variable-sized.
+ # TODO Remove this logic in a subsequent release since subsequences are not supported.
+ image_adjustment = image_tokens[batch_index][0]
+ subsequence_stream = torch.cat([image_adjustment, input_stream[batch_index, 0]], dim=0)
+ num_real_tokens = image_adjustment.shape[0] + num_real_text_tokens[batch_index][0]
+ all_si_stream.append(subsequence_stream[:num_real_tokens])
all_bi_stream.append(torch.cat(all_si_stream, dim=0))
return all_bi_stream
@@ -137,7 +140,7 @@ def _segment_prompt_into_text_token_conversions(prompt: str) -> List:
return prompt_text_list
-def _transform_coordinates_and_tokenize(prompt: str, transformed_image, tokenizer) -> List[int]:
+def _transform_coordinates_and_tokenize(prompt: str, scale_factor: float, tokenizer) -> List[int]:
"""
This function transforms the prompt in the following fashion:
- and to their respective token mappings
@@ -161,7 +164,7 @@ def _transform_coordinates_and_tokenize(prompt: str, transformed_image, tokenize
for elem in prompt_text_list:
if elem[1]:
# This is a location, we need to tokenize it
- within_tag_tokenized = _transform_within_tags(elem[0], transformed_image, tokenizer)
+ within_tag_tokenized = _transform_within_tags(elem[0], scale_factor, tokenizer)
# Surround the text with the open and close tags
transformed_prompt_tokens.extend(within_tag_tokenized)
else:
@@ -169,7 +172,7 @@ def _transform_coordinates_and_tokenize(prompt: str, transformed_image, tokenize
return transformed_prompt_tokens
-def _transform_within_tags(text: str, transformed_image, tokenizer) -> List[int]:
+def _transform_within_tags(text: str, scale_factor: float, tokenizer) -> List[int]:
"""
Given a bounding box of the fashion 1, 2, 3, 4 | 1, 2 This function is responsible for
converting 1, 2, 3, 4 into tokens of 1 2 3 4 without any commas.
@@ -188,16 +191,14 @@ def _transform_within_tags(text: str, transformed_image, tokenizer) -> List[int]
num_ints = [float(num.strip()) for num in num_int_strs]
# scale to transformed image siz
if len(num_ints) == 2:
- num_ints_translated = scale_point_to_transformed_image(
- x=num_ints[0], y=num_ints[1], transformed_image=transformed_image
- )
+ num_ints_translated = scale_point_to_transformed_image(x=num_ints[0], y=num_ints[1], scale_factor=scale_factor)
elif len(num_ints) == 4:
num_ints_translated = scale_bbox_to_transformed_image(
top=num_ints[0],
left=num_ints[1],
bottom=num_ints[2],
right=num_ints[3],
- transformed_image=transformed_image,
+ scale_factor=scale_factor,
)
else:
raise ValueError(f"Invalid number of ints: {len(num_ints)}")
@@ -209,7 +210,7 @@ def _transform_within_tags(text: str, transformed_image, tokenizer) -> List[int]
def _tokenize_prompts_with_image_and_batch(
tokenizer,
prompts: List[List[str]],
- transformed_images: Optional[List[List["torch.Tensor"]]],
+ scale_factors: Optional[List[List["torch.Tensor"]]],
max_tokens_to_generate: int,
max_position_embeddings: int,
add_BOS: bool, # Same issue with types as above
@@ -223,13 +224,13 @@ def _tokenize_prompts_with_image_and_batch(
"""
# If not tool use, tranform the coordinates while tokenizing
- if transformed_images is not None:
+ if scale_factors is not None:
transformed_prompt_tokens = []
- for prompt_seq, transformed_image_seq in zip(prompts, transformed_images):
+ for prompt_seq, scale_factor_seq in zip(prompts, scale_factors):
transformed_prompt_tokens.append(
[
- _transform_coordinates_and_tokenize(prompt, transformed_image, tokenizer)
- for prompt, transformed_image in zip(prompt_seq, transformed_image_seq)
+ _transform_coordinates_and_tokenize(prompt, scale_factor.item(), tokenizer)
+ for prompt, scale_factor in zip(prompt_seq, scale_factor_seq)
]
)
else:
@@ -260,7 +261,7 @@ def _tokenize_prompts_with_image_and_batch(
# Number of tokens in the each sample of the batch.
samples_length = min(max_prompt_len + max_tokens_to_generate, max_position_embeddings)
if max_prompt_len + max_tokens_to_generate > max_position_embeddings:
- print(
+ logger.warning(
f"Max subsequence prompt length of {max_prompt_len} + max tokens to generate {max_tokens_to_generate}",
f"exceeds context length of {max_position_embeddings}. Will generate as many tokens as possible.",
)
@@ -279,86 +280,30 @@ def _tokenize_prompts_with_image_and_batch(
return prompts_tokens_tensor, prompts_length_tensor
-def original_to_transformed_h_coords(self, original_coords):
- # apply crop
- cropped_coords = (
- self._clamp_coords(original_coords, min_value=self.crop_top, max_value=self.crop_bottom) - self.crop_top
- )
- # apply scale
- scaled_coords = self._scale_coords(cropped_coords, scale=self.scaled_h / self.original_h)
- # apply pad
- return scaled_coords + self.padding_top
+# Simplified assuming self.crop_top = self.padding_top = 0
+def original_to_transformed_h_coords(original_coords, scale_h):
+ return np.round(original_coords * scale_h).astype(np.int32)
-def original_to_transformed_w_coords(self, original_coords):
- # apply crop
- cropped_coords = (
- self._clamp_coords(original_coords, min_value=self.crop_left, max_value=self.crop_right) - self.crop_left
- )
- # apply scale
- scaled_coords = self._scale_coords(cropped_coords, scale=self.scaled_w / self.original_w)
- # apply pad
- return scaled_coords + self.padding_left
+# Simplified assuming self.crop_left = self.padding_left = 0
+def original_to_transformed_w_coords(original_coords, scale_w):
+ return np.round(original_coords * scale_w).astype(np.int32)
-def scale_point_to_transformed_image(x: float, y: float) -> List[int]:
- x_scaled = original_to_transformed_w_coords(np.array([x / 2]))[0]
- y_scaled = original_to_transformed_h_coords(np.array([y / 2]))[0]
+def scale_point_to_transformed_image(x: float, y: float, scale_factor: float) -> List[int]:
+ x_scaled = original_to_transformed_w_coords(np.array([x / 2]), scale_factor)[0]
+ y_scaled = original_to_transformed_h_coords(np.array([y / 2]), scale_factor)[0]
return [x_scaled, y_scaled]
-def scale_bbox_to_transformed_image(top: float, left: float, bottom: float, right: float) -> List[int]:
- top_scaled = original_to_transformed_w_coords(np.array([top / 2]))[0]
- left_scaled = original_to_transformed_h_coords(np.array([left / 2]))[0]
- bottom_scaled = original_to_transformed_w_coords(np.array([bottom / 2]))[0]
- right_scaled = original_to_transformed_h_coords(np.array([right / 2]))[0]
- return [top_scaled, left_scaled, bottom_scaled, right_scaled]
-
-
-# Copied from transformers.models.detr.image_processing_detr.max_across_indices
-def max_across_indices(values: Iterable[Any]) -> List[Any]:
- """
- Return the maximum value across all indices of an iterable of values.
- """
- return [max(values_i) for values_i in zip(*values)]
-
-
-# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
-def get_max_height_width(
- images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
+def scale_bbox_to_transformed_image(
+ top: float, left: float, bottom: float, right: float, scale_factor: float
) -> List[int]:
- """
- Get the maximum height and width across all images in a batch.
- """
- if input_data_format is None:
- input_data_format = infer_channel_dimension_format(images[0])
-
- if input_data_format == ChannelDimension.FIRST:
- _, max_height, max_width = max_across_indices([img.shape for img in images])
- elif input_data_format == ChannelDimension.LAST:
- max_height, max_width, _ = max_across_indices([img.shape for img in images])
- else:
- raise ValueError(f"Invalid channel dimension format: {input_data_format}")
- return (max_height, max_width)
-
-
-# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
-def make_pixel_mask(
- image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
-) -> np.ndarray:
- """
- Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
-
- Args:
- image (`np.ndarray`):
- Image to make the pixel mask for.
- output_size (`Tuple[int, int]`):
- Output size of the mask.
- """
- input_height, input_width = get_image_size(image, channel_dim=input_data_format)
- mask = np.zeros(output_size, dtype=np.int64)
- mask[:input_height, :input_width] = 1
- return mask
+ top_scaled = original_to_transformed_w_coords(np.array([top / 2]), scale_factor)[0]
+ left_scaled = original_to_transformed_h_coords(np.array([left / 2]), scale_factor)[0]
+ bottom_scaled = original_to_transformed_w_coords(np.array([bottom / 2]), scale_factor)[0]
+ right_scaled = original_to_transformed_h_coords(np.array([right / 2]), scale_factor)[0]
+ return [top_scaled, left_scaled, bottom_scaled, right_scaled]
class FuyuProcessor(ProcessorMixin):
@@ -384,42 +329,148 @@ def __init__(self, image_processor, tokenizer):
self.tokenizer = tokenizer
self.max_tokens_to_generate = 10
self.max_position_embeddings = 16384 # TODO Can't derive this from model files: where to set it?
- self.image_processor = FuyuImageProcessor()
-
- def _process_images(self, images):
- """Utility function to preprocess the images and extract necessary information about original formats."""
- batch_images = []
- image_unpadded_heights = []
- image_unpadded_widths = []
-
- for image in images:
- image = to_numpy_array(image)
- if not is_scaled_image(image):
- image = image / 255.0
- channel_dimension = infer_channel_dimension_format(image, 3)
- if channel_dimension == ChannelDimension.FIRST:
- width_index = 2
- height_index = 1
- elif channel_dimension == ChannelDimension.LAST:
- width_index = 1
- height_index = 0
-
- image_unpadded_widths.append([image.shape[width_index]])
- image_unpadded_heights.append([image.shape[height_index]])
-
- # Reproduct adept padding sampler
- padded_image = self.image_processor.apply_transformation(image)
-
- tensor_img = torch.Tensor(padded_image).permute(2, 0, 1)
- batch_images.append([tensor_img])
-
- return batch_images, torch.Tensor(image_unpadded_heights), torch.Tensor(image_unpadded_widths)
-
- def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
+ self.pad_token_id = 0
+ self.dummy_image_index = -1
+
+ def _left_pad_inputs_with_attention_mask(self, model_inputs: List[Dict], return_attention_mask: bool):
+ max_length_input_ids = max(entry["input_ids"].shape[1] for entry in model_inputs)
+ max_length_image_patch_indices = max(entry["image_patches_indices"].shape[1] for entry in model_inputs)
+
+ batched_inputs = {"input_ids": [], "image_patches": [], "image_patches_indices": [], "attention_mask": []}
+
+ for entry in model_inputs:
+ for key, tensor in entry.items():
+ if key == "input_ids":
+ num_padding_tokens = max_length_input_ids - tensor.shape[1]
+ padded_input_ids = torch.cat(
+ [
+ torch.full((tensor.shape[0], num_padding_tokens), self.pad_token_id, dtype=torch.long),
+ tensor,
+ ],
+ dim=1,
+ )
+ batched_inputs[key].append(padded_input_ids)
+
+ attention_mask = torch.cat(
+ [torch.zeros(tensor.shape[0], num_padding_tokens, dtype=torch.long), torch.ones_like(tensor)],
+ dim=1,
+ )
+ batched_inputs["attention_mask"].append(attention_mask)
+
+ elif key == "image_patches":
+ # For image_patches, we don't pad but just append them to the list.
+ batched_inputs[key].append(tensor)
+
+ else: # for image_patches_indices
+ num_padding_indices = max_length_image_patch_indices - tensor.shape[1]
+ padded_indices = torch.cat(
+ [
+ torch.full(
+ (tensor.shape[0], num_padding_indices), self.dummy_image_index, dtype=torch.long
+ ),
+ tensor,
+ ],
+ dim=1,
+ )
+ batched_inputs[key].append(padded_indices)
+ batched_keys = ["input_ids", "image_patches_indices"]
+ if return_attention_mask:
+ batched_keys.append("attention_mask")
+ for key in batched_keys:
+ batched_inputs[key] = torch.cat(batched_inputs[key], dim=0)
+
+ return batched_inputs
+
+ def get_sample_encoding(
+ self,
+ prompts,
+ scale_factors,
+ image_unpadded_heights,
+ image_unpadded_widths,
+ image_placeholder_id,
+ image_newline_id,
+ tensor_batch_images,
+ ):
+ image_present = torch.ones(1, 1, 1)
+ model_image_input = self.image_processor.preprocess_with_tokenizer_info(
+ image_input=tensor_batch_images,
+ image_present=image_present,
+ image_unpadded_h=image_unpadded_heights,
+ image_unpadded_w=image_unpadded_widths,
+ image_placeholder_id=image_placeholder_id,
+ image_newline_id=image_newline_id,
+ variable_sized=True,
+ )
+ # FIXME max_tokens_to_generate is embedded into this processor's call.
+ prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
+ tokenizer=self.tokenizer,
+ prompts=prompts,
+ scale_factors=scale_factors,
+ max_tokens_to_generate=self.max_tokens_to_generate,
+ max_position_embeddings=self.max_position_embeddings,
+ add_BOS=True,
+ add_beginning_of_answer_token=True,
+ )
+ image_padded_unpacked_tokens = construct_full_unpacked_stream(
+ num_real_text_tokens=prompts_length,
+ input_stream=prompt_tokens,
+ image_tokens=model_image_input["image_input_ids"],
+ batch_size=1,
+ num_sub_sequences=self.subsequence_length,
+ )
+ # Construct inputs for image patch indices.
+ unpacked_image_patch_indices_per_batch = construct_full_unpacked_stream(
+ num_real_text_tokens=prompts_length,
+ input_stream=torch.full_like(prompt_tokens, -1),
+ image_tokens=model_image_input["image_patch_indices_per_batch"],
+ batch_size=1,
+ num_sub_sequences=self.subsequence_length,
+ )
+ max_prompt_length = max(x.shape[-1] for x in image_padded_unpacked_tokens)
+ max_seq_len_batch = min(max_prompt_length + self.max_tokens_to_generate, self.max_position_embeddings)
+ tokens_to_place = min(max_seq_len_batch, max(0, image_padded_unpacked_tokens[0].shape[0]))
+
+ # Use same packing logic for the image patch indices.
+ image_patch_input_indices = full_unpacked_stream_to_tensor(
+ all_bi_tokens_to_place=[tokens_to_place],
+ full_unpacked_stream=unpacked_image_patch_indices_per_batch,
+ fill_value=-1,
+ batch_size=1,
+ new_seq_len=max_seq_len_batch,
+ offset=0,
+ )
+ image_patches_tensor = torch.stack([img[0] for img in model_image_input["image_patches"]])
+ batch_encoding = {
+ "input_ids": image_padded_unpacked_tokens[0].unsqueeze(0),
+ "image_patches": image_patches_tensor,
+ "image_patches_indices": image_patch_input_indices,
+ }
+ return batch_encoding
+
+ def __call__(
+ self,
+ text=None,
+ images=None,
+ add_special_tokens: bool = True,
+ return_attention_mask: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_token_type_ids: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs,
+ ) -> "FuyuBatchFeature":
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to
- encode the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
+ encode the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
FuyuImageProcessor's [`~FuyuImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
of the above two methods for more information.
@@ -433,130 +484,211 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
Returns:
- [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+ [`FuyuBatchEncoding`]: A [`FuyuBatchEncoding`] with the following fields:
- - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
- `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
- `None`).
- - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ - **input_ids** -- Tensor of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **image_patches** -- List of Tensor of image patches. Returned when `images` is not `None`.
+ - **image_patches_indices** -- Tensor of indices where patch embeddings have to be inserted by the model.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model when
+ `return_attention_mask=True`.
"""
+ requires_backends(self, ["torch"])
+
+ # --- Check input validity ---
+ if not return_attention_mask:
+ raise ValueError("`return_attention_mask=False` is not supported for this model.")
if text is None and images is None:
- raise ValueError("You have to specify either text or images. Both cannot be none.")
+ raise ValueError("You have to specify either text or images. Both cannot be None.")
+ if text is not None and images is None:
+ logger.warning("You are processing a text with no associated image. Make sure it is intended.")
+ self.current_processor = self.tokenizer
+ text_encoding = self.tokenizer(
+ text=text,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_token_type_ids=return_token_type_ids,
+ return_length=return_length,
+ verbose=verbose,
+ return_tensors=return_tensors,
+ **kwargs,
+ )
+ return text_encoding
+
+ if text is None and images is not None:
+ logger.warning("You are processing an image with no associated text. Make sure it is intended.")
+ prompts = [[""]]
if text is not None and images is not None:
if isinstance(text, str):
prompts = [[text]]
elif isinstance(text, list):
prompts = [[text_seq] for text_seq in text]
- batch_images = []
- if isinstance(images, PIL.Image.Image):
- images = [images]
- if isinstance(images, list):
- batch_images, image_unpadded_heights, image_unpadded_widths = self._process_images(images)
- # image_unpadded_heights = image_unpadded_heights.unsqueeze(0)
- # image_unpadded_widths = image_unpadded_widths.unsqueeze(0)
- else:
- raise ValueError("images must be a list of ndarrays or PIL Images to be processed.")
-
- # Note: the original adept code has a handling of image_unpadded_h and w, but it doesn't seem to hold
- # when there are several different size subsequences per batch. The current implementation reflects
- # that limitation and should be documented.
- #
- self.subsequence_length = 1 # Each batch contains only one sequence.
- self.batch_size = len(batch_images)
- # FIXME max_tokens_to_generate is embedded into this processor's call.
- prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
- tokenizer=self.tokenizer,
- prompts=prompts,
- transformed_images=batch_images,
- max_tokens_to_generate=self.max_tokens_to_generate,
- max_position_embeddings=self.max_position_embeddings,
- add_BOS=True,
- add_beginning_of_answer_token=True,
- )
- # same so far
-
- # This is 1 if there is an image per subsequence, else 0. [batch, 1, presence]
- # the remainder of current image processing logic assumes subsequence_size = 1.
- # Here it is OK as the model cannot handle > 1 subsequences
- # the image could be absent however and image presence should be inferred from user batch input
- # hence this code assumes the images are present. Use an assert?
-
- image_present = torch.ones(self.batch_size, 1, 1)
-
- image_placeholder_id = self.tokenizer("|SPEAKER|", add_special_tokens=False)["input_ids"][1]
- image_newline_id = self.tokenizer("|NEWLINE|", add_special_tokens=False)["input_ids"][1]
- tensor_batch_images = torch.stack([img[0] for img in batch_images]).unsqueeze(1)
- model_image_input = self.image_processor.process_images_for_model_input(
- image_input=tensor_batch_images,
- image_present=image_present,
- image_unpadded_h=image_unpadded_heights,
- image_unpadded_w=image_unpadded_widths,
- image_patch_dim_h=30,
- image_patch_dim_w=30,
+
+ # --- Preprocess images using self.image_processor ---
+
+ # FIXME - We hard code "pt" here because the rest of the processing assumes torch tensors
+ image_encoding = self.image_processor.preprocess(images, return_tensors="pt")
+ batch_images = image_encoding["images"]
+ image_unpadded_heights = image_encoding["image_unpadded_heights"]
+ image_unpadded_widths = image_encoding["image_unpadded_widths"]
+ scale_factors = image_encoding["image_scale_factors"]
+ self.subsequence_length = 1 # Each batch contains only one sequence.
+ self.batch_size = len(batch_images)
+
+ # --- Use self.tokenizer to get the ids of special tokens to insert into image ids ---
+
+ image_placeholder_id = self.tokenizer("|SPEAKER|", add_special_tokens=False)["input_ids"][1]
+ image_newline_id = self.tokenizer("|NEWLINE|", add_special_tokens=False)["input_ids"][1]
+ tensor_batch_images = torch.stack([img[0] for img in batch_images]).unsqueeze(1)
+
+ # --- Use self.image_processor again to obtain the full token ids and batch inputs ---
+ all_encodings = []
+
+ for prompt, scale_factor, image_unpadded_height, image_unpadded_width, tensor_batch_image in zip(
+ prompts, scale_factors, image_unpadded_heights, image_unpadded_widths, tensor_batch_images
+ ):
+ sample_encoding = self.get_sample_encoding(
+ prompts=[prompt],
+ scale_factors=[scale_factor],
+ image_unpadded_heights=torch.tensor([image_unpadded_height]),
+ image_unpadded_widths=torch.tensor([image_unpadded_width]),
image_placeholder_id=image_placeholder_id,
image_newline_id=image_newline_id,
- variable_sized=True,
+ tensor_batch_images=tensor_batch_image.unsqueeze(0),
)
+ all_encodings.append(sample_encoding)
+ batch_encoding = self._left_pad_inputs_with_attention_mask(
+ model_inputs=all_encodings, return_attention_mask=return_attention_mask
+ )
+ return FuyuBatchFeature(data=batch_encoding)
- image_padded_unpacked_tokens = construct_full_unpacked_stream(
- num_real_text_tokens=prompts_length,
- input_stream=prompt_tokens,
- image_tokens=model_image_input["image_input_ids"],
- batch_size=self.batch_size,
- num_sub_sequences=self.subsequence_length,
- )
- # Construct inputs for image patch indices.
- unpacked_image_patch_indices_per_batch = construct_full_unpacked_stream(
- num_real_text_tokens=prompts_length,
- input_stream=torch.full_like(prompt_tokens, -1),
- image_tokens=model_image_input["image_patch_indices_per_batch"],
- batch_size=self.batch_size,
- num_sub_sequences=self.subsequence_length,
- )
- max_prompt_length = max(x.shape[-1] for x in image_padded_unpacked_tokens)
- max_seq_len_batch = min(max_prompt_length + self.max_tokens_to_generate, self.max_position_embeddings)
- all_bi_tokens_to_place = []
- for bi in range(self.batch_size):
- tokens_to_place = min(max_seq_len_batch, max(0, image_padded_unpacked_tokens[bi].shape[0]))
- all_bi_tokens_to_place.append(tokens_to_place)
-
- # Use same packing logic for the image patch indices.
- image_patch_input_indices = full_unpacked_stream_to_tensor(
- all_bi_tokens_to_place=all_bi_tokens_to_place,
- full_unpacked_stream=unpacked_image_patch_indices_per_batch,
- fill_value=-1,
- batch_size=self.batch_size,
- new_seq_len=max_seq_len_batch,
- offset=0,
- )
+ def post_process_box_coordinates(self, outputs, target_sizes=None):
+ """
+ Transforms raw coordinates detected by [`FuyuForCausalLM`] to the original images' coordinate space.
+ Coordinates will be returned in "box" format, with the following pattern:
+ `top, left, bottom, right`
+
+ Point coordinates are not supported yet.
- image_patches_tensor = torch.stack([img[0] for img in model_image_input["image_patches"]]).unsqueeze(1)
- return {
- "input_ids": image_padded_unpacked_tokens[0].unsqueeze(0),
- "image_patches": image_patches_tensor[0][0].unsqueeze(0),
- "image_patches_indices": image_patch_input_indices,
- }
+ Args:
+ outputs ([`GenerateOutput`]):
+ Raw outputs from `generate`.
+ target_sizes (`torch.Tensor`, *optional*):
+ Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in
+ the batch. If set, found coordinates in the output sequence are rescaled to the target sizes. If left
+ to None, coordinates will not be rescaled.
+
+ Returns:
+ `GenerateOutput`: Same output type returned by `generate`, with output token ids replaced with
+ boxed and possible rescaled coordinates.
+ """
+
+ def scale_factor_to_fit(original_size, target_size=None):
+ height, width = original_size
+ if target_size is None:
+ max_height = self.image_processor.size["height"]
+ max_width = self.image_processor.size["width"]
+ else:
+ max_height, max_width = target_size
+ if width <= max_width and height <= max_height:
+ return 1.0
+ return min(max_height / height, max_width / width)
+
+ def find_delimiters_pair(tokens, start_token, end_token):
+ start_id = self.tokenizer.convert_tokens_to_ids(start_token)
+ end_id = self.tokenizer.convert_tokens_to_ids(end_token)
+
+ starting_positions = (tokens == start_id).nonzero(as_tuple=True)[0]
+ ending_positions = (tokens == end_id).nonzero(as_tuple=True)[0]
+
+ if torch.any(starting_positions) and torch.any(ending_positions):
+ return (starting_positions[0], ending_positions[0])
+ return (None, None)
+
+ def tokens_to_boxes(tokens, original_size):
+ while (pair := find_delimiters_pair(tokens, TOKEN_BBOX_OPEN_STRING, TOKEN_BBOX_CLOSE_STRING)) != (
+ None,
+ None,
+ ):
+ start, end = pair
+ if end != start + 5:
+ continue
+
+ # Retrieve transformed coordinates from tokens
+ coords = self.tokenizer.convert_ids_to_tokens(tokens[start + 1 : end])
+
+ # Scale back to original image size and multiply by 2
+ scale = scale_factor_to_fit(original_size)
+ top, left, bottom, right = [2 * int(float(c) / scale) for c in coords]
+
+ # Replace the IDs so they get detokenized right
+ replacement = f" {TEXT_REPR_BBOX_OPEN}{top}, {left}, {bottom}, {right}{TEXT_REPR_BBOX_CLOSE}"
+ replacement = self.tokenizer.tokenize(replacement)[1:]
+ replacement = self.tokenizer.convert_tokens_to_ids(replacement)
+ replacement = torch.tensor(replacement).to(tokens)
+
+ tokens = torch.cat([tokens[:start], replacement, tokens[end + 1 :]], 0)
+ return tokens
+
+ def tokens_to_points(tokens, original_size):
+ while (pair := find_delimiters_pair(tokens, TOKEN_POINT_OPEN_STRING, TOKEN_POINT_CLOSE_STRING)) != (
+ None,
+ None,
+ ):
+ start, end = pair
+ if end != start + 3:
+ continue
+
+ # Retrieve transformed coordinates from tokens
+ coords = self.tokenizer.convert_ids_to_tokens(tokens[start + 1 : end])
+
+ # Scale back to original image size and multiply by 2
+ scale = scale_factor_to_fit(original_size)
+ x, y = [2 * int(float(c) / scale) for c in coords]
+
+ # Replace the IDs so they get detokenized right
+ replacement = f" {TEXT_REPR_POINT_OPEN}{x}, {y}{TEXT_REPR_POINT_CLOSE}"
+ replacement = self.tokenizer.tokenize(replacement)[1:]
+ replacement = self.tokenizer.convert_tokens_to_ids(replacement)
+ replacement = torch.tensor(replacement).to(tokens)
+
+ tokens = torch.cat([tokens[:start], replacement, tokens[end + 1 :]], 0)
+ return tokens
+
+ if target_sizes is None:
+ target_sizes = ((self.image_processor.size["height"], self.image_processor.size["width"]),) * len(outputs)
+ elif target_sizes.shape[1] != 2:
+ raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
+
+ if len(outputs) != len(target_sizes):
+ raise ValueError("Make sure that you pass in as many target sizes as output sequences")
+
+ results = []
+ for seq, size in zip(outputs, target_sizes):
+ seq = tokens_to_boxes(seq, size)
+ seq = tokens_to_points(seq, size)
+ results.append(seq)
+
+ return results
def batch_decode(self, *args, **kwargs):
"""
- This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
- This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
diff --git a/tests/models/fuyu/test_image_processing_fuyu.py b/tests/models/fuyu/test_image_processing_fuyu.py
index 73f0936aacf13d..a9930e2fb81297 100644
--- a/tests/models/fuyu/test_image_processing_fuyu.py
+++ b/tests/models/fuyu/test_image_processing_fuyu.py
@@ -24,7 +24,8 @@
@require_torchvision
class TestFuyuImageProcessor(unittest.TestCase):
def setUp(self):
- self.processor = FuyuImageProcessor(target_height=160, target_width=320, padding_value=1.0)
+ self.size = {"height": 160, "width": 320}
+ self.processor = FuyuImageProcessor(size=self.size, padding_value=1.0)
self.batch_size = 3
self.channels = 3
self.height = 300
@@ -38,29 +39,25 @@ def setUp(self):
self.sample_image_pil = Image.fromarray(self.sample_image)
def test_patches(self):
- expected_num_patches = self.processor.get_num_patches(
- img_h=self.height, img_w=self.width, patch_dim_h=self.image_patch_dim_h, patch_dim_w=self.image_patch_dim_w
- )
+ expected_num_patches = self.processor.get_num_patches(image_height=self.height, image_width=self.width)
- patches_final = self.processor.patchify_image(
- image=self.image_input, patch_dim_h=self.image_patch_dim_h, patch_dim_w=self.image_patch_dim_w
- )
+ patches_final = self.processor.patchify_image(image=self.image_input)
assert (
patches_final.shape[1] == expected_num_patches
), f"Expected {expected_num_patches} patches, got {patches_final.shape[1]}."
def test_scale_to_target_aspect_ratio(self):
# (h:450, w:210) fitting (160, 320) -> (160, 210*160/450)
- scaled_image = self.processor._scale_to_target_aspect_ratio(self.sample_image)
+ scaled_image = self.processor.resize(self.sample_image, size=self.size)
self.assertEqual(scaled_image.shape[0], 160)
self.assertEqual(scaled_image.shape[1], 74)
def test_apply_transformation_numpy(self):
- transformed_image = self.processor.apply_transformation(self.sample_image)
- self.assertEqual(transformed_image.shape[0], 160)
- self.assertEqual(transformed_image.shape[1], 320)
+ transformed_image = self.processor.preprocess(self.sample_image).images[0][0]
+ self.assertEqual(transformed_image.shape[1], 160)
+ self.assertEqual(transformed_image.shape[2], 320)
def test_apply_transformation_pil(self):
- transformed_image = self.processor.apply_transformation(self.sample_image_pil)
- self.assertEqual(transformed_image.shape[0], 160)
- self.assertEqual(transformed_image.shape[1], 320)
+ transformed_image = self.processor.preprocess(self.sample_image_pil).images[0][0]
+ self.assertEqual(transformed_image.shape[1], 160)
+ self.assertEqual(transformed_image.shape[2], 320)
diff --git a/tests/models/fuyu/test_modeling_fuyu.py b/tests/models/fuyu/test_modeling_fuyu.py
index b9c061e7a00448..9fb6820e45ffb1 100644
--- a/tests/models/fuyu/test_modeling_fuyu.py
+++ b/tests/models/fuyu/test_modeling_fuyu.py
@@ -3,7 +3,7 @@
import requests
-from transformers import AutoTokenizer, FuyuConfig, is_torch_available, is_vision_available
+from transformers import FuyuConfig, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device
from ...test_modeling_common import ids_tensor, random_attention_mask
@@ -14,7 +14,7 @@
if is_torch_available() and is_vision_available():
- from transformers import FuyuImageProcessor, FuyuProcessor
+ from transformers import FuyuProcessor
if is_torch_available():
@@ -267,11 +267,8 @@ class FuyuIntegrationTest(unittest.TestCase): # , ModelTesterMixin)
all_model_classes = ("FuyuForCausalLM") if is_torch_available() else ()
def setUp(self):
- self.pretrained_model_name = "huggingface/new_model_release_weights"
- tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_name)
- image_processor = FuyuImageProcessor()
-
- self.processor = FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer)
+ self.pretrained_model_name = "adept/fuyu-8b"
+ self.processor = FuyuProcessor.from_pretrained(self.pretrained_model_name)
self.model = FuyuForCausalLM.from_pretrained(self.pretrained_model_name)
self.bus_image_url = (
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
@@ -280,9 +277,8 @@ def setUp(self):
@slow
def test_model_8b_chat_greedy_generation_bus_captioning(self):
- EXPECTED_TEXT_COMPLETION = """A bus parked on the side of a road.|ENDOFTEXT|"""
+ EXPECTED_TEXT_COMPLETION = """A blue bus parked on the side of a road.|ENDOFTEXT|"""
text_prompt_coco_captioning = "Generate a coco-style caption.\n"
-
model_inputs_bus_captioning = self.processor(text=text_prompt_coco_captioning, images=self.bus_image_pil)
generated_tokens = self.model.generate(**model_inputs_bus_captioning, max_new_tokens=10)
text = self.processor.tokenizer.batch_decode(generated_tokens)
@@ -297,7 +293,7 @@ def test_model_8b_chat_greedy_generation_bus_captioning(self):
"""
@slow
- @require_torch_gpu
+ @require_torch_accelerator
def test_model_8b_chat_greedy_generation_bus_color(self):
EXPECTED_TEXT_COMPLETION = "The bus is blue.\n|ENDOFTEXT|"
text_prompt_bus_color = "What color is the bus?\n"
@@ -314,7 +310,7 @@ def test_model_8b_chat_greedy_generation_bus_color(self):
self.assertEqual(EXPECTED_TEXT_COMPLETION, clean_sequence)
@slow
- @require_torch_gpu
+ @require_torch_accelerator
def test_model_8b_chat_greedy_generation_chart_vqa(self):
# fmt: off
EXPECTED_TEXT_TOKENS = ["The","life expectancy","at","birth","of male","s in","","20","18","is","","80",".","7",".","\n","|ENDOFTEXT|",]
@@ -340,7 +336,7 @@ def test_model_8b_chat_greedy_generation_chart_vqa(self):
self.assertEqual(expected_text_completion, clean_sequence)
@slow
- @require_torch_gpu
+ @require_torch_accelerator
def test_model_8b_chat_greedy_generation_bounding_box(self):
EXPECTED_TEXT_COMPLETION = "\x00194213202244\x01|ENDOFTEXT|"
text_prompt_bbox = "When presented with a box, perform OCR to extract text contained within it. If provided with text, generate the corresponding bounding box.\\nWilliams" # noqa: E231
diff --git a/tests/models/fuyu/test_processing_fuyu.py b/tests/models/fuyu/test_processing_fuyu.py
index 1c75b2b0ae318a..459386952c3ed9 100644
--- a/tests/models/fuyu/test_processing_fuyu.py
+++ b/tests/models/fuyu/test_processing_fuyu.py
@@ -26,16 +26,14 @@ class FuyuProcessingTest(unittest.TestCase): # TODO Which mixins do we add here
""" """
def setUp(self):
- pretrained_model_name = "huggingface/pre_release_model"
- tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
- image_processor = FuyuImageProcessor()
+ pretrained_model_name = "adept/fuyu-8b"
+ self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
+ self.image_processor = FuyuImageProcessor()
- processor = FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer)
- text_prompt = "Generate a coco-style caption.\\n"
+ self.processor = FuyuProcessor(image_processor=self.image_processor, tokenizer=self.tokenizer)
+ self.text_prompt = "Generate a coco-style caption.\\n"
bus_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
- bus_image_pil = Image.open(io.BytesIO(requests.get(bus_image_url).content))
-
- self.one_image_bus_model_inputs = processor(text=text_prompt, images=bus_image_pil)
+ self.bus_image_pil = Image.open(io.BytesIO(requests.get(bus_image_url).content))
def test_fuyu_processing(self):
"""
@@ -44,11 +42,119 @@ def test_fuyu_processing(self):
# fmt: off
EXPECTED_IMAGE_PATCH_INPUTS = torch.Tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, -1, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, -1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, -1, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, -1, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, -1, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, -1, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, -1, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, -1, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, -1, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, -1, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, -1, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, -1, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, -1, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,]]).to(torch.int64)
EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS = torch.Tensor([[71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 1, 128340, 71374, 71389, 120412, 71377, 71835, 71374, 73615, 71375, 71399, 71435, 71122,]]).to(torch.int64)
+
+ one_image_bus_model_inputs = self.processor(text=self.text_prompt, images=self.bus_image_pil)
+
+ # fmt: on
+ torch.testing.assert_close(one_image_bus_model_inputs["image_patches_indices"], EXPECTED_IMAGE_PATCH_INPUTS)
+ torch.testing.assert_close(one_image_bus_model_inputs["input_ids"], EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS)
+
+ def test_fuyu_processing_no_image(self):
+ """
+ Test to check processor works with just text input
+ """
+ processor_outputs = self.processor(text=self.text_prompt)
+ tokenizer_outputs = self.tokenizer(self.text_prompt)
+ self.assertEqual(processor_outputs["input_ids"], tokenizer_outputs["input_ids"])
+
+ def test_fuyu_processing_no_text(self):
+ """
+ Test to check processor works with just image input
+ """
+ # fmt: off
+ EXPECTED_IMAGE_PATCH_INPUTS = torch.Tensor([
+ [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
+ 14, 15, 16, 17, 18, 19, 20, 21, -1, 22, 23, 24, 25, 26,
+ 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
+ 41, 42, 43, -1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
+ 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, -1, 66,
+ 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
+ 81, 82, 83, 84, 85, 86, 87, -1, 88, 89, 90, 91, 92, 93,
+ 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
+ 108, 109, -1, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
+ 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, -1, 132, 133,
+ 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147,
+ 148, 149, 150, 151, 152, 153, -1, 154, 155, 156, 157, 158, 159, 160,
+ 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174,
+ 175, -1, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
+ 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, -1, 198, 199, 200,
+ 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214,
+ 215, 216, 217, 218, 219, -1, 220, 221, 222, 223, 224, 225, 226, 227,
+ 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241,
+ -1, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254,
+ 255, 256, 257, 258, 259, 260, 261, 262, 263, -1, 264, 265, 266, 267,
+ 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281,
+ 282, 283, 284, 285, -1, 286, 287, 288, 289, 290, 291, 292, 293, 294,
+ 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, -1,
+ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
+ ]).to(torch.int64)
+ # fmt: on
+
+ processor_outputs = self.processor(images=self.bus_image_pil)
+ self.assertTrue((processor_outputs["image_patches_indices"] == EXPECTED_IMAGE_PATCH_INPUTS).all())
+
+ def test_fuyu_processing_multiple_image_sample(self):
+ """
+ Test to check processor works with multiple image inputs for a single text input
+ """
+ # fmt: off
+ SINGLE_IMAGE_PATCH_INPUTS = torch.Tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, -1, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, -1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, -1, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, -1, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, -1, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, -1, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, -1, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, -1, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, -1, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, -1, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, -1, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, -1, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, -1, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,]]).to(torch.int64)
+ SINGLE_PADDED_UNPACKED_TOKEN_INPUTS = torch.Tensor([[71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 1, 128340, 71374, 71389, 120412, 71377, 71835, 71374, 73615, 71375, 71399, 71435, 71122,]]).to(torch.int64)
+
+ SINGLE_RESIZED_IMAGE_PATCH_INPUTS = torch.Tensor([[ 0, 1, 2, -1, 3, 4, 5, -1, 6, 7, 8, -1, 9, 10, 11, -1, 12, 13, 14, -1, 15, 16, 17, -1, 18, 19, 20, -1, 21, 22, 23, -1, 24, 25, 26, -1, 27, 28, 29, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]])
+ SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS = torch.Tensor([[ 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71019, 1, 128340, 71374, 71389, 120412, 71377, 71835, 71374, 73615, 71375, 71399, 71435, 71122]])
# fmt: on
- torch.testing.assert_close(
- self.one_image_bus_model_inputs["image_patches_indices"], EXPECTED_IMAGE_PATCH_INPUTS
+
+ # Batch of two images - equally sized
+ images = [self.bus_image_pil, self.bus_image_pil]
+ processor_outputs = self.processor(text=[self.text_prompt, self.text_prompt], images=images)
+
+ self.assertTrue(
+ (
+ processor_outputs["image_patches_indices"]
+ == torch.cat([SINGLE_IMAGE_PATCH_INPUTS, SINGLE_IMAGE_PATCH_INPUTS], dim=0)
+ ).all()
+ )
+ self.assertTrue(
+ (
+ processor_outputs["input_ids"]
+ == torch.cat([SINGLE_PADDED_UNPACKED_TOKEN_INPUTS, SINGLE_PADDED_UNPACKED_TOKEN_INPUTS], dim=0)
+ ).all()
)
- torch.testing.assert_close(self.one_image_bus_model_inputs["input_ids"], EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS)
+
+ # Processes single images with different sizes as expected
+ images = [self.bus_image_pil]
+ processor_outputs = self.processor(text=self.text_prompt, images=images)
+ self.assertTrue((processor_outputs["image_patches_indices"] == SINGLE_IMAGE_PATCH_INPUTS).all())
+ self.assertTrue((processor_outputs["input_ids"] == SINGLE_PADDED_UNPACKED_TOKEN_INPUTS).all())
+
+ images = [self.bus_image_pil.resize((64, 300))]
+ processor_outputs = self.processor(text=self.text_prompt, images=images)
+ self.assertTrue((processor_outputs["image_patches_indices"] == SINGLE_RESIZED_IMAGE_PATCH_INPUTS).all())
+ self.assertTrue((processor_outputs["input_ids"] == SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS).all())
+
+ # Batch of two images - different sizes. Left-pads the smaller image inputs
+ images = [self.bus_image_pil, self.bus_image_pil.resize((64, 300))]
+ processor_outputs = self.processor(text=[self.text_prompt, self.text_prompt], images=images)
+
+ padding_len_patch = SINGLE_IMAGE_PATCH_INPUTS.shape[1] - SINGLE_RESIZED_IMAGE_PATCH_INPUTS.shape[1]
+ padded_single_resized_image_patch = torch.cat(
+ [torch.ones([1, padding_len_patch]) * -1, SINGLE_RESIZED_IMAGE_PATCH_INPUTS], dim=1
+ )
+ expected_image_patch_inputs = torch.cat([SINGLE_IMAGE_PATCH_INPUTS, padded_single_resized_image_patch], dim=0)
+
+ padding_len_token = (
+ SINGLE_PADDED_UNPACKED_TOKEN_INPUTS.shape[1] - SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS.shape[1]
+ )
+ padded_single_resized_padded_unpacked_token_inputs = torch.cat(
+ [torch.zeros([1, padding_len_token]), SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS], dim=1
+ )
+ expected_padded_unpacked_token_inputs = torch.cat(
+ [SINGLE_PADDED_UNPACKED_TOKEN_INPUTS, padded_single_resized_padded_unpacked_token_inputs], dim=0
+ )
+
+ self.assertTrue((processor_outputs["image_patches_indices"] == expected_image_patch_inputs).all())
+ self.assertTrue((processor_outputs["input_ids"] == expected_padded_unpacked_token_inputs).all())
@require_torch
@@ -97,7 +203,6 @@ def setUp(self):
"""
Adding a mix of present and absent images.
"""
- self.image_processor = FuyuImageProcessor()
self.image_input = torch.randn([1, 1, 3, 64, 64])
self.image_present = torch.tensor([[1]])
@@ -108,19 +213,19 @@ def setUp(self):
self.image_placeholder_id = 999
self.image_newline_id = 888
self.variable_sized = True
+ self.image_processor = FuyuImageProcessor(
+ patch_size={"height": self.image_patch_dim_h, "width": self.image_patch_dim_w}
+ )
def test_process_images_for_model_input_fixed_sized(self):
self.variable_sized = False
- result = self.image_processor.process_images_for_model_input(
+ result = self.image_processor.preprocess_with_tokenizer_info(
image_input=self.image_input,
image_present=self.image_present,
image_unpadded_h=self.image_unpadded_h,
image_unpadded_w=self.image_unpadded_w,
- image_patch_dim_h=self.image_patch_dim_h,
- image_patch_dim_w=self.image_patch_dim_w,
image_placeholder_id=self.image_placeholder_id,
image_newline_id=self.image_newline_id,
variable_sized=self.variable_sized,
)
- print(result["images"][0][0])
self.assertEqual(result["images"][0][0].shape, torch.Size([3, 64, 64]))