From eb643a8209a9778206edc086236675bf1235a1dc Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Fri, 5 Aug 2022 15:36:01 +0100
Subject: [PATCH 1/2] Cast images to numpy arrays in __call__ to enable consistent behaviour with different configs

---
 .../models/flava/feature_extraction_flava.py | 24 +++++++---
 .../flava/test_feature_extraction_flava.py   | 48 +++++++++++++++++++
 2 files changed, 65 insertions(+), 7 deletions(-)

diff --git a/src/transformers/models/flava/feature_extraction_flava.py b/src/transformers/models/flava/feature_extraction_flava.py
index c3aba8c70b6ce9..62cbcce069731d 100644
--- a/src/transformers/models/flava/feature_extraction_flava.py
+++ b/src/transformers/models/flava/feature_extraction_flava.py
@@ -284,13 +284,15 @@ def __call__(
                 If True, the processor will return `codebook_pixel_values` providing image pixels to be used with
                 the default FLAVA codebook. Used in pretraining by Masked Image Modeling (MIM) loss.
-        return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
-            If set, will return tensors of a particular framework. Acceptable values are:
+        return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `None`):
+            If set, will return a tensor of a particular framework.
 
-            - `'tf'`: Return TensorFlow `tf.constant` objects.
-            - `'pt'`: Return PyTorch `torch.Tensor` objects.
-            - `'np'`: Return NumPy `np.ndarray` objects.
-            - `'jax'`: Return JAX `jnp.ndarray` objects.
+            Acceptable values are:
+            - `'tf'`: Return TensorFlow `tf.constant` object.
+            - `'pt'`: Return PyTorch `torch.Tensor` object.
+            - `'np'`: Return NumPy `np.ndarray` object.
+            - `'jax'`: Return JAX `jnp.ndarray` object.
+            - None: Return list of `np.ndarray` objects.
 
         Returns:
             [`BatchFeature`]: A [`BatchFeature`] with the following fields:
@@ -318,8 +320,16 @@ def __call__(
             images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
         if self.do_center_crop and self.crop_size is not None:
             images = [self.center_crop(image, self.crop_size) for image in images]
+
+        # if do_normalize=False, the casting to a numpy array won't happen, so we need to do it here
+        make_channel_first = True if isinstance(images[0], Image.Image) else images[0].shape[-1] in (1, 3)
+        images = [self.to_numpy_array(image, rescale=False, channel_first=make_channel_first) for image in images]
+
         if self.do_normalize:
-            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
+            images = [
+                self.normalize(image=image, mean=self.image_mean, std=self.image_std, rescale=True) for image in images
+            ]
+
         # return as BatchFeature
         data = {"pixel_values": images}
diff --git a/tests/models/flava/test_feature_extraction_flava.py b/tests/models/flava/test_feature_extraction_flava.py
index 793aa913aeb04b..e609aec4704f79 100644
--- a/tests/models/flava/test_feature_extraction_flava.py
+++ b/tests/models/flava/test_feature_extraction_flava.py
@@ -18,6 +18,7 @@
 
 import numpy as np
 
+from parameterized import parameterized
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_torch_available, is_vision_available
 
@@ -345,3 +346,50 @@ def test_codebook_pixels(self):
                 expected_width,
             ),
         )
+
+    @parameterized.expand(
+        [
+            ("do_resize_True_do_center_crop_True_do_normalize_True", True, True, True),
+            ("do_resize_True_do_center_crop_True_do_normalize_False", True, True, False),
+            ("do_resize_True_do_center_crop_False_do_normalize_True", True, False, True),
("do_resize_True_do_center_crop_False_do_normalize_False", True, False, False), + ("do_resize_False_do_center_crop_True_do_normalize_True", False, True, True), + ("do_resize_False_do_center_crop_True_do_normalize_False", False, True, False), + ("do_resize_False_do_center_crop_False_do_normalize_True", False, False, True), + ("do_resize_False_do_center_crop_False_do_normalize_False", False, False, False), + ] + ) + def test_call_flags(self, _, do_resize, do_center_crop, do_normalize): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + feature_extractor.do_center_crop = do_center_crop + feature_extractor.do_resize = do_resize + feature_extractor.do_normalize = do_normalize + # create random PIL images + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False) + + expected_shapes = [(3, *x.size[::-1]) for x in image_inputs] + if do_resize: + expected_shapes = [ + ( + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size, + self.feature_extract_tester.size, + ) + for _ in range(self.feature_extract_tester.batch_size) + ] + if do_center_crop: + expected_shapes = [ + ( + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ) + for _ in range(self.feature_extract_tester.batch_size) + ] + + pixel_values = feature_extractor(image_inputs, return_tensors=None)["pixel_values"] + self.assertEqual(len(pixel_values), self.feature_extract_tester.batch_size) + for idx, image in enumerate(pixel_values): + self.assertEqual(image.shape, expected_shapes[idx]) + self.assertIsInstance(image, np.ndarray) From aae78c1e29f2c5f916702f3c463af8c1cb0b4ae2 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Fri, 5 Aug 2022 16:35:31 +0100 Subject: [PATCH 2/2] Remove accidental clip changes --- .../models/clip/feature_extraction_clip.py | 23 +++----- .../clip/test_feature_extraction_clip.py | 54 ------------------- 2 files changed, 7 insertions(+), 70 deletions(-) diff --git a/src/transformers/models/clip/feature_extraction_clip.py b/src/transformers/models/clip/feature_extraction_clip.py index 4784955b9b636d..7f01b5e02b94df 100644 --- a/src/transformers/models/clip/feature_extraction_clip.py +++ b/src/transformers/models/clip/feature_extraction_clip.py @@ -108,15 +108,13 @@ def __call__( tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a number of channels, H and W are image height and width. - return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `None`): - If set, will return a tensor of a particular framework. + return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): + If set, will return tensors of a particular framework. Acceptable values are: - Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` object. - - `'pt'`: Return PyTorch `torch.Tensor` object. - - `'np'`: Return NumPy `np.ndarray` object. - - `'jax'`: Return JAX `jnp.ndarray` object. - - None: Return list of `np.ndarray` objects. + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. 
 
         Returns:
             [`BatchFeature`]: A [`BatchFeature`] with the following fields:
@@ -157,15 +155,8 @@ def __call__(
             ]
         if self.do_center_crop and self.crop_size is not None:
            images = [self.center_crop(image, self.crop_size) for image in images]
-
-        # if do_normalize=False, the casting to a numpy array won't happen, so we need to do it here
-        make_channel_first = True if isinstance(images[0], Image.Image) else images[0].shape[-1] in (1, 3)
-        images = [self.to_numpy_array(image, rescale=False, channel_first=make_channel_first) for image in images]
-
         if self.do_normalize:
-            images = [
-                self.normalize(image=image, mean=self.image_mean, std=self.image_std, rescale=True) for image in images
-            ]
+            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
 
         # return as BatchFeature
         data = {"pixel_values": images}
diff --git a/tests/models/clip/test_feature_extraction_clip.py b/tests/models/clip/test_feature_extraction_clip.py
index 05e13af1fdf90b..8f36a65ae2d596 100644
--- a/tests/models/clip/test_feature_extraction_clip.py
+++ b/tests/models/clip/test_feature_extraction_clip.py
@@ -18,7 +18,6 @@
 
 import numpy as np
 
-from parameterized import parameterized
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_torch_available, is_vision_available
 
@@ -293,56 +292,3 @@ def test_call_pil_four_channels(self):
                 self.feature_extract_tester.crop_size,
             ),
         )
-
-    @parameterized.expand(
-        [
-            ("do_resize_True_do_center_crop_True_do_normalize_True", True, True, True),
-            ("do_resize_True_do_center_crop_True_do_normalize_False", True, True, False),
-            ("do_resize_True_do_center_crop_False_do_normalize_True", True, False, True),
-            ("do_resize_True_do_center_crop_False_do_normalize_False", True, False, False),
-            ("do_resize_False_do_center_crop_True_do_normalize_True", False, True, True),
-            ("do_resize_False_do_center_crop_True_do_normalize_False", False, True, False),
-            ("do_resize_False_do_center_crop_False_do_normalize_True", False, False, True),
-            ("do_resize_False_do_center_crop_False_do_normalize_False", False, False, False),
-        ]
-    )
-    def test_call_flags(self, _, do_resize, do_center_crop, do_normalize):
-        # Initialize feature_extractor
-        feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
-        feature_extractor.do_center_crop = do_center_crop
-        feature_extractor.do_resize = do_resize
-        feature_extractor.do_normalize = do_normalize
-        # create random PIL images
-        image_inputs = self.feature_extract_tester.prepare_inputs(equal_resolution=False, torchify=True)
-
-        # expected_shapes = [(self.expected_encoded_image_num_channels, *x.size[::-1]) for x in image_inputs]
-        expected_shapes = [x.shape for x in image_inputs]
-        if do_resize:
-            # Same size logic inside resized
-            resized_shapes = []
-            for shape in expected_shapes:
-                c, h, w = shape
-                short, long = (w, h) if w <= h else (h, w)
-                min_size = self.feature_extract_tester.size
-                if short == min_size:
-                    resized_shapes.append((c, h, w))
-                else:
-                    short, long = min_size, int(long * min_size / short)
-                    resized_shape = (c, long, short) if w <= h else (c, short, long)
-                    resized_shapes.append(resized_shape)
-            expected_shapes = resized_shapes
-        if do_center_crop:
-            expected_shapes = [
-                (
-                    self.expected_encoded_image_num_channels,
-                    self.feature_extract_tester.crop_size,
-                    self.feature_extract_tester.crop_size,
-                )
-                for _ in range(self.feature_extract_tester.batch_size)
-            ]
-
-        pixel_values = feature_extractor(image_inputs, return_tensors=None)["pixel_values"]
-        self.assertEqual(len(pixel_values), self.feature_extract_tester.batch_size)
-        for idx, image in enumerate(pixel_values):
-            self.assertEqual(image.shape, expected_shapes[idx])
-            self.assertIsInstance(image, np.ndarray)
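
A minimal usage sketch of the behaviour PATCH 1/2 enables, for reference. It assumes a
transformers checkout with both patches applied; `facebook/flava-full` is only an example
checkpoint and the flag values are illustrative, not part of the diffs above:

    import numpy as np
    from PIL import Image
    from transformers import FlavaFeatureExtractor

    feature_extractor = FlavaFeatureExtractor.from_pretrained("facebook/flava-full")
    # Before the patch, do_normalize=False meant the cast to numpy never happened and
    # PIL images were returned as-is; the patch moves the cast out of normalize().
    feature_extractor.do_normalize = False

    image = Image.new("RGB", (64, 48))  # dummy image in place of real data
    outputs = feature_extractor(image, return_tensors=None)

    # pixel_values is now a list of channel-first numpy arrays regardless of the
    # do_resize / do_center_crop / do_normalize configuration.
    assert isinstance(outputs["pixel_values"][0], np.ndarray)
    print(outputs["pixel_values"][0].shape)  # (3, 224, 224) with the default config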