Skip to content

[proto] Fixed RandAug and all AA consistency tests #6519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 202 additions & 0 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1650,3 +1650,205 @@ def test__transform(self):
assert isinstance(ohe_labels, features.OneHotLabel)
assert ohe_labels.shape == (4, 3)
assert ohe_labels.categories == labels.categories == categories


class TestAPIConsistency:
@pytest.mark.parametrize("antialias", [True, False])
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
def test_random_resized_crop(self, antialias, inpt):
from torchvision.transforms import transforms as ref_transforms

size = 224
t_ref = ref_transforms.RandomResizedCrop(size, antialias=antialias)
t = transforms.RandomResizedCrop(size, antialias=antialias)

torch.manual_seed(12)
expected_output = t_ref(inpt)

torch.manual_seed(12)
output = t(inpt)

if isinstance(inpt, PIL.Image.Image):
expected_output = pil_to_tensor(expected_output)
output = pil_to_tensor(output)

torch.testing.assert_close(expected_output, output)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we use assert_equal, i.e. assert_close(..., rtol=0, atol=0) for the consistency checks? Any deviation means there is a difference that we don't want. Since we only compare tensor to tensor and PIL to PIL, the usual needed tolerances are probably not needed here.

In the framework I've added for consistency tests, I also added

assert_equal = functools.partial(_assert_equal, pair_types=[ImagePair], rtol=0, atol=0)

that handles PIL and tensor images simultaneously.


@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize("interpolation", [InterpolationMode.NEAREST, InterpolationMode.BILINEAR])
def test_randaug(self, inpt, interpolation, mocker):
from torchvision.transforms import autoaugment as ref_transforms

t_ref = ref_transforms.RandAugment(interpolation=interpolation, num_ops=1)
t = transforms.RandAugment(interpolation=interpolation, num_ops=1)

le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
randint_values = []
for i in range(le):
# Stable API, op_index random call
randint_values.append(i)
# Stable API, if signed there is another random call
if t._AUGMENTATION_SPACE[keys[i]][1]:
randint_values.append(0)
# New API, _get_random_item
randint_values.append(i)
randint_values = iter(randint_values)

mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
mocker.patch("torch.rand", return_value=1.0)

for i in range(le):
expected_output = t_ref(inpt)
output = t(inpt)

if isinstance(inpt, PIL.Image.Image):
expected_output = pil_to_tensor(expected_output)
output = pil_to_tensor(output)

torch.testing.assert_close(expected_output, output)

@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize("interpolation", [InterpolationMode.NEAREST, InterpolationMode.BILINEAR])
def test_trivial_aug(self, inpt, interpolation, mocker):
from torchvision.transforms import autoaugment as ref_transforms

t_ref = ref_transforms.TrivialAugmentWide(interpolation=interpolation)
t = transforms.TrivialAugmentWide(interpolation=interpolation)

le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
randint_values = []
for i in range(le):
# Stable API, op_index random call
randint_values.append(i)
key = keys[i]
# Stable API, random magnitude
aug_op = t._AUGMENTATION_SPACE[key]
magnitudes = aug_op[0](2, 0, 0)
if magnitudes is not None:
randint_values.append(5)
# Stable API, if signed there is another random call
if aug_op[1]:
randint_values.append(0)
# New API, _get_random_item
randint_values.append(i)
# New API, random magnitude
if magnitudes is not None:
randint_values.append(5)

randint_values = iter(randint_values)

mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
mocker.patch("torch.rand", return_value=1.0)

for _ in range(le):
expected_output = t_ref(inpt)
output = t(inpt)

if isinstance(inpt, PIL.Image.Image):
expected_output = pil_to_tensor(expected_output)
output = pil_to_tensor(output)

torch.testing.assert_close(expected_output, output)

@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize("interpolation", [InterpolationMode.NEAREST, InterpolationMode.BILINEAR])
def test_augmix(self, inpt, interpolation, mocker):
from torchvision.transforms import autoaugment as ref_transforms

t_ref = ref_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1)
t = transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t._sample_dirichlet = lambda t: t.softmax(dim=-1)

le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
randint_values = []
for i in range(le):
# Stable API, op_index random call
randint_values.append(i)
key = keys[i]
# Stable API, random magnitude
aug_op = t._AUGMENTATION_SPACE[key]
magnitudes = aug_op[0](2, 0, 0)
if magnitudes is not None:
randint_values.append(5)
# Stable API, if signed there is another random call
if aug_op[1]:
randint_values.append(0)
# New API, _get_random_item
randint_values.append(i)
# New API, random magnitude
if magnitudes is not None:
randint_values.append(5)

randint_values = iter(randint_values)

mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
mocker.patch("torch.rand", return_value=1.0)

