diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index bf897f571c6..548ed675ea8 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -1650,3 +1650,205 @@ def test__transform(self):
         assert isinstance(ohe_labels, features.OneHotLabel)
         assert ohe_labels.shape == (4, 3)
         assert ohe_labels.categories == labels.categories == categories
+
+
+class TestAPIConsistency:
+    @pytest.mark.parametrize("antialias", [True, False])
+    @pytest.mark.parametrize(
+        "inpt",
+        [
+            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
+            PIL.Image.new("RGB", (256, 256), 123),
+            features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+        ],
+    )
+    def test_random_resized_crop(self, antialias, inpt):
+        from torchvision.transforms import transforms as ref_transforms
+
+        size = 224
+        t_ref = ref_transforms.RandomResizedCrop(size, antialias=antialias)
+        t = transforms.RandomResizedCrop(size, antialias=antialias)
+
+        torch.manual_seed(12)
+        expected_output = t_ref(inpt)
+
+        torch.manual_seed(12)
+        output = t(inpt)
+
+        if isinstance(inpt, PIL.Image.Image):
+            expected_output = pil_to_tensor(expected_output)
+            output = pil_to_tensor(output)
+
+        torch.testing.assert_close(expected_output, output)
+
+    @pytest.mark.parametrize(
+        "inpt",
+        [
+            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
+            PIL.Image.new("RGB", (256, 256), 123),
+            features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+        ],
+    )
+    @pytest.mark.parametrize("interpolation", [InterpolationMode.NEAREST, InterpolationMode.BILINEAR])
+    def test_randaug(self, inpt, interpolation, mocker):
+        from torchvision.transforms import autoaugment as ref_transforms
+
+        t_ref = ref_transforms.RandAugment(interpolation=interpolation, num_ops=1)
+        t = transforms.RandAugment(interpolation=interpolation, num_ops=1)
+
+        num_augmentations = len(t._AUGMENTATION_SPACE)
+        keys = list(t._AUGMENTATION_SPACE.keys())
+        randint_values = []
+        for i in range(num_augmentations):
+            # Stable API: random call for the op index
+            randint_values.append(i)
+            # Stable API: if the op is signed, there is another random call
+            if t._AUGMENTATION_SPACE[keys[i]][1]:
+                randint_values.append(0)
+            # New API: _get_random_item
+            randint_values.append(i)
+        randint_values = iter(randint_values)
+
+        mocker.patch("torch.randint", side_effect=lambda *args, **kwargs: torch.tensor(next(randint_values)))
+        mocker.patch("torch.rand", return_value=1.0)
+
+        for _ in range(num_augmentations):
+            expected_output = t_ref(inpt)
+            output = t(inpt)
+
+            if isinstance(inpt, PIL.Image.Image):
+                expected_output = pil_to_tensor(expected_output)
+                output = pil_to_tensor(output)
+
+            torch.testing.assert_close(expected_output, output)
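+
+    # The two tests below reuse the mocking strategy from test_randaug: patch
+    # torch.randint to replay a precomputed queue of values so that the stable
+    # and prototype APIs draw identical (op index, magnitude bin, sign) values,
+    # and pin torch.rand to 1.0 so signed magnitudes are never flipped.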
+    @pytest.mark.parametrize(
+        "inpt",
+        [
+            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
+            PIL.Image.new("RGB", (256, 256), 123),
+            features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+        ],
+    )
+    @pytest.mark.parametrize("interpolation", [InterpolationMode.NEAREST, InterpolationMode.BILINEAR])
+    def test_trivial_aug(self, inpt, interpolation, mocker):
+        from torchvision.transforms import autoaugment as ref_transforms
+
+        t_ref = ref_transforms.TrivialAugmentWide(interpolation=interpolation)
+        t = transforms.TrivialAugmentWide(interpolation=interpolation)
+
+        num_augmentations = len(t._AUGMENTATION_SPACE)
+        keys = list(t._AUGMENTATION_SPACE.keys())
+        randint_values = []
+        for i in range(num_augmentations):
+            # Stable API: random call for the op index
+            randint_values.append(i)
+            key = keys[i]
+            # Stable API: random magnitude
+            aug_op = t._AUGMENTATION_SPACE[key]
+            magnitudes = aug_op[0](2, 0, 0)
+            if magnitudes is not None:
+                randint_values.append(5)
+            # Stable API: if the op is signed, there is another random call
+            if aug_op[1]:
+                randint_values.append(0)
+            # New API: _get_random_item
+            randint_values.append(i)
+            # New API: random magnitude
+            if magnitudes is not None:
+                randint_values.append(5)
+
+        randint_values = iter(randint_values)
+
+        mocker.patch("torch.randint", side_effect=lambda *args, **kwargs: torch.tensor(next(randint_values)))
+        mocker.patch("torch.rand", return_value=1.0)
+
+        for _ in range(num_augmentations):
+            expected_output = t_ref(inpt)
+            output = t(inpt)
+
+            if isinstance(inpt, PIL.Image.Image):
+                expected_output = pil_to_tensor(expected_output)
+                output = pil_to_tensor(output)
+
+            torch.testing.assert_close(expected_output, output)
+
+    @pytest.mark.parametrize(
+        "inpt",
+        [
+            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
+            PIL.Image.new("RGB", (256, 256), 123),
+            features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+        ],
+    )
+    @pytest.mark.parametrize("interpolation", [InterpolationMode.NEAREST, InterpolationMode.BILINEAR])
+    def test_augmix(self, inpt, interpolation, mocker):
+        from torchvision.transforms import autoaugment as ref_transforms
+
+        t_ref = ref_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
+        t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1)
+        t = transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
+        t._sample_dirichlet = lambda t: t.softmax(dim=-1)
+
+        num_augmentations = len(t._AUGMENTATION_SPACE)
+        keys = list(t._AUGMENTATION_SPACE.keys())
+        randint_values = []
+        for i in range(num_augmentations):
+            # Stable API: random call for the op index
+            randint_values.append(i)
+            key = keys[i]
+            # Stable API: random magnitude
+            aug_op = t._AUGMENTATION_SPACE[key]
+            magnitudes = aug_op[0](2, 0, 0)
+            if magnitudes is not None:
+                randint_values.append(5)
+            # Stable API: if the op is signed, there is another random call
+            if aug_op[1]:
+                randint_values.append(0)
+            # New API: _get_random_item
+            randint_values.append(i)
+            # New API: random magnitude
+            if magnitudes is not None:
+                randint_values.append(5)
+
+        randint_values = iter(randint_values)
+
+        mocker.patch("torch.randint", side_effect=lambda *args, **kwargs: torch.tensor(next(randint_values)))
+        mocker.patch("torch.rand", return_value=1.0)
+
+        expected_output = t_ref(inpt)
+        output = t(inpt)
+
+        if isinstance(inpt, PIL.Image.Image):
+            expected_output = pil_to_tensor(expected_output)
+            output = pil_to_tensor(output)
+
+        torch.testing.assert_close(expected_output, output)
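+
+    # Unlike the mocked tests above, the AutoAugment comparison below relies on
+    # torch.manual_seed alone, on the premise that both implementations consume
+    # their random draws in the same order.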
+    @pytest.mark.parametrize(
+        "inpt",
+        [
+            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
+            PIL.Image.new("RGB", (256, 256), 123),
+            features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+        ],
+    )
+    @pytest.mark.parametrize("interpolation", [InterpolationMode.NEAREST, InterpolationMode.BILINEAR])
+    def test_aa(self, inpt, interpolation):
+        from torchvision.transforms import autoaugment as ref_transforms
+
+        aa_policy = ref_transforms.AutoAugmentPolicy("imagenet")
+        t_ref = ref_transforms.AutoAugment(aa_policy, interpolation=interpolation)
+        t = transforms.AutoAugment(aa_policy, interpolation=interpolation)
+
+        torch.manual_seed(12)
+        expected_output = t_ref(inpt)
+
+        torch.manual_seed(12)
+        output = t(inpt)
+
+        if isinstance(inpt, PIL.Image.Image):
+            expected_output = pil_to_tensor(expected_output)
+            output = pil_to_tensor(output)
+
+        torch.testing.assert_close(expected_output, output)
diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py
index f113c64c346..a2bbb504fa6 100644
--- a/torchvision/prototype/transforms/_auto_augment.py
+++ b/torchvision/prototype/transforms/_auto_augment.py
@@ -69,27 +69,46 @@ def _apply_image_transform(
     interpolation: InterpolationMode,
     fill: Union[int, float, Sequence[int], Sequence[float]],
 ) -> Any:
+
+    # fill=0 is not equivalent to fill=None, https://github.com/pytorch/vision/issues/6517,
+    # so we map fill == 0 to None here
+    fill_: Optional[Union[int, float, Sequence[int], Sequence[float]]]
+    if isinstance(fill, int) and fill == 0:
+        fill_ = None
+    else:
+        fill_ = fill
+
     if transform_id == "Identity":
         return image
     elif transform_id == "ShearX":
+        # The shear magnitude passed to affine() has to be arctan(magnitude):
+        # official autoaug applies the shear matrix (1, level, 0, 0, 1, 0)
+        # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
+        # whereas
+        # torchvision applies (1, tan(level), 0, 0, 1, 0)
+        # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
         return F.affine(
             image,
             angle=0.0,
             translate=[0, 0],
             scale=1.0,
-            shear=[math.degrees(magnitude), 0.0],
+            shear=[math.degrees(math.atan(magnitude)), 0.0],
             interpolation=interpolation,
-            fill=fill,
+            fill=fill_,
+            center=[0, 0],
         )
     elif transform_id == "ShearY":
+        # The shear magnitude has to be arctan(magnitude)
+        # See the ShearX comment above
         return F.affine(
             image,
             angle=0.0,
             translate=[0, 0],
             scale=1.0,
-            shear=[0.0, math.degrees(magnitude)],
+            shear=[0.0, math.degrees(math.atan(magnitude))],
             interpolation=interpolation,
-            fill=fill,
+            fill=fill_,
+            center=[0, 0],
         )
     elif transform_id == "TranslateX":
         return F.affine(
@@ -99,7 +118,7 @@ def _apply_image_transform(
             scale=1.0,
             shear=[0.0, 0.0],
             interpolation=interpolation,
-            fill=fill,
+            fill=fill_,
         )
     elif transform_id == "TranslateY":
         return F.affine(
@@ -109,10 +128,10 @@ def _apply_image_transform(
             scale=1.0,
             shear=[0.0, 0.0],
             interpolation=interpolation,
-            fill=fill,
+            fill=fill_,
         )
     elif transform_id == "Rotate":
-        return F.rotate(image, angle=magnitude)
+        return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
     elif transform_id == "Brightness":
         return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
     elif transform_id == "Color":
@@ -340,19 +359,17 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
 
         id, image = self._extract_image(sample)
-        num_channels, height, width = get_chw(image)
+        _, height, width = get_chw(image)
 
         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
-
             magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
             if magnitudes is not None:
-                magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
+                magnitude = float(magnitudes[self.magnitude])
                 if signed and torch.rand(()) <= 0.5:
                     magnitude *= -1
             else:
                 magnitude = 0.0
-
             image = self._apply_image_transform(
                 image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
             )
@@ -397,7 +414,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
 
         id, image = self._extract_image(sample)
-        num_channels, height, width = get_chw(image)
+        _, height, width = get_chw(image)
 
         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -467,7 +484,7 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
         id, orig_image = self._extract_image(sample)
-        num_channels, height, width = get_chw(orig_image)
+        _, height, width = get_chw(orig_image)
 
         if isinstance(orig_image, torch.Tensor):
             image = orig_image
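Note: the fill_ remapping above, together with the _convert_fill_arg change below, works around fill=0 and fill=None taking different code paths in the kernels (https://github.com/pytorch/vision/issues/6517). A minimal sketch of the discrepancy against the stable API, illustrative only since the exact border pixels depend on the interpolation mode:

    import torch
    from torchvision.transforms.functional import InterpolationMode, rotate

    img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
    # fill=None skips the mask-based fill path while fill=0 goes through it,
    # so border pixels can differ under bilinear interpolation.
    out_none = rotate(img, angle=30.0, interpolation=InterpolationMode.BILINEAR, fill=None)
    out_zero = rotate(img, angle=30.0, interpolation=InterpolationMode.BILINEAR, fill=0)
    print(bool((out_none != out_zero).any()))  # can be True even though both "fill with zeros"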
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 6f852097b53..c4b68db96fd 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -379,8 +379,10 @@ def affine_segmentation_mask(
 
 
 def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]) -> Optional[List[float]]:
+    # fill=0 is not equivalent to fill=None, https://github.com/pytorch/vision/issues/6517,
+    # so we must not reassign a None fill to 0 here
     if fill is None:
-        fill = 0
+        return fill
 
     # This cast does Sequence -> List[float] to please mypy and torch.jit.script
     if not isinstance(fill, (int, float)):
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index eae86d5f105..8a88bf61ea8 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -549,8 +549,8 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L
 
     # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
    if fill is not None:
-        dummy = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
-        img = torch.cat((img, dummy), dim=1)
+        mask = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
+        img = torch.cat((img, mask), dim=1)
 
     img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
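Note: the dummy to mask rename above touches the single-pass fill technique shared by the geometric kernels: a ones channel is warped together with the image, and afterwards it marks which output pixels sampled from outside the input. A standalone sketch of the idea, simplified to the float/bilinear case and not the library's exact code:

    import torch
    from torch.nn.functional import grid_sample

    def warp_with_fill(img: torch.Tensor, grid: torch.Tensor, fill: float) -> torch.Tensor:
        # Warp a ones mask alongside the image in a single grid_sample call; where
        # the warped mask is < 1, the sample came (partly) from outside the image.
        mask = torch.ones_like(img[:, :1])
        out = grid_sample(torch.cat([img, mask], dim=1), grid, mode="bilinear", padding_mode="zeros", align_corners=False)
        warped, warped_mask = out[:, :-1], out[:, -1:]
        # Composite the fill color into the out-of-bounds region.
        return warped * warped_mask + (1.0 - warped_mask) * fill

For nearest interpolation the real kernel additionally thresholds the warped mask to {0, 1} before compositing.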