
Expand prototype transforms consistency tests to most random tests #6522

Merged · 5 commits · Aug 30, 2022
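In short, the new consistency checks for random transforms rely on seeding the RNG identically before running the legacy and the prototype transform, so both draw the same random parameters and their outputs can be compared directly. A minimal, self-contained sketch of that pattern, assuming the `torchvision.prototype.transforms` namespace from this PR's branch (the real test in this module additionally covers PIL inputs, several dtypes and sizes, and the error reporting shown further down):

```python
import torch
from torchvision import transforms as legacy_transforms
from torchvision.prototype import transforms as prototype_transforms

# Both transforms get the same constructor arguments ...
legacy = legacy_transforms.RandomAffine(degrees=30.0)
prototype = prototype_transforms.RandomAffine(degrees=30.0)

image = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)

# ... and the same RNG state, so they sample identical random parameters.
torch.manual_seed(0)
expected = legacy(image)

torch.manual_seed(0)
actual = prototype(image)

torch.testing.assert_close(actual, expected)
```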
11 changes: 6 additions & 5 deletions test/test_prototype_transforms.py
@@ -555,10 +555,10 @@ def test__get_params(self, degrees, translate, scale, shear, mocker):
        if translate is not None:
            w_max = int(round(translate[0] * w))
            h_max = int(round(translate[1] * h))
-            assert -w_max <= params["translations"][0] <= w_max
-            assert -h_max <= params["translations"][1] <= h_max
+            assert -w_max <= params["translate"][0] <= w_max
+            assert -h_max <= params["translate"][1] <= h_max
        else:
-            assert params["translations"] == (0, 0)
+            assert params["translate"] == (0, 0)

        if scale is not None:
            assert scale[0] <= params["scale"] <= scale[1]
@@ -759,7 +759,8 @@ def test__transform(self, kernel_size, sigma, mocker):
        if isinstance(kernel_size, (tuple, list)):
            assert transform.kernel_size == kernel_size
        else:
-            assert transform.kernel_size == (kernel_size, kernel_size)
+            kernel_size = (kernel_size, kernel_size)
+            assert transform.kernel_size == kernel_size

        if isinstance(sigma, (tuple, list)):
            assert transform.sigma == sigma
@@ -779,7 +780,7 @@ def test__transform(self, kernel_size, sigma, mocker):
        torch.manual_seed(12)
        params = transform._get_params(inpt)

-        fn.assert_called_once_with(inpt, **params)
+        fn.assert_called_once_with(inpt, kernel_size, **params)


class TestRandomColorOp:
238 changes: 226 additions & 12 deletions test/test_prototype_transforms_consistency.py
@@ -137,6 +137,7 @@ def parametrization(self):
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
+            ArgsKwargs(18, vertical_flip=True),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
    ),
@@ -215,6 +216,214 @@ def parametrization(self):
# images given that the transform does nothing but call it anyway.
supports_pil=False,
),
ConsistencyConfig(
prototype_transforms.RandomHorizontalFlip,
legacy_transforms.RandomHorizontalFlip,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
),
ConsistencyConfig(
prototype_transforms.RandomVerticalFlip,
legacy_transforms.RandomVerticalFlip,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
),
ConsistencyConfig(
prototype_transforms.RandomEqualize,
legacy_transforms.RandomEqualize,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
),
ConsistencyConfig(
prototype_transforms.RandomInvert,
legacy_transforms.RandomInvert,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
),
ConsistencyConfig(
prototype_transforms.RandomPosterize,
legacy_transforms.RandomPosterize,
[
ArgsKwargs(p=0, bits=5),
ArgsKwargs(p=1, bits=1),
ArgsKwargs(p=1, bits=3),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
),
ConsistencyConfig(
prototype_transforms.RandomSolarize,
legacy_transforms.RandomSolarize,
[
ArgsKwargs(p=0, threshold=0.5),
ArgsKwargs(p=1, threshold=0.3),
ArgsKwargs(p=1, threshold=0.99),
],
),
ConsistencyConfig(
prototype_transforms.RandomAutocontrast,
legacy_transforms.RandomAutocontrast,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
),
ConsistencyConfig(
prototype_transforms.RandomAdjustSharpness,
legacy_transforms.RandomAdjustSharpness,
[
ArgsKwargs(p=0, sharpness_factor=0.5),
ArgsKwargs(p=1, sharpness_factor=0.3),
ArgsKwargs(p=1, sharpness_factor=0.99),
],
),
ConsistencyConfig(
prototype_transforms.RandomGrayscale,
legacy_transforms.RandomGrayscale,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
),
ConsistencyConfig(
prototype_transforms.RandomResizedCrop,
legacy_transforms.RandomResizedCrop,
[
ArgsKwargs(16),
ArgsKwargs(17, scale=(0.3, 0.7)),
ArgsKwargs(25, ratio=(0.5, 1.5)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((29, 32), antialias=False),
ArgsKwargs((28, 31), antialias=True),
],
),
ConsistencyConfig(
prototype_transforms.RandomErasing,
legacy_transforms.RandomErasing,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
ArgsKwargs(p=1, scale=(0.3, 0.7)),
ArgsKwargs(p=1, ratio=(0.5, 1.5)),
ArgsKwargs(p=1, value=1),
ArgsKwargs(p=1, value=(1, 2, 3)),
ArgsKwargs(p=1, value="random"),
],
supports_pil=False,
),
ConsistencyConfig(
prototype_transforms.ColorJitter,
legacy_transforms.ColorJitter,
[
ArgsKwargs(),
ArgsKwargs(brightness=0.1),
ArgsKwargs(brightness=(0.2, 0.3)),
ArgsKwargs(contrast=0.4),
ArgsKwargs(contrast=(0.5, 0.6)),
ArgsKwargs(saturation=0.7),
ArgsKwargs(saturation=(0.8, 0.9)),
ArgsKwargs(hue=0.3),
ArgsKwargs(hue=(-0.1, 0.2)),
ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.7, hue=0.3),
],
),
ConsistencyConfig(
prototype_transforms.ElasticTransform,
legacy_transforms.ElasticTransform,
[
ArgsKwargs(),
ArgsKwargs(alpha=20.0),
ArgsKwargs(alpha=(15.3, 27.2)),
ArgsKwargs(sigma=3.0),
ArgsKwargs(sigma=(2.5, 3.9)),
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs(fill=1),
],
# ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(163, 163), (72, 333), (313, 95)]),
),
ConsistencyConfig(
prototype_transforms.GaussianBlur,
legacy_transforms.GaussianBlur,
[
ArgsKwargs(kernel_size=3),
ArgsKwargs(kernel_size=(1, 5)),
ArgsKwargs(kernel_size=3, sigma=0.7),
ArgsKwargs(kernel_size=5, sigma=(0.3, 1.4)),
],
),
ConsistencyConfig(
prototype_transforms.RandomAffine,
legacy_transforms.RandomAffine,
[
ArgsKwargs(degrees=30.0),
ArgsKwargs(degrees=(-20.0, 10.0)),
ArgsKwargs(degrees=0.0, translate=(0.4, 0.6)),
ArgsKwargs(degrees=0.0, scale=(0.3, 0.8)),
ArgsKwargs(degrees=0.0, shear=13),
ArgsKwargs(degrees=0.0, shear=(8, 17)),
ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)),
ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)),
ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(degrees=30.0, fill=1),
ArgsKwargs(degrees=30.0, fill=(2, 3, 4)),
ArgsKwargs(degrees=30.0, center=(0, 0)),
],
),
ConsistencyConfig(
prototype_transforms.RandomCrop,
legacy_transforms.RandomCrop,
[
ArgsKwargs(12),
ArgsKwargs((15, 17)),
ArgsKwargs(11, padding=1),
ArgsKwargs((8, 13), padding=(2, 3)),
ArgsKwargs((14, 9), padding=(0, 2, 1, 0)),
ArgsKwargs(36, pad_if_needed=True),
ArgsKwargs((7, 8), fill=1),
ArgsKwargs(5, fill=(1, 2, 3)),
ArgsKwargs(12),
ArgsKwargs(15, padding=2, padding_mode="edge"),
ArgsKwargs(17, padding=(1, 0), padding_mode="reflect"),
ArgsKwargs(8, padding=(3, 0, 0, 1), padding_mode="symmetric"),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(26, 26), (18, 33), (29, 22)]),
),
ConsistencyConfig(
prototype_transforms.RandomPerspective,
legacy_transforms.RandomPerspective,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
ArgsKwargs(p=1, distortion_scale=0.3),
ArgsKwargs(p=1, distortion_scale=0.2, interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
],
),
ConsistencyConfig(
prototype_transforms.RandomRotation,
legacy_transforms.RandomRotation,
[
ArgsKwargs(degrees=30.0),
ArgsKwargs(degrees=(-20.0, 10.0)),
ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.BILINEAR),
ArgsKwargs(degrees=30.0, expand=True),
ArgsKwargs(degrees=30.0, center=(0, 0)),
ArgsKwargs(degrees=30.0, fill=1),
ArgsKwargs(degrees=30.0, fill=(1, 2, 3)),
],
),
]


@@ -227,23 +436,23 @@ def test_automatic_coverage_deterministic():
        and not issubclass(obj, enum.Enum)
        and name
        not in {
-            "Compose",
            # This framework is based on the assumption that the input image can always be a tensor and optionally a
-            # PIL image. The transforms below require a non-tensor input and thus have to be tested manually.
+            # PIL image, but the transforms below require a non-tensor input.
            "PILToTensor",
            "ToTensor",
+            # Transform containers cannot be tested without other transforms
+            "Compose",
+            "RandomApply",
+            "RandomChoice",
+            "RandomOrder",
+            # If the random parameter generation in the legacy and prototype transform is the same, setting the seed
+            # should be sufficient. In that case, the transforms below should be tested automatically.
+            "AugMix",
+            "AutoAugment",
+            "RandAugment",
+            "TrivialAugmentWide",
        }
    }
-    # filter out random transformations
-    legacy = {name for name in legacy if "Random" not in name} - {
-        "AugMix",
-        "TrivialAugmentWide",
-        "GaussianBlur",
-        "RandAugment",
-        "AutoAugment",
-        "ColorJitter",
-        "ElasticTransform",
-    }

    prototype = {config.legacy_cls.__name__ for config in CONSISTENCY_CONFIGS}
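Aside: the coverage check above presumably boils down to set arithmetic over class names. A trivial, self-contained illustration of that idea; the real test builds `legacy` by introspecting `torchvision.transforms` and uses the exclusion set shown in the diff:

```python
# Every legacy transform that is not explicitly excluded must be covered by a
# ConsistencyConfig entry (names here are placeholders for illustration only).
legacy = {"RandomRotation", "RandomCrop", "ColorJitter"}   # found in torchvision.transforms
covered = {"RandomRotation", "RandomCrop", "ColorJitter"}  # config.legacy_cls.__name__ per config

missing = legacy - covered
assert not missing, f"Consistency configs missing for: {sorted(missing)}"
```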

@@ -285,6 +494,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
    image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"

    try:
+        torch.manual_seed(0)
        output_legacy_tensor = legacy(image_tensor)
    except Exception as exc:
        raise pytest.UsageError(
@@ -294,6 +504,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
        ) from exc

    try:
+        torch.manual_seed(0)
        output_prototype_tensor = prototype(image_tensor)
    except Exception as exc:
        raise AssertionError(
@@ -309,6 +520,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
        )

    try:
+        torch.manual_seed(0)
        output_prototype_image = prototype(image)
    except Exception as exc:
        raise AssertionError(
@@ -325,6 +537,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,

    if image_pil is not None:
        try:
+            torch.manual_seed(0)
            output_legacy_pil = legacy(image_pil)
        except Exception as exc:
            raise pytest.UsageError(
@@ -334,6 +547,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
            ) from exc

        try:
+            torch.manual_seed(0)
            output_prototype_pil = prototype(image_pil)
        except Exception as exc:
            raise AssertionError(
8 changes: 4 additions & 4 deletions torchvision/prototype/transforms/_geometry.py
@@ -373,9 +373,9 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
            max_dy = float(self.translate[1] * height)
            tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
            ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
-            translations = (tx, ty)
+            translate = (tx, ty)
        else:
-            translations = (0, 0)
+            translate = (0, 0)

        if self.scale is not None:
            scale = float(torch.empty(1).uniform_(self.scale[0], self.scale[1]).item())
@@ -389,7 +389,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
            shear_y = float(torch.empty(1).uniform_(self.shear[2], self.shear[3]).item())

        shear = (shear_x, shear_y)
-        return dict(angle=angle, translations=translations, scale=scale, shear=shear)
+        return dict(angle=angle, translate=translate, scale=scale, shear=shear)
Collaborator:
Nice catch! Seems like I just copied the name from the stable RandomAffine.get_params, but unfortunately the previous tests didn't reveal that.


    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.affine(
@@ -489,7 +489,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
class RandomPerspective(_RandomApplyTransform):
    def __init__(
        self,
-        distortion_scale: float,
+        distortion_scale: float = 0.5,
        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        p: float = 0.5,
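A note on why the `translations` → `translate` rename above is more than cosmetic: `_transform` forwards the sampled params via `**params` into the affine functional, so the dict keys have to match that functional's keyword arguments. A small sketch against the stable `torchvision.transforms.functional.affine`, whose relevant keywords the prototype functional appears to mirror judging from the diff:

```python
import torch
from torchvision.transforms import functional as F

img = torch.zeros(3, 16, 16)

# Keys mirror F.affine's keyword arguments, so the dict can be unpacked directly,
# the way RandomAffine._transform does with **params.
params = dict(angle=10.0, translate=[2, 3], scale=1.0, shear=[0.0, 0.0])
out = F.affine(img, **params)

# With the old key the unpacking would fail:
# F.affine(img, angle=10.0, translations=[2, 3], scale=1.0, shear=[0.0, 0.0])
#   -> TypeError (unexpected keyword argument 'translations')
```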
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_misc.py
@@ -136,7 +136,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
        return dict(sigma=[sigma, sigma])

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.gaussian_blur(inpt, **params)
+        return F.gaussian_blur(inpt, self.kernel_size, **params)


class ToDtype(Lambda):
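The `_misc.py` fix follows the same pattern: `_get_params` only samples `sigma`, while `kernel_size` lives on the transform instance, so it has to be passed explicitly next to the unpacked params. A small sketch using the stable `torchvision.transforms.functional.gaussian_blur`, which takes the same arguments here as the prototype functional per the diff:

```python
import torch
from torchvision.transforms import functional as F

img = torch.rand(3, 32, 32)

kernel_size = [3, 3]             # fixed at construction time (self.kernel_size)
params = dict(sigma=[0.7, 0.7])  # sampled per call by _get_params()

# kernel_size is not part of the sampled params, hence the explicit argument:
out = F.gaussian_blur(img, kernel_size, **params)
```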