From 60110cb6e9f490798f8e89b26a8637db8fcf5c15 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 21 Aug 2023 14:13:24 +0200 Subject: [PATCH 1/9] enforce pickleability for v2 transforms and wrapped datasets --- test/datasets_utils.py | 50 +++++++++++++--------- test/test_transforms_v2.py | 7 ++- test/test_transforms_v2_refactored.py | 3 ++ torchvision/datapoints/_dataset_wrapper.py | 9 +++- 4 files changed, 46 insertions(+), 23 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index b6f22d766df..4281b8184d2 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -5,6 +5,7 @@ import itertools import os import pathlib +import pickle import random import shutil import string @@ -572,35 +573,42 @@ def test_transforms_v2_wrapper(self, config): try: with self.create_dataset(config) as (dataset, info): - for target_keys in [None, "all"]: - if target_keys is not None and self.DATASET_CLASS not in { - torchvision.datasets.CocoDetection, - torchvision.datasets.VOCDetection, - torchvision.datasets.Kitti, - torchvision.datasets.WIDERFace, - }: - with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"): - wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys) - continue - - wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys) - assert isinstance(wrapped_dataset, self.DATASET_CLASS) - assert len(wrapped_dataset) == info["num_examples"] - - wrapped_sample = wrapped_dataset[0] - assert tree_any( - lambda item: isinstance(item, (datapoints.Datapoint, PIL.Image.Image)), wrapped_sample - ) + wrap_dataset_for_transforms_v2(dataset) except TypeError as error: msg = f"No wrapper exists for dataset class {type(dataset).__name__}" if str(error).startswith(msg): - pytest.skip(msg) + return raise error except RuntimeError as error: if "currently not supported by this wrapper" in str(error): - pytest.skip("Config is currently not supported by this wrapper") + return raise error + for target_keys, de_serialize in itertools.product( + [None, "all"], [lambda d: d, lambda d: pickle.loads(pickle.dumps(d))] + ): + + with self.create_dataset(config) as (dataset, info): + if target_keys is not None and self.DATASET_CLASS not in { + torchvision.datasets.CocoDetection, + torchvision.datasets.VOCDetection, + torchvision.datasets.Kitti, + torchvision.datasets.WIDERFace, + }: + with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"): + wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys) + continue + + wrapped_dataset = de_serialize(wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)) + + assert isinstance(wrapped_dataset, self.DATASET_CLASS) + assert len(wrapped_dataset) == info["num_examples"] + + wrapped_sample = wrapped_dataset[0] + assert tree_any( + lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample + ) + class ImageDatasetTestCase(DatasetTestCase): """Abstract base class for image dataset testcases. 
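Both the dataset test above and the transform test in the next file apply the same trick: a de_serialize parameter that is either the identity or a pickle round trip, so every existing behavioral assertion doubles as a serialization check. A minimal standalone sketch of the pattern (the `double` transform is an illustrative stand-in, not torchvision code):

import pickle

import pytest


def double(x):
    # Stand-in for a transform; any picklable callable works here.
    return 2 * x


# Running each case twice -- once as-is, once through a pickle round trip --
# surfaces serialization bugs in every behavioral test instead of requiring
# dedicated pickling tests.
@pytest.mark.parametrize("de_serialize", [lambda t: t, lambda t: pickle.loads(pickle.dumps(t))])
def test_double(de_serialize):
    transform = de_serialize(double)
    assert transform(3) == 6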
diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 9630132e271..2eee463e0c1 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1,5 +1,6 @@ import itertools import pathlib +import pickle import random import warnings @@ -169,8 +170,11 @@ class TestSmoke: next(make_vanilla_tensor_images()), ], ) + @pytest.mark.parametrize("de_serialize", [lambda t: t, lambda t: pickle.loads(pickle.dumps(t))]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_common(self, transform, adapter, container_type, image_or_video, device): + def test_common(self, transform, adapter, container_type, image_or_video, de_serialize, device): + transform = de_serialize(transform) + canvas_size = F.get_size(image_or_video) input = dict( image_or_video=image_or_video, @@ -234,6 +238,7 @@ def test_common(self, transform, adapter, container_type, image_or_video, device tensor=torch.empty(5), array=np.empty(5), ) + if adapter is not None: input = adapter(transform, input, device) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index f57736e5abd..09c1bc33e16 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2,6 +2,7 @@ import decimal import inspect import math +import pickle import re from pathlib import Path from unittest import mock @@ -247,6 +248,8 @@ def _check_transform_v1_compatibility(transform, input): def check_transform(transform_cls, input, *args, **kwargs): transform = transform_cls(*args, **kwargs) + pickle.loads(pickle.dumps(transform)) + output = transform(input) assert isinstance(output, type(input)) diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py index 3f1c41debf5..f423ba38cee 100644 --- a/torchvision/datapoints/_dataset_wrapper.py +++ b/torchvision/datapoints/_dataset_wrapper.py @@ -3,7 +3,6 @@ from __future__ import annotations import collections.abc - import contextlib from collections import defaultdict @@ -97,6 +96,10 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None): f"but got {target_keys}" ) + return _make_wrapped_dataset(dataset, target_keys) + + +def _make_wrapped_dataset(dataset, target_keys): # Imagine we have isinstance(dataset, datasets.ImageNet). This will create a new class with the name # "WrappedImageNet" at runtime that doubly inherits from VisionDatasetDatapointWrapper (see below) as well as the # original ImageNet class. This allows the user to do regular isinstance(wrapped_dataset, datasets.ImageNet) checks, @@ -162,6 +165,7 @@ def __init__(self, dataset, target_keys): raise TypeError(msg) self._dataset = dataset + self._target_keys = target_keys self._wrapper = wrapper_factory(dataset, target_keys) # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them. 
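The wrapper starts storing self._target_keys so that the __reduce__ hook added in the next hunk can rebuild the instance: the doubly-inheriting wrapper class is generated at runtime, so pickle cannot locate it by qualified name, and serialization is instead routed through a module-level factory. A self-contained sketch of that idiom, with illustrative names rather than torchvision's:

import pickle


class _Base:
    def __init__(self, payload):
        self.payload = payload

    def __reduce__(self):
        # Instead of pickling the dynamic class (which pickle cannot resolve
        # by qualified name), tell pickle to call the module-level factory
        # again with the stored constructor arguments.
        return make_wrapped, (self.payload,)


def make_wrapped(payload):
    # type() creates a new class at runtime, mirroring the doubly-inheriting
    # wrapper classes built by wrap_dataset_for_transforms_v2.
    cls = type("WrappedBase", (_Base,), {})
    return cls(payload)


obj = make_wrapped(42)
restored = pickle.loads(pickle.dumps(obj))
assert restored.payload == 42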
@@ -197,6 +201,9 @@ def __getitem__(self, idx): def __len__(self): return len(self._dataset) + def __reduce__(self): + return _make_wrapped_dataset, (self._dataset, self._target_keys) + def raise_not_supported(description): raise RuntimeError( From c68a6deb82be55f6c829ca687a3669cb85a775d1 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 22 Aug 2023 10:33:33 +0200 Subject: [PATCH 2/9] use DataLoader for testing on select configs --- test/datasets_utils.py | 86 ++++++++++++++++-------------------------- test/test_datasets.py | 68 +++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+), 53 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 4281b8184d2..ead07edd6cd 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -3,9 +3,9 @@ import importlib import inspect import itertools +import multiprocessing import os import pathlib -import pickle import random import shutil import string @@ -171,6 +171,38 @@ def wrapper(self): return wrapper +def _no_collate(batch): + return batch + + +def check_transforms_v2_wrapper(dataset_test_case, *, config=None, supports_target_keys=False): + from torch.utils.data import DataLoader + from torchvision import datapoints + from torchvision.datasets import wrap_dataset_for_transforms_v2 + + target_keyss = [None] + if supports_target_keys: + target_keyss.append("all") + + for target_keys, multiprocessing_context in itertools.product( + target_keyss, multiprocessing.get_all_start_methods() + ): + with dataset_test_case.create_dataset(config) as (dataset, info): + wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys) + + assert isinstance(wrapped_dataset, type(dataset)) + assert len(wrapped_dataset) == info["num_examples"] + + dataloader = DataLoader( + wrapped_dataset, num_workers=2, multiprocessing_context=multiprocessing_context, collate_fn=_no_collate + ) + + for wrapped_sample in dataloader: + assert tree_any( + lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample + ) + + class DatasetTestCase(unittest.TestCase): """Abstract base class for all dataset testcases. 
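Iterating over multiprocessing.get_all_start_methods() is what turns the DataLoader check above into a pickleability test: "fork" workers inherit the parent process's memory and never serialize the dataset, while "spawn" workers start in a fresh interpreter and receive the dataset via pickle. A runnable sketch of the distinction, using a made-up dataset:

import multiprocessing
import pickle

from torch.utils.data import DataLoader, Dataset


class TinySquares(Dataset):
    """Illustrative stand-in for a wrapped dataset (not a torchvision class)."""

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return idx * idx


if __name__ == "__main__":
    # "fork" copies the parent's memory into the worker; only "spawn"
    # (and "forkserver") forces the dataset through pickle.
    print(multiprocessing.get_all_start_methods())

    dataset = TinySquares()
    pickle.loads(pickle.dumps(dataset))  # the round trip that spawn relies on

    loader = DataLoader(dataset, num_workers=2, multiprocessing_context="spawn")
    assert sorted(x.item() for batch in loader for x in batch) == [0, 1, 4, 9]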
@@ -566,49 +598,6 @@ def test_transforms(self, config): mock.assert_called() - @test_all_configs - def test_transforms_v2_wrapper(self, config): - from torchvision import datapoints - from torchvision.datasets import wrap_dataset_for_transforms_v2 - - try: - with self.create_dataset(config) as (dataset, info): - wrap_dataset_for_transforms_v2(dataset) - except TypeError as error: - msg = f"No wrapper exists for dataset class {type(dataset).__name__}" - if str(error).startswith(msg): - return - raise error - except RuntimeError as error: - if "currently not supported by this wrapper" in str(error): - return - raise error - - for target_keys, de_serialize in itertools.product( - [None, "all"], [lambda d: d, lambda d: pickle.loads(pickle.dumps(d))] - ): - - with self.create_dataset(config) as (dataset, info): - if target_keys is not None and self.DATASET_CLASS not in { - torchvision.datasets.CocoDetection, - torchvision.datasets.VOCDetection, - torchvision.datasets.Kitti, - torchvision.datasets.WIDERFace, - }: - with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"): - wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys) - continue - - wrapped_dataset = de_serialize(wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)) - - assert isinstance(wrapped_dataset, self.DATASET_CLASS) - assert len(wrapped_dataset) == info["num_examples"] - - wrapped_sample = wrapped_dataset[0] - assert tree_any( - lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample - ) - class ImageDatasetTestCase(DatasetTestCase): """Abstract base class for image dataset testcases. @@ -690,15 +679,6 @@ def wrapper(tmpdir, config): return wrapper - @test_all_configs - def test_transforms_v2_wrapper(self, config): - # `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the `config` if it is set explicitly - # or use the supported `"TCHW"` - if config.setdefault("output_format", "TCHW") == "THWC": - return - - super().test_transforms_v2_wrapper.__wrapped__(self, config) - def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor: r"""Create a random uint8 tensor. 
diff --git a/test/test_datasets.py b/test/test_datasets.py index ed6aa17d3f9..b3088686ba3 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -183,6 +183,9 @@ def test_combined_targets(self): ), "Type of the combined target does not match the type of the corresponding individual target: " f"{actual} is not {expected}", + def test_transforms_v2_wrapper(self): + datasets_utils.check_transforms_v2_wrapper(self, config=dict(target_type="category")) + class Caltech256TestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.Caltech256 @@ -203,6 +206,9 @@ def inject_fake_data(self, tmpdir, config): return num_images_per_category * len(categories) + def test_transforms_v2_wrapper(self): + datasets_utils.check_transforms_v2_wrapper(self) + class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.WIDERFace @@ -258,6 +264,9 @@ def inject_fake_data(self, tmpdir, config): return split_to_num_examples[config["split"]] + def test_transforms_v2_wrapper(self): + datasets_utils.check_transforms_v2_wrapper(self, supports_target_keys=True) + class CityScapesTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.Cityscapes @@ -382,6 +391,10 @@ def test_feature_types_target_polygon(self): assert isinstance(polygon_img, PIL.Image.Image) (polygon_target, info["expected_polygon_target"]) + def test_transforms_v2_wrapper(self): + for target_type in ["instance", "semantic", ["instance", "semantic"]]: + datasets_utils.check_transforms_v2_wrapper(self, config=dict(target_type=target_type)) + class ImageNetTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.ImageNet @@ -413,6 +426,9 @@ def inject_fake_data(self, tmpdir, config): torch.save((wnid_to_classes, None), tmpdir / "meta.bin") return num_examples + def test_transforms_v2_wrapper(self): + datasets_utils.check_transforms_v2_wrapper(self) + class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.CIFAR10 @@ -470,6 +486,9 @@ def test_class_to_idx(self): actual = dataset.class_to_idx assert actual == expected + def test_transforms_v2_wrapper(self): + datasets_utils.check_transforms_v2_wrapper(self) + class CIFAR100(CIFAR10TestCase): DATASET_CLASS = datasets.CIFAR100 @@ -484,6 +503,9 @@ class CIFAR100(CIFAR10TestCase): categories_key="fine_label_names", ) + def test_transforms_v2_wrapper(self): + datasets_utils.check_transforms_v2_wrapper(self) + class CelebATestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.CelebA @@ -607,6 +629,10 @@ def test_images_names_split(self): assert merged_imgs_names == all_imgs_names + def test_transforms_v2_wrapper(self): + for target_type in ["identity", "bbox", ["identity", "bbox"]]: + datasets_utils.check_transforms_v2_wrapper(self, config=dict(target_type=target_type)) + class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.VOCSegmentation @@ -694,6 +720,9 @@ def add_bndbox(obj, bndbox=None): return data + def test_transforms_v2_wrapper(self): + datasets_utils.check_transforms_v2_wrapper(self) + class VOCDetectionTestCase(VOCSegmentationTestCase): DATASET_CLASS = datasets.VOCDetection @@ -714,6 +743,10 @@ def test_annotations(self): assert object == info["annotation"] + def test_transforms_v2_wrapper(self): + for target_type in ["identity", "bbox", ["identity", "bbox"]]: + datasets_utils.check_transforms_v2_wrapper(self, supports_target_keys=True) + class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.CocoDetection 
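The collate_fn=_no_collate passed to the DataLoader in these checks sidesteps the default collation, which tries to stack samples into batched tensors and fails on PIL images or variable-length annotation dicts; handing the batch through as a plain list keeps the wrapper check independent of sample structure. A small sketch of the failure mode and the workaround, with made-up sample shapes:

from torch.utils.data import DataLoader, Dataset


class VariableAnnotations(Dataset):
    # Illustrative: each sample carries a different number of "labels",
    # similar to detection targets.
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return idx, {"labels": list(range(idx))}


def no_collate(batch):
    return batch  # pass samples through as a plain list


loader = DataLoader(VariableAnnotations(), batch_size=2, collate_fn=no_collate)
for batch in loader:
    print(batch)
# Without no_collate, the default collate_fn raises a RuntimeError
# complaining that each element in the batch should be of equal size.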
@@ -784,6 +817,9 @@ def _create_json(self, root, name, content): json.dump(content, fh) return file + def test_transforms_v2_wrapper(self): + datasets_utils.check_transforms_v2_wrapper(self, supports_target_keys=True) + class CocoCaptionsTestCase(CocoDetectionTestCase): DATASET_CLASS = datasets.CocoCaptions @@ -800,6 +836,11 @@ def test_captions(self): _, captions = dataset[0] assert tuple(captions) == tuple(info["captions"]) + def test_transforms_v2_wrapper(self): + # We need to define this method, because otherwise the test from the super class will + # be run + pytest.skip("CocoCaptions is currently not supported by the v2 wrapper.") + class UCF101TestCase(datasets_utils.VideoDatasetTestCase): DATASET_CLASS = datasets.UCF101 @@ -860,6 +901,9 @@ def _create_annotation_file(self, root, name, video_files): with open(pathlib.Path(root) / name, "w") as fh: fh.writelines(f"{str(file).replace(os.sep, '/')}\n" for file in sorted(video_files)) + def test_transforms_v2_wrapper(self): + datasets_utils.check_transforms_v2_wrapper(self, config=dict(output_format="TCHW")) + class LSUNTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.LSUN @@ -966,6 +1010,9 @@ def inject_fake_data(self, tmpdir, config): ) return num_videos_per_class * len(classes) + def test_transforms_v2_wrapper(self): + datasets_utils.check_transforms_v2_wrapper(self, config=dict(output_format="TCHW")) + class HMDB51TestCase(datasets_utils.VideoDatasetTestCase): DATASET_CLASS = datasets.HMDB51 @@ -1026,6 +1073,9 @@ def _create_split_files(self, root, video_files, fold, train): return num_train_videos if train else (num_videos - num_train_videos) + def test_transforms_v2_wrapper(self): + datasets_utils.check_transforms_v2_wrapper(self, config=dict(output_format="TCHW")) + class OmniglotTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.Omniglot @@ -1193,6 +1243,9 @@ def _create_segmentation(self, size): def _file_stem(self, idx): return f"2008_{idx:06d}" + def test_transforms_v2_wrapper(self): + datasets_utils.check_transforms_v2_wrapper(self, config=dict(mode="segmentation")) + class FakeDataTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.FakeData @@ -1434,6 +1487,9 @@ def _magic(self, dtype, dims): def _encode(self, v): return torch.tensor(v, dtype=torch.int32).numpy().tobytes()[::-1] + def test_transforms_v2_wrapper(self): + datasets_utils.check_transforms_v2_wrapper(self) + class FashionMNISTTestCase(MNISTTestCase): DATASET_CLASS = datasets.FashionMNIST @@ -1585,6 +1641,9 @@ def test_classes(self, config): assert len(dataset.classes) == len(info["classes"]) assert all([a == b for a, b in zip(dataset.classes, info["classes"])]) + def test_transforms_v2_wrapper(self): + datasets_utils.check_transforms_v2_wrapper(self) + class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.ImageFolder @@ -1606,6 +1665,9 @@ def test_classes(self, config): assert len(dataset.classes) == len(info["classes"]) assert all([a == b for a, b in zip(dataset.classes, info["classes"])]) + def test_transforms_v2_wrapper(self): + datasets_utils.check_transforms_v2_wrapper(self) + class KittiTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.Kitti @@ -1642,6 +1704,9 @@ def inject_fake_data(self, tmpdir, config): return split_to_num_examples[config["train"]] + def test_transforms_v2_wrapper(self): + datasets_utils.check_transforms_v2_wrapper(self, supports_target_keys=True) + class SvhnTestCase(datasets_utils.ImageDatasetTestCase): 
DATASET_CLASS = datasets.SVHN @@ -2516,6 +2581,9 @@ def _meta_to_split_and_classification_ann(self, meta, idx): breed_id = "-1" return (image_id, class_id, species, breed_id) + def test_transforms_v2_wrapper(self): + datasets_utils.check_transforms_v2_wrapper(self) + class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.StanfordCars From d228bdbc43d55331fd67a7e6c7ee79b521069635 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 22 Aug 2023 10:44:23 +0200 Subject: [PATCH 3/9] cleanup --- torchvision/datapoints/_dataset_wrapper.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py index f423ba38cee..cbcb5a5027a 100644 --- a/torchvision/datapoints/_dataset_wrapper.py +++ b/torchvision/datapoints/_dataset_wrapper.py @@ -96,10 +96,6 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None): f"but got {target_keys}" ) - return _make_wrapped_dataset(dataset, target_keys) - - -def _make_wrapped_dataset(dataset, target_keys): # Imagine we have isinstance(dataset, datasets.ImageNet). This will create a new class with the name # "WrappedImageNet" at runtime that doubly inherits from VisionDatasetDatapointWrapper (see below) as well as the # original ImageNet class. This allows the user to do regular isinstance(wrapped_dataset, datasets.ImageNet) checks, @@ -202,7 +198,7 @@ def __len__(self): return len(self._dataset) def __reduce__(self): - return _make_wrapped_dataset, (self._dataset, self._target_keys) + return wrap_dataset_for_transforms_v2, (self._dataset, self._target_keys) def raise_not_supported(description): From 1efe583c263ef1cb016a48e850fa003d525d0edc Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 22 Aug 2023 10:48:34 +0200 Subject: [PATCH 4/9] cleanup --- test/test_transforms_v2.py | 1 - torchvision/datapoints/_dataset_wrapper.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 2eee463e0c1..afba0a4997f 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -238,7 +238,6 @@ def test_common(self, transform, adapter, container_type, image_or_video, de_ser tensor=torch.empty(5), array=np.empty(5), ) - if adapter is not None: input = adapter(transform, input, device) diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py index cbcb5a5027a..de84d3e6ec6 100644 --- a/torchvision/datapoints/_dataset_wrapper.py +++ b/torchvision/datapoints/_dataset_wrapper.py @@ -3,6 +3,7 @@ from __future__ import annotations import collections.abc + import contextlib from collections import defaultdict From 535862079dc78ad444d62a5c0c1a04a392643a19 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 22 Aug 2023 10:59:15 +0200 Subject: [PATCH 5/9] streamline v2 check --- test/datasets_utils.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index ead07edd6cd..484b2170fe4 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -3,7 +3,6 @@ import importlib import inspect import itertools -import multiprocessing import os import pathlib import random @@ -180,27 +179,30 @@ def check_transforms_v2_wrapper(dataset_test_case, *, config=None, supports_targ from torchvision import datapoints from torchvision.datasets import wrap_dataset_for_transforms_v2 + def check_wrapped_samples(dataset): + for wrapped_sample in 
dataset: + assert tree_any( + lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample + ) + target_keyss = [None] if supports_target_keys: target_keyss.append("all") - for target_keys, multiprocessing_context in itertools.product( - target_keyss, multiprocessing.get_all_start_methods() - ): + for target_keys in target_keyss: with dataset_test_case.create_dataset(config) as (dataset, info): wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys) assert isinstance(wrapped_dataset, type(dataset)) assert len(wrapped_dataset) == info["num_examples"] - dataloader = DataLoader( - wrapped_dataset, num_workers=2, multiprocessing_context=multiprocessing_context, collate_fn=_no_collate - ) + check_wrapped_samples(wrapped_dataset) - for wrapped_sample in dataloader: - assert tree_any( - lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample - ) + with dataset_test_case.create_dataset(config) as (dataset, _): + wrapped_dataset = wrap_dataset_for_transforms_v2(dataset) + dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate) + + check_wrapped_samples(dataloader) class DatasetTestCase(unittest.TestCase): From af66bd005ea82361de846bb5d85c65720c7b8643 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 22 Aug 2023 11:12:48 +0200 Subject: [PATCH 6/9] run DataLoader test only on macOS --- test/datasets_utils.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 484b2170fe4..e946f77d224 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -5,6 +5,7 @@ import itertools import os import pathlib +import platform import random import shutil import string @@ -198,11 +199,16 @@ def check_wrapped_samples(dataset): check_wrapped_samples(wrapped_dataset) - with dataset_test_case.create_dataset(config) as (dataset, _): - wrapped_dataset = wrap_dataset_for_transforms_v2(dataset) - dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate) + # On macOS, forking for multiprocessing is not available and thus spawning is used by default. For this to work, + # the whole pipeline including the dataset needs to be pickleable, which is what we are enforcing here. 
+ if platform.system() == "Darwin": + with dataset_test_case.create_dataset(config) as (dataset, _): + wrapped_dataset = wrap_dataset_for_transforms_v2(dataset) + dataloader = DataLoader( + wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate + ) - check_wrapped_samples(dataloader) + check_wrapped_samples(dataloader) class DatasetTestCase(unittest.TestCase): From f339e6c2adc62d2b6acd985e15b5d5af5bbcdfa9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 22 Aug 2023 11:15:16 +0200 Subject: [PATCH 7/9] only run v2 checks once per group --- test/test_datasets.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index b3088686ba3..05f0f72d291 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -206,9 +206,6 @@ def inject_fake_data(self, tmpdir, config): return num_images_per_category * len(categories) - def test_transforms_v2_wrapper(self): - datasets_utils.check_transforms_v2_wrapper(self) - class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.WIDERFace @@ -486,9 +483,6 @@ def test_class_to_idx(self): actual = dataset.class_to_idx assert actual == expected - def test_transforms_v2_wrapper(self): - datasets_utils.check_transforms_v2_wrapper(self) - class CIFAR100(CIFAR10TestCase): DATASET_CLASS = datasets.CIFAR100 @@ -503,9 +497,6 @@ class CIFAR100(CIFAR10TestCase): categories_key="fine_label_names", ) - def test_transforms_v2_wrapper(self): - datasets_utils.check_transforms_v2_wrapper(self) - class CelebATestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.CelebA @@ -901,9 +892,6 @@ def _create_annotation_file(self, root, name, video_files): with open(pathlib.Path(root) / name, "w") as fh: fh.writelines(f"{str(file).replace(os.sep, '/')}\n" for file in sorted(video_files)) - def test_transforms_v2_wrapper(self): - datasets_utils.check_transforms_v2_wrapper(self, config=dict(output_format="TCHW")) - class LSUNTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.LSUN @@ -1073,9 +1061,6 @@ def _create_split_files(self, root, video_files, fold, train): return num_train_videos if train else (num_videos - num_train_videos) - def test_transforms_v2_wrapper(self): - datasets_utils.check_transforms_v2_wrapper(self, config=dict(output_format="TCHW")) - class OmniglotTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.Omniglot @@ -1487,9 +1472,6 @@ def _magic(self, dtype, dims): def _encode(self, v): return torch.tensor(v, dtype=torch.int32).numpy().tobytes()[::-1] - def test_transforms_v2_wrapper(self): - datasets_utils.check_transforms_v2_wrapper(self) - class FashionMNISTTestCase(MNISTTestCase): DATASET_CLASS = datasets.FashionMNIST @@ -1641,9 +1623,6 @@ def test_classes(self, config): assert len(dataset.classes) == len(info["classes"]) assert all([a == b for a, b in zip(dataset.classes, info["classes"])]) - def test_transforms_v2_wrapper(self): - datasets_utils.check_transforms_v2_wrapper(self) - class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.ImageFolder @@ -1665,9 +1644,6 @@ def test_classes(self, config): assert len(dataset.classes) == len(info["classes"]) assert all([a == b for a, b in zip(dataset.classes, info["classes"])]) - def test_transforms_v2_wrapper(self): - datasets_utils.check_transforms_v2_wrapper(self) - class KittiTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.Kitti From 
1e35ee7e53e437e65b76d36b338ea5c5584d401f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 22 Aug 2023 13:47:36 +0200 Subject: [PATCH 8/9] reinstate old test --- test/datasets_utils.py | 110 ++++++++++++++++++++++++++--------------- test/test_datasets.py | 63 +++++++++++++---------- 2 files changed, 107 insertions(+), 66 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index e946f77d224..50a742acf1f 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -171,46 +171,6 @@ def wrapper(self): return wrapper -def _no_collate(batch): - return batch - - -def check_transforms_v2_wrapper(dataset_test_case, *, config=None, supports_target_keys=False): - from torch.utils.data import DataLoader - from torchvision import datapoints - from torchvision.datasets import wrap_dataset_for_transforms_v2 - - def check_wrapped_samples(dataset): - for wrapped_sample in dataset: - assert tree_any( - lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample - ) - - target_keyss = [None] - if supports_target_keys: - target_keyss.append("all") - - for target_keys in target_keyss: - with dataset_test_case.create_dataset(config) as (dataset, info): - wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys) - - assert isinstance(wrapped_dataset, type(dataset)) - assert len(wrapped_dataset) == info["num_examples"] - - check_wrapped_samples(wrapped_dataset) - - # On macOS, forking for multiprocessing is not available and thus spawning is used by default. For this to work, - # the whole pipeline including the dataset needs to be pickleable, which is what we are enforcing here. - if platform.system() == "Darwin": - with dataset_test_case.create_dataset(config) as (dataset, _): - wrapped_dataset = wrap_dataset_for_transforms_v2(dataset) - dataloader = DataLoader( - wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate - ) - - check_wrapped_samples(dataloader) - - class DatasetTestCase(unittest.TestCase): """Abstract base class for all dataset testcases. 
@@ -606,6 +566,42 @@ def test_transforms(self, config): mock.assert_called() + @test_all_configs + def test_transforms_v2_wrapper(self, config): + from torchvision import datapoints + from torchvision.datasets import wrap_dataset_for_transforms_v2 + + try: + with self.create_dataset(config) as (dataset, info): + for target_keys in [None, "all"]: + if target_keys is not None and self.DATASET_CLASS not in { + torchvision.datasets.CocoDetection, + torchvision.datasets.VOCDetection, + torchvision.datasets.Kitti, + torchvision.datasets.WIDERFace, + }: + with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"): + wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys) + continue + + wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys) + assert isinstance(wrapped_dataset, self.DATASET_CLASS) + assert len(wrapped_dataset) == info["num_examples"] + + wrapped_sample = wrapped_dataset[0] + assert tree_any( + lambda item: isinstance(item, (datapoints.Datapoint, PIL.Image.Image)), wrapped_sample + ) + except TypeError as error: + msg = f"No wrapper exists for dataset class {type(dataset).__name__}" + if str(error).startswith(msg): + pytest.skip(msg) + raise error + except RuntimeError as error: + if "currently not supported by this wrapper" in str(error): + pytest.skip("Config is currently not supported by this wrapper") + raise error + class ImageDatasetTestCase(DatasetTestCase): """Abstract base class for image dataset testcases. @@ -687,6 +683,40 @@ def wrapper(tmpdir, config): return wrapper + @test_all_configs + def test_transforms_v2_wrapper(self, config): + # `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the `config` if it is set explicitly + # or use the supported `"TCHW"` + if config.setdefault("output_format", "TCHW") == "THWC": + return + + super().test_transforms_v2_wrapper.__wrapped__(self, config) + + +def _no_collate(batch): + return batch + + +def check_transforms_v2_wrapper_spawn(dataset): + # On Linux and Windows, the DataLoader forks the main process by default. This is not available on macOS, so new + # subprocesses are spawned. This requires the whole pipeline including the dataset to be pickleable, which is what + # we are enforcing here. + if platform.system() != "Darwin": + pytest.skip("Multiprocessing spawning is only checked on macOS.") + + from torch.utils.data import DataLoader + from torchvision import datapoints + from torchvision.datasets import wrap_dataset_for_transforms_v2 + + wrapped_dataset = wrap_dataset_for_transforms_v2(dataset) + + dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate) + + for wrapped_sample in dataloader: + assert tree_any( + lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample + ) + def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor: r"""Create a random uint8 tensor. 
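The tree_any helper used throughout these checks is assumed here to behave like any() over the leaves of an arbitrarily nested sample; a sketch of those semantics built on torch.utils._pytree.tree_flatten:

import PIL.Image
import torch
from torch.utils._pytree import tree_flatten


def tree_any(pred, tree):
    # Flatten nested tuples/lists/dicts to their leaves and test each one.
    leaves, _ = tree_flatten(tree)
    return any(pred(leaf) for leaf in leaves)


sample = (PIL.Image.new("RGB", (4, 4)), {"boxes": torch.zeros(1, 4), "label": 0})
assert tree_any(lambda item: isinstance(item, PIL.Image.Image), sample)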
diff --git a/test/test_datasets.py b/test/test_datasets.py index 05f0f72d291..304ad756579 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -183,8 +183,9 @@ def test_combined_targets(self): ), "Type of the combined target does not match the type of the corresponding individual target: " f"{actual} is not {expected}", - def test_transforms_v2_wrapper(self): - datasets_utils.check_transforms_v2_wrapper(self, config=dict(target_type="category")) + def test_transforms_v2_wrapper_spawn(self): + with self.create_dataset(target_type="category") as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset) class Caltech256TestCase(datasets_utils.ImageDatasetTestCase): @@ -261,8 +262,9 @@ def inject_fake_data(self, tmpdir, config): return split_to_num_examples[config["split"]] - def test_transforms_v2_wrapper(self): - datasets_utils.check_transforms_v2_wrapper(self, supports_target_keys=True) + def test_transforms_v2_wrapper_spawn(self): + with self.create_dataset() as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset) class CityScapesTestCase(datasets_utils.ImageDatasetTestCase): @@ -388,9 +390,10 @@ def test_feature_types_target_polygon(self): assert isinstance(polygon_img, PIL.Image.Image) (polygon_target, info["expected_polygon_target"]) - def test_transforms_v2_wrapper(self): + def test_transforms_v2_wrapper_spawn(self): for target_type in ["instance", "semantic", ["instance", "semantic"]]: - datasets_utils.check_transforms_v2_wrapper(self, config=dict(target_type=target_type)) + with self.create_dataset(target_type=target_type) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset) class ImageNetTestCase(datasets_utils.ImageDatasetTestCase): @@ -423,8 +426,9 @@ def inject_fake_data(self, tmpdir, config): torch.save((wnid_to_classes, None), tmpdir / "meta.bin") return num_examples - def test_transforms_v2_wrapper(self): - datasets_utils.check_transforms_v2_wrapper(self) + def test_transforms_v2_wrapper_spawn(self): + with self.create_dataset() as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset) class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase): @@ -620,9 +624,10 @@ def test_images_names_split(self): assert merged_imgs_names == all_imgs_names - def test_transforms_v2_wrapper(self): + def test_transforms_v2_wrapper_spawn(self): for target_type in ["identity", "bbox", ["identity", "bbox"]]: - datasets_utils.check_transforms_v2_wrapper(self, config=dict(target_type=target_type)) + with self.create_dataset(target_type=target_type) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset) class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase): @@ -711,8 +716,9 @@ def add_bndbox(obj, bndbox=None): return data - def test_transforms_v2_wrapper(self): - datasets_utils.check_transforms_v2_wrapper(self) + def test_transforms_v2_wrapper_spawn(self): + with self.create_dataset() as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset) class VOCDetectionTestCase(VOCSegmentationTestCase): @@ -734,9 +740,9 @@ def test_annotations(self): assert object == info["annotation"] - def test_transforms_v2_wrapper(self): - for target_type in ["identity", "bbox", ["identity", "bbox"]]: - datasets_utils.check_transforms_v2_wrapper(self, supports_target_keys=True) + def test_transforms_v2_wrapper_spawn(self): + with self.create_dataset() as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset) class 
CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase): @@ -808,8 +814,9 @@ def _create_json(self, root, name, content): json.dump(content, fh) return file - def test_transforms_v2_wrapper(self): - datasets_utils.check_transforms_v2_wrapper(self, supports_target_keys=True) + def test_transforms_v2_wrapper_spawn(self): + with self.create_dataset() as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset) class CocoCaptionsTestCase(CocoDetectionTestCase): @@ -827,7 +834,7 @@ def test_captions(self): _, captions = dataset[0] assert tuple(captions) == tuple(info["captions"]) - def test_transforms_v2_wrapper(self): + def test_transforms_v2_wrapper_spawn(self): # We need to define this method, because otherwise the test from the super class will # be run pytest.skip("CocoCaptions is currently not supported by the v2 wrapper.") @@ -998,8 +1005,9 @@ def inject_fake_data(self, tmpdir, config): ) return num_videos_per_class * len(classes) - def test_transforms_v2_wrapper(self): - datasets_utils.check_transforms_v2_wrapper(self, config=dict(output_format="TCHW")) + def test_transforms_v2_wrapper_spawn(self): + with self.create_dataset(output_format="TCHW") as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset) class HMDB51TestCase(datasets_utils.VideoDatasetTestCase): @@ -1228,8 +1236,9 @@ def _create_segmentation(self, size): def _file_stem(self, idx): return f"2008_{idx:06d}" - def test_transforms_v2_wrapper(self): - datasets_utils.check_transforms_v2_wrapper(self, config=dict(mode="segmentation")) + def test_transforms_v2_wrapper_spawn(self): + with self.create_dataset(mode="segmentation") as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset) class FakeDataTestCase(datasets_utils.ImageDatasetTestCase): @@ -1680,8 +1689,9 @@ def inject_fake_data(self, tmpdir, config): return split_to_num_examples[config["train"]] - def test_transforms_v2_wrapper(self): - datasets_utils.check_transforms_v2_wrapper(self, supports_target_keys=True) + def test_transforms_v2_wrapper_spawn(self): + with self.create_dataset() as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset) class SvhnTestCase(datasets_utils.ImageDatasetTestCase): @@ -2557,8 +2567,9 @@ def _meta_to_split_and_classification_ann(self, meta, idx): breed_id = "-1" return (image_id, class_id, species, breed_id) - def test_transforms_v2_wrapper(self): - datasets_utils.check_transforms_v2_wrapper(self) + def test_transforms_v2_wrapper_spawn(self): + with self.create_dataset() as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset) class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase): From edc043e5d0b78b7315c6856e9fff60be7ea9083e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 22 Aug 2023 14:29:20 +0200 Subject: [PATCH 9/9] fix broken tests --- test/datasets_utils.py | 2 +- test/test_datasets.py | 2 +- torchvision/datasets/widerface.py | 14 +++++++------- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 50a742acf1f..8afc6ddb369 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -549,7 +549,7 @@ def test_feature_types(self, config): @test_all_configs def test_num_examples(self, config): with self.create_dataset(config) as (dataset, info): - assert len(dataset) == info["num_examples"] + assert len(list(dataset)) == len(dataset) == info["num_examples"] @test_all_configs def test_transforms(self, config): diff --git a/test/test_datasets.py 
b/test/test_datasets.py index 304ad756579..265316264f8 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -194,7 +194,7 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase): def inject_fake_data(self, tmpdir, config): tmpdir = pathlib.Path(tmpdir) / "caltech256" / "256_ObjectCategories" - categories = ((1, "ak47"), (127, "laptop-101"), (257, "clutter")) + categories = ((1, "ak47"), (2, "american-flag"), (3, "backpack")) num_images_per_category = 2 for idx, category in categories: diff --git a/torchvision/datasets/widerface.py b/torchvision/datasets/widerface.py index b46c7982d8b..aa520455ef1 100644 --- a/torchvision/datasets/widerface.py +++ b/torchvision/datasets/widerface.py @@ -137,13 +137,13 @@ def parse_train_val_annotations_file(self) -> None: { "img_path": img_path, "annotations": { - "bbox": labels_tensor[:, 0:4], # x, y, width, height - "blur": labels_tensor[:, 4], - "expression": labels_tensor[:, 5], - "illumination": labels_tensor[:, 6], - "occlusion": labels_tensor[:, 7], - "pose": labels_tensor[:, 8], - "invalid": labels_tensor[:, 9], + "bbox": labels_tensor[:, 0:4].clone(), # x, y, width, height + "blur": labels_tensor[:, 4].clone(), + "expression": labels_tensor[:, 5].clone(), + "illumination": labels_tensor[:, 6].clone(), + "occlusion": labels_tensor[:, 7].clone(), + "pose": labels_tensor[:, 8].clone(), + "invalid": labels_tensor[:, 9].clone(), }, } )
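The .clone() calls in the final hunk plausibly address how PyTorch serializes views: pickling a slice writes out the slice's entire underlying storage, so annotation fields that are views into one shared labels tensor each drag the full tensor along with them. Cloning gives each field its own compact storage. A sketch of the effect (sizes are illustrative):

import pickle

import torch

base = torch.rand(1000, 10)
view = base[:, 0:4]          # shares storage with all 10,000 elements of base
copy = base[:, 0:4].clone()  # owns a compact storage of 4,000 elements

# The pickled view carries the whole shared storage; the clone does not.
assert len(pickle.dumps(view)) > len(pickle.dumps(copy))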