|
7 | 7 | import torch |
8 | 8 | from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS |
9 | 9 | from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair |
| 10 | +from torch.utils.data import DataLoader |
10 | 11 | from torch.utils.data.graph import traverse |
11 | 12 | from torch.utils.data.graph_settings import get_all_graph_pipes |
12 | 13 | from torchdata.datapipes.iter import Shuffler, ShardingFilter |
@@ -109,19 +110,39 @@ def test_transformable(self, test_home, dataset_mock, config): |
109 | 110 |
|
110 | 111 | next(iter(dataset.map(transforms.Identity()))) |
111 | 112 |
|
112 | | - @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237") |
| 113 | + @pytest.mark.parametrize("only_datapipe", [False, True]) |
113 | 114 | @parametrize_dataset_mocks(DATASET_MOCKS) |
114 | | - def test_serializable(self, test_home, dataset_mock, config): |
| 115 | + def test_traversable(self, test_home, dataset_mock, config, only_datapipe): |
115 | 116 | dataset_mock.prepare(test_home, config) |
| 117 | + dataset = datasets.load(dataset_mock.name, **config) |
116 | 118 |
|
| 119 | + traverse(dataset, only_datapipe=only_datapipe) |
| 120 | + |
| 121 | + @parametrize_dataset_mocks(DATASET_MOCKS) |
| 122 | + def test_serializable(self, test_home, dataset_mock, config): |
| 123 | + dataset_mock.prepare(test_home, config) |
117 | 124 | dataset = datasets.load(dataset_mock.name, **config) |
118 | 125 |
|
119 | 126 | pickle.dumps(dataset) |
120 | 127 |
|
| 128 | + @pytest.mark.parametrize("num_workers", [0, 1]) |
| 129 | + @parametrize_dataset_mocks(DATASET_MOCKS) |
| 130 | + def test_data_loader(self, test_home, dataset_mock, config, num_workers): |
| 131 | + dataset_mock.prepare(test_home, config) |
| 132 | + dataset = datasets.load(dataset_mock.name, **config) |
| 133 | + |
| 134 | + dl = DataLoader( |
| 135 | + dataset, |
| 136 | + batch_size=2, |
| 137 | + num_workers=num_workers, |
| 138 | + collate_fn=lambda batch: batch, |
| 139 | + ) |
| 140 | + |
| 141 | + next(iter(dl)) |
| 142 | + |
121 | 143 | # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also |
122 | 144 | # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680 |
123 | 145 | # contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now. |
124 | | - @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237") |
125 | 146 | @parametrize_dataset_mocks(DATASET_MOCKS) |
126 | 147 | @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter)) |
127 | 148 | def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type): |
|
0 commit comments