Update on "[DataLoader2] Adding guard to randomness state for backward compatibility"


Follow-up to #998 for backward compatibility.

Differential Revision: [D44747988](https://our.internmc.facebook.com/intern/diff/D44747988)

[ghstack-poisoned]
NivekT committed Apr 6, 2023
2 parents 2f646e0 + f275b1c commit 288d369
Showing 3 changed files with 74 additions and 4 deletions.
66 changes: 62 additions & 4 deletions test/dataloader2/test_mprs.py
@@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import multiprocessing as mp
import unittest
from unittest import TestCase
@@ -14,7 +13,7 @@
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, MultiProcessingReadingService
-from torchdata.datapipes.iter import IterableWrapper
+from torchdata.datapipes.iter import IterableWrapper, IterDataPipe


def _add_one(x: int) -> int:
@@ -46,6 +45,17 @@ def _dispatching_dp(n_elements=1000):
    return dp


class NonShardableDataPipe(IterDataPipe):
    def __init__(self, dp: IterDataPipe):
        self.dp = dp

    def is_replicable(self):
        # Marking the pipe as non-replicable keeps it in the main process
        # instead of replicating it across the worker processes.
        return False

    def __iter__(self):
        yield from self.dp


class TestMultiProcessingReadingService(TestCase):
r"""
This tests specific functionalities of MultiProcessingReadingService, notably
@@ -64,7 +74,7 @@ def test_early_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None:
            worker_prefetch_cnt=worker_prefetch,
            multiprocessing_context=ctx,
        )
-        dl = DataLoader2(dp, reading_service=rs)
+        dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
        it = iter(dl)
        for _ in range(10):
            _ = next(it)
@@ -82,7 +92,7 @@ def test_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None:
            worker_prefetch_cnt=worker_prefetch,
            multiprocessing_context=ctx,
        )
-        dl = DataLoader2(dp, reading_service=rs)
+        dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
        _ = list(dl)
        dl.shutdown()

@@ -248,6 +258,54 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
            res.append(x)
        self.assertEqual(9, len(res))

    def test_initial_epoch_checkpointing(self):
        dp = IterableWrapper(range(20)).shuffle().sharding_filter()
        # Note that the second `shuffle` occurs in the main process, which uses a different RNG from
        # the one used by the `shuffle` in the worker processes
        dp = NonShardableDataPipe(dp).shuffle()  # type: ignore[assignment, arg-type]
        rs = MultiProcessingReadingService(num_workers=2)

        # Functional Test: Saving state before iterator is created
        dl: DataLoader2 = DataLoader2(datapipe=dp, reading_service=rs)
        dl.seed(1)
        initial_state = dl.state_dict()
        it1 = iter(dl)

        restored_dl: DataLoader2 = DataLoader2.from_state(initial_state, rs)  # type: ignore[arg-type]
        restored_dl._restore_checkpoint_beginning_of_epoch()
        self.assertEqual(list(it1), list(restored_dl))

        dl.shutdown()
        restored_dl.shutdown()

        # Functional Test: Saving state after iterator is created
        dl = DataLoader2(datapipe=dp, reading_service=rs)
        dl.seed(1)
        it1 = iter(dl)
        initial_state = dl.state_dict()

        restored_dl = DataLoader2.from_state(initial_state, rs)  # type: ignore[arg-type]
        restored_dl._restore_checkpoint_beginning_of_epoch()
        self.assertEqual(list(it1), list(restored_dl))

        dl.shutdown()
        restored_dl.shutdown()

        # Functional Test: Saving state after iterator is created and iteration has begun
        dl = DataLoader2(datapipe=dp, reading_service=rs)
        dl.seed(1)
        it1 = iter(dl)
        temp = next(it1)  # Starts iterating
        initial_state = dl.state_dict()

        restored_dl = DataLoader2.from_state(initial_state, rs)  # type: ignore[arg-type]
        restored_dl._restore_checkpoint_beginning_of_epoch()

        # The restored DataLoader2 replays the epoch from the beginning, so the element already
        # consumed from `it1` is prepended before comparing
        self.assertEqual([temp] + list(it1), list(restored_dl))

        dl.shutdown()
        restored_dl.shutdown()

    # TODO: Test cases when there is official support of `pause` and `resume` with round-robin sharding
    # Currently, using sharding_round_robin raises a warning
    # def test_round_robin_dispatching_pause_limit(self):
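For readers skimming the diff, the save/restore flow that `test_initial_epoch_checkpointing` exercises is short. A minimal sketch mirroring the first case in the test above (same `DataLoader2.from_state` and `_restore_checkpoint_beginning_of_epoch` calls; not an official recipe):

from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(20)).shuffle().sharding_filter()
rs = MultiProcessingReadingService(num_workers=2)

dl = DataLoader2(datapipe=dp, reading_service=rs)
dl.seed(1)
state = dl.state_dict()  # checkpoint taken before the epoch starts

restored = DataLoader2.from_state(state, rs)
restored._restore_checkpoint_beginning_of_epoch()  # rewind to the start of the epoch
assert list(dl) == list(restored)  # identical order, thanks to the shared randomness state

dl.shutdown()
restored.shutdown()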
10 changes: 10 additions & 0 deletions torchdata/dataloader2/random/seed_generator.py
@@ -83,3 +83,13 @@ def spawn(self, worker_id: int, inplace: bool = False) -> "SeedGenerator":
            self._worker_rng = self._worker_rng.spawn(worker_id)
            return self
        return SeedGenerator(seed=None, _rngs=(self._shared_rng.clone(), self._worker_rng.spawn(worker_id)))

    def __getstate__(self):
        # Expose both RNGs so that pickling preserves the full randomness state
        state = (
            self._shared_rng,
            self._worker_rng,
        )
        return state

    def __setstate__(self, state):
        self._shared_rng, self._worker_rng = state
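
These hooks let the generator's randomness state survive pickling; a quick sketch (assuming `SeedGenerator` is importable from `torchdata.dataloader2.random`):

import pickle

from torchdata.dataloader2.random import SeedGenerator

gen = SeedGenerator(seed=1)
clone = pickle.loads(pickle.dumps(gen))  # round-trips via __getstate__/__setstate__
# `clone` now carries the same shared and worker RNG states as `gen`, so both
# produce identical seeds downstream.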
2 changes: 2 additions & 0 deletions torchdata/dataloader2/reading_service.py
@@ -312,6 +312,8 @@ def initialize_iteration(
    ) -> Optional[Callable[[DataPipe], DataPipe]]:
        assert self._end_datapipe is not None

        # Set random seeds for the DataPipes that are in the main process (NOT those in worker processes).
        # Worker seeds are set in `process_reset_fn`
        set_graph_random_seed(self._end_datapipe, seed_generator)

        if self._mp:
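As the new comment spells out, seeding happens in two places. A sketch of which `shuffle` gets seeded where, reusing the two-stage pipeline (and the `NonShardableDataPipe` helper) from the test file above:

# Runs in the worker processes; its seed is set in `process_reset_fn`.
dp = IterableWrapper(range(20)).shuffle().sharding_filter()

# Runs in the main process; its seed is set by `set_graph_random_seed` above.
dp = NonShardableDataPipe(dp).shuffle()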
