Skip to content

Commit

Permalink
make transforms v2 JIT scriptable (#7135)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Jan 31, 2023
1 parent 170160a commit 7cf0f4c
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 17 deletions.
71 changes: 58 additions & 13 deletions test/test_prototype_transforms_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])


class NotScriptableArgsKwargs(ArgsKwargs):
    """Marker subclass for parametrizations that make a transform non-scriptable.

    Parameters wrapped in this class are still exercised by the eager-mode
    consistency tests, but the JIT consistency tests filter them out via an
    ``isinstance`` check.
    """


class ConsistencyConfig:
def __init__(
self,
Expand Down Expand Up @@ -73,7 +82,7 @@ def __init__(
prototype_transforms.Resize,
legacy_transforms.Resize,
[
ArgsKwargs(32),
NotScriptableArgsKwargs(32),
ArgsKwargs([32]),
ArgsKwargs((32, 29)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
Expand All @@ -84,8 +93,10 @@ def __init__(
# ArgsKwargs((30, 27), interpolation=0),
# ArgsKwargs((35, 29), interpolation=2),
# ArgsKwargs((34, 25), interpolation=3),
ArgsKwargs(31, max_size=32),
ArgsKwargs(30, max_size=100),
NotScriptableArgsKwargs(31, max_size=32),
ArgsKwargs([31], max_size=32),
NotScriptableArgsKwargs(30, max_size=100),
ArgsKwargs([30], max_size=100),
ArgsKwargs((29, 32), antialias=False),
ArgsKwargs((28, 31), antialias=True),
],
Expand Down Expand Up @@ -121,14 +132,15 @@ def __init__(
prototype_transforms.Pad,
legacy_transforms.Pad,
[
ArgsKwargs(3),
NotScriptableArgsKwargs(3),
ArgsKwargs([3]),
ArgsKwargs([2, 3]),
ArgsKwargs([3, 2, 1, 4]),
ArgsKwargs(5, fill=1, padding_mode="constant"),
ArgsKwargs(5, padding_mode="edge"),
ArgsKwargs(5, padding_mode="reflect"),
ArgsKwargs(5, padding_mode="symmetric"),
NotScriptableArgsKwargs(5, fill=1, padding_mode="constant"),
ArgsKwargs([5], fill=1, padding_mode="constant"),
NotScriptableArgsKwargs(5, padding_mode="edge"),
NotScriptableArgsKwargs(5, padding_mode="reflect"),
NotScriptableArgsKwargs(5, padding_mode="symmetric"),
],
),
ConsistencyConfig(
Expand Down Expand Up @@ -170,7 +182,7 @@ def __init__(
ConsistencyConfig(
prototype_transforms.ToPILImage,
legacy_transforms.ToPILImage,
[ArgsKwargs()],
[NotScriptableArgsKwargs()],
make_images_kwargs=dict(
color_spaces=[
"GRAY",
Expand All @@ -186,7 +198,7 @@ def __init__(
prototype_transforms.Lambda,
legacy_transforms.Lambda,
[
ArgsKwargs(lambda image: image / 2),
NotScriptableArgsKwargs(lambda image: image / 2),
],
# Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
# images given that the transform does nothing but call it anyway.
Expand Down Expand Up @@ -380,14 +392,15 @@ def __init__(
[
ArgsKwargs(12),
ArgsKwargs((15, 17)),
ArgsKwargs(11, padding=1),
NotScriptableArgsKwargs(11, padding=1),
ArgsKwargs(11, padding=[1]),
ArgsKwargs((8, 13), padding=(2, 3)),
ArgsKwargs((14, 9), padding=(0, 2, 1, 0)),
ArgsKwargs(36, pad_if_needed=True),
ArgsKwargs((7, 8), fill=1),
ArgsKwargs(5, fill=(1, 2, 3)),
NotScriptableArgsKwargs(5, fill=(1, 2, 3)),
ArgsKwargs(12),
ArgsKwargs(15, padding=2, padding_mode="edge"),
NotScriptableArgsKwargs(15, padding=2, padding_mode="edge"),
ArgsKwargs(17, padding=(1, 0), padding_mode="reflect"),
ArgsKwargs(8, padding=(3, 0, 0, 1), padding_mode="symmetric"),
],
Expand Down Expand Up @@ -642,6 +655,38 @@ def test_call_consistency(config, args_kwargs):
)


@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
        if not isinstance(args_kwargs, NotScriptableArgsKwargs)
    ],
)
def test_jit_consistency(config, args_kwargs):
    """Check that the scripted prototype transform matches the scripted legacy transform."""
    args, kwargs = args_kwargs

    # Instantiate both transforms in eager mode first, then script them.
    prototype_eager = config.prototype_cls(*args, **kwargs)
    legacy_eager = config.legacy_cls(*args, **kwargs)

    legacy_scripted = torch.jit.script(legacy_eager)
    prototype_scripted = torch.jit.script(prototype_eager)

    for image in make_images(**config.make_images_kwargs):
        # Scripted transforms only accept plain tensors, so strip the datapoint subclass.
        plain_tensor = image.as_subclass(torch.Tensor)

        # Re-seed before each call so random transforms draw identical parameters.
        torch.manual_seed(0)
        expected = legacy_scripted(plain_tensor)

        torch.manual_seed(0)
        actual = prototype_scripted(plain_tensor)

        assert_close(actual, expected, **config.closeness_kwargs)


class TestContainerTransforms:
"""
Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
Expand Down
10 changes: 9 additions & 1 deletion torchvision/prototype/transforms/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

from torchvision import transforms as _transforms
from torchvision.ops import masks_to_boxes
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
Expand All @@ -16,6 +16,14 @@


class RandomErasing(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomErasing

def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
return dict(
super()._extract_params_for_v1_transform(),
value="random" if self.value is None else self.value,
)

_transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)

def __init__(
Expand Down
8 changes: 7 additions & 1 deletion torchvision/prototype/transforms/_auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec

from torchvision import transforms as _transforms
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._meta import get_spatial_size
Expand Down Expand Up @@ -161,6 +161,8 @@ def _apply_image_or_video_transform(


class AutoAugment(_AutoAugmentBase):
_v1_transform_cls = _transforms.AutoAugment

_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
Expand Down Expand Up @@ -315,6 +317,7 @@ def forward(self, *inputs: Any) -> Any:


class RandAugment(_AutoAugmentBase):
_v1_transform_cls = _transforms.RandAugment
_AUGMENTATION_SPACE = {
"Identity": (lambda num_bins, height, width: None, False),
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
Expand Down Expand Up @@ -375,6 +378,7 @@ def forward(self, *inputs: Any) -> Any:


class TrivialAugmentWide(_AutoAugmentBase):
_v1_transform_cls = _transforms.TrivialAugmentWide
_AUGMENTATION_SPACE = {
"Identity": (lambda num_bins, height, width: None, False),
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
Expand Down Expand Up @@ -425,6 +429,8 @@ def forward(self, *inputs: Any) -> Any:


class AugMix(_AutoAugmentBase):
_v1_transform_cls = _transforms.AugMix

_PARTIAL_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
Expand Down
23 changes: 22 additions & 1 deletion torchvision/prototype/transforms/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import PIL.Image
import torch

from torchvision import transforms as _transforms
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, Transform

Expand All @@ -12,6 +12,8 @@


class Grayscale(Transform):
_v1_transform_cls = _transforms.Grayscale

_transformed_types = (
datapoints.Image,
PIL.Image.Image,
Expand All @@ -28,6 +30,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class RandomGrayscale(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomGrayscale

_transformed_types = (
datapoints.Image,
PIL.Image.Image,
Expand All @@ -47,6 +51,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class ColorJitter(Transform):
_v1_transform_cls = _transforms.ColorJitter

def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
return {attr: value or 0 for attr, value in super()._extract_params_for_v1_transform().items()}

def __init__(
self,
brightness: Optional[Union[float, Sequence[float]]] = None,
Expand Down Expand Up @@ -194,16 +203,22 @@ def _transform(


class RandomEqualize(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomEqualize

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.equalize(inpt)


class RandomInvert(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomInvert

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.invert(inpt)


class RandomPosterize(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomPosterize

def __init__(self, bits: int, p: float = 0.5) -> None:
super().__init__(p=p)
self.bits = bits
Expand All @@ -213,6 +228,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class RandomSolarize(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomSolarize

def __init__(self, threshold: float, p: float = 0.5) -> None:
super().__init__(p=p)
self.threshold = threshold
Expand All @@ -222,11 +239,15 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class RandomAutocontrast(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomAutocontrast

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.autocontrast(inpt)


class RandomAdjustSharpness(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomAdjustSharpness

def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
super().__init__(p=p)
self.sharpness_factor = sharpness_factor
Expand Down
Loading

0 comments on commit 7cf0f4c

Please sign in to comment.