Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change default of antialias parameter from None to 'warn' #7160

Merged
merged 18 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions gallery/plot_optical_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,10 @@ def plot(imgs, **imshow_kwargs):

#########################
# The RAFT model accepts RGB images. We first get the frames from
# :func:`~torchvision.io.read_video` and resize them to ensure their
# dimensions are divisible by 8. Then we use the transforms bundled into the
# weights in order to preprocess the input and rescale its values to the
# :func:`~torchvision.io.read_video` and resize them to ensure their dimensions
# are divisible by 8. Note that we explicitly use ``antialias=False``, because
# this is how those models were trained. Then we use the transforms bundled into
# the weights in order to preprocess the input and rescale its values to the
# required ``[-1, 1]`` interval.

from torchvision.models.optical_flow import Raft_Large_Weights
Expand All @@ -93,8 +94,8 @@ def plot(imgs, **imshow_kwargs):


def preprocess(img1_batch, img2_batch):
img1_batch = F.resize(img1_batch, size=[520, 960])
img2_batch = F.resize(img2_batch, size=[520, 960])
img1_batch = F.resize(img1_batch, size=[520, 960], antialias=False)
img2_batch = F.resize(img2_batch, size=[520, 960], antialias=False)
return transforms(img1_batch, img2_batch)


Expand Down
6 changes: 5 additions & 1 deletion references/depth/stereo/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,11 @@ def forward(
INTERP_MODE = self._interpolation_mode_strategy()

for img in images:
resized_images += (F.resize(img, self.resize_size, interpolation=INTERP_MODE),)
# We hard-code antialias=False to preserve results after we changed
# its default from None to True (see
# https://github.com/pytorch/vision/pull/7160)
# TODO: we could re-train the stereo models with antialias=True?
resized_images += (F.resize(img, self.resize_size, interpolation=INTERP_MODE, antialias=False),)

for dsp in disparities:
if dsp is not None:
Expand Down
8 changes: 6 additions & 2 deletions references/optical_flow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,12 @@ def forward(self, img1, img2, flow, valid_flow_mask):

if torch.rand(1).item() < self.resize_prob:
# rescale the images
img1 = F.resize(img1, size=(new_h, new_w))
img2 = F.resize(img2, size=(new_h, new_w))
# We hard-code antialias=False to preserve results after we changed
# its default from None to True (see
# https://github.com/pytorch/vision/pull/7160)
# TODO: we could re-train the OF models with antialias=True?
img1 = F.resize(img1, size=(new_h, new_w), antialias=False)
img2 = F.resize(img2, size=(new_h, new_w), antialias=False)
if valid_flow_mask is None:
flow = F.resize(flow, size=(new_h, new_w))
flow = flow * torch.tensor([scale_x, scale_y])[:, None, None]
Expand Down
12 changes: 10 additions & 2 deletions references/video_classification/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ def __init__(
):
trans = [
transforms.ConvertImageDtype(torch.float32),
transforms.Resize(resize_size),
# We hard-code antialias=False to preserve results after we changed
# its default from None to True (see
# https://github.com/pytorch/vision/pull/7160)
# TODO: we could re-train the video models with antialias=True?
transforms.Resize(resize_size, antialias=False),
]
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
Expand All @@ -31,7 +35,11 @@ def __init__(self, *, crop_size, resize_size, mean=(0.43216, 0.394666, 0.37645),
self.transforms = transforms.Compose(
[
transforms.ConvertImageDtype(torch.float32),
transforms.Resize(resize_size),
# We hard-code antialias=False to preserve results after we changed
# its default from None to True (see
# https://github.com/pytorch/vision/pull/7160)
# TODO: we could re-train the video models with antialias=True?
transforms.Resize(resize_size, antialias=False),
transforms.Normalize(mean=mean, std=std),
transforms.CenterCrop(crop_size),
ConvertBCHWtoCBHW(),
Expand Down
52 changes: 42 additions & 10 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import math
import os
import re
import warnings
from functools import partial
from typing import Sequence

Expand Down Expand Up @@ -531,8 +532,8 @@ def test_resize(device, dt, size, max_size, interpolation):
tensor = tensor.to(dt)
batch_tensors = batch_tensors.to(dt)

resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size, antialias=True)
pmeier marked this conversation as resolved.
Show resolved Hide resolved
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size, antialias=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: As we are now using antialias=True, we can also reduce the tol below:

_assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=8.0)

Copy link
Member Author

@NicolasHug NicolasHug Feb 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I'll merge now to unblock, and I'll open a PR to check the CI on a lower tol
EDIT: #7233


assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]

Expand All @@ -557,10 +558,12 @@ def test_resize(device, dt, size, max_size, interpolation):
else:
script_size = size

resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, max_size=max_size)
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, max_size=max_size, antialias=True)
assert_equal(resized_tensor, resize_result)

