diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 52e457a070f..bf897f571c6 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -555,10 +555,10 @@ def test__get_params(self, degrees, translate, scale, shear, mocker): if translate is not None: w_max = int(round(translate[0] * w)) h_max = int(round(translate[1] * h)) - assert -w_max <= params["translations"][0] <= w_max - assert -h_max <= params["translations"][1] <= h_max + assert -w_max <= params["translate"][0] <= w_max + assert -h_max <= params["translate"][1] <= h_max else: - assert params["translations"] == (0, 0) + assert params["translate"] == (0, 0) if scale is not None: assert scale[0] <= params["scale"] <= scale[1] @@ -759,7 +759,8 @@ def test__transform(self, kernel_size, sigma, mocker): if isinstance(kernel_size, (tuple, list)): assert transform.kernel_size == kernel_size else: - assert transform.kernel_size == (kernel_size, kernel_size) + kernel_size = (kernel_size, kernel_size) + assert transform.kernel_size == kernel_size if isinstance(sigma, (tuple, list)): assert transform.sigma == sigma @@ -779,7 +780,7 @@ def test__transform(self, kernel_size, sigma, mocker): torch.manual_seed(12) params = transform._get_params(inpt) - fn.assert_called_once_with(inpt, **params) + fn.assert_called_once_with(inpt, kernel_size, **params) class TestRandomColorOp: diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index e4a48e542ff..97755d87d3a 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -137,6 +137,7 @@ def parametrization(self): [ ArgsKwargs(18), ArgsKwargs((18, 13)), + ArgsKwargs(18, vertical_flip=True), ], make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]), ), @@ -215,6 +216,214 @@ def parametrization(self): # images given that the transform does nothing but call it anyway. 
supports_pil=False, ), + ConsistencyConfig( + prototype_transforms.RandomHorizontalFlip, + legacy_transforms.RandomHorizontalFlip, + [ + ArgsKwargs(p=0), + ArgsKwargs(p=1), + ], + ), + ConsistencyConfig( + prototype_transforms.RandomVerticalFlip, + legacy_transforms.RandomVerticalFlip, + [ + ArgsKwargs(p=0), + ArgsKwargs(p=1), + ], + ), + ConsistencyConfig( + prototype_transforms.RandomEqualize, + legacy_transforms.RandomEqualize, + [ + ArgsKwargs(p=0), + ArgsKwargs(p=1), + ], + make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]), + ), + ConsistencyConfig( + prototype_transforms.RandomInvert, + legacy_transforms.RandomInvert, + [ + ArgsKwargs(p=0), + ArgsKwargs(p=1), + ], + ), + ConsistencyConfig( + prototype_transforms.RandomPosterize, + legacy_transforms.RandomPosterize, + [ + ArgsKwargs(p=0, bits=5), + ArgsKwargs(p=1, bits=1), + ArgsKwargs(p=1, bits=3), + ], + make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]), + ), + ConsistencyConfig( + prototype_transforms.RandomSolarize, + legacy_transforms.RandomSolarize, + [ + ArgsKwargs(p=0, threshold=0.5), + ArgsKwargs(p=1, threshold=0.3), + ArgsKwargs(p=1, threshold=0.99), + ], + ), + ConsistencyConfig( + prototype_transforms.RandomAutocontrast, + legacy_transforms.RandomAutocontrast, + [ + ArgsKwargs(p=0), + ArgsKwargs(p=1), + ], + ), + ConsistencyConfig( + prototype_transforms.RandomAdjustSharpness, + legacy_transforms.RandomAdjustSharpness, + [ + ArgsKwargs(p=0, sharpness_factor=0.5), + ArgsKwargs(p=1, sharpness_factor=0.3), + ArgsKwargs(p=1, sharpness_factor=0.99), + ], + ), + ConsistencyConfig( + prototype_transforms.RandomGrayscale, + legacy_transforms.RandomGrayscale, + [ + ArgsKwargs(p=0), + ArgsKwargs(p=1), + ], + ), + ConsistencyConfig( + prototype_transforms.RandomResizedCrop, + legacy_transforms.RandomResizedCrop, + [ + ArgsKwargs(16), + ArgsKwargs(17, scale=(0.3, 0.7)), + ArgsKwargs(25, ratio=(0.5, 1.5)), + ArgsKwargs((31, 28), 
interpolation=prototype_transforms.InterpolationMode.NEAREST), + ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC), + ArgsKwargs((29, 32), antialias=False), + ArgsKwargs((28, 31), antialias=True), + ], + ), + ConsistencyConfig( + prototype_transforms.RandomErasing, + legacy_transforms.RandomErasing, + [ + ArgsKwargs(p=0), + ArgsKwargs(p=1), + ArgsKwargs(p=1, scale=(0.3, 0.7)), + ArgsKwargs(p=1, ratio=(0.5, 1.5)), + ArgsKwargs(p=1, value=1), + ArgsKwargs(p=1, value=(1, 2, 3)), + ArgsKwargs(p=1, value="random"), + ], + supports_pil=False, + ), + ConsistencyConfig( + prototype_transforms.ColorJitter, + legacy_transforms.ColorJitter, + [ + ArgsKwargs(), + ArgsKwargs(brightness=0.1), + ArgsKwargs(brightness=(0.2, 0.3)), + ArgsKwargs(contrast=0.4), + ArgsKwargs(contrast=(0.5, 0.6)), + ArgsKwargs(saturation=0.7), + ArgsKwargs(saturation=(0.8, 0.9)), + ArgsKwargs(hue=0.3), + ArgsKwargs(hue=(-0.1, 0.2)), + ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.7, hue=0.3), + ], + ), + ConsistencyConfig( + prototype_transforms.ElasticTransform, + legacy_transforms.ElasticTransform, + [ + ArgsKwargs(), + ArgsKwargs(alpha=20.0), + ArgsKwargs(alpha=(15.3, 27.2)), + ArgsKwargs(sigma=3.0), + ArgsKwargs(sigma=(2.5, 3.9)), + ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST), + ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC), + ArgsKwargs(fill=1), + ], + # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image + make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(163, 163), (72, 333), (313, 95)]), + ), + ConsistencyConfig( + prototype_transforms.GaussianBlur, + legacy_transforms.GaussianBlur, + [ + ArgsKwargs(kernel_size=3), + ArgsKwargs(kernel_size=(1, 5)), + ArgsKwargs(kernel_size=3, sigma=0.7), + ArgsKwargs(kernel_size=5, sigma=(0.3, 1.4)), + ], + ), + ConsistencyConfig( + prototype_transforms.RandomAffine, + legacy_transforms.RandomAffine, 
+ [ + ArgsKwargs(degrees=30.0), + ArgsKwargs(degrees=(-20.0, 10.0)), + ArgsKwargs(degrees=0.0, translate=(0.4, 0.6)), + ArgsKwargs(degrees=0.0, scale=(0.3, 0.8)), + ArgsKwargs(degrees=0.0, shear=13), + ArgsKwargs(degrees=0.0, shear=(8, 17)), + ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)), + ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)), + ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.NEAREST), + ArgsKwargs(degrees=30.0, fill=1), + ArgsKwargs(degrees=30.0, fill=(2, 3, 4)), + ArgsKwargs(degrees=30.0, center=(0, 0)), + ], + ), + ConsistencyConfig( + prototype_transforms.RandomCrop, + legacy_transforms.RandomCrop, + [ + ArgsKwargs(12), + ArgsKwargs((15, 17)), + ArgsKwargs(11, padding=1), + ArgsKwargs((8, 13), padding=(2, 3)), + ArgsKwargs((14, 9), padding=(0, 2, 1, 0)), + ArgsKwargs(36, pad_if_needed=True), + ArgsKwargs((7, 8), fill=1), + ArgsKwargs(5, fill=(1, 2, 3)), + ArgsKwargs(12), + ArgsKwargs(15, padding=2, padding_mode="edge"), + ArgsKwargs(17, padding=(1, 0), padding_mode="reflect"), + ArgsKwargs(8, padding=(3, 0, 0, 1), padding_mode="symmetric"), + ], + make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(26, 26), (18, 33), (29, 22)]), + ), + ConsistencyConfig( + prototype_transforms.RandomPerspective, + legacy_transforms.RandomPerspective, + [ + ArgsKwargs(p=0), + ArgsKwargs(p=1), + ArgsKwargs(p=1, distortion_scale=0.3), + ArgsKwargs(p=1, distortion_scale=0.2, interpolation=prototype_transforms.InterpolationMode.NEAREST), + ArgsKwargs(p=1, distortion_scale=0.1, fill=1), + ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)), + ], + ), + ConsistencyConfig( + prototype_transforms.RandomRotation, + legacy_transforms.RandomRotation, + [ + ArgsKwargs(degrees=30.0), + ArgsKwargs(degrees=(-20.0, 10.0)), + ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.BILINEAR), + ArgsKwargs(degrees=30.0, expand=True), + ArgsKwargs(degrees=30.0, center=(0, 0)), + 
ArgsKwargs(degrees=30.0, fill=1), + ArgsKwargs(degrees=30.0, fill=(1, 2, 3)), + ], + ), ] @@ -227,23 +436,23 @@ def test_automatic_coverage_deterministic(): and not issubclass(obj, enum.Enum) and name not in { - "Compose", # This framework is based on the assumption that the input image can always be a tensor and optionally a - # PIL image. The transforms below require a non-tensor input and thus have to be tested manually. + # PIL image, but the transforms below require a non-tensor input. "PILToTensor", "ToTensor", + # Transform containers cannot be tested without other transforms + "Compose", + "RandomApply", + "RandomChoice", + "RandomOrder", + # If the random parameter generation in the legacy and prototype transform is the same, setting the seed + # should be sufficient. In that case, the transforms below should be tested automatically. + "AugMix", + "AutoAugment", + "RandAugment", + "TrivialAugmentWide", } } - # filter out random transformations - legacy = {name for name in legacy if "Random" not in name} - { - "AugMix", - "TrivialAugmentWide", - "GaussianBlur", - "RandAugment", - "AutoAugment", - "ColorJitter", - "ElasticTransform", - } prototype = {config.legacy_cls.__name__ for config in CONSISTENCY_CONFIGS} @@ -285,6 +494,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]" try: + torch.manual_seed(0) output_legacy_tensor = legacy(image_tensor) except Exception as exc: raise pytest.UsageError( @@ -294,6 +504,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, ) from exc try: + torch.manual_seed(0) output_prototype_tensor = prototype(image_tensor) except Exception as exc: raise AssertionError( @@ -309,6 +520,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, ) try: + torch.manual_seed(0) output_prototype_image = prototype(image) except Exception as exc: raise AssertionError( @@ 
-325,6 +537,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, if image_pil is not None: try: + torch.manual_seed(0) output_legacy_pil = legacy(image_pil) except Exception as exc: raise pytest.UsageError( @@ -334,6 +547,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, ) from exc try: + torch.manual_seed(0) output_prototype_pil = prototype(image_pil) except Exception as exc: raise AssertionError( diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index d9c57abfd2a..ee1cf6c44df 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -373,9 +373,9 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: max_dy = float(self.translate[1] * height) tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item())) ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item())) - translations = (tx, ty) + translate = (tx, ty) else: - translations = (0, 0) + translate = (0, 0) if self.scale is not None: scale = float(torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()) @@ -389,7 +389,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: shear_y = float(torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()) shear = (shear_x, shear_y) - return dict(angle=angle, translations=translations, scale=scale, shear=shear) + return dict(angle=angle, translate=translate, scale=scale, shear=shear) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.affine( @@ -489,7 +489,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class RandomPerspective(_RandomApplyTransform): def __init__( self, - distortion_scale: float, + distortion_scale: float = 0.5, fill: Union[int, float, Sequence[int], Sequence[float]] = 0, interpolation: InterpolationMode = InterpolationMode.BILINEAR, p: float = 0.5, diff --git a/torchvision/prototype/transforms/_misc.py 
b/torchvision/prototype/transforms/_misc.py index e0235d160e6..124e4b77274 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -136,7 +136,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(sigma=[sigma, sigma]) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.gaussian_blur(inpt, **params) + return F.gaussian_blur(inpt, self.kernel_size, **params) class ToDtype(Lambda):