Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DataLoader2] Saving and restoring initial seed generator #998

Closed
NivekT wants to merge 19 commits into pytorch:main from a feature branch
Closed
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
05c0bf1
[DataLoader2] Saving and restoring initial seed generator
NivekT Feb 8, 2023
45b6998
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 8, 2023
9d6f38d
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 9, 2023
389e567
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 9, 2023
955e412
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 10, 2023
90278bf
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 10, 2023
1a8ebdd
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 13, 2023
fa1f93f
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 15, 2023
4653af3
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 15, 2023
15f774e
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 15, 2023
5de743f
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 16, 2023
18a5c26
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 28, 2023
9f87c01
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 28, 2023
3cdd2ea
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Mar 17, 2023
ef850ed
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Mar 17, 2023
b715066
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Mar 17, 2023
5e56222
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Mar 24, 2023
0433509
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Mar 24, 2023
0d4854c
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Mar 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions test/dataloader2/test_mprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,65 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
res.append(x)
self.assertEqual(9, len(res))

def test_initial_epoch_checkpointing(self):
    """Checkpointing the randomness state of ``DataLoader2`` at the start of an epoch.

    Covers three snapshot points: before an iterator exists, right after
    ``iter()`` is called, and after iteration has begun (where the restored
    loader must be rewound with ``_restore_checkpoint_beginning_of_epoch``).
    """
    dp = IterableWrapper(range(20)).shuffle().sharding_filter()
    rs = MultiProcessingReadingService(num_workers=2)

    # Functional Test: Saving state before iterator is created
    dl: DataLoader2 = DataLoader2(datapipe=dp, reading_service=rs)
    dl.seed(1)
    initial_state = dl.state_dict()
    it1 = iter(dl)

    restored_dl: DataLoader2 = DataLoader2.from_state(initial_state, rs)  # type: ignore[arg-type]
    self.assertEqual(list(it1), list(restored_dl))

    dl.shutdown()
    restored_dl.shutdown()

    # Functional Test: Saving state after iterator is created
    dl = DataLoader2(datapipe=dp, reading_service=rs)
    dl.seed(1)
    it1 = iter(dl)
    initial_state = dl.state_dict()

    restored_dl = DataLoader2.from_state(initial_state, rs)  # type: ignore[arg-type]
    self.assertEqual(list(it1), list(restored_dl))

    dl.shutdown()
    restored_dl.shutdown()

    # Functional Test: Saving state after iterator is created and has begun iterating
    dl = DataLoader2(datapipe=dp, reading_service=rs)
    dl.seed(1)
    it1 = iter(dl)
    next(it1)  # Starts iterating
    initial_state = dl.state_dict()

    # Restored without rewinding: continues from the seed state captured mid-epoch
    restored_dl = DataLoader2.from_state(initial_state, rs)  # type: ignore[arg-type]
    # Restored and rewound to the beginning of the epoch: replays the full epoch
    restored_dl2 = DataLoader2.from_state(initial_state, rs)  # type: ignore[arg-type]
    restored_dl2._restore_checkpoint_beginning_of_epoch()

    res1 = list(it1)  # the elements `it1` has not yet consumed
    res2 = list(restored_dl)
    res3 = list(restored_dl2)

    # NOTE(review): intended contract — `restored_dl2` replays the whole epoch,
    # so its output should equal the element already consumed by `next(it1)`
    # followed by `res1`. The assertion below was commented out in the original
    # (`self.assertEqual(list(it1), list(restored_dl)[1:])`); confirm ordering
    # guarantees with multiprocessing workers before re-enabling.
    # TODO: assert res1 == res3[1:] once worker ordering is deterministic
    self.assertEqual(len(res3), len(res1) + 1)
    self.assertEqual(len(res2), len(res1) + 1)

    dl.shutdown()
    restored_dl.shutdown()
    restored_dl2.shutdown()

# TODO: Test cases when there is official support of `pause` and `resume` with round-robin sharding
# Currently, using sharding_round_robin raises a warning
# def test_round_robin_dispatching_pause_limit(self):
Expand Down
30 changes: 28 additions & 2 deletions torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import pickle
import warnings