expected_output = t_ref(inpt)
output = t(inpt)

if isinstance(inpt, PIL.Image.Image):
expected_output = pil_to_tensor(expected_output)
output = pil_to_tensor(output)

torch.testing.assert_close(expected_output, output)

@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize("interpolation", [InterpolationMode.NEAREST, InterpolationMode.BILINEAR])
def test_aa(self, inpt, interpolation):
from torchvision.transforms import autoaugment as ref_transforms

aa_policy = ref_transforms.AutoAugmentPolicy("imagenet")
t_ref = ref_transforms.AutoAugment(aa_policy, interpolation=interpolation)
t = transforms.AutoAugment(aa_policy, interpolation=interpolation)

torch.manual_seed(12)
expected_output = t_ref(inpt)

torch.manual_seed(12)
output = t(inpt)

if isinstance(inpt, PIL.Image.Image):
expected_output = pil_to_tensor(expected_output)
output = pil_to_tensor(output)

torch.testing.assert_close(expected_output, output)
43 changes: 30 additions & 13 deletions torchvision/prototype/transforms/_auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,27 +69,46 @@ def _apply_image_transform(
interpolation: InterpolationMode,
fill: Union[int, float, Sequence[int], Sequence[float]],
) -> Any:

# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we have to put fill as None if fill == 0
fill_: Optional[Union[int, float, Sequence[int], Sequence[float]]]
if isinstance(fill, int) and fill == 0:
fill_ = None
else:
fill_ = fill

if transform_id == "Identity":
return image
elif transform_id == "ShearX":
# magnitude should be arctan(magnitude)
# official autoaug: (1, level, 0, 0, 1, 0)
# https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
# compared to
# torchvision: (1, tan(level), 0, 0, 1, 0)
# https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
return F.affine(
image,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[math.degrees(magnitude), 0.0],
shear=[math.degrees(math.atan(magnitude)), 0.0],
interpolation=interpolation,
fill=fill,
fill=fill_,
center=[0, 0],
)
elif transform_id == "ShearY":
# magnitude should be arctan(magnitude)
# See above
return F.affine(
image,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[0.0, math.degrees(magnitude)],
shear=[0.0, math.degrees(math.atan(magnitude))],
interpolation=interpolation,
fill=fill,
fill=fill_,
center=[0, 0],
)
elif transform_id == "TranslateX":
return F.affine(
Expand All @@ -99,7 +118,7 @@ def _apply_image_transform(
scale=1.0,
shear=[0.0, 0.0],
interpolation=interpolation,
fill=fill,
fill=fill_,
)
elif transform_id == "TranslateY":
return F.affine(
Expand All @@ -109,10 +128,10 @@ def _apply_image_transform(
scale=1.0,
shear=[0.0, 0.0],
interpolation=interpolation,
fill=fill,
fill=fill_,
)
elif transform_id == "Rotate":
return F.rotate(image, angle=magnitude)
return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
elif transform_id == "Brightness":
return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
elif transform_id == "Color":
Expand Down Expand Up @@ -340,19 +359,17 @@ def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]

id, image = self._extract_image(sample)
num_channels, height, width = get_chw(image)
_, height, width = get_chw(image)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, but weird that flake8 didn't complain about an unused variable.


for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
magnitude = float(magnitudes[self.magnitude])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0

image = self._apply_image_transform(
image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
Expand Down Expand Up @@ -397,7 +414,7 @@ def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]

id, image = self._extract_image(sample)
num_channels, height, width = get_chw(image)
_, height, width = get_chw(image)

transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

Expand Down Expand Up @@ -467,7 +484,7 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, orig_image = self._extract_image(sample)
num_channels, height, width = get_chw(orig_image)
_, height, width = get_chw(orig_image)

if isinstance(orig_image, torch.Tensor):
image = orig_image
Expand Down
6 changes: 5 additions & 1 deletion torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,12 @@ def affine_segmentation_mask(


def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]) -> Optional[List[float]]:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need to change this in the stable transforms as well? I would feel a lot more confident in this patch if the stable CI confirms this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stable is using fill=None everywhere where I could spotted the difference, so that's why there is no change there.
However, we may want to fix the current inconsistency in stable with fill=None != fill=0 (#6517).

# So, we can't reassign fill to 0
# if fill is None:
# fill = 0
if fill is None:
fill = 0
return fill

# This cast does Sequence -> List[float] to please mypy and torch.jit.script
if not isinstance(fill, (int, float)):
Expand Down
4 changes: 2 additions & 2 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,8 +549,8 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L

# Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
if fill is not None:
dummy = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
img = torch.cat((img, dummy), dim=1)
mask = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, dummy is indeed not a good name.

img = torch.cat((img, mask), dim=1)

img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)

Expand Down