Skip to content

Commit

Permalink
Support integer values for interpolation in the prototype transforms (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Feb 15, 2023
1 parent f627b9d commit 0e0a5dc
Show file tree
Hide file tree
Showing 12 changed files with 153 additions and 96 deletions.
6 changes: 3 additions & 3 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,7 +1534,7 @@ def test__get_params(self, mocker):
assert int(spatial_size[1] * r_min) <= width <= int(spatial_size[1] * r_max)

def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock()
interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
antialias_sentinel = mocker.MagicMock()

transform = transforms.ScaleJitter(
Expand Down Expand Up @@ -1581,7 +1581,7 @@ def test__get_params(self, min_size, max_size, mocker):
assert shorter in min_size

def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock()
interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
antialias_sentinel = mocker.MagicMock()

transform = transforms.RandomShortestSize(
Expand Down Expand Up @@ -1945,7 +1945,7 @@ def test__get_params(self):
assert min_size <= size < max_size

def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock()
interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
antialias_sentinel = mocker.MagicMock()

transform = transforms.RandomResize(
Expand Down
46 changes: 36 additions & 10 deletions test/test_prototype_transforms_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ def __init__(
ArgsKwargs((32, 29)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST),
ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR),
ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC),
NotScriptableArgsKwargs(31, max_size=32),
ArgsKwargs([31], max_size=32),
NotScriptableArgsKwargs(30, max_size=100),
Expand Down Expand Up @@ -305,6 +308,8 @@ def __init__(
ArgsKwargs(25, ratio=(0.5, 1.5)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST),
ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC),
ArgsKwargs((29, 32), antialias=False),
ArgsKwargs((28, 31), antialias=True),
],
Expand Down Expand Up @@ -352,6 +357,8 @@ def __init__(
ArgsKwargs(sigma=(2.5, 3.9)),
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs(interpolation=PIL.Image.NEAREST),
ArgsKwargs(interpolation=PIL.Image.BICUBIC),
ArgsKwargs(fill=1),
],
# ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
Expand Down Expand Up @@ -386,6 +393,7 @@ def __init__(
ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)),
ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)),
ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(degrees=30.0, interpolation=PIL.Image.NEAREST),
ArgsKwargs(degrees=30.0, fill=1),
ArgsKwargs(degrees=30.0, fill=(2, 3, 4)),
ArgsKwargs(degrees=30.0, center=(0, 0)),
Expand Down Expand Up @@ -420,6 +428,7 @@ def __init__(
ArgsKwargs(p=1),
ArgsKwargs(p=1, distortion_scale=0.3),
ArgsKwargs(p=1, distortion_scale=0.2, interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(p=1, distortion_scale=0.2, interpolation=PIL.Image.NEAREST),
ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
],
Expand All @@ -432,6 +441,7 @@ def __init__(
ArgsKwargs(degrees=30.0),
ArgsKwargs(degrees=(-20.0, 10.0)),
ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.BILINEAR),
ArgsKwargs(degrees=30.0, interpolation=PIL.Image.BILINEAR),
ArgsKwargs(degrees=30.0, expand=True),
ArgsKwargs(degrees=30.0, center=(0, 0)),
ArgsKwargs(degrees=30.0, fill=1),
Expand Down Expand Up @@ -851,7 +861,11 @@ class TestAATransforms:
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_randaug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
Expand Down Expand Up @@ -889,7 +903,11 @@ def test_randaug(self, inpt, interpolation, mocker):
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_trivial_aug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
Expand Down Expand Up @@ -937,7 +955,11 @@ def test_trivial_aug(self, inpt, interpolation, mocker):
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_augmix(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
Expand Down Expand Up @@ -986,7 +1008,11 @@ def test_augmix(self, inpt, interpolation, mocker):
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_aa(self, inpt, interpolation):
aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
Expand Down Expand Up @@ -1264,13 +1290,13 @@ def test_random_resize_eval(self, mocker):
(legacy_F.convert_image_dtype, {}),
(legacy_F.to_pil_image, {}),
(legacy_F.normalize, {}),
(legacy_F.resize, {}),
(legacy_F.resize, {"interpolation"}),
(legacy_F.pad, {"padding", "fill"}),
(legacy_F.crop, {}),
(legacy_F.center_crop, {}),
(legacy_F.resized_crop, {}),
(legacy_F.resized_crop, {"interpolation"}),
(legacy_F.hflip, {}),
(legacy_F.perspective, {"startpoints", "endpoints", "fill"}),
(legacy_F.perspective, {"startpoints", "endpoints", "fill", "interpolation"}),
(legacy_F.vflip, {}),
(legacy_F.five_crop, {}),
(legacy_F.ten_crop, {}),
Expand All @@ -1279,8 +1305,8 @@ def test_random_resize_eval(self, mocker):
(legacy_F.adjust_saturation, {}),
(legacy_F.adjust_hue, {}),
(legacy_F.adjust_gamma, {}),
(legacy_F.rotate, {"center", "fill"}),
(legacy_F.affine, {"angle", "translate", "center", "fill"}),
(legacy_F.rotate, {"center", "fill", "interpolation"}),
(legacy_F.affine, {"angle", "translate", "center", "fill", "interpolation"}),
(legacy_F.to_grayscale, {}),
(legacy_F.rgb_to_grayscale, {}),
(legacy_F.to_tensor, {}),
Expand All @@ -1292,7 +1318,7 @@ def test_random_resize_eval(self, mocker):
(legacy_F.adjust_sharpness, {}),
(legacy_F.autocontrast, {}),
(legacy_F.equalize, {}),
(legacy_F.elastic_transform, {"fill"}),
(legacy_F.elastic_transform, {"fill", "interpolation"}),
],
)
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
Expand Down
12 changes: 6 additions & 6 deletions torchvision/prototype/datapoints/_bounding_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def vertical_flip(self) -> BoundingBox:
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> BoundingBox:
Expand Down Expand Up @@ -107,7 +107,7 @@ def resized_crop(
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> BoundingBox:
output, spatial_size = self._F.resized_crop_bounding_box(
Expand All @@ -133,7 +133,7 @@ def pad(
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: FillTypeJIT = None,
Expand All @@ -154,7 +154,7 @@ def affine(
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> BoundingBox:
Expand All @@ -174,7 +174,7 @@ def perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> BoundingBox:
Expand All @@ -191,7 +191,7 @@ def perspective(
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> BoundingBox:
output = self._F.elastic_bounding_box(
Expand Down
12 changes: 6 additions & 6 deletions torchvision/prototype/datapoints/_datapoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def vertical_flip(self) -> Datapoint:
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Datapoint:
Expand All @@ -162,7 +162,7 @@ def resized_crop(
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> Datapoint:
return self
Expand All @@ -178,7 +178,7 @@ def pad(
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: FillTypeJIT = None,
Expand All @@ -191,7 +191,7 @@ def affine(
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Datapoint:
Expand All @@ -201,7 +201,7 @@ def perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> Datapoint:
Expand All @@ -210,7 +210,7 @@ def perspective(
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Datapoint:
return self
Expand Down
12 changes: 6 additions & 6 deletions torchvision/prototype/datapoints/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def vertical_flip(self) -> Image:
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Image:
Expand All @@ -86,7 +86,7 @@ def resized_crop(
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> Image:
output = self._F.resized_crop_image_tensor(
Expand All @@ -113,7 +113,7 @@ def pad(
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: FillTypeJIT = None,
Expand All @@ -129,7 +129,7 @@ def affine(
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Image:
Expand All @@ -149,7 +149,7 @@ def perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> Image:
Expand All @@ -166,7 +166,7 @@ def perspective(
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Image:
output = self._F.elastic_image_tensor(
Expand Down
12 changes: 6 additions & 6 deletions torchvision/prototype/datapoints/_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def vertical_flip(self) -> Mask:
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Mask:
Expand All @@ -75,7 +75,7 @@ def resized_crop(
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
antialias: Optional[Union[str, bool]] = "warn",
) -> Mask:
output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size)
Expand All @@ -93,7 +93,7 @@ def pad(
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: FillTypeJIT = None,
Expand All @@ -107,7 +107,7 @@ def affine(
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Mask:
Expand All @@ -126,7 +126,7 @@ def perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> Mask:
Expand All @@ -138,7 +138,7 @@ def perspective(
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
) -> Mask:
output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)
Expand Down
Loading

0 comments on commit 0e0a5dc

Please sign in to comment.