Skip to content

add consistency tests for prototype container transforms #6525

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 112 additions & 37 deletions test/test_prototype_transforms_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,38 +464,18 @@ def test_automatic_coverage_deterministic():
)


@pytest.mark.parametrize(
("prototype_transform_cls", "legacy_transform_cls", "args_kwargs", "make_images_kwargs", "supports_pil"),
itertools.chain.from_iterable(config.parametrization() for config in CONSISTENCY_CONFIGS),
)
def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, make_images_kwargs, supports_pil):
args, kwargs = args_kwargs

try:
legacy = legacy_transform_cls(*args, **kwargs)
except Exception as exc:
raise pytest.UsageError(
f"Initializing the legacy transform failed with the error above. "
f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
) from exc
def check_consistency(prototype_transform, legacy_transform, images=None, supports_pil=True):
if images is None:
images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)

try:
prototype = prototype_transform_cls(*args, **kwargs)
except Exception as exc:
raise AssertionError(
"Initializing the prototype transform failed with the error above. "
"This means there is a consistency bug in the constructor."
) from exc
for image in images:
image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"

for image in make_images(**make_images_kwargs):
image_tensor = torch.Tensor(image)
image_pil = to_image_pil(image) if image.ndim == 3 and supports_pil else None

image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"

try:
torch.manual_seed(0)
output_legacy_tensor = legacy(image_tensor)
output_legacy_tensor = legacy_transform(image_tensor)
except Exception as exc:
raise pytest.UsageError(
f"Transforming a tensor image {image_repr} failed in the legacy transform with the "
Expand All @@ -505,7 +485,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,

try:
torch.manual_seed(0)
output_prototype_tensor = prototype(image_tensor)
output_prototype_tensor = prototype_transform(image_tensor)
except Exception as exc:
raise AssertionError(
f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
Expand All @@ -521,7 +501,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,

try:
torch.manual_seed(0)
output_prototype_image = prototype(image)
output_prototype_image = prototype_transform(image)
except Exception as exc:
raise AssertionError(
f"Transforming a feature image with shape {image_repr} failed in the prototype transform with "
Expand All @@ -535,10 +515,12 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}",
)

if image_pil is not None:
if image.ndim == 3 and supports_pil:
image_pil = to_image_pil(image)

try:
torch.manual_seed(0)
output_legacy_pil = legacy(image_pil)
output_legacy_pil = legacy_transform(image_pil)
except Exception as exc:
raise pytest.UsageError(
f"Transforming a PIL image with shape {image_repr} failed in the legacy transform with the "
Expand All @@ -548,7 +530,7 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,

try:
torch.manual_seed(0)
output_prototype_pil = prototype(image_pil)
output_prototype_pil = prototype_transform(image_pil)
except Exception as exc:
raise AssertionError(
f"Transforming a PIL image with shape {image_repr} failed in the prototype transform with "
Expand All @@ -563,23 +545,116 @@ def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs,
)


@pytest.mark.parametrize(
    ("prototype_transform_cls", "legacy_transform_cls", "args_kwargs", "make_images_kwargs", "supports_pil"),
    itertools.chain.from_iterable(config.parametrization() for config in CONSISTENCY_CONFIGS),
)
def test_consistency(prototype_transform_cls, legacy_transform_cls, args_kwargs, make_images_kwargs, supports_pil):
    # Instantiate both the legacy and the prototype transform from the exact same
    # configuration, then hand them to the shared consistency checker.
    args, kwargs = args_kwargs

    # A failure here is a problem with the test setup rather than the prototype code,
    # so surface it as a usage error instead of a test failure.
    try:
        legacy = legacy_transform_cls(*args, **kwargs)
    except Exception as exc:
        raise pytest.UsageError(
            f"Initializing the legacy transform failed with the error above. "
            f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
        ) from exc

    # The legacy constructor accepted the arguments, so the prototype constructor
    # must as well; anything else is a genuine consistency bug.
    try:
        prototype = prototype_transform_cls(*args, **kwargs)
    except Exception as exc:
        raise AssertionError(
            "Initializing the prototype transform failed with the error above. "
            "This means there is a consistency bug in the constructor."
        ) from exc

    check_consistency(prototype, legacy, images=make_images(**make_images_kwargs), supports_pil=supports_pil)


class TestContainerTransforms:
    """
    Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
    consistency automatically tests the wrapped transforms' consistency.

    Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
    that were already tested for consistency above.
    """

    def test_compose(self):
        prototype_transform = prototype_transforms.Compose(
            [
                prototype_transforms.Resize(256),
                prototype_transforms.CenterCrop(224),
            ]
        )
        legacy_transform = legacy_transforms.Compose(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ]
        )

        check_consistency(prototype_transform, legacy_transform)

    @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
    def test_random_apply(self, p):
        # Fixed: the prototype container previously wrapped `legacy_transforms.CenterCrop`,
        # so the prototype `CenterCrop` was never exercised inside `RandomApply`. The
        # prototype container must wrap prototype transforms, mirroring `test_compose`.
        prototype_transform = prototype_transforms.RandomApply(
            [
                prototype_transforms.Resize(256),
                prototype_transforms.CenterCrop(224),
            ],
            p=p,
        )
        legacy_transform = legacy_transforms.RandomApply(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ],
            p=p,
        )

        check_consistency(prototype_transform, legacy_transform)

    # We can't test other values for `p` since the random parameter generation is different
    @pytest.mark.parametrize("p", [(0, 1), (1, 0)])
    def test_random_choice(self, p):
        # Fixed: same mix-up as in `test_random_apply` — the prototype `RandomChoice`
        # must wrap prototype transforms, not legacy ones.
        prototype_transform = prototype_transforms.RandomChoice(
            [
                prototype_transforms.Resize(256),
                prototype_transforms.CenterCrop(224),
            ],
            p=p,
        )
        legacy_transform = legacy_transforms.RandomChoice(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ],
            p=p,
        )

        check_consistency(prototype_transform, legacy_transform)


