diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py
index d4d9f1c9614..e9192f44f52 100644
--- a/test/prototype_common_utils.py
+++ b/test/prototype_common_utils.py
@@ -69,7 +69,8 @@ def __init__(
 
     def _process_inputs(self, actual, expected, *, id, allow_subclasses):
         actual, expected = [
-            to_image_tensor(input) if not isinstance(input, torch.Tensor) else input for input in [actual, expected]
+            to_image_tensor(input) if not isinstance(input, torch.Tensor) else features.Image(input)
+            for input in [actual, expected]
         ]
         # This broadcast is needed, because `features.Mask`'s can have a 2D shape, but converting the equivalent PIL
         # image to a tensor adds a singleton leading dimension.
diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py
index fac2eb0bd94..9e2e3051189 100644
--- a/test/test_prototype_transforms_consistency.py
+++ b/test/test_prototype_transforms_consistency.py
@@ -1,5 +1,7 @@
 import enum
 import inspect
+import random
+from collections import defaultdict
 from importlib.machinery import SourceFileLoader
 from pathlib import Path
 
@@ -16,13 +18,15 @@
     make_image,
     make_images,
     make_label,
+    make_segmentation_mask,
 )
 from torchvision import transforms as legacy_transforms
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import features, transforms as prototype_transforms
+from torchvision.prototype.transforms import functional as F
+from torchvision.prototype.transforms._utils import query_chw
 from torchvision.prototype.transforms.functional import to_image_pil
-
 
 DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])
 
 
@@ -852,10 +856,12 @@ def test_aa(self, inpt, interpolation):
 
         assert_equal(expected_output, output)
 
-# Import reference detection transforms here for consistency checks
-# torchvision/references/detection/transforms.py
-ref_det_filepath = Path(__file__).parent.parent / "references" / "detection" / "transforms.py"
-det_transforms = SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()
+def import_transforms_from_references(reference):
+    ref_det_filepath = Path(__file__).parent.parent / "references" / reference / "transforms.py"
+    return SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()
+
+
+det_transforms = import_transforms_from_references("detection")
 
 
 class TestRefDetTransforms:
@@ -873,7 +879,7 @@ def make_datapoints(self, with_mask=True):
 
         yield (pil_image, target)
 
-        tensor_image = torch.randint(0, 256, size=(3, *size), dtype=torch.uint8)
+        tensor_image = torch.Tensor(make_image(size=size, color_space=features.ColorSpace.RGB))
         target = {
             "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -883,7 +889,7 @@ def make_datapoints(self, with_mask=True):
 
         yield (tensor_image, target)
 
-        feature_image = features.Image(torch.randint(0, 256, size=(3, *size), dtype=torch.uint8))
+        feature_image = make_image(size=size, color_space=features.ColorSpace.RGB)
         target = {
             "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -927,3 +933,165 @@ def test_transform(self, t_ref, t, data_kwargs):
             expected_output = t_ref(*dp)
 
             assert_equal(expected_output, output)
+
+
+seg_transforms = import_transforms_from_references("segmentation")
+
+
+# We need this transform for two reasons:
+# 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its namesake
+#    in the segmentation references. Thus, we cannot use it with `pad_if_needed=True`.
+# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
+class PadIfSmaller(prototype_transforms.Transform):
+    def __init__(self, size, fill=0):
+        super().__init__()
+        self.size = size
+        self.fill = prototype_transforms._geometry._setup_fill_arg(fill)
+
+    def _get_params(self, sample):
+        _, height, width = query_chw(sample)
+        padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
+        needs_padding = any(padding)
+        return dict(padding=padding, needs_padding=needs_padding)
+
+    def _transform(self, inpt, params):
+        if not params["needs_padding"]:
+            return inpt
+
+        fill = self.fill[type(inpt)]
+        fill = F._geometry._convert_fill_arg(fill)
+
+        return F.pad(inpt, padding=params["padding"], fill=fill)
+
+
+class TestRefSegTransforms:
+    def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
+        size = (256, 640)
+        num_categories = 21
+
+        conv_fns = []
+        if supports_pil:
+            conv_fns.append(to_image_pil)
+        conv_fns.extend([torch.Tensor, lambda x: x])
+
+        for conv_fn in conv_fns:
+            feature_image = make_image(size=size, color_space=features.ColorSpace.RGB, dtype=image_dtype)
+            feature_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
+
+            dp = (conv_fn(feature_image), feature_mask)
+            dp_ref = (
+                to_image_pil(feature_image) if supports_pil else torch.Tensor(feature_image),
+                to_image_pil(feature_mask),
+            )
+
+            yield dp, dp_ref
+
+    def set_seed(self, seed=12):
+        torch.manual_seed(seed)
+        random.seed(seed)
+
+    def check(self, t, t_ref, data_kwargs=None):
+        for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):
+
+            self.set_seed()
+            output = t(dp)
+
+            self.set_seed()
+            expected_output = t_ref(*dp_ref)
+
+            assert_equal(output, expected_output)
+
+    @pytest.mark.parametrize(
+        ("t_ref", "t", "data_kwargs"),
+        [
+            (
+                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
+                prototype_transforms.RandomHorizontalFlip(p=1.0),
+                dict(),
+            ),
+            (
+                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
+                prototype_transforms.RandomHorizontalFlip(p=0.0),
+                dict(),
+            ),
+            (
+                seg_transforms.RandomCrop(size=480),
+                prototype_transforms.Compose(
+                    [
+                        PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {features.Mask: 255})),
+                        prototype_transforms.RandomCrop(size=480),
+                    ]
+                ),
+                dict(),
+            ),
+            (
+                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+                prototype_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+                dict(supports_pil=False, image_dtype=torch.float),
+            ),
+        ],
+    )
+    def test_common(self, t_ref, t, data_kwargs):
+        self.check(t, t_ref, data_kwargs)
+
+    def check_resize(self, mocker, t_ref, t):
+        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
+        mock_ref = mocker.patch("torchvision.transforms.functional.resize")
+
+        for dp, dp_ref in self.make_datapoints():
+            mock.reset_mock()
+            mock_ref.reset_mock()
+
+            self.set_seed()
+            t(dp)
+            assert mock.call_count == 2
+            assert all(
+                actual is expected
+                for actual, expected in zip([call_args[0][0] for call_args in mock.call_args_list], dp)
+            )
+
+            self.set_seed()
+            t_ref(*dp_ref)
+            assert mock_ref.call_count == 2
+            assert all(
+                actual is expected
+                for actual, expected in zip([call_args[0][0] for call_args in mock_ref.call_args_list], dp_ref)
+            )
+
+            for args_kwargs, args_kwargs_ref in zip(mock.call_args_list, mock_ref.call_args_list):
+                assert args_kwargs[0][1] == [args_kwargs_ref[0][1]]
+
+    def test_random_resize_train(self, mocker):
+        base_size = 520
+        min_size = base_size // 2
+        max_size = base_size * 2
+
+        randint = torch.randint
+
+        def patched_randint(a, b, *other_args, **kwargs):
+            if kwargs or len(other_args) > 1 or other_args[0] != ():
+                return randint(a, b, *other_args, **kwargs)
+
+            return random.randint(a, b)
+
+        # We patch torch.randint -> random.randint here, because we can't patch modules that are not imported
+        # normally
+        t = prototype_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
+        mocker.patch(
+            "torchvision.prototype.transforms._geometry.torch.randint",
+            new=patched_randint,
+        )
+
+        t_ref = seg_transforms.RandomResize(min_size=min_size, max_size=max_size)
+
+        self.check_resize(mocker, t_ref, t)
+
+    def test_random_resize_eval(self, mocker):
+        torch.manual_seed(0)
+        base_size = 520
+
+        t = prototype_transforms.Resize(size=base_size, antialias=True)
+
+        t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size)
+
+        self.check_resize(mocker, t_ref, t)
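
For context on the `PadIfSmaller` + `RandomCrop` composition in the parametrization above: the segmentation references fold the padding into their own `RandomCrop` via a `pad_if_smaller` helper that pads images with 0 and masks with 255 (the ignore index), which is what the `defaultdict(lambda: 0, {features.Mask: 255})` fill reproduces. A minimal sketch of that helper, written from memory of `references/segmentation/transforms.py`, so treat the details as an approximation:

```python
import torchvision.transforms.functional as F


def pad_if_smaller(img, size, fill=0):
    # PIL reports size as (width, height); pad only on the right/bottom,
    # which matches the `padding = [0, 0, ...]` scheme of PadIfSmaller above.
    min_size = min(img.size)
    if min_size < size:
        ow, oh = img.size
        padh = size - oh if oh < size else 0
        padw = size - ow if ow < size else 0
        img = F.pad(img, (0, 0, padw, padh), fill=fill)
    return img
```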
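The `torch.randint` -> `random.randint` shim in `test_random_resize_train` is needed because the reference `RandomResize` draws its size from Python's `random` module while the prototype transform draws from torch's RNG; the two are independent streams, so seeding both identically is not enough. A self-contained demonstration, with bounds mirroring `min_size`/`max_size` for `base_size = 520`:

```python
import random

import torch

random.seed(12)
torch.manual_seed(12)

# Same seed, different generators: the sampled sizes generally disagree, so the
# consistency check would fail without routing the scalar torch.randint call
# through random.randint, as patched_randint does. (Note also that
# random.randint includes the upper bound while torch.randint excludes it.)
size_ref = random.randint(260, 1040)
size_proto = int(torch.randint(260, 1040, ()))
print(size_ref, size_proto, size_ref == size_proto)
```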
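Similarly, the final assertion in `check_resize`, `args_kwargs[0][1] == [args_kwargs_ref[0][1]]`, encodes a calling-convention difference: the prototype transform is expected to pass `size` as a single-element list where the reference functional receives a bare int. A hypothetical mock-based illustration of that indexing (the mocks and the 480 value are stand-ins, not the actual resize implementations):

```python
from unittest import mock

resize_proto = mock.Mock()
resize_ref = mock.Mock()

# Stand-in calls: the prototype side wraps the size in a list, the reference
# side passes it bare.
resize_proto("image", [480])
resize_ref("image", 480)

# call_args[0] is the positional-args tuple, so [0][1] is the size argument.
assert resize_proto.call_args[0][1] == [resize_ref.call_args[0][1]]
```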