Skip to content

Commit

Permalink
[PrototypeRS] Adding support for naive snapshotting
Browse files Browse the repository at this point in the history
ghstack-source-id: 5a70ee1b1b792e148e0bd9503ee7eeaf15f3f128
Pull Request resolved: #915
  • Loading branch information
NivekT committed Dec 7, 2022
1 parent 9eec278 commit 250aa2b
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 5 deletions.
63 changes: 59 additions & 4 deletions test/dataloader2/test_proto_multi_rs.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,67 @@ def test_reading_service_pause_stop_yield(self) -> None:
)
dl.shutdown()

def test_reading_service_pause_stop_yield_variants(self) -> None:
    """Confirm that ``dl`` stops yielding elements after ``dl.pause()`` is called.

    Covers the prefetch configurations (rs7-rs9) not exercised by
    ``test_reading_service_pause_stop_yield``.

    NOTE: this method was previously also named
    ``test_reading_service_pause_stop_yield``; a second ``def`` with the same
    name inside the class silently replaced the earlier test so it never ran.
    Renamed so both tests are discovered.
    """
    # Pause after the 3rd sample has been yielded; no further samples should arrive.
    n_samples_before_pause = 3

    rs7 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=1)
    rs8 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=1, main_prefetch_cnt=0)
    rs9 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0)

    test_rss2 = [rs7, rs8, rs9]
    for n, rs in enumerate(test_rss2):
        dl: DataLoader2 = DataLoader2(self.double_pause_dp, reading_service=rs)
        res = []
        for i, x in enumerate(dl):
            res.append(x)
            if i == n_samples_before_pause - 1:
                dl.pause()
        self.assertEqual(
            n_samples_before_pause,
            len(res),
            msg=f"The test is failing for rs{n + 7}, with num_workers = {rs.num_workers}, "
            f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}",
        )
        dl.shutdown()

# TODO: Implement `test_reading_service_snapshot`
# def test_reading_service_snapshot(self) -> None:
#     pass

def test_dataloader2_snapshot(self) -> None:
    """Naive DataPipe snapshotting round-trip.

    Take a snapshot (number of samples yielded so far + initial seed) after a
    few samples, then restore a fresh ``DataLoader2`` from that snapshot and
    check that it deterministically reproduces the same stream.
    """
    rs1 = PrototypeMultiProcessingReadingService(num_workers=1, worker_prefetch_cnt=0, main_prefetch_cnt=0)
    # TODO(review): re-enable once naive snapshotting supports prefetching / multiple workers
    # rs2 = PrototypeMultiProcessingReadingService(num_workers=1, worker_prefetch_cnt=0, main_prefetch_cnt=2)
    # rs3 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0)
    # rs4 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0)
    # rs5 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2)
    # rs6 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2)

    n_samples_before_snapshot = 3

    n_samples_yielded = 0
    initial_seed_rng = None

    test_rss = [rs1]
    for rs in test_rss:
        dl: DataLoader2 = DataLoader2(self.dp1, reading_service=rs)
        res = []
        for i, x in enumerate(dl):
            res.append(x)
            if i == n_samples_before_snapshot - 1:
                n_samples_yielded, initial_seed_rng = dl._get_naive_datapipe_snapshot()
                break
        dl.shutdown()
        self.assertEqual(n_samples_before_snapshot, len(res))
        self.assertEqual(n_samples_before_snapshot, n_samples_yielded)

        dl_restored: DataLoader2 = DataLoader2(self.dp1, reading_service=rs)
        dl_restored._restore_naive_datapipe_snapshot(n_samples_yielded, initial_seed_rng)
        restored_res = list(dl_restored)
        # BUG FIX: the slice previously ended at `n_samples_before_snapshot - 1`,
        # comparing the 3-element `res` (asserted above) against a 2-element
        # slice — `assertEqual` on lists of unequal length can never pass. The
        # assertion below (sorted `restored_res` == the full range) shows the
        # restored run yields every element, so the first
        # `n_samples_before_snapshot` elements are the ones to compare.
        self.assertEqual(res, restored_res[0:n_samples_before_snapshot])  # Check if elements are the same
        self.assertEqual(list(range(self.n_elements)), sorted(restored_res))
        dl_restored.shutdown()

    # TODO: Need to figure out the reset situation within `_simple_graph_snapshot_restoration` and ProtoRS


if __name__ == "__main__":
Expand Down
33 changes: 32 additions & 1 deletion torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
serialize_datapipe,
wrap_datapipe_for_serialization,
)
from torchdata.dataloader2.reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface
from torchdata.dataloader2.reading_service import (
CheckpointableReadingServiceInterface,
PrototypeMultiProcessingReadingService,
ReadingServiceInterface,
)

