Merge pull request #16 from amyeroberts/type-cast-before-normalize-vit
Type cast before normalize vit
amyeroberts authored Sep 2, 2022
2 parents b94e71a + 8431b12 commit a2e74b0
Showing 2 changed files with 51 additions and 7 deletions.
23 changes: 16 additions & 7 deletions src/transformers/models/vit/feature_extraction_vit.py
@@ -98,13 +98,15 @@ def __call__(
                 tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                 number of channels, H and W are image height and width.
-            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
-                If set, will return tensors of a particular framework. Acceptable values are:
+            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `None`):
+                If set, will return a tensor of a particular framework.
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
+                Acceptable values are:
+                - `'tf'`: Return TensorFlow `tf.constant` object.
+                - `'pt'`: Return PyTorch `torch.Tensor` object.
+                - `'np'`: Return NumPy `np.ndarray` object.
+                - `'jax'`: Return JAX `jnp.ndarray` object.
+                - None: Return list of `np.ndarray` objects.
         Returns:
             [`BatchFeature`]: A [`BatchFeature`] with the following fields:
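For context on the new `None` default, here is a minimal usage sketch, not part of the diff; the checkpoint name is illustrative, and `return_tensors="pt"` additionally requires PyTorch:

from PIL import Image
from transformers import ViTFeatureExtractor

feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
images = [Image.new("RGB", (500, 300)), Image.new("RGB", (64, 64))]

# Default (return_tensors=None): a list of channel-first np.ndarray objects.
as_list = feature_extractor(images)["pixel_values"]
# return_tensors="pt": a single stacked torch.Tensor.
as_batch = feature_extractor(images, return_tensors="pt")["pixel_values"]
print(type(as_list[0]), as_list[0].shape)  # <class 'numpy.ndarray'> (3, 224, 224)
print(as_batch.shape)                      # torch.Size([2, 3, 224, 224])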
@@ -139,8 +141,15 @@ def __call__(
         # transformations (resizing + normalization)
         if self.do_resize and self.size is not None:
             images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]

+        # if do_normalize=False, the casting to a numpy array won't happen, so we need to do it here
+        make_channel_first = True if isinstance(images[0], Image.Image) else images[0].shape[-1] in (1, 3)
+        images = [self.to_numpy_array(image, rescale=False, channel_first=make_channel_first) for image in images]
+
         if self.do_normalize:
-            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
+            images = [
+                self.normalize(image=image, mean=self.image_mean, std=self.image_std, rescale=True) for image in images
+            ]

         # return as BatchFeature
         data = {"pixel_values": images}
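The heart of the fix is that PIL inputs are now cast to channel-first NumPy arrays even when `do_normalize=False`. A standalone sketch of that cast, using plain NumPy rather than the library's `to_numpy_array` helper (`to_channel_first` is an illustrative name, not a transformers API):

import numpy as np
from PIL import Image

def to_channel_first(image):
    # PIL -> (H, W, C) uint8 array, without rescaling to [0, 1]
    array = np.asarray(image) if isinstance(image, Image.Image) else image
    # Move channels first only if the array is channels-last, mirroring
    # the `shape[-1] in (1, 3)` check in the diff above.
    if array.ndim == 3 and array.shape[-1] in (1, 3):
        array = array.transpose(2, 0, 1)
    return array

image = Image.new("RGB", (224, 224))
assert to_channel_first(image).shape == (3, 224, 224)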
35 changes: 35 additions & 0 deletions tests/models/vit/test_feature_extraction_vit.py
@@ -18,6 +18,7 @@

 import numpy as np

+from parameterized import parameterized
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_torch_available, is_vision_available

@@ -189,3 +190,37 @@ def test_call_pytorch(self):
                     self.feature_extract_tester.size,
                 ),
             )
+
+    @parameterized.expand(
+        [
+            ("do_resize_True_do_normalize_True", True, True),
+            ("do_resize_True_do_normalize_False", True, False),
+            ("do_resize_False_do_normalize_True", False, True),
+            ("do_resize_False_do_normalize_False", False, False),
+        ]
+    )
+    def test_call_flags(self, _, do_resize, do_normalize):
+        # Initialize feature_extractor
+        feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+        feature_extractor.do_resize = do_resize
+        feature_extractor.do_normalize = do_normalize
+        # create random PIL images
+        image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
+
+        if do_resize:
+            expected_shapes = [
+                (
+                    self.feature_extract_tester.num_channels,
+                    self.feature_extract_tester.size,
+                    self.feature_extract_tester.size,
+                )
+                for _ in range(self.feature_extract_tester.batch_size)
+            ]
+        else:
+            expected_shapes = [(3, *x.size[::-1]) for x in image_inputs]
+
+        pixel_values = feature_extractor(image_inputs, return_tensors=None)["pixel_values"]
+        self.assertEqual(len(pixel_values), self.feature_extract_tester.batch_size)
+        for idx, image in enumerate(pixel_values):
+            self.assertEqual(image.shape, expected_shapes[idx])
+            self.assertIsInstance(image, np.ndarray)
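A note on the test's shape: `parameterized.expand` generates one unittest case per tuple and folds the first (string) element into the generated test name, which is why the method discards it as `_`. A small self-contained sketch of that naming behavior, with illustrative names:

import unittest
from parameterized import parameterized

class Demo(unittest.TestCase):
    @parameterized.expand([("flag_on", True), ("flag_off", False)])
    def test_flag(self, _, flag):
        # `_` receives the human-readable label; `flag` is the actual parameter.
        self.assertIsInstance(flag, bool)

# Discovery produces test_flag_0_flag_on and test_flag_1_flag_off, so a single
# variant can be selected with, e.g., `pytest -k flag_off`.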