_test_fn_on_batch(batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size)
_test_fn_on_batch(
batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size, antialias=True
)


@pytest.mark.parametrize("device", cpu_and_gpu())
Expand All @@ -576,9 +579,9 @@ def test_resize_asserts(device):
"Please use InterpolationMode enum."
),
):
res1 = F.resize(tensor, size=32, interpolation=2)
res1 = F.resize(tensor, size=32, interpolation=2, antialias=True)

res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
res2 = F.resize(tensor, size=32, interpolation=BILINEAR, antialias=True)
assert_equal(res1, res2)

for img in (tensor, pil_img):
Expand Down Expand Up @@ -608,7 +611,7 @@ def test_resize_antialias(device, dt, size, interpolation):
tensor = tensor.to(dt)

resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, antialias=True)

assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]

Expand Down Expand Up @@ -657,6 +660,23 @@ def test_assert_resize_antialias(interpolation):
F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)


def test_resize_antialias_default_warning():
# Checks when the functional resize ops warn about the ``antialias`` default
# on a uint8 tensor: calls that leave ``antialias`` (and ``interpolation``)
# unset must emit a UserWarning, while NEAREST calls must stay silent.

img = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8)

match = "The default value of the antialias"
with pytest.warns(UserWarning, match=match):
F.resize(img, size=(20, 20))
with pytest.warns(UserWarning, match=match):
F.resized_crop(img, 0, 0, 10, 10, size=(20, 20))

# For modes that aren't bicubic or bilinear, don't throw a warning
# (the "error" filter turns any emitted warning into an exception, so a
# stray warning fails the test)
with warnings.catch_warnings():
warnings.simplefilter("error")
F.resize(img, size=(20, 20), interpolation=NEAREST)
F.resized_crop(img, 0, 0, 10, 10, size=(20, 20), interpolation=NEAREST)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dt", [torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("size", [[10, 7], [10, 42], [42, 7]])
Expand Down Expand Up @@ -985,12 +1005,16 @@ def test_resized_crop(device, mode):
# 1) resize to the same size, crop to the same size => should be identity
tensor, _ = _create_data(26, 36, device=device)

out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode)
out_tensor = F.resized_crop(
tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode, antialias=True
)
assert_equal(tensor, out_tensor, msg=f"{out_tensor[0, :5, :5]} vs {tensor[0, :5, :5]}")

