diff --git a/test/dataloader2/test_mprs.py b/test/dataloader2/test_mprs.py index 05a7ba27b..b4fb1d7f8 100644 --- a/test/dataloader2/test_mprs.py +++ b/test/dataloader2/test_mprs.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + import multiprocessing as mp import unittest from unittest import TestCase @@ -13,7 +14,7 @@ from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, MultiProcessingReadingService -from torchdata.datapipes.iter import IterableWrapper, IterDataPipe +from torchdata.datapipes.iter import IterableWrapper def _add_one(x: int) -> int: @@ -45,17 +46,6 @@ def _dispatching_dp(n_elements=1000): return dp -class NonShardableDataPipe(IterDataPipe): - def __init__(self, dp: IterDataPipe): - self.dp = dp - - def is_replicable(self): - return False - - def __iter__(self): - yield from self.dp - - class TestMultiProcessingReadingService(TestCase): r""" This tests specific functionalities of MultiProcessingReadingService, notably @@ -74,7 +64,7 @@ def test_early_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None: worker_prefetch_cnt=worker_prefetch, multiprocessing_context=ctx, ) - dl: DataLoader2 = DataLoader2(dp, reading_service=rs) + dl = DataLoader2(dp, reading_service=rs) it = iter(dl) for _ in range(10): _ = next(it) @@ -92,7 +82,7 @@ def test_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None: worker_prefetch_cnt=worker_prefetch, multiprocessing_context=ctx, ) - dl: DataLoader2 = DataLoader2(dp, reading_service=rs) + dl = DataLoader2(dp, reading_service=rs) _ = list(dl) dl.shutdown() @@ -258,54 +248,6 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr res.append(x) self.assertEqual(9, len(res)) - def test_initial_epoch_checkpointing(self): - dp = IterableWrapper(range(20)).shuffle().sharding_filter() - # Note that the second `shuffle` occurs in the main process, which uses a different RNG from - # the `shuffle` done in the worker processes - dp = NonShardableDataPipe(dp).shuffle() # type: ignore[assignment, arg-type] - rs = MultiProcessingReadingService(num_workers=2) - - # Functional Test: Saving state before iterator is created - dl: DataLoader2 = DataLoader2(datapipe=dp, reading_service=rs) - dl.seed(1) - initial_state = dl.state_dict() - it1 = iter(dl) - - restored_dl: DataLoader2 = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type] - restored_dl._restore_checkpoint_beginning_of_epoch() - self.assertEqual(list(it1), list(restored_dl)) - - dl.shutdown() - restored_dl.shutdown() - - # Functional Test: Saving state after iterator is created - dl = DataLoader2(datapipe=dp, reading_service=rs) - dl.seed(1) - it1 = iter(dl) - initial_state = dl.state_dict() - - restored_dl = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type] - restored_dl._restore_checkpoint_beginning_of_epoch() - self.assertEqual(list(it1), list(restored_dl)) - - dl.shutdown() - restored_dl.shutdown() - - # Functional Test: Saving state after iterator is created and began iterating - dl = DataLoader2(datapipe=dp, reading_service=rs) - dl.seed(1) - it1 = iter(dl) - temp = next(it1) # Starts iterating - initial_state = dl.state_dict() - - restored_dl = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type] - restored_dl._restore_checkpoint_beginning_of_epoch() - - self.assertEqual([temp] + list(it1), list(restored_dl)) # Note skipping over 1st element from actual result - - dl.shutdown() - restored_dl.shutdown() - # TODO: Test cases when there is official support of `pause` and `resume` with round-robin sharding # Currently, using sharding_round_robin raises a warning # def test_round_robin_dispatching_pause_limit(self): diff --git a/torchdata/dataloader2/random/seed_generator.py b/torchdata/dataloader2/random/seed_generator.py index fa4fdfabc..2f67ee2be 100644 --- a/torchdata/dataloader2/random/seed_generator.py +++ b/torchdata/dataloader2/random/seed_generator.py @@ -83,13 +83,3 @@ def spawn(self, worker_id: int, inplace: bool = False) -> "SeedGenerator": self._worker_rng = self._worker_rng.spawn(worker_id) return self return SeedGenerator(seed=None, _rngs=(self._shared_rng.clone(), self._worker_rng.spawn(worker_id))) - - def __getstate__(self): - state = ( - self._shared_rng, - self._worker_rng, - ) - return state - - def __setstate__(self, state): - self._shared_rng, self._worker_rng = state diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py index d8a43e349..4109c05fe 100644 --- a/torchdata/dataloader2/reading_service.py +++ b/torchdata/dataloader2/reading_service.py @@ -312,8 +312,6 @@ def initialize_iteration( ) -> Optional[Callable[[DataPipe], DataPipe]]: assert self._end_datapipe is not None - # Set random seeds for DataPipe that are in the main process (NOT those in worker processes) - # Worker seeds are set in `process_reset_fn` set_graph_random_seed(self._end_datapipe, seed_generator) if self._mp: