Skip to content

Commit

Permalink
[PoC] compatibility layer between stable datasets and prototype trans…
Browse files Browse the repository at this point in the history
…forms (#6663)
  • Loading branch information
pmeier authored Feb 10, 2023
1 parent 17088a6 commit a9d2572
Show file tree
Hide file tree
Showing 4 changed files with 446 additions and 5 deletions.
32 changes: 32 additions & 0 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import torchvision.datasets
import torchvision.io
from common_utils import disable_console_output, get_tmp_dir
from torch.utils._pytree import tree_any
from torchvision.transforms.functional import get_dimensions


Expand Down Expand Up @@ -581,6 +582,28 @@ def test_transforms(self, config):

mock.assert_called()

@test_all_configs
def test_transforms_v2_wrapper(self, config):
# Although this is a stable test, we unconditionally import from `torchvision.prototype` here. The wrapper needs
# to be available with the next release when v2 is released. Thus, if this import somehow fails on the release
# branch, we screwed up the roll-out
from torchvision.prototype.datapoints import wrap_dataset_for_transforms_v2
from torchvision.prototype.datapoints._datapoint import Datapoint

try:
with self.create_dataset(config) as (dataset, _):
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
wrapped_sample = wrapped_dataset[0]
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
except TypeError as error:
if str(error).startswith(f"No wrapper exist for dataset class {type(dataset).__name__}"):
return
raise error
except RuntimeError as error:
if "currently not supported by this wrapper" in str(error):
return
raise error


class ImageDatasetTestCase(DatasetTestCase):
"""Abstract base class for image dataset testcases.
Expand Down Expand Up @@ -662,6 +685,15 @@ def wrapper(tmpdir, config):

return wrapper

@test_all_configs
def test_transforms_v2_wrapper(self, config):
# `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the `config` if it is set explicitly
# or use the supported `"TCHW"`
if config.setdefault("output_format", "TCHW") == "THWC":
return

super().test_transforms_v2_wrapper.__wrapped__(self, config)


def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
r"""Create a random uint8 tensor.
Expand Down
18 changes: 13 additions & 5 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,11 +763,19 @@ def _create_annotation_file(self, root, name, file_names, num_annotations_per_im
return info

def _create_annotations(self, image_ids, num_annotations_per_image):
annotations = datasets_utils.combinations_grid(
image_id=image_ids, bbox=([1.0, 2.0, 3.0, 4.0],) * num_annotations_per_image
)
for id, annotation in enumerate(annotations):
annotation["id"] = id
annotations = []
annotion_id = 0
for image_id in itertools.islice(itertools.cycle(image_ids), len(image_ids) * num_annotations_per_image):
annotations.append(
dict(
image_id=image_id,
id=annotion_id,
bbox=torch.rand(4).tolist(),
segmentation=[torch.rand(8).tolist()],
category_id=int(torch.randint(91, ())),
)
)
annotion_id += 1
return annotations, dict()

def _create_json(self, root, name, content):
Expand Down
2 changes: 2 additions & 0 deletions torchvision/prototype/datapoints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@
from ._label import Label, OneHotLabel
from ._mask import Mask
from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT

from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip
Loading

0 comments on commit a9d2572

Please sign in to comment.