Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DataLoader2] Saving and restoring initial seed generator #998

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
05c0bf1
[DataLoader2] Saving and restoring initial seed generator
NivekT Feb 8, 2023
45b6998
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 8, 2023
9d6f38d
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 9, 2023
389e567
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 9, 2023
955e412
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 10, 2023
90278bf
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 10, 2023
1a8ebdd
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 13, 2023
fa1f93f
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 15, 2023
4653af3
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 15, 2023
15f774e
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 15, 2023
5de743f
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 16, 2023
18a5c26
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 28, 2023
9f87c01
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 28, 2023
3cdd2ea
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Mar 17, 2023
ef850ed
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Mar 17, 2023
b715066
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Mar 17, 2023
5e56222
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Mar 24, 2023
0433509
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Mar 24, 2023
0d4854c
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Mar 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 62 additions & 4 deletions test/dataloader2/test_mprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import multiprocessing as mp
import unittest
from unittest import TestCase
Expand All @@ -14,7 +13,7 @@
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe


def _add_one(x: int) -> int:
Expand Down Expand Up @@ -46,6 +45,17 @@ def _dispatching_dp(n_elements=1000):
return dp


class NonShardableDataPipe(IterDataPipe):
def __init__(self, dp: IterDataPipe):
self.dp = dp

def is_replicable(self):
return False

def __iter__(self):
yield from self.dp


class TestMultiProcessingReadingService(TestCase):
r"""
This tests specific functionalities of MultiProcessingReadingService, notably
Expand All @@ -64,7 +74,7 @@ def test_early_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None:
worker_prefetch_cnt=worker_prefetch,
multiprocessing_context=ctx,
)
dl = DataLoader2(dp, reading_service=rs)
dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
it = iter(dl)
for _ in range(10):
_ = next(it)
Expand All @@ -82,7 +92,7 @@ def test_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None:
worker_prefetch_cnt=worker_prefetch,
multiprocessing_context=ctx,
)
dl = DataLoader2(dp, reading_service=rs)
dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
_ = list(dl)
dl.shutdown()

Expand Down Expand Up @@ -248,6 +258,54 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
res.append(x)
self.assertEqual(9, len(res))

def test_initial_epoch_checkpointing(self):
dp = IterableWrapper(range(20)).shuffle().sharding_filter()
# Note that the second `shuffle` occurs in the main process, which uses a different RNG from
# the `shuffle` done in the worker processes
dp = NonShardableDataPipe(dp).shuffle() # type: ignore[assignment, arg-type]
rs = MultiProcessingReadingService(num_workers=2)

# Functional Test: Saving state before iterator is created
dl: DataLoader2 = DataLoader2(datapipe=dp, reading_service=rs)
dl.seed(1)
initial_state = dl.state_dict()
it1 = iter(dl)

restored_dl: DataLoader2 = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type]
restored_dl._restore_checkpoint_beginning_of_epoch()
self.assertEqual(list(it1), list(restored_dl))

dl.shutdown()
restored_dl.shutdown()

# Functional Test: Saving state after iterator is created
dl = DataLoader2(datapipe=dp, reading_service=rs)
dl.seed(1)
it1 = iter(dl)
initial_state = dl.state_dict()

restored_dl = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type]
restored_dl._restore_checkpoint_beginning_of_epoch()
self.assertEqual(list(it1), list(restored_dl))

dl.shutdown()
restored_dl.shutdown()

# Functional Test: Saving state after iterator is created and began iterating
dl = DataLoader2(datapipe=dp, reading_service=rs)
dl.seed(1)
it1 = iter(dl)
temp = next(it1) # Starts iterating
initial_state = dl.state_dict()

restored_dl = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type]
restored_dl._restore_checkpoint_beginning_of_epoch()

self.assertEqual([temp] + list(it1), list(restored_dl)) # Note skipping over 1st element from actual result

dl.shutdown()
restored_dl.shutdown()

# TODO: Test cases when there is official support of `pause` and `resume` with round-robin sharding
# Currently, using sharding_round_robin raises a warning
# def test_round_robin_dispatching_pause_limit(self):
Expand Down
37 changes: 36 additions & 1 deletion torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import pickle
import warnings

from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union
Expand All @@ -19,6 +19,7 @@
T_co = TypeVar("T_co", covariant=True)
SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe"
READING_SERVICE_STATE_KEY_NAME = "reading_service_state"
RANDOMNESS_STATE_KEY_NAME = "randomness_state"


class DataLoader2Iterator(Iterator[T_co]):
Expand Down Expand Up @@ -176,6 +177,8 @@ def __init__(
self._seed_generator: SeedGenerator = SeedGenerator()
self._seed: Optional[int] = None
self._reset_seed: bool = True
# Seed generator as of beginning of each epoch
self._initial_seed_generator: SeedGenerator = clone(self._seed_generator)

def __iter__(self) -> DataLoader2Iterator[T_co]:
r"""
Expand All @@ -198,6 +201,9 @@ def __iter__(self) -> DataLoader2Iterator[T_co]:
else:
self._seed_generator.seed()

# Saving initial seed generator state
self._initial_seed_generator = clone(self._seed_generator)

if not self._adapted and self.reading_service is not None:
if self.reading_service_state is None:
self.datapipe = self.reading_service.initialize(self.datapipe)
Expand Down Expand Up @@ -269,10 +275,17 @@ def state_dict(self) -> Dict[str, Any]:

NivekT marked this conversation as resolved.
Show resolved Hide resolved
# Serialize datapipe after applying adapters and before reading service adaption
serialized_datapipe = serialize_datapipe(self._datapipe_before_reading_service_adapt)
serialized_randomness_state = (
self._seed,
self._reset_seed,
pickle.dumps(self._seed_generator),
NivekT marked this conversation as resolved.
Show resolved Hide resolved
pickle.dumps(self._initial_seed_generator),
)

return {
SERIALIZED_DATAPIPE_KEY_NAME: serialized_datapipe,
READING_SERVICE_STATE_KEY_NAME: reading_service_state,
RANDOMNESS_STATE_KEY_NAME: serialized_randomness_state,
}

@classmethod
Expand All @@ -294,6 +307,12 @@ def from_state(
reading_service=reading_service,
)
data_loader.reading_service_state = reading_service_state

randomness_state = state[RANDOMNESS_STATE_KEY_NAME]
data_loader._seed, data_loader._reset_seed = randomness_state[0], randomness_state[1]
data_loader._seed_generator = pickle.loads(randomness_state[2])
data_loader._initial_seed_generator = pickle.loads(randomness_state[3])

return data_loader

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
Expand All @@ -320,12 +339,28 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.datapipe = deserialized_datapipe
self.reading_service_state = reading_service_state

randomness_state = state_dict[RANDOMNESS_STATE_KEY_NAME]
self._seed, self._reset_seed = randomness_state[0], randomness_state[1]
self._seed_generator = pickle.loads(randomness_state[2])
self._initial_seed_generator = pickle.loads(randomness_state[3])

# re-initialize datapipe_adapter_fn and _datapipe_before_reading_service_adapt
if self.datapipe_adapter_fns is not None:
for adapter_fn in self.datapipe_adapter_fns:
self.datapipe = adapter_fn(self.datapipe)
self._datapipe_before_reading_service_adapt = clone(self.datapipe)

def _restore_checkpoint_beginning_of_epoch(self) -> None:
r"""
At the beginning of each iteration (epoch), the initial state of randomness is automatically saved.
That state is also saved as part of ``state_dict``. This method restores the current DataLoader2 RNG state
to that initial state.

The common use case is to invoke this method after ``DataLoader2``'s state is restored (through
``.from_state(...)`` or ``load_state_dict(...)``) in order to resume from the beginning of the last-ran epoch.
"""
self._seed_generator = self._initial_seed_generator

def _pause(self):
if hasattr(self.reading_service, "_pause"):
self._is_paused = True
Expand Down
10 changes: 10 additions & 0 deletions torchdata/dataloader2/random/seed_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,13 @@ def spawn(self, worker_id: int, inplace: bool = False) -> "SeedGenerator":
self._worker_rng = self._worker_rng.spawn(worker_id)
return self
return SeedGenerator(seed=None, _rngs=(self._shared_rng.clone(), self._worker_rng.spawn(worker_id)))

def __getstate__(self):
state = (
self._shared_rng,
self._worker_rng,
)
return state

def __setstate__(self, state):
self._shared_rng, self._worker_rng = state
2 changes: 2 additions & 0 deletions torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ def initialize_iteration(
) -> Optional[Callable[[DataPipe], DataPipe]]:
assert self._end_datapipe is not None

# Set random seeds for DataPipe that are in the main process (NOT those in worker processes)
# Worker seeds are set in `process_reset_fn`
set_graph_random_seed(self._end_datapipe, seed_generator)

if self._mp:
Expand Down