Skip to content

Commit beb09cb

Browse files
authored
🔴Make center_crop fast equivalent to slow (#40856)
make center_crop fast equivalent to slow
1 parent: d4af0d9 · commit: beb09cb

File tree

4 files changed

+22
-17
lines changed

4 files changed

+22
-17
lines changed

src/transformers/image_processing_utils_fast.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,10 +405,11 @@ def rescale_and_normalize(
405405
def center_crop(
406406
self,
407407
image: "torch.Tensor",
408-
size: dict[str, int],
408+
size: SizeDict,
409409
**kwargs,
410410
) -> "torch.Tensor":
411411
"""
412+
Note: override torchvision's center_crop to have the same behavior as the slow processor.
412413
Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
413414
any edge, the image is padded with 0's and then center cropped.
414415
@@ -423,7 +424,24 @@ def center_crop(
423424
"""
424425
if size.height is None or size.width is None:
425426
raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
426-
return F.center_crop(image, (size["height"], size["width"]))
427+
image_height, image_width = image.shape[-2:]
428+
crop_height, crop_width = size.height, size.width
429+
430+
if crop_width > image_width or crop_height > image_height:
431+
padding_ltrb = [
432+
(crop_width - image_width) // 2 if crop_width > image_width else 0,
433+
(crop_height - image_height) // 2 if crop_height > image_height else 0,
434+
(crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
435+
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
436+
]
437+
image = F.pad(image, padding_ltrb, fill=0) # PIL uses fill value 0
438+
image_height, image_width = image.shape[-2:]
439+
if crop_width == image_width and crop_height == image_height:
440+
return image
441+
442+
crop_top = int((image_height - crop_height) / 2.0)
443+
crop_left = int((image_width - crop_width) / 2.0)
444+
return F.crop(image, crop_top, crop_left, crop_height, crop_width)
427445

428446
def convert_to_rgb(
429447
self,

src/transformers/models/perceiver/image_processing_perceiver_fast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def center_crop(
8181
min_dim = min(height, width)
8282
cropped_height = int((size.height / crop_size.height) * min_dim)
8383
cropped_width = int((size.width / crop_size.width) * min_dim)
84-
return F.center_crop(image, (cropped_height, cropped_width))
84+
return super().center_crop(image, SizeDict(height=cropped_height, width=cropped_width))
8585

8686
def _preprocess(
8787
self,

tests/models/chinese_clip/test_image_processing_chinese_clip.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ class ChineseCLIPImageProcessingTestFourChannels(ImageProcessingTestMixin, unitt
141141

142142
def setUp(self):
143143
super().setUp()
144-
self.image_processor_tester = ChineseCLIPImageProcessingTester(self, num_channels=4, do_center_crop=True)
144+
self.image_processor_tester = ChineseCLIPImageProcessingTester(self, num_channels=3, do_center_crop=True)
145145
self.expected_encoded_image_num_channels = 3
146146

147147
@property
@@ -160,14 +160,6 @@ def test_image_processor_properties(self):
160160
self.assertTrue(hasattr(image_processing, "image_std"))
161161
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
162162

163-
@unittest.skip(reason="ChineseCLIPImageProcessor does not support 4 channels yet") # FIXME Amy
164-
def test_call_numpy(self):
165-
return super().test_call_numpy()
166-
167-
@unittest.skip(reason="ChineseCLIPImageProcessor does not support 4 channels yet") # FIXME Amy
168-
def test_call_pytorch(self):
169-
return super().test_call_torch()
170-
171163
@unittest.skip(
172164
reason="ChineseCLIPImageProcessor doesn't treat 4 channel PIL and numpy consistently yet"
173165
) # FIXME Amy

tests/test_image_processing_common.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,6 @@ def test_slow_fast_equivalence_batched(self):
200200
if self.image_processing_class is None or self.fast_image_processing_class is None:
201201
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
202202

203-
if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
204-
self.skipTest(
205-
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
206-
)
207-
208203
dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
209204
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
210205
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)

0 commit comments

Comments (0)