diff --git a/src/transformers/models/flava/feature_extraction_flava.py b/src/transformers/models/flava/feature_extraction_flava.py index c3aba8c70b6ce9..62cbcce069731d 100644 --- a/src/transformers/models/flava/feature_extraction_flava.py +++ b/src/transformers/models/flava/feature_extraction_flava.py @@ -284,13 +284,15 @@ def __call__( If True, the processor will return `codebook_pixel_values` providing image pixels to be used with the default FLAVA codebook. Used in pretraining by Masked Image Modeling (MIM) loss. - return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): - If set, will return tensors of a particular framework. Acceptable values are: + return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `None`): + If set, will return a tensor of a particular framework. - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. + Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` object. + - `'pt'`: Return PyTorch `torch.Tensor` object. + - `'np'`: Return NumPy `np.ndarray` object. + - `'jax'`: Return JAX `jnp.ndarray` object. + - None: Return list of `np.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -318,8 +320,16 @@ def __call__( images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images] if self.do_center_crop and self.crop_size is not None: images = [self.center_crop(image, self.crop_size) for image in images] + + # if do_normalize=False, the casting to a numpy array won't happen, so we need to do it here + make_channel_first = True if isinstance(images[0], Image.Image) else images[0].shape[-1] in (1, 3) + images = [self.to_numpy_array(image, rescale=False, channel_first=make_channel_first) for image in images] + if self.do_normalize: - images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images] + images = [ + self.normalize(image=image, mean=self.image_mean, std=self.image_std, rescale=True) for image in images + ] + # return as BatchFeature data = {"pixel_values": images} diff --git a/tests/models/flava/test_feature_extraction_flava.py b/tests/models/flava/test_feature_extraction_flava.py index 793aa913aeb04b..e609aec4704f79 100644 --- a/tests/models/flava/test_feature_extraction_flava.py +++ b/tests/models/flava/test_feature_extraction_flava.py @@ -18,6 +18,7 @@ import numpy as np +from parameterized import parameterized from transformers.testing_utils import require_torch, require_vision from transformers.utils import is_torch_available, is_vision_available @@ -345,3 +346,50 @@ def test_codebook_pixels(self): expected_width, ), ) + + @parameterized.expand( + [ + ("do_resize_True_do_center_crop_True_do_normalize_True", True, True, True), + ("do_resize_True_do_center_crop_True_do_normalize_False", True, True, False), + ("do_resize_True_do_center_crop_False_do_normalize_True", True, False, True), + ("do_resize_True_do_center_crop_False_do_normalize_False", True, False, False), + ("do_resize_False_do_center_crop_True_do_normalize_True", False, True, True), + ("do_resize_False_do_center_crop_True_do_normalize_False", False, True, False), + ("do_resize_False_do_center_crop_False_do_normalize_True", False, False, True), + ("do_resize_False_do_center_crop_False_do_normalize_False", False, False, False), + ] + ) + def test_call_flags(self, _, do_resize, do_center_crop, do_normalize): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + feature_extractor.do_center_crop = do_center_crop + feature_extractor.do_resize = do_resize + feature_extractor.do_normalize = do_normalize + # create random PIL images + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False) + + expected_shapes = [(3, *x.size[::-1]) for x in image_inputs] + if do_resize: + expected_shapes = [ + ( + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size, + self.feature_extract_tester.size, + ) + for _ in range(self.feature_extract_tester.batch_size) + ] + if do_center_crop: + expected_shapes = [ + ( + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ) + for _ in range(self.feature_extract_tester.batch_size) + ] + + pixel_values = feature_extractor(image_inputs, return_tensors=None)["pixel_values"] + self.assertEqual(len(pixel_values), self.feature_extract_tester.batch_size) + for idx, image in enumerate(pixel_values): + self.assertEqual(image.shape, expected_shapes[idx]) + self.assertIsInstance(image, np.ndarray)