Skip to content

Commit

Permalink
Merge pull request #19 from amyeroberts/type-cast-before-normalize-la…
Browse files Browse the repository at this point in the history
…youtlmv2

Type cast before normalize layoutlmv2
  • Loading branch information
amyeroberts authored Sep 2, 2022
2 parents dcf02c6 + d0c7db1 commit 3911f5d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,15 @@ def __call__(
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `None`):
If set, will return a tensor of a particular framework.
Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` object.
- `'pt'`: Return PyTorch `torch.Tensor` object.
- `'np'`: Return NumPy `np.ndarray` object.
- `'jax'`: Return JAX `jnp.ndarray` object.
- None: Return list of `np.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
Expand Down
31 changes: 31 additions & 0 deletions tests/models/layoutlmv2/test_feature_extraction_layoutlmv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import numpy as np

from parameterized import parameterized
from transformers.testing_utils import require_pytesseract, require_torch
from transformers.utils import is_pytesseract_available, is_torch_available

Expand Down Expand Up @@ -219,3 +220,33 @@ def test_layoutlmv2_integration_test(self):
224,
),
)

@parameterized.expand(
[
("do_resize_True", True),
("do_resize_False", False),
]
)
def test_call_flags(self, _, do_resize):
# Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
feature_extractor.do_resize = do_resize
# create random PIL images
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)

expected_shapes = [(3, *x.size[::-1]) for x in image_inputs]
if do_resize:
expected_shapes = [
(
self.feature_extract_tester.num_channels,
self.feature_extract_tester.size,
self.feature_extract_tester.size,
)
for _ in range(self.feature_extract_tester.batch_size)
]

pixel_values = feature_extractor(image_inputs, return_tensors=None)["pixel_values"]
self.assertEqual(len(pixel_values), self.feature_extract_tester.batch_size)
for idx, image in enumerate(pixel_values):
self.assertEqual(image.shape, expected_shapes[idx])
self.assertIsInstance(image, np.ndarray)

0 comments on commit 3911f5d

Please sign in to comment.