Skip to content

Commit

Permalink
Add center_crop to ImageFeatureExtractoMixin (#11066)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgugger authored Apr 5, 2021
1 parent abb7430 commit 090e3e6
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 0 deletions.
52 changes: 52 additions & 0 deletions src/transformers/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,55 @@ def resize(self, image, size, resample=PIL.Image.BILINEAR):
image = self.to_pil_image(image)

return image.resize(size, resample=resample)

def center_crop(self, image, size):
"""
Crops :obj:`image` to the given size using a center crop. Note that if the image is too small to be cropped to
the size given, it will be padded (so the returned result has the size asked).
Args:
image (:obj:`PIL.Image.Image` or :obj:`np.ndarray` or :obj:`torch.Tensor`):
The image to resize.
size (:obj:`int` or :obj:`Tuple[int, int]`):
The size to which crop the image.
"""
self._ensure_format_supported(image)
if not isinstance(size, tuple):
size = (size, size)

# PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width)
image_shape = (image.size[1], image.size[0]) if isinstance(image, PIL.Image.Image) else image.shape[-2:]
top = (image_shape[0] - size[0]) // 2
bottom = top + size[0] # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
left = (image_shape[1] - size[1]) // 2
right = left + size[1] # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.

# For PIL Images we have a method to crop directly.
if isinstance(image, PIL.Image.Image):
return image.crop((left, top, right, bottom))

# Check if all the dimensions are inside the image.
if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
return image[..., top:bottom, left:right]

# Otherwise, we may need to pad if the image is too small. Oh joy...
new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
if isinstance(image, np.ndarray):
new_image = np.zeros_like(image, shape=new_shape)
elif is_torch_tensor(image):
new_image = image.new_zeros(new_shape)

top_pad = (new_shape[-2] - image_shape[0]) // 2
bottom_pad = top_pad + image_shape[0]
left_pad = (new_shape[-1] - image_shape[1]) // 2
right_pad = left_pad + image_shape[1]
new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image

top += top_pad
bottom += top_pad
left += left_pad
right += left_pad

return new_image[
..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
]
52 changes: 52 additions & 0 deletions tests/test_image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,55 @@ def test_normalize_tensor(self):

normalized_tensor = feature_extractor.normalize(tensor, torch.tensor(mean), torch.tensor(std))
self.assertTrue(torch.equal(normalized_tensor, expected))

def test_center_crop_image(self):
feature_extractor = ImageFeatureExtractionMixin()
image = get_random_image(16, 32)

# Test various crop sizes: bigger on all dimensions, on one of the dimensions only and on both dimensions.
crop_sizes = [8, (8, 64), 20, (32, 64)]
for size in crop_sizes:
cropped_image = feature_extractor.center_crop(image, size)
self.assertTrue(isinstance(cropped_image, PIL.Image.Image))

# PIL Image.size is transposed compared to NumPy or PyTorch (width first instead of height first).
expected_size = (size, size) if isinstance(size, int) else (size[1], size[0])
self.assertEqual(cropped_image.size, expected_size)

def test_center_crop_array(self):
feature_extractor = ImageFeatureExtractionMixin()
image = get_random_image(16, 32)
array = feature_extractor.to_numpy_array(image)

# Test various crop sizes: bigger on all dimensions, on one of the dimensions only and on both dimensions.
crop_sizes = [8, (8, 64), 20, (32, 64)]
for size in crop_sizes:
cropped_array = feature_extractor.center_crop(array, size)
self.assertTrue(isinstance(cropped_array, np.ndarray))

expected_size = (size, size) if isinstance(size, int) else size
self.assertEqual(cropped_array.shape[-2:], expected_size)

# Check result is consistent with PIL.Image.crop
cropped_image = feature_extractor.center_crop(image, size)
self.assertTrue(np.array_equal(cropped_array, feature_extractor.to_numpy_array(cropped_image)))

@require_torch
def test_center_crop_tensor(self):
feature_extractor = ImageFeatureExtractionMixin()
image = get_random_image(16, 32)
array = feature_extractor.to_numpy_array(image)
tensor = torch.tensor(array)

# Test various crop sizes: bigger on all dimensions, on one of the dimensions only and on both dimensions.
crop_sizes = [8, (8, 64), 20, (32, 64)]
for size in crop_sizes:
cropped_tensor = feature_extractor.center_crop(tensor, size)
self.assertTrue(isinstance(cropped_tensor, torch.Tensor))

expected_size = (size, size) if isinstance(size, int) else size
self.assertEqual(cropped_tensor.shape[-2:], expected_size)

# Check result is consistent with PIL.Image.crop
cropped_image = feature_extractor.center_crop(image, size)
self.assertTrue(torch.equal(cropped_tensor, torch.tensor(feature_extractor.to_numpy_array(cropped_image))))

0 comments on commit 090e3e6

Please sign in to comment.