# 2) resize by half and crop a TL corner
tensor, _ = _create_data(26, 36, device=device)
out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST)
out_tensor = F.resized_crop(
tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST, antialias=True
NicolasHug marked this conversation as resolved.
Show resolved Hide resolved
)
expected_out_tensor = tensor[:, :20:2, :30:2]
assert_equal(
expected_out_tensor,
Expand All @@ -1000,7 +1024,15 @@ def test_resized_crop(device, mode):

batch_tensors = _create_data_batch(26, 36, num_samples=4, device=device)
_test_fn_on_batch(
batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=NEAREST
batch_tensors,
F.resized_crop,
top=1,
left=2,
height=20,
width=30,
size=[10, 15],
interpolation=NEAREST,
antialias=True,
)


Expand Down
20 changes: 20 additions & 0 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,5 +1030,25 @@ def test_raft(model_fn, scripted):
_assert_expected(flow_pred.cpu(), name=model_fn.__name__, atol=1e-2, rtol=1)


def test_presets_antialias():
# Checks which weight presets warn about the ``antialias`` default when
# their transforms are applied to a uint8 tensor batch.

img = torch.randint(0, 256, size=(1, 3, 224, 224), dtype=torch.uint8)

# Presets built without an explicit ``antialias`` are expected to warn
match = "The default value of the antialias parameter"
with pytest.warns(UserWarning, match=match):
models.ResNet18_Weights.DEFAULT.transforms()(img)
with pytest.warns(UserWarning, match=match):
models.segmentation.DeepLabV3_ResNet50_Weights.DEFAULT.transforms()(img)

# With the "error" filter active, any emitted warning raises and fails the test
with warnings.catch_warnings():
warnings.simplefilter("error")
models.ResNet18_Weights.DEFAULT.transforms(antialias=True)(img)
models.segmentation.DeepLabV3_ResNet50_Weights.DEFAULT.transforms(antialias=True)(img)

# NOTE(review): indentation is not visible here; presumably these three calls
# are still inside the catch_warnings block, i.e. these presets must not warn
# even without an explicit ``antialias`` - confirm against the real file.
models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()(img)
models.video.R3D_18_Weights.DEFAULT.transforms()(img)
models.optical_flow.Raft_Small_Weights.DEFAULT.transforms()(img, img)
Comment on lines +1068 to +1070
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For these 3, AFAICT, either Resize() is not used in the evaluation presets, or the training was already done on Tensors (i.e. no interpolation was done anyway); in the latter case I hard-coded antialias=False and added comments in the code to explain why.



if __name__ == "__main__":
pytest.main([__file__])
23 changes: 17 additions & 6 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import random
import re
import warnings
from functools import partial

import numpy as np
Expand Down Expand Up @@ -319,7 +320,7 @@ def test_randomresized_params():
scale_range = (scale_min, scale_min + round(random.random(), 2))
aspect_min = max(round(random.random(), 2), epsilon)
aspect_ratio_range = (aspect_min, aspect_min + round(random.random(), 2))
randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range)
randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range, antialias=True)
i, j, h, w = randresizecrop.get_params(img, scale_range, aspect_ratio_range)
aspect_ratio_obtained = w / h
assert (
Expand Down Expand Up @@ -366,7 +367,7 @@ def test_randomresized_params():
def test_resize(height, width, osize, max_size):
img = Image.new("RGB", size=(width, height), color=127)

t = transforms.Resize(osize, max_size=max_size)
t = transforms.Resize(osize, max_size=max_size, antialias=True)
result = t(img)

msg = f"{height}, {width} - {osize} - {max_size}"
Expand Down Expand Up @@ -424,7 +425,7 @@ def test_resize_sequence_output(height, width, osize):
img = Image.new("RGB", size=(width, height), color=127)
oheight, owidth = osize

t = transforms.Resize(osize)
t = transforms.Resize(osize, antialias=True)
result = t(img)

assert (owidth, oheight) == result.size
Expand All @@ -439,6 +440,16 @@ def test_resize_antialias_error():
t(img)


def test_resize_antialias_default_warning():
# Checks that Resize and RandomResizedCrop, built without an explicit
# ``antialias``, emit no warning when applied to a PIL image.

img = Image.new("RGB", size=(10, 10), color=127)
# We make sure we don't warn for PIL images since the default behaviour doesn't change
# (the "error" filter turns any emitted warning into an exception, failing the test)
with warnings.catch_warnings():
warnings.simplefilter("error")
transforms.Resize((20, 20))(img)
transforms.RandomResizedCrop((20, 20))(img)


@pytest.mark.parametrize("height, width", ((32, 64), (64, 32)))
def test_resize_size_equals_small_edge_size(height, width):
# Non-regression test for https://github.com/pytorch/vision/issues/5405
Expand All @@ -447,7 +458,7 @@ def test_resize_size_equals_small_edge_size(height, width):
img = Image.new("RGB", size=(width, height), color=127)

small_edge = min(height, width)
t = transforms.Resize(small_edge, max_size=max_size)
t = transforms.Resize(small_edge, max_size=max_size, antialias=True)
result = t(img)
assert max(result.size) == max_size

Expand Down Expand Up @@ -1424,11 +1435,11 @@ def test_random_choice(proba_passthrough, seed):
def test_random_order():
random_state = random.getstate()
random.seed(42)
random_order_transform = transforms.RandomOrder([transforms.Resize(20), transforms.CenterCrop(10)])
random_order_transform = transforms.RandomOrder([transforms.Resize(20, antialias=True), transforms.CenterCrop(10)])
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
num_samples = 250
num_normal_order = 0
resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20)(img))
resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20, antialias=True)(img))
for _ in range(num_samples):
out = random_order_transform(img)
if out == resize_crop_out:
Expand Down
25 changes: 21 additions & 4 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import sys
import warnings

