Skip to content

Commit d9d7f6a

Browse files
authored
Revert change in compile_friendly_resize (#40645)
fix
1 parent 738b223 commit d9d7f6a

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/transformers/image_processing_utils_fast.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,9 +375,13 @@ def compile_friendly_resize(
375375
A wrapper around `F.resize` so that it is compatible with torch.compile when the image is a uint8 tensor.
376376
"""
377377
if image.dtype == torch.uint8:
378-
image = image.float() / 255
378+
# 256 is used on purpose instead of 255 to avoid numerical differences
379+
# see https://github.com/huggingface/transformers/pull/38540#discussion_r2127165652
380+
image = image.float() / 256
379381
image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
380-
image = image * 255
382+
image = image * 256
383+
# torch.where is used on purpose instead of torch.clamp to avoid bug in torch.compile
384+
# see https://github.com/huggingface/transformers/pull/38540#discussion_r2126888471
381385
image = torch.where(image > 255, 255, image)
382386
image = torch.where(image < 0, 0, image)
383387
image = image.round().to(torch.uint8)

0 commit comments

Comments
 (0)