make transforms v2 JIT scriptable #7135

Merged (16 commits) on Jan 31, 2023
71 changes: 58 additions & 13 deletions test/test_prototype_transforms_consistency.py
@@ -34,6 +34,15 @@
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])


class NotScriptableArgsKwargs(ArgsKwargs):
"""
This class is used to mark parameters that render the transform non-scriptable. They still work in eager mode and
thus will be tested there, but will be skipped by the JIT tests.
"""

pass
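
An aside on why the bare-scalar variants below are marked non-scriptable: TorchScript enforces the List[int] annotations on the underlying functionals, so a plain int that is fine in eager Python fails to compile. A minimal sketch with the stable v1 API (shapes arbitrary):

import torch
from torchvision import transforms

# A bare int is accepted in eager mode thanks to Python's dynamic typing.
eager_out = transforms.Resize(32)(torch.rand(3, 64, 64))

# TorchScript enforces the List[int] annotation on the underlying
# functional, so only the sequence form is expected to compile.
scripted = torch.jit.script(transforms.Resize([32]))
scripted_out = scripted(torch.rand(3, 64, 64))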


class ConsistencyConfig:
def __init__(
self,
@@ -73,7 +82,7 @@ def __init__(
prototype_transforms.Resize,
legacy_transforms.Resize,
[
ArgsKwargs(32),
NotScriptableArgsKwargs(32),
ArgsKwargs([32]),
ArgsKwargs((32, 29)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
@@ -84,8 +93,10 @@
# ArgsKwargs((30, 27), interpolation=0),
# ArgsKwargs((35, 29), interpolation=2),
# ArgsKwargs((34, 25), interpolation=3),
ArgsKwargs(31, max_size=32),
ArgsKwargs(30, max_size=100),
NotScriptableArgsKwargs(31, max_size=32),
ArgsKwargs([31], max_size=32),
NotScriptableArgsKwargs(30, max_size=100),
ArgsKwargs([30], max_size=100),
ArgsKwargs((29, 32), antialias=False),
ArgsKwargs((28, 31), antialias=True),
],
@@ -121,14 +132,15 @@ def __init__(
prototype_transforms.Pad,
legacy_transforms.Pad,
[
ArgsKwargs(3),
NotScriptableArgsKwargs(3),
ArgsKwargs([3]),
ArgsKwargs([2, 3]),
ArgsKwargs([3, 2, 1, 4]),
ArgsKwargs(5, fill=1, padding_mode="constant"),
ArgsKwargs(5, padding_mode="edge"),
ArgsKwargs(5, padding_mode="reflect"),
ArgsKwargs(5, padding_mode="symmetric"),
NotScriptableArgsKwargs(5, fill=1, padding_mode="constant"),
ArgsKwargs([5], fill=1, padding_mode="constant"),
NotScriptableArgsKwargs(5, padding_mode="edge"),
NotScriptableArgsKwargs(5, padding_mode="reflect"),
NotScriptableArgsKwargs(5, padding_mode="symmetric"),
],
),
ConsistencyConfig(
@@ -170,7 +182,7 @@
ConsistencyConfig(
prototype_transforms.ToPILImage,
legacy_transforms.ToPILImage,
[ArgsKwargs()],
[NotScriptableArgsKwargs()],
make_images_kwargs=dict(
color_spaces=[
"GRAY",
@@ -186,7 +198,7 @@
prototype_transforms.Lambda,
legacy_transforms.Lambda,
[
ArgsKwargs(lambda image: image / 2),
NotScriptableArgsKwargs(lambda image: image / 2),
],
# Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
# images given that the transform does nothing but call it anyway.
@@ -380,14 +392,15 @@ def __init__(
[
ArgsKwargs(12),
ArgsKwargs((15, 17)),
ArgsKwargs(11, padding=1),
NotScriptableArgsKwargs(11, padding=1),
ArgsKwargs(11, padding=[1]),
ArgsKwargs((8, 13), padding=(2, 3)),
ArgsKwargs((14, 9), padding=(0, 2, 1, 0)),
ArgsKwargs(36, pad_if_needed=True),
ArgsKwargs((7, 8), fill=1),
ArgsKwargs(5, fill=(1, 2, 3)),
NotScriptableArgsKwargs(5, fill=(1, 2, 3)),
ArgsKwargs(12),
ArgsKwargs(15, padding=2, padding_mode="edge"),
NotScriptableArgsKwargs(15, padding=2, padding_mode="edge"),
ArgsKwargs(17, padding=(1, 0), padding_mode="reflect"),
ArgsKwargs(8, padding=(3, 0, 0, 1), padding_mode="symmetric"),
],
@@ -642,6 +655,38 @@ def test_call_consistency(config, args_kwargs):
)


@pytest.mark.parametrize(
("config", "args_kwargs"),
[
pytest.param(
config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
)
for config in CONSISTENCY_CONFIGS
for idx, args_kwargs in enumerate(config.args_kwargs)
if not isinstance(args_kwargs, NotScriptableArgsKwargs)
],
)
def test_jit_consistency(config, args_kwargs):
args, kwargs = args_kwargs

prototype_transform_eager = config.prototype_cls(*args, **kwargs)
legacy_transform_eager = config.legacy_cls(*args, **kwargs)

legacy_transform_scripted = torch.jit.script(legacy_transform_eager)
prototype_transform_scripted = torch.jit.script(prototype_transform_eager)

for image in make_images(**config.make_images_kwargs):
image = image.as_subclass(torch.Tensor)

torch.manual_seed(0)
output_legacy_scripted = legacy_transform_scripted(image)

torch.manual_seed(0)
output_prototype_scripted = prototype_transform_scripted(image)

assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs)
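
The new test_jit_consistency above scripts both the legacy and the prototype transform, runs them on the same seeded tensor images, and compares the outputs within each config's closeness tolerances. The user-facing behavior it guards is roughly the following (transform choice and shapes are arbitrary; assumed to work once this PR lands):

import torch
from torchvision.prototype import transforms as prototype_transforms

transform = prototype_transforms.Resize([32])
scripted = torch.jit.script(transform)  # raised an error for v2 transforms before this PR

# Scripted transforms operate on plain tensors; datapoints and PIL
# images presumably stay an eager-only feature.
image = torch.rand(3, 64, 48)
output = scripted(image)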


class TestContainerTransforms:
"""
Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
10 changes: 9 additions & 1 deletion torchvision/prototype/transforms/_augment.py
@@ -6,7 +6,7 @@
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

from torchvision import transforms as _transforms
from torchvision.ops import masks_to_boxes
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
@@ -16,6 +16,14 @@


class RandomErasing(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomErasing

def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
return dict(
super()._extract_params_for_v1_transform(),
value="random" if self.value is None else self.value,
)

_transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)

def __init__(
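The _v1_transform_cls attributes added throughout these files are the glue for scripting: judging from this diff, the base Transform presumably implements torch.jit's __prepare_scriptable__ protocol and substitutes the v1 class, constructed from _extract_params_for_v1_transform(). RandomErasing overrides the extraction because v2 encodes a random erase value as None where v1 spells it "random". A sketch of the assumed base-class mechanism, not the verbatim implementation:

from typing import Any, Dict, Optional, Type

import torch.nn as nn


class Transform(nn.Module):
    # Subclasses point this at their torchvision.transforms (v1) twin.
    _v1_transform_cls: Optional[Type[nn.Module]] = None

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # Default behavior: forward the public constructor attributes
        # unchanged, skipping nn.Module bookkeeping such as `training`.
        common = nn.Module().__dict__.keys()
        return {
            attr: value
            for attr, value in self.__dict__.items()
            if not attr.startswith("_") and attr not in common
        }

    def __prepare_scriptable__(self) -> nn.Module:
        # torch.jit.script consults this hook before compiling a module;
        # returning the v1 instance makes the v2 transform scriptable.
        if self._v1_transform_cls is None:
            raise RuntimeError(f"{type(self).__name__} has no scriptable v1 equivalent")
        return self._v1_transform_cls(**self._extract_params_for_v1_transform())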
8 changes: 7 additions & 1 deletion torchvision/prototype/transforms/_auto_augment.py
@@ -5,7 +5,7 @@
import torch

from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec

from torchvision import transforms as _transforms
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._meta import get_spatial_size
@@ -161,6 +161,8 @@ def _apply_image_or_video_transform(


class AutoAugment(_AutoAugmentBase):
_v1_transform_cls = _transforms.AutoAugment

_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
@@ -315,6 +317,7 @@ def forward(self, *inputs: Any) -> Any:


class RandAugment(_AutoAugmentBase):
_v1_transform_cls = _transforms.RandAugment
_AUGMENTATION_SPACE = {
"Identity": (lambda num_bins, height, width: None, False),
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
@@ -375,6 +378,7 @@ def forward(self, *inputs: Any) -> Any:


class TrivialAugmentWide(_AutoAugmentBase):
_v1_transform_cls = _transforms.TrivialAugmentWide
_AUGMENTATION_SPACE = {
"Identity": (lambda num_bins, height, width: None, False),
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
@@ -425,6 +429,8 @@ def forward(self, *inputs: Any) -> Any:


class AugMix(_AutoAugmentBase):
_v1_transform_cls = _transforms.AugMix

_PARTIAL_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
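For the four auto-augmentation transforms, the one-line _v1_transform_cls assignment is the entire change; their constructor signatures presumably match the v1 classes one-to-one, so the default parameter extraction suffices. A smoke test of that assumption, using only default constructors:

import torch
from torchvision.prototype import transforms as prototype_transforms

for cls in (
    prototype_transforms.AutoAugment,
    prototype_transforms.RandAugment,
    prototype_transforms.TrivialAugmentWide,
    prototype_transforms.AugMix,
):
    torch.jit.script(cls())  # each should now compile via its v1 equivalent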
23 changes: 22 additions & 1 deletion torchvision/prototype/transforms/_color.py
@@ -3,7 +3,7 @@

import PIL.Image
import torch

from torchvision import transforms as _transforms
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, Transform

@@ -12,6 +12,8 @@


class Grayscale(Transform):
_v1_transform_cls = _transforms.Grayscale

_transformed_types = (
datapoints.Image,
PIL.Image.Image,
@@ -28,6 +30,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class RandomGrayscale(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomGrayscale

_transformed_types = (
datapoints.Image,
PIL.Image.Image,
@@ -47,6 +51,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class ColorJitter(Transform):
_v1_transform_cls = _transforms.ColorJitter

def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
return {attr: value or 0 for attr, value in super()._extract_params_for_v1_transform().items()}

def __init__(
self,
brightness: Optional[Union[float, Sequence[float]]] = None,
@@ -194,16 +203,22 @@ def _transform(


class RandomEqualize(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomEqualize

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.equalize(inpt)


class RandomInvert(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomInvert

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.invert(inpt)


class RandomPosterize(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomPosterize

def __init__(self, bits: int, p: float = 0.5) -> None:
super().__init__(p=p)
self.bits = bits
@@ -213,6 +228,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class RandomSolarize(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomSolarize

def __init__(self, threshold: float, p: float = 0.5) -> None:
super().__init__(p=p)
self.threshold = threshold
@@ -222,11 +239,15 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class RandomAutocontrast(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomAutocontrast

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.autocontrast(inpt)


class RandomAdjustSharpness(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomAdjustSharpness

def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
super().__init__(p=p)
self.sharpness_factor = sharpness_factor
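ColorJitter is the one transform in this file that needs a custom extraction: v2 represents a disabled jitter factor as None, while the v1 constructor presumably expects 0 for the same meaning, hence the `value or 0` mapping. In isolation (attribute names from the diff, values illustrative):

# v2-style parameters with contrast and saturation disabled (hypothetical values)
v2_params = {"brightness": (0.875, 1.125), "contrast": None, "saturation": None, "hue": (-0.05, 0.05)}

# What would be handed to the v1 constructor:
v1_params = {attr: value or 0 for attr, value in v2_params.items()}
# {'brightness': (0.875, 1.125), 'contrast': 0, 'saturation': 0, 'hue': (-0.05, 0.05)}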