diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 2ac7e78e6a2..81bfa74acce 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -34,6 +34,15 @@ DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)]) +class NotScriptableArgsKwargs(ArgsKwargs): + """ + This class is used to mark parameters that render the transform non-scriptable. They still work in eager mode and + thus will be tested there, but will be skipped by the JIT tests. + """ + + pass + + class ConsistencyConfig: def __init__( self, @@ -73,7 +82,7 @@ def __init__( prototype_transforms.Resize, legacy_transforms.Resize, [ - ArgsKwargs(32), + NotScriptableArgsKwargs(32), ArgsKwargs([32]), ArgsKwargs((32, 29)), ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST), @@ -84,8 +93,10 @@ def __init__( # ArgsKwargs((30, 27), interpolation=0), # ArgsKwargs((35, 29), interpolation=2), # ArgsKwargs((34, 25), interpolation=3), - ArgsKwargs(31, max_size=32), - ArgsKwargs(30, max_size=100), + NotScriptableArgsKwargs(31, max_size=32), + ArgsKwargs([31], max_size=32), + NotScriptableArgsKwargs(30, max_size=100), + ArgsKwargs([30], max_size=100), ArgsKwargs((29, 32), antialias=False), ArgsKwargs((28, 31), antialias=True), ], @@ -121,14 +132,15 @@ def __init__( prototype_transforms.Pad, legacy_transforms.Pad, [ - ArgsKwargs(3), + NotScriptableArgsKwargs(3), ArgsKwargs([3]), ArgsKwargs([2, 3]), ArgsKwargs([3, 2, 1, 4]), - ArgsKwargs(5, fill=1, padding_mode="constant"), - ArgsKwargs(5, padding_mode="edge"), - ArgsKwargs(5, padding_mode="reflect"), - ArgsKwargs(5, padding_mode="symmetric"), + NotScriptableArgsKwargs(5, fill=1, padding_mode="constant"), + ArgsKwargs([5], fill=1, padding_mode="constant"), + NotScriptableArgsKwargs(5, padding_mode="edge"), + NotScriptableArgsKwargs(5, padding_mode="reflect"), + NotScriptableArgsKwargs(5, padding_mode="symmetric"), ], ), ConsistencyConfig( @@ -170,7 +182,7 @@ def __init__( ConsistencyConfig( prototype_transforms.ToPILImage, legacy_transforms.ToPILImage, - [ArgsKwargs()], + [NotScriptableArgsKwargs()], make_images_kwargs=dict( color_spaces=[ "GRAY", @@ -186,7 +198,7 @@ def __init__( prototype_transforms.Lambda, legacy_transforms.Lambda, [ - ArgsKwargs(lambda image: image / 2), + NotScriptableArgsKwargs(lambda image: image / 2), ], # Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL # images given that the transform does nothing but call it anyway.
@@ -380,14 +392,15 @@ def __init__( [ ArgsKwargs(12), ArgsKwargs((15, 17)), - ArgsKwargs(11, padding=1), + NotScriptableArgsKwargs(11, padding=1), + ArgsKwargs(11, padding=[1]), ArgsKwargs((8, 13), padding=(2, 3)), ArgsKwargs((14, 9), padding=(0, 2, 1, 0)), ArgsKwargs(36, pad_if_needed=True), ArgsKwargs((7, 8), fill=1), - ArgsKwargs(5, fill=(1, 2, 3)), + NotScriptableArgsKwargs(5, fill=(1, 2, 3)), ArgsKwargs(12), - ArgsKwargs(15, padding=2, padding_mode="edge"), + NotScriptableArgsKwargs(15, padding=2, padding_mode="edge"), ArgsKwargs(17, padding=(1, 0), padding_mode="reflect"), ArgsKwargs(8, padding=(3, 0, 0, 1), padding_mode="symmetric"), ], @@ -642,6 +655,38 @@ def test_call_consistency(config, args_kwargs): ) +@pytest.mark.parametrize( + ("config", "args_kwargs"), + [ + pytest.param( + config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}" + ) + for config in CONSISTENCY_CONFIGS + for idx, args_kwargs in enumerate(config.args_kwargs) + if not isinstance(args_kwargs, NotScriptableArgsKwargs) + ], +) +def test_jit_consistency(config, args_kwargs): + args, kwargs = args_kwargs + + prototype_transform_eager = config.prototype_cls(*args, **kwargs) + legacy_transform_eager = config.legacy_cls(*args, **kwargs) + + legacy_transform_scripted = torch.jit.script(legacy_transform_eager) + prototype_transform_scripted = torch.jit.script(prototype_transform_eager) + + for image in make_images(**config.make_images_kwargs): + image = image.as_subclass(torch.Tensor) + + torch.manual_seed(0) + output_legacy_scripted = legacy_transform_scripted(image) + + torch.manual_seed(0) + output_prototype_scripted = prototype_transform_scripted(image) + + assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs) + + class TestContainerTransforms: """ Since we are testing containers here, we also need some transforms to wrap. 
Thus, testing a container transform for diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 667193784da..65b672b7edc 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -6,7 +6,7 @@ import PIL.Image import torch from torch.utils._pytree import tree_flatten, tree_unflatten - +from torchvision import transforms as _transforms from torchvision.ops import masks_to_boxes from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform @@ -16,6 +16,14 @@ class RandomErasing(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomErasing + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + return dict( + super()._extract_params_for_v1_transform(), + value="random" if self.value is None else self.value, + ) + _transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video) def __init__( diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index d4f2ca2143b..50b17068aaf 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -5,7 +5,7 @@ import torch from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec - +from torchvision import transforms as _transforms from torchvision.prototype import datapoints from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform from torchvision.prototype.transforms.functional._meta import get_spatial_size @@ -161,6 +161,8 @@ def _apply_image_or_video_transform( class AutoAugment(_AutoAugmentBase): + _v1_transform_cls = _transforms.AutoAugment + _AUGMENTATION_SPACE = { "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), @@ -315,6 +317,7 @@ def forward(self, *inputs: Any) -> Any: class RandAugment(_AutoAugmentBase): + _v1_transform_cls = _transforms.RandAugment _AUGMENTATION_SPACE = { "Identity": (lambda num_bins, height, width: None, False), "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), @@ -375,6 +378,7 @@ def forward(self, *inputs: Any) -> Any: class TrivialAugmentWide(_AutoAugmentBase): + _v1_transform_cls = _transforms.TrivialAugmentWide _AUGMENTATION_SPACE = { "Identity": (lambda num_bins, height, width: None, False), "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), @@ -425,6 +429,8 @@ def forward(self, *inputs: Any) -> Any: class AugMix(_AutoAugmentBase): + _v1_transform_cls = _transforms.AugMix + _PARTIAL_AUGMENTATION_SPACE = { "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 17b02e36953..a360e076b1d 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -3,7 +3,7 @@ import PIL.Image import torch - +from torchvision import transforms as _transforms from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F, Transform @@ -12,6 +12,8 @@ class Grayscale(Transform): + _v1_transform_cls = _transforms.Grayscale + 
_transformed_types = ( datapoints.Image, PIL.Image.Image, @@ -28,6 +30,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class RandomGrayscale(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomGrayscale + _transformed_types = ( datapoints.Image, PIL.Image.Image, @@ -47,6 +51,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class ColorJitter(Transform): + _v1_transform_cls = _transforms.ColorJitter + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + return {attr: value or 0 for attr, value in super()._extract_params_for_v1_transform().items()} + def __init__( self, brightness: Optional[Union[float, Sequence[float]]] = None, @@ -194,16 +203,22 @@ def _transform( class RandomEqualize(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomEqualize + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.equalize(inpt) class RandomInvert(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomInvert + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.invert(inpt) class RandomPosterize(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomPosterize + def __init__(self, bits: int, p: float = 0.5) -> None: super().__init__(p=p) self.bits = bits @@ -213,6 +228,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class RandomSolarize(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomSolarize + def __init__(self, threshold: float, p: float = 0.5) -> None: super().__init__(p=p) self.threshold = threshold @@ -222,11 +239,15 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class RandomAutocontrast(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomAutocontrast + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.autocontrast(inpt) class RandomAdjustSharpness(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomAdjustSharpness + def __init__(self, sharpness_factor: float, p: float = 0.5) -> None: super().__init__(p=p) self.sharpness_factor = sharpness_factor diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 8282a5d4d5f..70ae972d9e2 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -6,6 +6,7 @@ import PIL.Image import torch +from torchvision import transforms as _transforms from torchvision.ops.boxes import box_iou from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform @@ -25,16 +26,22 @@ class RandomHorizontalFlip(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomHorizontalFlip + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.horizontal_flip(inpt) class RandomVerticalFlip(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomVerticalFlip + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.vertical_flip(inpt) class Resize(Transform): + _v1_transform_cls = _transforms.Resize + def __init__( self, size: Union[int, Sequence[int]], @@ -69,6 +76,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class CenterCrop(Transform): + _v1_transform_cls = _transforms.CenterCrop + def __init__(self, size: Union[int, Sequence[int]]): super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") @@ -78,6 +87,8 @@ def _transform(self, inpt: Any, 
params: Dict[str, Any]) -> Any: class RandomResizedCrop(Transform): + _v1_transform_cls = _transforms.RandomResizedCrop + def __init__( self, size: Union[int, Sequence[int]], @@ -174,6 +185,8 @@ class FiveCrop(Transform): torch.Size([5]) """ + _v1_transform_cls = _transforms.FiveCrop + _transformed_types = ( datapoints.Image, PIL.Image.Image, @@ -200,6 +213,8 @@ class TenCrop(Transform): See :class:`~torchvision.prototype.transforms.FiveCrop` for an example. """ + _v1_transform_cls = _transforms.TenCrop + _transformed_types = ( datapoints.Image, PIL.Image.Image, @@ -223,6 +238,18 @@ def _transform( class Pad(Transform): + _v1_transform_cls = _transforms.Pad + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + params = super()._extract_params_for_v1_transform() + + if not (params["fill"] is None or isinstance(params["fill"], (int, float))): + raise ValueError( + f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill} for images." + ) + + return params + def __init__( self, padding: Union[int, Sequence[int]], @@ -285,6 +312,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class RandomRotation(Transform): + _v1_transform_cls = _transforms.RandomRotation + def __init__( self, degrees: Union[numbers.Number, Sequence], @@ -322,6 +351,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class RandomAffine(Transform): + _v1_transform_cls = _transforms.RandomAffine + def __init__( self, degrees: Union[numbers.Number, Sequence], @@ -399,6 +430,24 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class RandomCrop(Transform): + _v1_transform_cls = _transforms.RandomCrop + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + params = super()._extract_params_for_v1_transform() + + if not (params["fill"] is None or isinstance(params["fill"], (int, float))): + raise ValueError( + f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
+ ) + + padding = self.padding + if padding is not None: + pad_left, pad_right, pad_top, pad_bottom = padding + padding = [pad_left, pad_top, pad_right, pad_bottom] + params["padding"] = padding + + return params + def __init__( self, size: Union[int, Sequence[int]], @@ -491,6 +540,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class RandomPerspective(_RandomApplyTransform): + _v1_transform_cls = _transforms.RandomPerspective + def __init__( self, distortion_scale: float = 0.5, @@ -550,6 +601,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class ElasticTransform(Transform): + _v1_transform_cls = _transforms.ElasticTransform + def __init__( self, alpha: Union[float, Sequence[float]] = 50.0, diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index 0373ee1baf3..1cef6eeb8f2 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -2,6 +2,7 @@ import torch +from torchvision import transforms as _transforms from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F, Transform @@ -27,6 +28,8 @@ def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> da class ConvertDtype(Transform): + _v1_transform_cls = _transforms.ConvertImageDtype + _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) def __init__(self, dtype: torch.dtype = torch.float32) -> None: diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 70a695199fc..07ab53aff82 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -4,6 +4,7 @@ import torch +from torchvision import transforms as _transforms from torchvision.ops import remove_small_boxes from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F, Transform @@ -39,6 +40,8 @@ def extra_repr(self) -> str: class LinearTransformation(Transform): + _v1_transform_cls = _transforms.LinearTransformation + _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor): @@ -94,6 +97,7 @@ def _transform( class Normalize(Transform): + _v1_transform_cls = _transforms.Normalize _transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video) def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False): @@ -113,6 +117,8 @@ def _transform( class GaussianBlur(Transform): + _v1_transform_cls = _transforms.GaussianBlur + def __init__( self, kernel_size: Union[int, Sequence[int]], sigma: Union[int, float, Sequence[float]] = (0.1, 2.0) ) -> None: diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index 43224cabd38..a1fb3846a24 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import enum -from typing import Any, Callable, Dict, List, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import PIL.Image import torch @@ -54,6 +56,51 @@ def extra_repr(self) -> str: return ", ".join(extra) + # This attribute should be set on all transforms that have a v1 equivalent. Doing so enables the v2 transformation + # to be scriptable. 
See `_extract_params_for_v1_transform()` and `__prepare_scriptable__` for details. + _v1_transform_cls: Optional[Type[nn.Module]] = None + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + # This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current + # v2 transform instance. It does two things: + # 1. Extract all available public attributes that are specific to that transform and not `nn.Module` in general + # 2. If available, handle the `fill` attribute for v1 compatibility (see below for details) + # Override this method on the v2 transform class if the above is not sufficient. For example, this might happen + # if the v2 transform introduced new parameters that are not supported by the v1 transform. + common_attrs = nn.Module().__dict__.keys() + params = { + attr: value + for attr, value in self.__dict__.items() + if not attr.startswith("_") and attr not in common_attrs + } + + # transforms v2 handles the `fill` parameter in a more complex way than v1. By default, the input is parsed + # with `prototype.transforms._utils._setup_fill_arg()`, which returns a defaultdict that holds the fill value + # for the different datapoint types. Below we extract the value for tensors and return that together with the + # other params. + # This is needed for `Pad`, `ElasticTransform`, `RandomAffine`, `RandomCrop`, `RandomPerspective` and + # `RandomRotation` + if "fill" in params: + fill_type_defaultdict = params.pop("fill") + params["fill"] = fill_type_defaultdict[torch.Tensor] + + return params + + def __prepare_scriptable__(self) -> nn.Module: + # This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the return + # value is scripted instead of the original object. Since the v1 transforms + # are JIT scriptable, and we made sure that for single image inputs v1 and v2 are equivalent, we just return the + # equivalent v1 transform here. This of course only makes transforms v2 JIT scriptable as long as transforms v1 + # is around. + if self._v1_transform_cls is None: + raise RuntimeError( + f"Transform {type(self).__name__}() cannot be JIT scripted. " + f"This is only supported for backward compatibility with transforms that are already in v1. " + f"For torchscript support (on tensors only), you can use the functional API instead." + ) + + return self._v1_transform_cls(**self._extract_params_for_v1_transform()) + class _RandomApplyTransform(Transform): def __init__(self, p: float = 0.5) -> None:
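
For illustration, a minimal sketch of the scripting path this patch enables. This is not part of the diff above; it assumes the prototype namespace is importable as `torchvision.prototype.transforms` as shown in this PR:

import torch
from torchvision.prototype import transforms

# p=1.0 makes the flip deterministic, so the scripted output is comparable.
transform = transforms.RandomHorizontalFlip(p=1.0)

# `torch.jit.script` calls `transform.__prepare_scriptable__()` first. That hook
# builds the equivalent v1 transform from `_extract_params_for_v1_transform()`
# and returns it, so the v1 instance is what actually gets scripted.
scripted = torch.jit.script(transform)

image = torch.rand(3, 32, 32)
torch.testing.assert_close(scripted(image), image.flip(-1))

Scalar arguments such as `Resize(32)` stay eager-only, which is exactly what the `NotScriptableArgsKwargs` marker in the test file encodes.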