diff --git a/test/test_dataloader2.py b/test/test_dataloader2.py
index a08bfaa8d..dbc4a8f01 100644
--- a/test/test_dataloader2.py
+++ b/test/test_dataloader2.py
@@ -9,6 +9,8 @@
 import unittest
 from unittest import TestCase
 
+from torch.utils.data.graph import DataPipe
+
 from torchdata.dataloader2 import (
     DataLoader2,
     MultiProcessingReadingService,
@@ -19,15 +21,31 @@
 from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
 
 
+class _ReadingServiceWrapper:
+    def __init__(self, dp):
+        self.dp = dp
+
+    def __iter__(self):
+        self.it = iter(self.dp)
+        return self
+
+    def __next__(self):
+        return next(self.it)
+
+    @staticmethod
+    def return_one():
+        return 1
+
+
 class TestReadingService(ReadingServiceInterface):
-    def initialize(self, dp: IterDataPipe) -> IterDataPipe:
-        return dp
+    def initialize(self, dp: DataPipe) -> DataPipe:
+        return _ReadingServiceWrapper(dp)  # type: ignore[return-value]
 
 
 class DataLoader2Test(TestCase):
     def test_dataloader2(self) -> None:
         test_data_pipe = IterableWrapper(range(3))
-        data_loader = DataLoader2(datapipe=test_data_pipe)
+        data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)
 
         expected_batch = 0
         for batch in iter(data_loader):
@@ -36,12 +54,12 @@ def test_dataloader2(self) -> None:
 
     def test_dataloader2_shutdown(self) -> None:
         test_data_pipe = IterableWrapper(range(3))
-        data_loader = DataLoader2(datapipe=test_data_pipe)
+        data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)
 
         data_loader.shutdown()
 
     def test_dataloader2_state_dict(self) -> None:
         test_data_pipe = IterableWrapper(range(3))
-        data_loader = DataLoader2(datapipe=test_data_pipe)
+        data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)
 
         state = data_loader.state_dict()
         self.assertIsNotNone(state)
@@ -52,7 +70,7 @@ def test_dataloader2_state_dict(self) -> None:
     def test_dataloader2_reading_service(self) -> None:
         test_data_pipe = IterableWrapper(range(3))
         reading_service = TestReadingService()
-        data_loader = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
+        data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
 
         expected_batch = 0
         for batch in iter(data_loader):
@@ -62,7 +80,7 @@ def test_dataloader2_reading_service(self) -> None:
     def test_dataloader2_multi_process_reading_service(self) -> None:
         test_data_pipe = IterableWrapper(range(3))
         reading_service = MultiProcessingReadingService()
-        data_loader = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
+        data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
 
         expected_batch = 0
         for batch in iter(data_loader):
@@ -72,7 +90,7 @@ def test_dataloader2_multi_process_reading_service(self) -> None:
     def test_dataloader2_load_state_dict(self) -> None:
         test_data_pipe = IterableWrapper(range(3))
         reading_service = TestReadingService()
-        data_loader = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
+        data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
 
         batch = next(iter(data_loader))
         self.assertEqual(batch, 0)
@@ -84,7 +102,7 @@ def test_dataloader2_load_state_dict(self) -> None:
         data_loader.shutdown()
 
         test_data_pipe_2 = IterableWrapper(range(5))
-        restored_data_loader = DataLoader2(datapipe=test_data_pipe_2, reading_service=reading_service)
+        restored_data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe_2, reading_service=reading_service)
         restored_data_loader.load_state_dict(state)
 
         restored_data_loader_datapipe = restored_data_loader.datapipe
@@ -104,6 +122,42 @@ def test_dataloader2_load_state_dict(self) -> None:
 
         restored_data_loader.shutdown()
 
+    def test_dataloader2_reset(self) -> None:
+
+        test_data_pipe = IterableWrapper(range(10))
+        reading_services = [None, TestReadingService(), MultiProcessingReadingService(num_workers=1)]
+
+        for reading_service in reading_services:
+            data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
+
+            # Functional Test: Ensure multiple sequential reads of DL2 are possible
+            self.assertEqual(list(range(10)), list(data_loader))
+            self.assertEqual(list(range(10)), list(data_loader))
+            self.assertEqual(list(range(10)), list(data_loader))
+
+            # Functional Test: Ensure that the creation of a new iterator invalidates the old one
+            it1 = iter(data_loader)
+            self.assertEqual(0, next(it1))
+            self.assertEqual(1, next(it1))
+            it2 = iter(data_loader)
+            self.assertEqual(0, next(it2))
+            self.assertEqual(1, next(it2))
+            with self.assertRaisesRegex(RuntimeError, "iterator has been invalidated"):
+                next(it1)
+            self.assertEqual(list(range(2, 10)), list(it2))
+
+    def test_dataloader2_delegate_attribute(self) -> None:
+        test_data_pipe = IterableWrapper(range(10))
+        data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=TestReadingService())
+
+        # Functional Test: Ensure multiple sequential reads of DL2 are possible
+        self.assertEqual(list(range(10)), list(data_loader))
+        self.assertEqual(list(range(10)), list(data_loader))
+
+        # Functional Test: Ensure that attributes/methods of `dataloader._datapipe_iter` can be used
+        it = iter(data_loader)
+        self.assertEqual(1, it.return_one())  # type: ignore[attr-defined]
+
 
 class DataLoader2ConsistencyTest(TestCase):
     r"""
diff --git a/torchdata/dataloader2/__init__.py b/torchdata/dataloader2/__init__.py
index 4351b7300..2414d13d0 100644
--- a/torchdata/dataloader2/__init__.py
+++ b/torchdata/dataloader2/__init__.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
-from .dataloader2 import DataLoader2
+from .dataloader2 import DataLoader2, DataLoader2Iterator
 from .error import PauseIteration
 from .reading_service import (
     MultiProcessingReadingService,
@@ -16,6 +16,7 @@
 
 __all__ = [
     "DataLoader2",
+    "DataLoader2Iterator",
     "MultiProcessingReadingService",
     "PauseIteration",
     "PrototypeMultiProcessingReadingService",
diff --git a/torchdata/dataloader2/dataloader2.py b/torchdata/dataloader2/dataloader2.py
index c6b862bad..555839e76 100644
--- a/torchdata/dataloader2/dataloader2.py
+++ b/torchdata/dataloader2/dataloader2.py
@@ -43,6 +43,42 @@ class ConcurrencySpec:
     persistent_workers: bool = False
 
 
+class DataLoader2Iterator(Iterator):
+    def __init__(self, dataloader: "DataLoader2", iterator_id: int):
+        self.dataloader = dataloader
+        self.iterator_id = iterator_id
+
+    def __next__(self) -> T_co:
+        if self.iterator_id == self.dataloader.valid_iterator_id:
+            self.dataloader._reset_iter = True
+            try:
+                return next(self.dataloader._datapipe_iter)  # type: ignore[arg-type]
+            except PauseIteration:
+                raise StopIteration
+            except StopIteration:
+                if self.dataloader.reading_service is not None:
+                    self.dataloader.reading_service.finalize_iteration()
+                raise
+        else:
+            if self.dataloader.reading_service is not None:
+                self.dataloader.reading_service.finalize_iteration()
+            raise RuntimeError(
+                "This iterator has been invalidated because another iterator has been created "
+                "from the same DataLoader2.\n"
+                "This may be caused by multiple references to the same DataLoader2. "
" + "For feedback regarding this single iterator per DataLoader2 constraint, feel free " + "to comment on this issue: https://github.com/pytorch/data/issues/45." + ) + + def __getattr__(self, name): + """ + To delegate operations to ``dataloader._datapipe_iter``. + """ + if self.dataloader._datapipe_iter is None: + raise AttributeError + return getattr(self.dataloader._datapipe_iter, name) + + class DataLoader2(Generic[T_co]): def __init__( self, @@ -53,7 +89,7 @@ def __init__( self.datapipe = datapipe self._adapted: bool = False self._datapipe_iter: Optional[Iterator[T_co]] = None - self._reset_iter: bool = True + self._reset_iter: bool = True # Sets to `False` when __iter__ starts, and `True` when `StopIteration`` # TODO(630): Some ReadingServices might want to validate adapters, we can add this feature if datapipe_adapter_fn is None: self.datapipe_adapter_fns = None @@ -62,8 +98,9 @@ def __init__( else: self.datapipe_adapter_fns = [datapipe_adapter_fn] self.reading_service = reading_service - self.reading_service_state: Optional[bytes] = None + self.reading_service_state: Optional[bytes] = None # is not `None` when `load_state_dict` is called self._terminated: bool = False + self.valid_iterator_id: Optional[int] = None if self.datapipe_adapter_fns is not None: for adapter_fn in self.datapipe_adapter_fns: @@ -88,28 +125,10 @@ def __iter__(self) -> Iterator[T_co]: self.reading_service.initialize_iteration() self._datapipe_iter = iter(self.datapipe) - self._reset_iter = False - return self - - def __next__(self) -> T_co: - if self._reset_iter: - raise StopIteration - try: - return next(self._datapipe_iter) # type: ignore[arg-type] - except PauseIteration: - raise StopIteration - except StopIteration: - if self.reading_service is not None: - self.reading_service.finalize_iteration() - self._reset_iter = True - raise - - def __getattr__(self, name: str) -> Any: - if self._datapipe_iter is None: - raise AttributeError - return getattr(self._datapipe_iter, name) + self.valid_iterator_id = 0 if self.valid_iterator_id is None else self.valid_iterator_id + 1 + return DataLoader2Iterator(self, self.valid_iterator_id) def __del__(self) -> None: self.shutdown() @@ -175,7 +194,8 @@ def load_state_dict(self, state: Dict[str, Any]) -> None: # iterator has already been created: 1) iterator is just created 2) iterator is created and iter is exhausted if self._datapipe_iter is not None: raise RuntimeError( - "DataLoaderV2 iterator has already been created, `load_state_dict()` can’t be called. Please create a new dataloader in order to use load state dict." + "DataLoaderV2 iterator has already been created, `load_state_dict()` can’t be called. " + "Please create a new dataloader in order to use load state dict." 
             )
 
         serialized_datapipe = state[SERIALIZED_DATAPIPE_KEY_NAME]
diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py
index 407c72fcc..e144e9027 100644
--- a/torchdata/dataloader2/reading_service.py
+++ b/torchdata/dataloader2/reading_service.py
@@ -146,10 +146,11 @@ def init_datapipe_process(num_workers, worker_id, datapipe):
 
     def initialize(self, datapipe: DataPipe) -> DataPipe:
         if self.num_workers == 0:
-            # TODO(616): Warn and recommend usage of InPorcessReadingService
+            # TODO(616): Warn and recommend usage of InProcessReadingService
             return datapipe
         for worker_id in range(self.num_workers):
-            # TODO(617): Separate into function, because we also need to apply distributed seed and call it inside process
+            # TODO(617): Separate into function, because we also need to apply distributed seed
+            # and call it inside process
             call_inside_process = functools.partial(self.init_datapipe_process, self.num_workers, worker_id)
             ctx = mp.get_context(self.multiprocessing_context)
             (process, req_queue, res_queue) = communication.eventloop.SpawnProcessForDataPipeline(
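
A minimal usage sketch (not part of the patch) of the single-active-iterator behavior introduced above, mirroring `test_dataloader2_reset` with the default (`None`) reading service:

```python
from torchdata.dataloader2 import DataLoader2
from torchdata.datapipes.iter import IterableWrapper

dl: DataLoader2 = DataLoader2(datapipe=IterableWrapper(range(10)))

# Each call to `iter()` resets iteration, so sequential full reads work.
assert list(dl) == list(range(10))
assert list(dl) == list(range(10))

# Creating a second iterator restarts iteration and invalidates the first.
it1 = iter(dl)
assert next(it1) == 0
it2 = iter(dl)
assert next(it2) == 0

try:
    next(it1)
except RuntimeError as e:
    # Raised by DataLoader2Iterator.__next__ when its `iterator_id` is stale.
    assert "iterator has been invalidated" in str(e)

assert list(it2) == list(range(1, 10))
dl.shutdown()
```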