T_co = TypeVar("T_co", covariant=True)
SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe"
Expand Down Expand Up @@ -152,6 +156,7 @@ def __iter__(self) -> DataLoader2Iterator[T_co]:
if self._reset_iter:
if not self._adapted and self.reading_service is not None:
if self.reading_service_state is None:
# Only called once when `self._adapted = False`
self.datapipe = self.reading_service.initialize(self.datapipe)
else:
if not isinstance(self.reading_service, CheckpointableReadingServiceInterface):
Expand Down Expand Up @@ -273,3 +278,29 @@ def resume(self):
self._paused = False
else:
warnings.warn("ReadingService doesn't support resume.")

def _get_naive_datapipe_snapshot(self):
    """Capture a naive snapshot of the DataPipe graph.

    Pauses the reading service, reads back the snapshot (the number of
    samples yielded so far and the initial seed), then resumes and returns
    that pair. Only ``PrototypeMultiProcessingReadingService`` is supported.
    """
    if not isinstance(self.reading_service, PrototypeMultiProcessingReadingService):
        raise RuntimeError(
            "Only `PrototypeMultiProcessingReadingService` " "currently supports naive DataPipe snapshotting."
        )
    self.pause()
    snapshot = self.reading_service._get_naive_datapipe_snapshot()
    self.resume()
    return snapshot

def _restore_naive_datapipe_snapshot(self, n_samples_yielded, initial_seed) -> None:
    """Restore a snapshot previously taken via ``_get_naive_datapipe_snapshot``.

    Adapts the DataPipe graph through the reading service if that has not
    happened yet, then delegates the fast-forward/restore to the reading
    service. Only ``PrototypeMultiProcessingReadingService`` is supported.
    """
    rs = self.reading_service
    if not isinstance(rs, PrototypeMultiProcessingReadingService):
        raise RuntimeError(
            "Only `PrototypeMultiProcessingReadingService` " "currently supports naive DataPipe snapshotting."
        )
    if not self._adapted:
        # Adapt exactly once, mirroring what `__iter__` does on first iteration.
        self.datapipe = rs.initialize(self.datapipe)
        self._adapted = True
    rs._restore_naive_datapipe_snapshot(n_samples_yielded, initial_seed)
    # TODO: I might want to skip `initialize_iteration` after this????

# TODO: Integrate this with the existing API? Is anyone using these at the moment?
12 changes: 12 additions & 0 deletions torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter.grouping import SHARDING_PRIORITIES
from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration

from torchdata._constants import default_dl2_worker_join_timeout_in_s, default_timeout_in_s
from torchdata.dataloader2 import communication
Expand Down Expand Up @@ -249,6 +250,7 @@ def __init__(
self._pg = None
self._world_size = 1
self._rank = 0
self._initial_seed = None

def initialize(self, datapipe: DataPipe) -> DataPipe:
r"""
Expand Down Expand Up @@ -303,6 +305,7 @@ def initialize_iteration(self) -> None:
shared_seed_int: int = shared_seed.item() # type: ignore[assignment]
_seed_generator = torch.Generator()
_seed_generator.manual_seed(shared_seed_int)
self._initial_seed = shared_seed_int
torch.utils.data.graph_settings.apply_random_seed(
self.end_datapipe, # type: ignore[arg-type]
_seed_generator,
Expand Down Expand Up @@ -397,6 +400,15 @@ def _resume(self):
if self.main_prefetch_cnt > 0 and self.num_workers > 0:
self.end_datapipe.resume() # type: ignore[union-attr]

def _get_naive_datapipe_snapshot(self):
    # Snapshot = (number of samples the end datapipe has yielded so far,
    #             the shared seed recorded during `initialize_iteration`).
    n_yielded = self.end_datapipe._number_of_samples_yielded
    return n_yielded, self._initial_seed

def _restore_naive_datapipe_snapshot(self, n_samples_yielded, initial_seed):
    # Rebuild the RNG from the seed recorded at snapshot time, then
    # fast-forward the graph past the samples that were already consumed.
    rng = torch.Generator()
    rng.manual_seed(initial_seed)
    _simple_graph_snapshot_restoration(self.end_datapipe, n_samples_yielded, rng)
    # TODO: I might want to skip `initialize_iteration` after this????


class MultiProcessingReadingService(ReadingServiceInterface):
r"""
Expand Down

0 comments on commit 250aa2b

Please sign in to comment.