import numpy as np
import pytest
Expand Down Expand Up @@ -371,7 +372,7 @@ class TestResize:
def test_resize_int(self, size):
# TODO: Minimal check for bug-fix, improve this later
x = torch.rand(3, 32, 46)
t = T.Resize(size=size)
t = T.Resize(size=size, antialias=True)
y = t(x)
# If size is an int, smaller edge of the image will be matched to this number.
# i.e, if height > width, then image will be rescaled to (size * height / width, size).
Expand All @@ -394,13 +395,13 @@ def test_resize_scripted(self, dt, size, max_size, interpolation, device):
if max_size is not None and len(size) != 1:
pytest.skip("Size should be an int or a sequence of length 1 if max_size is specified")

transform = T.Resize(size=size, interpolation=interpolation, max_size=max_size)
transform = T.Resize(size=size, interpolation=interpolation, max_size=max_size, antialias=True)
s_transform = torch.jit.script(transform)
_test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

def test_resize_save_load(self, tmpdir):
fn = T.Resize(size=[32])
fn = T.Resize(size=[32], antialias=True)
_test_fn_save_load(fn, tmpdir)

@pytest.mark.parametrize("device", cpu_and_gpu())
Expand All @@ -424,9 +425,25 @@ def test_resized_crop(self, scale, ratio, size, interpolation, antialias, device
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

def test_resized_crop_save_load(self, tmpdir):
fn = T.RandomResizedCrop(size=[32])
fn = T.RandomResizedCrop(size=[32], antialias=True)
_test_fn_save_load(fn, tmpdir)

def test_antialias_default_warning(self):
# Checks when the Resize / RandomResizedCrop transforms warn about the
# ``antialias`` default on a uint8 tensor: transforms built without
# ``antialias`` (and ``interpolation``) must emit a UserWarning, while
# NEAREST transforms must stay silent.

img = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8)

match = "The default value of the antialias"
with pytest.warns(UserWarning, match=match):
T.Resize((20, 20))(img)
with pytest.warns(UserWarning, match=match):
T.RandomResizedCrop((20, 20))(img)

# For modes that aren't bicubic or bilinear, don't throw a warning
# (the "error" filter turns any emitted warning into an exception, so a
# stray warning fails the test)
with warnings.catch_warnings():
warnings.simplefilter("error")
T.Resize((20, 20), interpolation=NEAREST)(img)
T.RandomResizedCrop((20, 20), interpolation=NEAREST)(img)


def _test_random_affine_helper(device, **kwargs):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/datapoints/_bounding_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def resize( # type: ignore[override]
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[bool] = None,
antialias: Optional[Union[str, bool]] = "warn",
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
) -> BoundingBox:
output, spatial_size = self._F.resize_bounding_box(
self.as_subclass(torch.Tensor), spatial_size=self.spatial_size, size=size, max_size=max_size
Expand All @@ -105,7 +105,7 @@ def resized_crop(
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> BoundingBox:
output, spatial_size = self._F.resized_crop_bounding_box(
self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size
Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/datapoints/_datapoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def resize( # type: ignore[override]
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[bool] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Datapoint:
return self

Expand All @@ -180,7 +180,7 @@ def resized_crop(
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Datapoint:
return self

Expand Down
Loading