diff --git a/test/common_utils.py b/test/common_utils.py
index 5a853771301..10209cecdea 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -187,7 +187,7 @@ def _assert_approx_equal_tensor_to_pil(
     tensor = tensor.to(torch.float)
     pil_tensor = pil_tensor.to(torch.float)
     err = getattr(torch, agg_method)(torch.abs(tensor - pil_tensor)).item()
-    assert err < tol
+    assert err < tol, f"{err} vs {tol}"
 
 
 def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index fb07112bb83..9bc499467b7 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -14,9 +14,11 @@
     cpu_and_gpu,
     assert_equal,
 )
+from PIL import Image
 from torchvision import transforms as T
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms import functional as F
+from torchvision.transforms.autoaugment import _apply_op
 
 NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
 
@@ -725,6 +727,48 @@ def test_autoaugment_save(augmentation, tmpdir):
     s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))
 
 
+@pytest.mark.parametrize("interpolation", [F.InterpolationMode.NEAREST, F.InterpolationMode.BILINEAR])
+@pytest.mark.parametrize("mode", ["X", "Y"])
+def test_autoaugment__op_apply_shear(interpolation, mode):
+    # We check that torchvision's implementation of shear is equivalent
+    # to the official CIFAR10 autoaugment implementation:
+    # https://github.com/tensorflow/models/blob/885fda091c46c59d6c7bb5c7e760935eacc229da/research/autoaugment/augmentation_transforms.py#L273-L290
+    image_size = 32
+
+    def shear(pil_img, level, mode, resample):
+        if mode == "X":
+            matrix = (1, level, 0, 0, 1, 0)
+        elif mode == "Y":
+            matrix = (1, 0, 0, level, 1, 0)
+        return pil_img.transform((image_size, image_size), Image.AFFINE, matrix, resample=resample)
+
+    t_img, pil_img = _create_data(image_size, image_size)
+
+    resample_pil = {
+        F.InterpolationMode.NEAREST: Image.NEAREST,
+        F.InterpolationMode.BILINEAR: Image.BILINEAR,
+    }[interpolation]
+
+    level = 0.3
+    expected_out = shear(pil_img, level, mode=mode, resample=resample_pil)
+
+    # Check pil output vs expected pil
+    out = _apply_op(pil_img, op_name=f"Shear{mode}", magnitude=level, interpolation=interpolation, fill=0)
+    assert out == expected_out
+
+    if interpolation == F.InterpolationMode.BILINEAR:
+        # We skip bilinear mode for tensors as
+        # affine transformation results are not exactly the same
+        # between tensors and pil images
+        # MAE is around 1.40
+        # Max Abs error can be 163 or 170
+        return
+
+    # Check tensor output vs expected pil
+    out = _apply_op(t_img, op_name=f"Shear{mode}", magnitude=level, interpolation=interpolation, fill=0)
+    _assert_approx_equal_tensor_to_pil(out, expected_out)
+
+
 @pytest.mark.parametrize("device", cpu_and_gpu())
 @pytest.mark.parametrize(
     "config",
diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py
index 2c42bc375f6..228b2f8dd9b 100644
--- a/torchvision/transforms/autoaugment.py
+++ b/torchvision/transforms/autoaugment.py
@@ -14,23 +14,31 @@ def _apply_op(
     img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]
 ):
     if op_name == "ShearX":
+        # the shear angle passed to F.affine should be arctan(magnitude)
+        # official autoaug: (1, level, 0, 0, 1, 0)
+        # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
+        # compared to
+        # torchvision: (1, tan(level), 0, 0, 1, 0)
+        # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
         img = F.affine(
             img,
             angle=0.0,
             translate=[0, 0],
             scale=1.0,
-            shear=[math.degrees(magnitude), 0.0],
+            shear=[math.degrees(math.atan(magnitude)), 0.0],
             interpolation=interpolation,
             fill=fill,
             center=[0, 0],
         )
     elif op_name == "ShearY":
+        # the shear angle passed to F.affine should be arctan(magnitude)
+        # See the ShearX branch above
         img = F.affine(
             img,
             angle=0.0,
             translate=[0, 0],
             scale=1.0,
-            shear=[0.0, math.degrees(magnitude)],
+            shear=[0.0, math.degrees(math.atan(magnitude))],
             interpolation=interpolation,
             fill=fill,
             center=[0, 0],