Skip to content

Commit

Permalink
Add uint8 bicubic support to ResizeV2 (#7668)
Browse files Browse the repository at this point in the history
Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
  • Loading branch information
NicolasHug and vfdev-5 authored Jun 14, 2023
1 parent e44bba1 commit 8324c48
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 17 deletions.
22 changes: 18 additions & 4 deletions test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,8 @@ def __init__(
ArgsKwargs([32]),
ArgsKwargs((32, 29)),
ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST),
ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR),
ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC),
NotScriptableArgsKwargs(31, max_size=32),
ArgsKwargs([31], max_size=32),
NotScriptableArgsKwargs(30, max_size=100),
Expand All @@ -101,6 +99,15 @@ def __init__(
# atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
closeness_kwargs=dict(rtol=0, atol=1),
),
ConsistencyConfig(
v2_transforms.Resize,
legacy_transforms.Resize,
[
ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC, antialias=True),
ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC, antialias=True),
],
closeness_kwargs=dict(rtol=0, atol=21),
),
ConsistencyConfig(
v2_transforms.CenterCrop,
legacy_transforms.CenterCrop,
Expand Down Expand Up @@ -309,15 +316,22 @@ def __init__(
ArgsKwargs(17, scale=(0.3, 0.7)),
ArgsKwargs(25, ratio=(0.5, 1.5)),
ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST),
ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC),
ArgsKwargs((29, 32), antialias=False),
ArgsKwargs((28, 31), antialias=True),
],
# atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
closeness_kwargs=dict(rtol=0, atol=1),
),
ConsistencyConfig(
v2_transforms.RandomResizedCrop,
legacy_transforms.RandomResizedCrop,
[
ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC, antialias=True),
ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC, antialias=True),
],
closeness_kwargs=dict(rtol=0, atol=21),
),
ConsistencyConfig(
v2_transforms.RandomErasing,
legacy_transforms.RandomErasing,
Expand Down
28 changes: 23 additions & 5 deletions test/transforms_v2_kernel_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,17 +257,20 @@ def sample_inputs_resize_image_tensor():

for image_loader, interpolation in itertools.product(
make_image_loaders(sizes=["random"], color_spaces=["RGB"]),
[
F.InterpolationMode.NEAREST,
F.InterpolationMode.BILINEAR,
F.InterpolationMode.BICUBIC,
],
[F.InterpolationMode.NEAREST, F.InterpolationMode.BILINEAR],
):
yield ArgsKwargs(image_loader, size=[min(image_loader.spatial_size) + 1], interpolation=interpolation)

yield ArgsKwargs(make_image_loader(size=(11, 17)), size=20, max_size=25)


def sample_inputs_resize_image_tensor_bicubic():
for image_loader, interpolation in itertools.product(
make_image_loaders(sizes=["random"], color_spaces=["RGB"]), [F.InterpolationMode.BICUBIC]
):
yield ArgsKwargs(image_loader, size=[min(image_loader.spatial_size) + 1], interpolation=interpolation)


@pil_reference_wrapper
def reference_resize_image_tensor(*args, **kwargs):
if not kwargs.pop("antialias", False) and kwargs.get("interpolation", F.InterpolationMode.BILINEAR) in {
Expand Down Expand Up @@ -364,6 +367,21 @@ def reference_inputs_resize_bounding_box():
xfail_jit_python_scalar_arg("size"),
],
),
KernelInfo(
F.resize_image_tensor,
sample_inputs_fn=sample_inputs_resize_image_tensor_bicubic,
reference_fn=reference_resize_image_tensor,
reference_inputs_fn=reference_inputs_resize_image_tensor,
float32_vs_uint8=True,
closeness_kwargs={
**pil_reference_pixel_difference(10, mae=True),
**cuda_vs_cpu_pixel_difference(atol=30),
**float32_vs_uint8_pixel_difference(1, mae=True),
},
test_marks=[
xfail_jit_python_scalar_arg("size"),
],
),
KernelInfo(
F.resize_bounding_box,
sample_inputs_fn=sample_inputs_resize_bounding_box,
Expand Down
16 changes: 8 additions & 8 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,13 @@ def resize_image_tensor(
if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
# uint8 dtype can be included for cpu and cuda input if nearest mode
acceptable_dtypes.append(torch.uint8)
elif (
interpolation == InterpolationMode.BILINEAR
and image.device.type == "cpu"
and "AVX2" in torch.backends.cpu.get_cpu_capability()
):
# uint8 dtype support for bilinear mode is limited to cpu and
# according to our benchmarks non-AVX CPUs should prefer u8->f32->interpolate->u8 path
acceptable_dtypes.append(torch.uint8)
elif image.device.type == "cpu":
# uint8 dtype support for bilinear and bicubic is limited to cpu and
# according to our benchmarks, non-AVX CPUs should still prefer u8->f32->interpolate->u8 path for bilinear
if (interpolation == InterpolationMode.BILINEAR and "AVX2" in torch.backends.cpu.get_cpu_capability()) or (
interpolation == InterpolationMode.BICUBIC
):
acceptable_dtypes.append(torch.uint8)

strides = image.stride()
if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
Expand Down Expand Up @@ -227,6 +226,7 @@ def resize_image_tensor(

if need_cast:
if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
# This path is hit on non-AVX archs, or on GPU.
image = image.clamp_(min=0, max=255)
if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
image = image.round_()
Expand Down

0 comments on commit 8324c48

Please sign in to comment.