From b4d527e0b7cbb66501004632b5cb9002b7739277 Mon Sep 17 00:00:00 2001
From: Niels
Date: Sat, 31 Aug 2024 10:22:12 +0200
Subject: [PATCH] Use numpy

---
 .../models/llava/image_processing_llava.py   | 82 +++++++++----------
 .../llava/test_image_processing_llava.py     | 33 ++++++++
 2 files changed, 74 insertions(+), 41 deletions(-)

diff --git a/src/transformers/models/llava/image_processing_llava.py b/src/transformers/models/llava/image_processing_llava.py
index 5af59911138739..2e40b94675fef4 100644
--- a/src/transformers/models/llava/image_processing_llava.py
+++ b/src/transformers/models/llava/image_processing_llava.py
@@ -17,14 +17,11 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
-from PIL import Image
 
 from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
 from ...image_transforms import (
-    PaddingMode,
     convert_to_rgb,
     get_resize_output_image_size,
-    pad,
     resize,
     to_channel_dimension_format,
 )
@@ -157,25 +154,6 @@ def __init__(
             # `shortest_edge` key.
             delattr(self, "use_square_size")
 
-    def pad_to_square_original(
-        self, image: Image.Image, background_color: Union[int, Tuple[int, int, int]] = 0
-    ) -> Image.Image:
-        """
-        Pads an image to make it square.
-        """
-        print("Image size:", image.size)
-        width, height = image.size
-        if width == height:
-            return image
-        elif width > height:
-            result = Image.new(image.mode, (width, width), background_color)
-            result.paste(image, (0, (width - height) // 2))
-            return result
-        else:
-            result = Image.new(image.mode, (height, height), background_color)
-            result.paste(image, ((height - width) // 2, 0))
-            return result
-
     def pad_to_square(
         self,
         image: np.array,
@@ -199,27 +177,42 @@ def pad_to_square(
         Returns:
             `np.ndarray`: The padded image.
         """
-        height, width = get_image_size(image, input_data_format)
+        h, w = get_image_size(image, input_data_format)
+        c = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1]
 
-        if height == width:
+        if h == w:
             return image
 
-        max_dim = max(height, width)
-        pad_height = max_dim - height
-        pad_width = max_dim - width
-
-        padding = (
-            (pad_height // 2, pad_height - pad_height // 2),
-            (pad_width // 2, pad_width - pad_width // 2),
-        )
-
-        return pad(
-            image=image,
-            padding=padding,
-            mode=PaddingMode.CONSTANT,
-            constant_values=background_color,
-            input_data_format=input_data_format,
-        )
+        max_dim = max(h, w)
+
+        # Ensure background_color is the correct shape
+        if isinstance(background_color, (int, float)):
+            background_color = [background_color] * c
+        elif len(background_color) != c:
+            raise ValueError(f"background_color must have {c} elements to match the number of channels")
+
+        if input_data_format == ChannelDimension.FIRST:
+            result = np.zeros((c, max_dim, max_dim), dtype=image.dtype)
+            for i, color in enumerate(background_color):
+                result[i, :, :] = color
+            if w > h:
+                start = (max_dim - h) // 2
+                result[:, start : start + h, :] = image
+            else:
+                start = (max_dim - w) // 2
+                result[:, :, start : start + w] = image
+        else:
+            result = np.zeros((max_dim, max_dim, c), dtype=image.dtype)
+            for i, color in enumerate(background_color):
+                result[:, :, i] = color
+            if w > h:
+                start = (max_dim - h) // 2
+                result[start : start + h, :, :] = image
+            else:
+                start = (max_dim - w) // 2
+                result[:, start : start + w, :] = image
+
+        return result
 
     def resize(
         self,
@@ -368,7 +361,6 @@ def preprocess(
                     "torch.Tensor, tf.Tensor or jax.ndarray."
                )
         validate_preprocess_arguments(
-            do_pad=do_pad,
             do_rescale=do_rescale,
             rescale_factor=rescale_factor,
             do_normalize=do_normalize,
@@ -397,6 +389,10 @@ def preprocess(
             # We assume that all images have the same channel dimension format.
             input_data_format = infer_channel_dimension_format(images[0])
 
+        print("Input data format:", input_data_format)
+        for image in images:
+            print(image.shape)
+
         if do_pad:
             images = [
                 self.pad_to_square(
@@ -407,6 +403,10 @@ def preprocess(
                 for image in images
             ]
 
+        print("After padding:")
+        for image in images:
+            print(image.shape)
+
         if do_resize:
             images = [
                 self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
diff --git a/tests/models/llava/test_image_processing_llava.py b/tests/models/llava/test_image_processing_llava.py
index 4b7f6e9b1bc3cf..7811b2110351ee 100644
--- a/tests/models/llava/test_image_processing_llava.py
+++ b/tests/models/llava/test_image_processing_llava.py
@@ -15,6 +15,9 @@
 import unittest
+from typing import Tuple, Union
+
+import numpy as np
 
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_vision_available
@@ -23,6 +26,8 @@
 if is_vision_available():
+    from PIL import Image
+
     from transformers import LlavaImageProcessor
@@ -128,3 +133,31 @@ def test_image_processor_from_dict_with_kwargs(self):
         image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
         self.assertEqual(image_processor.size, {"shortest_edge": 42})
         self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
+
+    # Ignore copy
+    def test_padding(self):
+        # taken from original implementation
+        def pad_to_square_original(
+            image: Image.Image, background_color: Union[int, Tuple[int, int, int]] = 0
+        ) -> Image.Image:
+            width, height = image.size
+            if width == height:
+                return image
+            elif width > height:
+                result = Image.new(image.mode, (width, width), background_color)
+                result.paste(image, (0, (width - height) // 2))
+                return result
+            else:
+                result = Image.new(image.mode, (height, height), background_color)
+                result.paste(image, ((height - width) // 2, 0))
+                return result
+
+        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
+
+        # create random numpy tensors
+        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
+        for image in image_inputs:
+            padded_image = image_processor.pad_to_square(image)
+            padded_image_original = np.array(pad_to_square_original(Image.fromarray(image)))
+
+            np.testing.assert_array_equal(padded_image, padded_image_original)
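
Note for reviewers: if you want to sanity-check the numpy padding logic from this patch outside the repository, a minimal standalone sketch of the same technique (constant-color center padding to a square canvas, in either channel layout) is below. The function name, the channels_first flag and the sample shapes are illustrative only and are not part of the patch or of the transformers API; the slicing mirrors the pad_to_square body added in image_processing_llava.py.

import numpy as np


def pad_to_square_example(image: np.ndarray, background_color=0, channels_first=False) -> np.ndarray:
    """Center-pad an image to a square canvas (illustrative re-implementation of the patch logic)."""
    if channels_first:
        c, h, w = image.shape
    else:
        h, w, c = image.shape
    if h == w:
        return image

    max_dim = max(h, w)

    # Broadcast a scalar background color over all channels, as the patch does.
    if isinstance(background_color, (int, float)):
        background_color = [background_color] * c
    elif len(background_color) != c:
        raise ValueError(f"background_color must have {c} elements")

    # Square canvas filled with the background color.
    shape = (c, max_dim, max_dim) if channels_first else (max_dim, max_dim, c)
    result = np.zeros(shape, dtype=image.dtype)
    for i, color in enumerate(background_color):
        if channels_first:
            result[i, :, :] = color
        else:
            result[:, :, i] = color

    # Paste the original image centered along its shorter side.
    start = (max_dim - min(h, w)) // 2
    if channels_first:
        if w > h:
            result[:, start : start + h, :] = image
        else:
            result[:, :, start : start + w] = image
    else:
        if w > h:
            result[start : start + h, :, :] = image
        else:
            result[:, start : start + w, :] = image
    return result


if __name__ == "__main__":
    img = np.random.randint(0, 255, size=(30, 50, 3), dtype=np.uint8)  # landscape HWC image
    padded = pad_to_square_example(img, background_color=(127, 127, 127))
    print(padded.shape)  # -> (50, 50, 3)

For a channels-last uint8 array this should produce the same result as the PIL Image.new/paste reference that test_padding compares against.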