diff --git a/test/dataloader2/test_mprs.py b/test/dataloader2/test_mprs.py
index b4fb1d7f8..05a7ba27b 100644
--- a/test/dataloader2/test_mprs.py
+++ b/test/dataloader2/test_mprs.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-
 import multiprocessing as mp
 import unittest
 from unittest import TestCase
@@ -14,7 +13,7 @@ from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
 
 from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, MultiProcessingReadingService
 
-from torchdata.datapipes.iter import IterableWrapper
+from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
 
 
 def _add_one(x: int) -> int:
@@ -46,6 +45,17 @@ def _dispatching_dp(n_elements=1000):
     return dp
 
 
+class NonShardableDataPipe(IterDataPipe):
+    def __init__(self, dp: IterDataPipe):
+        self.dp = dp
+
+    def is_replicable(self):
+        return False
+
+    def __iter__(self):
+        yield from self.dp
+
+
 class TestMultiProcessingReadingService(TestCase):
     r"""
     This tests specific functionalities of MultiProcessingReadingService, notably
@@ -64,7 +74,7 @@ def test_early_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None:
             worker_prefetch_cnt=worker_prefetch,
             multiprocessing_context=ctx,
         )
-        dl = DataLoader2(dp, reading_service=rs)
+        dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
         it = iter(dl)
         for _ in range(10):
             _ = next(it)
@@ -82,7 +92,7 @@ def test_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None:
             worker_prefetch_cnt=worker_prefetch,
             multiprocessing_context=ctx,
         )
-        dl = DataLoader2(dp, reading_service=rs)
+        dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
         _ = list(dl)
         dl.shutdown()
 
@@ -248,6 +258,54 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt):
             res.append(x)
         self.assertEqual(9, len(res))
 
+    def test_initial_epoch_checkpointing(self):
+        dp = IterableWrapper(range(20)).shuffle().sharding_filter()
+        # Note that the second `shuffle` occurs in the main process, which uses a different RNG from
+        # the `shuffle` done in the worker processes
+        dp = NonShardableDataPipe(dp).shuffle()  # type: ignore[assignment, arg-type]
+        rs = MultiProcessingReadingService(num_workers=2)
+
+        # Functional Test: Saving state before iterator is created
+        dl: DataLoader2 = DataLoader2(datapipe=dp, reading_service=rs)
+        dl.seed(1)
+        initial_state = dl.state_dict()
+        it1 = iter(dl)
+
+        restored_dl: DataLoader2 = DataLoader2.from_state(initial_state, rs)  # type: ignore[arg-type]
+        restored_dl._restore_checkpoint_beginning_of_epoch()
+        self.assertEqual(list(it1), list(restored_dl))
+
+        dl.shutdown()
+        restored_dl.shutdown()
+
+        # Functional Test: Saving state after iterator is created
+        dl = DataLoader2(datapipe=dp, reading_service=rs)
+        dl.seed(1)
+        it1 = iter(dl)
+        initial_state = dl.state_dict()
+
+        restored_dl = DataLoader2.from_state(initial_state, rs)  # type: ignore[arg-type]
+        restored_dl._restore_checkpoint_beginning_of_epoch()
+        self.assertEqual(list(it1), list(restored_dl))
+
+        dl.shutdown()
+        restored_dl.shutdown()
+
+        # Functional Test: Saving state after iterator is created and iteration has begun
+        dl = DataLoader2(datapipe=dp, reading_service=rs)
+        dl.seed(1)
+        it1 = iter(dl)
+        temp = next(it1)  # Starts iterating
+        initial_state = dl.state_dict()
+
+        restored_dl = DataLoader2.from_state(initial_state, rs)  # type: ignore[arg-type]
+        restored_dl._restore_checkpoint_beginning_of_epoch()
+
+        self.assertEqual([temp] + list(it1), list(restored_dl))  # Note: prepend `temp`, the element already consumed from `it1`
+
+        dl.shutdown()
+        restored_dl.shutdown()
+
     # TODO: Test cases when there is official support of `pause` and `resume` with round-robin sharding
     # Currently, using sharding_round_robin raises a warning
     # def test_round_robin_dispatching_pause_limit(self):
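The checkpoint-and-restore flow that `test_initial_epoch_checkpointing` exercises condenses to the sketch below. This is a distillation of the test above, not additional API surface; `_restore_checkpoint_beginning_of_epoch` is the private method added in the `dataloader2.py` hunks that follow, and the pipeline, seed, and worker count are arbitrary:

    from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
    from torchdata.datapipes.iter import IterableWrapper

    dp = IterableWrapper(range(20)).shuffle().sharding_filter()
    rs = MultiProcessingReadingService(num_workers=2)

    dl = DataLoader2(dp, reading_service=rs)
    dl.seed(1)
    it = iter(dl)
    first = next(it)         # partially consume the epoch
    state = dl.state_dict()  # snapshot taken mid-epoch

    # Restore, then rewind the RNG to the state it had when the epoch began;
    # the restored loader replays the epoch in the exact same shuffled order.
    restored = DataLoader2.from_state(state, rs)
    restored._restore_checkpoint_beginning_of_epoch()
    assert [first] + list(it) == list(restored)

    dl.shutdown()
    restored.shutdown()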
diff --git a/torchdata/dataloader2/dataloader2.py b/torchdata/dataloader2/dataloader2.py
index 893e51180..15b582187 100644
--- a/torchdata/dataloader2/dataloader2.py
+++ b/torchdata/dataloader2/dataloader2.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-
+import pickle
 import warnings
 from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union
 
@@ -19,6 +19,7 @@
 T_co = TypeVar("T_co", covariant=True)
 SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe"
 READING_SERVICE_STATE_KEY_NAME = "reading_service_state"
+RANDOMNESS_STATE_KEY_NAME = "randomness_state"
 
 
 class DataLoader2Iterator(Iterator[T_co]):
@@ -176,6 +177,8 @@ def __init__(
         self._seed_generator: SeedGenerator = SeedGenerator()
         self._seed: Optional[int] = None
         self._reset_seed: bool = True
+        # Seed generator as of the beginning of each epoch
+        self._initial_seed_generator: SeedGenerator = clone(self._seed_generator)
 
     def __iter__(self) -> DataLoader2Iterator[T_co]:
         r"""
@@ -198,6 +201,9 @@ def __iter__(self) -> DataLoader2Iterator[T_co]:
             else:
                 self._seed_generator.seed()
 
+            # Save the initial seed generator state
+            self._initial_seed_generator = clone(self._seed_generator)
+
         if not self._adapted and self.reading_service is not None:
             if self.reading_service_state is None:
                 self.datapipe = self.reading_service.initialize(self.datapipe)
@@ -269,10 +275,17 @@ def state_dict(self) -> Dict[str, Any]:
 
         # Serialize datapipe after applying adapters and before reading service adaption
         serialized_datapipe = serialize_datapipe(self._datapipe_before_reading_service_adapt)
+        serialized_randomness_state = (
+            self._seed,
+            self._reset_seed,
+            pickle.dumps(self._seed_generator),
+            pickle.dumps(self._initial_seed_generator),
+        )
 
         return {
             SERIALIZED_DATAPIPE_KEY_NAME: serialized_datapipe,
             READING_SERVICE_STATE_KEY_NAME: reading_service_state,
+            RANDOMNESS_STATE_KEY_NAME: serialized_randomness_state,
         }
 
     @classmethod
@@ -294,6 +307,12 @@ def from_state(
             reading_service=reading_service,
         )
         data_loader.reading_service_state = reading_service_state
+
+        randomness_state = state[RANDOMNESS_STATE_KEY_NAME]
+        data_loader._seed, data_loader._reset_seed = randomness_state[0], randomness_state[1]
+        data_loader._seed_generator = pickle.loads(randomness_state[2])
+        data_loader._initial_seed_generator = pickle.loads(randomness_state[3])
+
         return data_loader
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -320,12 +339,28 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         self.datapipe = deserialized_datapipe
         self.reading_service_state = reading_service_state
 
+        randomness_state = state_dict[RANDOMNESS_STATE_KEY_NAME]
+        self._seed, self._reset_seed = randomness_state[0], randomness_state[1]
+        self._seed_generator = pickle.loads(randomness_state[2])
+        self._initial_seed_generator = pickle.loads(randomness_state[3])
+
         # re-initialize datapipe_adapter_fn and _datapipe_before_reading_service_adapt
         if self.datapipe_adapter_fns is not None:
             for adapter_fn in self.datapipe_adapter_fns:
                 self.datapipe = adapter_fn(self.datapipe)
         self._datapipe_before_reading_service_adapt = clone(self.datapipe)
 
+    def _restore_checkpoint_beginning_of_epoch(self) -> None:
+        r"""
+        At the beginning of each iteration (epoch), the initial state of randomness is automatically saved.
+        That state is also saved as part of ``state_dict``. This method restores the RNG state of
+        ``DataLoader2`` to that initial state.
+
+        The common use case is to invoke this method after ``DataLoader2``'s state is restored (through
+        ``.from_state(...)`` or ``load_state_dict(...)``) in order to resume from the beginning of the last-run epoch.
+        """
+        self._seed_generator = self._initial_seed_generator
+
     def _pause(self):
         if hasattr(self.reading_service, "_pause"):
             self._is_paused = True
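With the `state_dict` and `from_state`/`load_state_dict` changes above, a checkpoint is a plain dict with three entries. A quick way to inspect the new `"randomness_state"` entry end to end (a minimal sketch; it assumes a `DataLoader2` with no reading service attached, which leaves `"reading_service_state"` empty):

    import pickle

    from torchdata.dataloader2 import DataLoader2
    from torchdata.datapipes.iter import IterableWrapper

    dl = DataLoader2(IterableWrapper(range(10)))
    dl.seed(1)
    state = dl.state_dict()

    # The new entry is the 4-tuple built in `state_dict` above:
    # (seed, reset-seed flag, pickled current SeedGenerator,
    #  pickled beginning-of-epoch SeedGenerator)
    seed, reset_seed, gen_bytes, initial_gen_bytes = state["randomness_state"]
    assert seed == 1
    seed_generator = pickle.loads(gen_bytes)
    initial_seed_generator = pickle.loads(initial_gen_bytes)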
diff --git a/torchdata/dataloader2/random/seed_generator.py b/torchdata/dataloader2/random/seed_generator.py
index 2f67ee2be..fa4fdfabc 100644
--- a/torchdata/dataloader2/random/seed_generator.py
+++ b/torchdata/dataloader2/random/seed_generator.py
@@ -83,3 +83,13 @@ def spawn(self, worker_id: int, inplace: bool = False) -> "SeedGenerator":
             self._worker_rng = self._worker_rng.spawn(worker_id)
             return self
         return SeedGenerator(seed=None, _rngs=(self._shared_rng.clone(), self._worker_rng.spawn(worker_id)))
+
+    def __getstate__(self):
+        state = (
+            self._shared_rng,
+            self._worker_rng,
+        )
+        return state
+
+    def __setstate__(self, state):
+        self._shared_rng, self._worker_rng = state
diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py
index 4109c05fe..d8a43e349 100644
--- a/torchdata/dataloader2/reading_service.py
+++ b/torchdata/dataloader2/reading_service.py
@@ -312,6 +312,8 @@ def initialize_iteration(
     ) -> Optional[Callable[[DataPipe], DataPipe]]:
         assert self._end_datapipe is not None
 
+        # Set random seeds for DataPipes that are in the main process (NOT those in worker processes)
+        # Worker seeds are set in `process_reset_fn`
        set_graph_random_seed(self._end_datapipe, seed_generator)
 
         if self._mp:
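The `__getstate__`/`__setstate__` pair added to `SeedGenerator` is what the `pickle.dumps`/`pickle.loads` calls in `dataloader2.py` rely on. A round-trip sanity check (a sketch; it assumes `SeedGenerator` is exported from `torchdata.dataloader2.random` and that generators with identical internal state emit identical seed sequences):

    import pickle

    from torchdata.dataloader2.random import SeedGenerator

    gen = SeedGenerator(seed=1)
    restored_gen = pickle.loads(pickle.dumps(gen))  # exercises __getstate__/__setstate__

    # Same (_shared_rng, _worker_rng) state, so the same seeds come out
    assert gen.generate_seed() == restored_gen.generate_seed()
    assert gen.generate_shared_seed() == restored_gen.generate_shared_seed()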