Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compatibility layer between stable datasets and prototype transforms #6663

Merged
merged 39 commits into from
Feb 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
d6786ac
PoC
pmeier Sep 21, 2022
3a916c8
Merge branch 'main'
pmeier Sep 27, 2022
d77ef0b
Merge branch 'main' into dataset-wrappers
pmeier Sep 28, 2022
63e1148
cleanup
pmeier Sep 28, 2022
13a820c
Merge branch 'main' into dataset-wrappers
pmeier Sep 28, 2022
fb600a7
Merge branch 'main' into dataset-wrappers
pmeier Oct 4, 2022
cae3e71
Merge branch 'main' into dataset-wrappers
pmeier Jan 30, 2023
dbfac05
refactor
pmeier Jan 31, 2023
2dba1c7
handle None label for test set use case
pmeier Jan 31, 2023
bcd7620
minor cleanup
pmeier Jan 31, 2023
f72ed86
Merge branch 'main' into dataset-wrappers
pmeier Feb 1, 2023
fe6be60
minor refactorings
pmeier Feb 1, 2023
cff9092
minor cache refactoring for COCO
pmeier Feb 1, 2023
9965492
remove GenericDatapoint for now
pmeier Feb 1, 2023
a588686
Merge branch 'main' into dataset-wrappers
pmeier Feb 2, 2023
d64e1a9
add all detection and segmentation datasets
pmeier Feb 2, 2023
49cc8e7
add Image/DatasetFolder
pmeier Feb 2, 2023
8e12bad
add video datasets
pmeier Feb 2, 2023
7a9f083
nuke annotations
pmeier Feb 2, 2023
7f7efd5
reinstate transform and target_transform disabling
pmeier Feb 2, 2023
e6f2b68
address minor comments
pmeier Feb 3, 2023
4c3860e
Merge branch 'main' into dataset-wrappers
pmeier Feb 6, 2023
58f21f4
Merge branch 'main' into dataset-wrappers
pmeier Feb 9, 2023
22288ce
remove categories and refactor wrapping architecture
pmeier Feb 9, 2023
a88aec3
add tests
pmeier Feb 9, 2023
ce740c1
cleanup
pmeier Feb 9, 2023
edad790
Merge branch 'main' into dataset-wrappers
pmeier Feb 9, 2023
3398822
remove GenericDatapoint
pmeier Feb 9, 2023
b565426
Merge branch 'dataset-wrappers' of https://github.com/pmeier/vision i…
pmeier Feb 9, 2023
a236f9c
Merge branch 'main' into dataset-wrappers
pmeier Feb 9, 2023
331a66d
move wrapper instantiation into the class
pmeier Feb 9, 2023
48405b8
use decorator registering everywhere
pmeier Feb 9, 2023
0286238
hard depend on wrapper in stable tests
pmeier Feb 9, 2023
be42cc9
remove target type wrapping default
pmeier Feb 9, 2023
e3c4d50
make test more strict
pmeier Feb 9, 2023
351becb
fix cityscapes instance return
pmeier Feb 9, 2023
8ed41ba
add comment for two stage design
pmeier Feb 9, 2023
f0e1af7
Merge branch 'main' into dataset-wrappers
pmeier Feb 9, 2023
dbebe40
Merge branch 'main' into dataset-wrappers
pmeier Feb 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import torchvision.datasets
import torchvision.io
from common_utils import disable_console_output, get_tmp_dir
from torch.utils._pytree import tree_any
from torchvision.transforms.functional import get_dimensions


Expand Down Expand Up @@ -581,6 +582,28 @@ def test_transforms(self, config):

mock.assert_called()

@test_all_configs
def test_transforms_v2_wrapper(self, config):
# Although this is a stable test, we unconditionally import from `torchvision.prototype` here. The wrapper needs
# to be available with the next release when v2 is released. Thus, if this import somehow fails on the release
# branch, we screwed up the roll-out
from torchvision.prototype.datapoints import wrap_dataset_for_transforms_v2
from torchvision.prototype.datapoints._datapoint import Datapoint

try:
with self.create_dataset(config) as (dataset, _):
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
wrapped_sample = wrapped_dataset[0]
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
except TypeError as error:
if str(error).startswith(f"No wrapper exist for dataset class {type(dataset).__name__}"):
return
raise error
except RuntimeError as error:
if "currently not supported by this wrapper" in str(error):
return
raise error


class ImageDatasetTestCase(DatasetTestCase):
"""Abstract base class for image dataset testcases.
Expand Down Expand Up @@ -662,6 +685,15 @@ def wrapper(tmpdir, config):

return wrapper

@test_all_configs
def test_transforms_v2_wrapper(self, config):
# `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the `config` if it is set explicitly
# or use the supported `"TCHW"`
if config.setdefault("output_format", "TCHW") == "THWC":
return

super().test_transforms_v2_wrapper.__wrapped__(self, config)


def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
r"""Create a random uint8 tensor.
Expand Down
18 changes: 13 additions & 5 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,11 +763,19 @@ def _create_annotation_file(self, root, name, file_names, num_annotations_per_im
return info

def _create_annotations(self, image_ids, num_annotations_per_image):
annotations = datasets_utils.combinations_grid(
image_id=image_ids, bbox=([1.0, 2.0, 3.0, 4.0],) * num_annotations_per_image
)
for id, annotation in enumerate(annotations):
annotation["id"] = id
annotations = []
annotion_id = 0
for image_id in itertools.islice(itertools.cycle(image_ids), len(image_ids) * num_annotations_per_image):
annotations.append(
dict(
image_id=image_id,
id=annotion_id,
bbox=torch.rand(4).tolist(),
segmentation=[torch.rand(8).tolist()],
category_id=int(torch.randint(91, ())),
pmeier marked this conversation as resolved.
Show resolved Hide resolved
)
)
annotion_id += 1
return annotations, dict()

def _create_json(self, root, name, content):
Expand Down
2 changes: 2 additions & 0 deletions torchvision/prototype/datapoints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@
from ._label import Label, OneHotLabel
from ._mask import Mask
from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT

from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip
Loading