Skip to content

Commit

Permalink
fix resize
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Nov 3, 2022
1 parent d0394b7 commit 5f33f4a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
5 changes: 3 additions & 2 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,9 @@ def resize_image_tensor(
)

if need_cast:
if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
image = image.clamp_(min=0, max=255)
# bicubic interpolation can overshoot
if interpolation == InterpolationMode.BICUBIC:
image = image.clamp_(min=0, max=_FT._max_value(dtype))
image = image.round_().to(dtype=dtype)

return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
Expand Down
4 changes: 2 additions & 2 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,8 +458,8 @@ def resize(

img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias)

if interpolation == "bicubic" and out_dtype == torch.uint8:
img = img.clamp(min=0, max=255)
if interpolation == "bicubic" and out_dtype not in (torch.float32, torch.float64):
img = img.clamp(min=0, max=_max_value(out_dtype))

img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)

Expand Down

0 comments on commit 5f33f4a

Please sign in to comment.