Skip to content

Commit

Permalink
Fix issue related to duplicate collation (#531)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #531

Since we want users to always provide their own collation in the DataPipe graph, we add this function to avoid duplicate collation. It will no longer be needed once this reading service no longer relies on DLv1.

Separately, the [linter issue](#364) is modified to track the fact that we want users to always provide a `Collator`.

Fixes part of #530

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D37348504

Pulled By: NivekT

fbshipit-source-id: d04f425dae47d679edad2ad5669b10e29b831500
  • Loading branch information
NivekT authored and facebook-github-bot committed Jun 23, 2022
1 parent cc3f866 commit 6c8f778
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
21 changes: 21 additions & 0 deletions test/test_dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,26 @@ def test_dataloader2_load_state_dict(self) -> None:
restored_data_loader.shutdown()


class DataLoader2ConsistencyTest(TestCase):
    r"""
    Ensure that ``DataLoader2`` behaves consistently across ``ReadingServices``
    (and, eventually, consistently with ``DataLoaderV1``).
    """

    def test_dataloader2_batch_collate(self) -> None:
        # Batching and collation are done inside the DataPipe graph itself,
        # so no reading service should add a second (duplicate) collation.
        source_dp: IterDataPipe = IterableWrapper(range(10)).batch(2).collate()  # type: ignore[assignment]

        # Baseline: DataLoader2 with no reading service attached.
        dl_baseline: DataLoader2 = DataLoader2(source_dp)

        # Same graph, driven through the multiprocessing reading service.
        mp_rs = MultiProcessingReadingService(num_workers=0)
        dl_with_rs: DataLoader2 = DataLoader2(source_dp, reading_service=mp_rs)

        # Every paired batch from the two loaders must be element-wise equal.
        self.assertTrue(all(lhs.eq(rhs).all() for lhs, rhs in zip(dl_baseline, dl_with_rs)))

    def test_dataloader2_shuffle(self) -> None:
        # TODO: verify shuffle consistency across reading services
        pass


if __name__ == "__main__":
unittest.main()
7 changes: 7 additions & 0 deletions torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ def restore(self, datapipe: IterDataPipe, serialized_state: bytes) -> IterDataPi
pass


def _collate_no_op(batch):
return batch[0]


class MultiProcessingReadingService(ReadingServiceInterface):
num_workers: int
pin_memory: bool
Expand Down Expand Up @@ -126,6 +130,9 @@ def initialize(self, datapipe: IterDataPipe) -> IterDataPipe:
multiprocessing_context=self.multiprocessing_context,
prefetch_factor=self.prefetch_factor,
persistent_workers=self.persistent_workers,
# TODO: `collate_fn` is necessary until we stop using DLv1 https://github.com/pytorch/data/issues/530
collate_fn=_collate_no_op,
batch_size=1, # This reading service assume batching is done via DataPipe
)
return IterableWrapper(self.dl_) # type: ignore[return-value]

Expand Down

0 comments on commit 6c8f778

Please sign in to comment.