From 842fc0b85fb5699f763793c8ba54b56553607149 Mon Sep 17 00:00:00 2001
From: Gokul
Date: Tue, 30 Jul 2024 11:04:07 -0700
Subject: [PATCH] Add stateful dataloader tutorial docs (#1303)

---
 docs/source/index.rst                        |   1 +
 docs/source/stateful_dataloader_tutorial.rst | 177 +++++++++++++++++++
 2 files changed, 178 insertions(+)
 create mode 100644 docs/source/stateful_dataloader_tutorial.rst

diff --git a/docs/source/index.rst b/docs/source/index.rst
index b70d852a1..cec30b4ad 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -48,6 +48,7 @@ Features described in this documentation are classified by release status:
    :maxdepth: 2
    :caption: Tutorial and Examples:
 
+   stateful_dataloader_tutorial.rst
    dp_tutorial.rst
    dlv2_tutorial.rst
    examples.rst
diff --git a/docs/source/stateful_dataloader_tutorial.rst b/docs/source/stateful_dataloader_tutorial.rst
new file mode 100644
index 000000000..7bfc7f930
--- /dev/null
+++ b/docs/source/stateful_dataloader_tutorial.rst
@@ -0,0 +1,177 @@
Stateful DataLoader Tutorial
============================

Saving and loading state
------------------------

``StatefulDataLoader`` adds the ``state_dict`` and ``load_state_dict`` methods to ``torch.utils.data.DataLoader``. State can be fetched and set as follows:

.. code:: python

    from torchdata.stateful_dataloader import StatefulDataLoader

    dataloader = StatefulDataLoader(dataset, num_workers=2)
    for i, batch in enumerate(dataloader):
        ...
        if i == 10:
            state_dict = dataloader.state_dict()
            break

    # Training run resumes with the previous checkpoint
    dataloader = StatefulDataLoader(dataset, num_workers=2)
    # Resume state with DataLoader
    dataloader.load_state_dict(state_dict)
    for i, batch in enumerate(dataloader):
        ...

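In a full training loop you will usually want to persist this state to disk together with the rest of your checkpoint rather than keeping it in memory. Below is a minimal sketch of one way to do that with ``torch.save`` / ``torch.load``; the ``model``, ``optimizer``, and checkpoint path are hypothetical placeholders, not part of the ``StatefulDataLoader`` API.

.. code:: python

    import torch
    from torchdata.stateful_dataloader import StatefulDataLoader

    dataloader = StatefulDataLoader(dataset, num_workers=2)

    for i, batch in enumerate(dataloader):
        ...
        if i == 10:
            # Bundle the dataloader state with the other (hypothetical) training state
            torch.save(
                {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "dataloader": dataloader.state_dict(),
                },
                "checkpoint.pt",
            )
            break

    # In a later run, restore everything from the same file
    checkpoint = torch.load("checkpoint.pt")
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    dataloader = StatefulDataLoader(dataset, num_workers=2)
    dataloader.load_state_dict(checkpoint["dataloader"])
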
Saving Custom State with Map-Style Datasets
-------------------------------------------

For efficient resuming of `Map-style datasets <https://pytorch.org/docs/stable/data.html#map-style-datasets>`_, you can resume iteration by defining ``state_dict`` / ``load_state_dict`` methods on your sampler. If your dataset has worker-specific state (e.g. RNG transform state), you can also add ``state_dict`` / ``load_state_dict`` methods to your dataset.

.. code:: python

    from typing import *

    import torch
    import torch.utils.data
    from torchdata.stateful_dataloader import StatefulDataLoader

    # If you are using the default RandomSampler and BatchSampler in torch.utils.data,
    # they are patched when you import torchdata.stateful_dataloader, so defining a
    # custom sampler here is unnecessary.
    class MySampler(torch.utils.data.Sampler[int]):
        def __init__(self, high: int, seed: int, limit: int):
            self.seed, self.high, self.limit = seed, high, limit
            self.g = torch.Generator()
            self.g.manual_seed(self.seed)
            self.i = 0

        def __iter__(self):
            while self.i < self.limit:
                val = int(torch.randint(high=self.high, size=(1,), generator=self.g))
                self.i += 1
                yield val

        def load_state_dict(self, state_dict: Dict[str, Any]):
            self.i = state_dict["i"]
            self.g.set_state(state_dict["rng"])

        def state_dict(self) -> Dict[str, Any]:
            return {"i": self.i, "rng": self.g.get_state()}

    # Optional: save dataset random transform state
    class NoisyRange(torch.utils.data.Dataset):
        def __init__(self, high: int, mean: float, std: float):
            self.high, self.mean, self.std = high, torch.tensor([float(mean)]), float(std)

        def __len__(self):
            return self.high

        def __getitem__(self, idx: int) -> float:
            if not (0 <= idx < self.high):
                raise IndexError()
            x = torch.normal(self.mean, self.std)
            noise = x.item()
            return idx + noise

        def load_state_dict(self, state_dict):
            torch.set_rng_state(state_dict["rng"])

        def state_dict(self):
            return {"rng": torch.get_rng_state()}

    # Test both single/multiprocess dataloading
    for num_workers in [0, 2]:
        print(f"{num_workers=}")
        dl = StatefulDataLoader(NoisyRange(5, 1, 1), sampler=MySampler(5, 1, 10),
                                batch_size=2, drop_last=False, num_workers=num_workers)

        batches = []
        for i, batch in enumerate(dl):
            batches.append(batch)
            if i == 2:
                sd = dl.state_dict()

        dl.load_state_dict(sd)
        batches2 = list(dl)

        print(batches[3:])
        print(batches2)

    """
    Output:
    num_workers=0
    [tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
    [tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
    num_workers=2
    [tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
    [tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
    """

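As the comment in the example above points out, the default ``RandomSampler`` and ``BatchSampler`` are patched when ``torchdata.stateful_dataloader`` is imported, so a custom sampler is only needed for custom sampling logic. The short sketch below (reusing the ``NoisyRange`` dataset defined above) illustrates mid-epoch resumption with the built-in shuffling; it demonstrates the same mechanism rather than any additional API.

.. code:: python

    from torchdata.stateful_dataloader import StatefulDataLoader

    # The built-in RandomSampler/BatchSampler are already stateful once
    # torchdata.stateful_dataloader has been imported, so shuffle=True resumes too.
    dl = StatefulDataLoader(NoisyRange(5, 1, 1), shuffle=True, batch_size=2)

    it = iter(dl)
    next(it)                  # consume the first batch
    sd = dl.state_dict()      # snapshot mid-epoch

    dl2 = StatefulDataLoader(NoisyRange(5, 1, 1), shuffle=True, batch_size=2)
    dl2.load_state_dict(sd)   # iteration continues from the second batch
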
Saving Custom State with Iterable-Style Datasets
------------------------------------------------

Tracking iteration order with `Iterable-style datasets <https://pytorch.org/docs/stable/data.html#iterable-style-datasets>`_ requires capturing state from each worker-level instance of the dataset. You can define ``state_dict`` / ``load_state_dict`` methods on your dataset which capture worker-level state. :class:`StatefulDataLoader` will handle aggregation across workers and distribution back to the workers. Calling ``load_state_dict`` requires the :class:`StatefulDataLoader` to have the same ``num_workers`` as the one that produced the provided ``state_dict``.

.. code:: python

    from typing import *

    import torch
    import torch.utils.data
    from torchdata.stateful_dataloader import StatefulDataLoader


    class MyIterableDataset(torch.utils.data.IterableDataset):
        def __init__(self, high: int, seed: int):
            self.high, self.seed = high, seed
            self.g = torch.Generator()
            self.i = 0

        def __iter__(self):
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                worker_id = worker_info.id
                num_workers = worker_info.num_workers
            else:
                worker_id = 0
                num_workers = 1
            # Shard a fixed permutation of the indices across workers
            self.g.manual_seed(self.seed)
            arr = torch.randperm(self.high, generator=self.g)
            arr = arr[worker_id:self.high:num_workers]
            # Resume from self.i, which load_state_dict may have restored
            for idx in range(self.i, len(arr)):
                self.i += 1
                yield arr[idx]
            self.i = 0

        def state_dict(self):
            return {"i": self.i}

        def load_state_dict(self, state_dict):
            self.i = state_dict["i"]

    # Test both single/multiprocess dataloading
    for num_workers in [0, 2]:
        print(f"{num_workers=}")
        dl = StatefulDataLoader(
            MyIterableDataset(12, 0), batch_size=2, drop_last=False,
            num_workers=num_workers)

        batches = []
        for i, batch in enumerate(dl):
            batches.append(batch)
            if i == 2:
                sd = dl.state_dict()

        dl.load_state_dict(sd)
        batches2 = list(dl)

        print(batches[3:])
        print(batches2)

    """
    Output:
    num_workers=0
    [tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])]
    [tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])]
    num_workers=2
    [tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
    [tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
    """

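Since a ``state_dict`` can only be loaded into a :class:`StatefulDataLoader` constructed with the same ``num_workers`` that produced it, it can be convenient to record the worker count next to the dataloader state in whatever checkpoint structure you use. Below is a small sketch of that bookkeeping, continuing from the example above; the ``checkpoint`` dictionary layout is just one possible convention, not part of the API.

.. code:: python

    num_workers = 2
    dl = StatefulDataLoader(MyIterableDataset(12, 0), batch_size=2,
                            num_workers=num_workers)

    for i, batch in enumerate(dl):
        if i == 2:
            # Hypothetical convention: keep the worker count next to the state
            checkpoint = {"dl_state": dl.state_dict(), "num_workers": num_workers}
            break

    # Rebuild the loader with the recorded worker count before loading state
    resumed = StatefulDataLoader(MyIterableDataset(12, 0), batch_size=2,
                                 num_workers=checkpoint["num_workers"])
    resumed.load_state_dict(checkpoint["dl_state"])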