from dataclasses import dataclass
Expand All @@ -20,6 +19,7 @@
T_co = TypeVar("T_co", covariant=True)
SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe"
READING_SERVICE_STATE_KEY_NAME = "reading_service_state"
RANDOMNESS_STATE_KEY_NAME = "randomness_state"


@dataclass
Expand Down Expand Up @@ -185,6 +185,8 @@ def __init__(
self._seed_generator: SeedGenerator = SeedGenerator()
self._seed: Optional[int] = None
self._reset_seed: bool = True
# Seed generator as of beginning of each epoch
self._initial_seed_generator: SeedGenerator = clone(self._seed_generator)

def __iter__(self) -> DataLoader2Iterator[T_co]:
r"""
Expand All @@ -207,6 +209,9 @@ def __iter__(self) -> DataLoader2Iterator[T_co]:
else:
self._seed_generator.seed()

# Saving initial seed generator state
self._initial_seed_generator = clone(self._seed_generator)

if not self._adapted and self.reading_service is not None:
if self.reading_service_state is None:
self.datapipe = self.reading_service.initialize(self.datapipe)
Expand Down Expand Up @@ -278,10 +283,17 @@ def state_dict(self) -> Dict[str, Any]:

NivekT marked this conversation as resolved.
Show resolved Hide resolved
# Serialize datapipe after applying adapters and before reading service adaption
serialized_datapipe = serialize_datapipe(self._datapipe_before_reading_service_adapt)
serialized_randomness_state = (
self._seed,
self._reset_seed,
pickle.dumps(self._seed_generator),
NivekT marked this conversation as resolved.
Show resolved Hide resolved
pickle.dumps(self._initial_seed_generator),
)

return {
SERIALIZED_DATAPIPE_KEY_NAME: serialized_datapipe,
READING_SERVICE_STATE_KEY_NAME: reading_service_state,
RANDOMNESS_STATE_KEY_NAME: serialized_randomness_state,
}

@classmethod
Expand All @@ -303,6 +315,12 @@ def from_state(
reading_service=reading_service,
)
data_loader.reading_service_state = reading_service_state

randomness_state = state[RANDOMNESS_STATE_KEY_NAME]
data_loader._seed, data_loader._reset_seed = randomness_state[0], randomness_state[1]
data_loader._seed_generator = pickle.loads(randomness_state[2])
data_loader._initial_seed_generator = pickle.loads(randomness_state[3])

return data_loader

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
Expand All @@ -329,12 +347,20 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.datapipe = deserialized_datapipe
self.reading_service_state = reading_service_state

randomness_state = state_dict[RANDOMNESS_STATE_KEY_NAME]
self._seed, self._reset_seed = randomness_state[0], randomness_state[1]
self._seed_generator = pickle.loads(randomness_state[2])
self._initial_seed_generator = pickle.loads(randomness_state[3])

# re-initialize datapipe_adapter_fn and _datapipe_before_reading_service_adapt
if self.datapipe_adapter_fns is not None:
for adapter_fn in self.datapipe_adapter_fns:
self.datapipe = adapter_fn(self.datapipe)
self._datapipe_before_reading_service_adapt = clone(self.datapipe)

def _restore_checkpoint_beginning_of_epoch(self) -> None:
    r"""
    Rewind the randomness state of this ``DataLoader2`` to the snapshot taken
    at the beginning of the most recent epoch, so that iterating again replays
    that epoch's seed sequence from the start.
    """
    # Clone rather than alias the snapshot: `__iter__` mutates
    # `self._seed_generator` in place (re-seeding / cloning into the next
    # snapshot), so aliasing would corrupt `_initial_seed_generator` after the
    # first restore and make the checkpoint single-use.
    self._seed_generator = clone(self._initial_seed_generator)

def _pause(self):
if hasattr(self.reading_service, "_pause"):
self._is_paused = True
Expand Down