class TestToTensorTransforms:
def test_pil_to_tensor(self):
prototype_transform = prototype_transforms.PILToTensor()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to create the transform for every image.

legacy_transform = legacy_transforms.PILToTensor()

for image in make_images(extra_dims=[()]):
image_pil = to_image_pil(image)

prototype_transform = prototype_transforms.PILToTensor()
legacy_transform = legacy_transforms.PILToTensor()

assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))

def test_to_tensor(self):
prototype_transform = prototype_transforms.ToTensor()
legacy_transform = legacy_transforms.ToTensor()

for image in make_images(extra_dims=[()]):
image_pil = to_image_pil(image)
image_numpy = np.array(image_pil)

prototype_transform = prototype_transforms.ToTensor()
legacy_transform = legacy_transforms.ToTensor()

assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
assert_equal(prototype_transform(image_numpy), legacy_transform(image_numpy))
25 changes: 14 additions & 11 deletions torchvision/prototype/transforms/_container.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence
from typing import Any, Callable, List, Optional, Sequence

import torch
from torchvision.prototype.transforms import Transform

from ._transform import _RandomApplyTransform


class Compose(Transform):
def __init__(self, transforms: Sequence[Callable]) -> None:
Expand All @@ -21,16 +19,21 @@ def forward(self, *inputs: Any) -> Any:
return sample


class RandomApply(_RandomApplyTransform):
def __init__(self, transform: Transform, p: float = 0.5) -> None:
super().__init__(p=p)
self.transform = transform
class RandomApply(Compose):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really understand why RandomApply needs to support multiple transforms. I've re-added it here for BC, but I think we should deprecate this in favor of what our internal _RandomApplyTransform does. If the user needs multiple transforms, they can simply wrap them in a Compose before passing it to RandomApply.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. Maybe we can extend Compose with a p argument to cover the RandomApply feature...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for now we should keep it and review in the future what we want to deprecate. It's true that many of the transforms can be written as a combination of other transforms. For example, the majority of the Random* transforms that support a p probability could have used RandomApply. Unfortunately, dropping some of these functionalities not only breaks BC but also leads to more verbose code. So we should have a separate discussion over what should be deprecated and why.

def __init__(self, transforms: Sequence[Callable], p: float = 0.5) -> None:
super().__init__(transforms)

if not (0.0 <= p <= 1.0):
raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
self.p = p

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self.transform(inpt)
if torch.rand(1) >= self.p:
return sample

def extra_repr(self) -> str:
return f"p={self.p}"
return super().forward(sample)


class RandomChoice(Transform):
Expand Down