Skip to content

Commit

Permalink
Merge pull request #15 from amyeroberts/type-cast-before-normalize-dpt
Browse files Browse the repository at this point in the history
Type cast before normalize dpt
  • Loading branch information
amyeroberts authored Sep 2, 2022
2 parents 568cd54 + 36a13c9 commit 8879aca
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 7 deletions.
23 changes: 16 additions & 7 deletions src/transformers/models/dpt/feature_extraction_dpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,15 @@ def __call__(
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~file_utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `None`):
If set, will return a tensor of a particular framework.
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` object.
- `'pt'`: Return PyTorch `torch.Tensor` object.
- `'np'`: Return NumPy `np.ndarray` object.
- `'jax'`: Return JAX `jnp.ndarray` object.
- None: Return list of `np.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
Expand Down Expand Up @@ -192,8 +194,15 @@ def __call__(
for idx, image in enumerate(images):
size = self.update_size(image)
images[idx] = self.resize(image, size=size, resample=self.resample)

# if do_normalize=False, the casting to a numpy array won't happen, so we need to do it here
make_channel_first = True if isinstance(images[0], Image.Image) else images[0].shape[-1] in (1, 3)
images = [self.to_numpy_array(image, rescale=False, channel_first=make_channel_first) for image in images]

if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
images = [
self.normalize(image=image, mean=self.image_mean, std=self.image_std, rescale=True) for image in images
]

# return as BatchFeature
data = {"pixel_values": images}
Expand Down
38 changes: 38 additions & 0 deletions tests/models/dpt/test_feature_extraction_dpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import numpy as np

from parameterized import parameterized
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision

Expand Down Expand Up @@ -186,3 +187,40 @@ def test_call_pytorch(self):
self.feature_extract_tester.size,
),
)

@parameterized.expand(
[
("do_resize_True_do_normalize_True", True, True),
("do_resize_True_do_normalize_False", True, False),
("do_resize_True_do_normalize_True", True, True),
("do_resize_True_do_normalize_False", True, False),
("do_resize_False_do_normalize_True", False, True),
("do_resize_False_do_normalize_False", False, False),
("do_resize_False_do_normalize_True", False, True),
("do_resize_False_do_normalize_False", False, False),
]
)
def test_call_flags(self, _, do_resize, do_normalize):
# Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
feature_extractor.do_resize = do_resize
feature_extractor.do_normalize = do_normalize
# create random PIL images
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)

expected_shapes = [(3, *x.size[::-1]) for x in image_inputs]
if do_resize:
expected_shapes = [
(
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
)
for _ in range(self.feature_extract_tester.batch_size)
]

pixel_values = feature_extractor(image_inputs, return_tensors=None)["pixel_values"]
self.assertEqual(len(pixel_values), self.feature_extract_tester.batch_size)
for idx, image in enumerate(pixel_values):
self.assertEqual(image.shape, expected_shapes[idx])
self.assertIsInstance(image, np.ndarray)

0 comments on commit 8879aca

Please sign in to comment.