Skip to content

Commit

Permalink
Added wrap_dataset_for_transforms_v2 into datasets and handled beta w… (
Browse files Browse the repository at this point in the history
#7279)

Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
  • Loading branch information
vfdev-5 and NicolasHug authored Feb 17, 2023
1 parent 56b0497 commit ac1512b
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 49 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/test-linux-cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
# Create Conda Env
conda create -yp ci_env python="${PYTHON_VERSION}" numpy libpng jpeg scipy
conda activate /work/ci_env
# Install PyTorch, Torchvision, and testing libraries
set -ex
conda install \
Expand All @@ -55,3 +55,9 @@ jobs:
# Run Tests
python3 -m torch.utils.collect_env
python3 -m pytest --junitxml=test-results/junit.xml -v --durations 20
# Specific test for warnings on "from torchvision.datasets import wrap_dataset_for_transforms_v2"
# We keep them separate to avoid any side effects due to warnings / imports.
# TODO: Remove this and add proper tests (possibly using a sub-process solution as described
# in https://github.com/pytorch/vision/pull/7269).
python3 -m pytest -v test/check_v2_dataset_warnings.py
19 changes: 19 additions & 0 deletions test/check_v2_dataset_warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest


def test_warns_if_imported_from_datasets(mocker):
mocker.patch("torchvision._WARN_ABOUT_BETA_TRANSFORMS", return_value=True)

import torchvision

with pytest.warns(UserWarning, match=torchvision._BETA_TRANSFORMS_WARNING):
from torchvision.datasets import wrap_dataset_for_transforms_v2

assert callable(wrap_dataset_for_transforms_v2)


@pytest.mark.filterwarnings("error")
def test_no_warns_if_imported_from_datasets():
from torchvision.datasets import wrap_dataset_for_transforms_v2

assert callable(wrap_dataset_for_transforms_v2)
2 changes: 1 addition & 1 deletion test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,8 +584,8 @@ def test_transforms(self, config):

@test_all_configs
def test_transforms_v2_wrapper(self, config):
from torchvision.datapoints import wrap_dataset_for_transforms_v2
from torchvision.datapoints._datapoint import Datapoint
from torchvision.datasets import wrap_dataset_for_transforms_v2

try:
with self.create_dataset(config) as (dataset, _):
Expand Down
43 changes: 43 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pathlib
import pickle
import random
import re
import shutil
import string
import unittest
Expand Down Expand Up @@ -3309,5 +3310,47 @@ def test_bad_input(self):
pass


class TestDatasetWrapper:
def test_unknown_type(self):
unknown_object = object()
with pytest.raises(
TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`")
):
datasets.wrap_dataset_for_transforms_v2(unknown_object)

def test_unknown_dataset(self):
class MyVisionDataset(datasets.VisionDataset):
pass

dataset = MyVisionDataset("root")

with pytest.raises(TypeError, match="No wrapper exist"):
datasets.wrap_dataset_for_transforms_v2(dataset)

def test_missing_wrapper(self):
dataset = datasets.FakeData()

with pytest.raises(TypeError, match="please open an issue"):
datasets.wrap_dataset_for_transforms_v2(dataset)

def test_subclass(self, mocker):
from torchvision import datapoints

sentinel = object()
mocker.patch.dict(
datapoints._dataset_wrapper.WRAPPER_FACTORIES,
clear=False,
values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel},
)

class MyFakeData(datasets.FakeData):
pass

dataset = MyFakeData()
wrapped_dataset = datasets.wrap_dataset_for_transforms_v2(dataset)

assert wrapped_dataset[0] is sentinel


if __name__ == "__main__":
unittest.main()
44 changes: 1 addition & 43 deletions test/test_prototype_datapoints.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import re

import pytest
import torch

from PIL import Image

from torchvision import datapoints, datasets
from torchvision import datapoints
from torchvision.prototype import datapoints as proto_datapoints


Expand Down Expand Up @@ -163,43 +161,3 @@ def test_bbox_instance(data, format):
if isinstance(format, str):
format = datapoints.BoundingBoxFormat.from_str(format.upper())
assert bboxes.format == format


class TestDatasetWrapper:
def test_unknown_type(self):
unknown_object = object()
with pytest.raises(
TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`")
):
datapoints.wrap_dataset_for_transforms_v2(unknown_object)

def test_unknown_dataset(self):
class MyVisionDataset(datasets.VisionDataset):
pass

dataset = MyVisionDataset("root")

with pytest.raises(TypeError, match="No wrapper exist"):
datapoints.wrap_dataset_for_transforms_v2(dataset)

def test_missing_wrapper(self):
dataset = datasets.FakeData()

with pytest.raises(TypeError, match="please open an issue"):
datapoints.wrap_dataset_for_transforms_v2(dataset)

def test_subclass(self, mocker):
sentinel = object()
mocker.patch.dict(
datapoints._dataset_wrapper.WRAPPER_FACTORIES,
clear=False,
values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel},
)

class MyFakeData(datasets.FakeData):
pass

dataset = MyFakeData()
wrapped_dataset = datapoints.wrap_dataset_for_transforms_v2(dataset)

assert wrapped_dataset[0] is sentinel
6 changes: 2 additions & 4 deletions torchvision/datapoints/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS

from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT
from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
from ._mask import Mask
from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video

from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip

from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS

if _WARN_ABOUT_BETA_TRANSFORMS:
import warnings

Expand Down
15 changes: 15 additions & 0 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,18 @@
"InStereo2k",
"ETH3DStereo",
)


# We override current module's attributes to handle the import:
# from torchvision.datasets import wrap_dataset_for_transforms_v2
# with beta state v2 warning from torchvision.datapoints
# We also want to avoid raising the warning when importing other attributes
# from torchvision.datasets
# Ref: https://peps.python.org/pep-0562/
def __getattr__(name):
if name in ("wrap_dataset_for_transforms_v2",):
from torchvision.datapoints._dataset_wrapper import wrap_dataset_for_transforms_v2

return wrap_dataset_for_transforms_v2

raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

0 comments on commit ac1512b

Please sign in to comment.