-
Notifications
You must be signed in to change notification settings - Fork 152
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add stateful dataloader tutorial docs (#1303)
- Loading branch information
1 parent
dc01c77
commit 842fc0b
Showing
2 changed files
with
178 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
Stateful DataLoader Tutorial | ||
============================ | ||
|
||
Saving and loading state | ||
------------------------ | ||
|
||
Stateful DataLoader adds the ``load_state_dict``, ``state_dict`` methods to the ``torch.utils.data.DataLoader``. State fetch and set can be done as follows: | ||
|
||
.. code:: python | ||
from torchdata.stateful_dataloader import StatefulDataLoader | ||
dataloader = StatefulDataLoader(dataset, num_workers=2) | ||
for i, batch in enumerate(dataloader): | ||
... | ||
if i == 10: | ||
state_dict = dataloader.state_dict() | ||
break | ||
# Training run resumes with the previous checkpoint | ||
dataloader = StatefulDataLoader(dataset, num_workers=2) | ||
# Resume state with DataLoader | ||
dataloader.load_state_dict(state_dict) | ||
for i, batch in enumerate(dataloader): | ||
... | ||
Saving Custom State with Map-Style Datasets | ||
------------------------------------------- | ||
|
||
For efficient resuming of `Map-style datasets <https://pytorch.org/docs/stable/data.html#map-style-datasets>`_, you can resume iteration by defining ``state_dict`` / ``load_state_dict`` methods in your sampler. If your dataset has worker-specific state (eg RNG transform state) you can add ``state_dict`` / ``load_state_dict`` methods to your dataset. | ||
|
||
.. code:: python | ||
from typing import * | ||
import torch | ||
import torch.utils.data | ||
from torchdata.stateful_dataloader import StatefulDataLoader | ||
# If you are using the default RandomSampler and BatchSampler in torch.utils.data, they are patched when you import torchdata.stateful_dataloader so that defining, a custom sampler here is unnecessary | ||
class MySampler(torch.utils.data.Sampler[int]): | ||
def __init__(self, high: int, seed: int, limit: int): | ||
self.seed, self.high, self.limit = seed, high, limit | ||
self.g = torch.Generator() | ||
self.g.manual_seed(self.seed) | ||
self.i = 0 | ||
def __iter__(self): | ||
while self.i < self.limit: | ||
val = int(torch.randint(high=self.high, size=(1,), generator=self.g)) | ||
self.i += 1 | ||
yield val | ||
def load_state_dict(self, state_dict: Dict[str, Any]): | ||
self.i = state_dict["i"] | ||
self.g.set_state(state_dict["rng"]) | ||
def state_dict(self) -> Dict[str, Any]: | ||
return {"i": self.i, "rng": self.g.get_state()} | ||
# Optional: save dataset random transform state | ||
class NoisyRange(torch.utils.data.Dataset): | ||
def __init__(self, high: int, mean: float, std: float): | ||
self.high, self.mean, self.std = high, torch.tensor([float(mean)]), float(std) | ||
def __len__(self): | ||
return self.high | ||
def __getitem__(self, idx: int) -> float: | ||
if not (0 <= idx < self.high): | ||
raise IndexError() | ||
x = torch.normal(self.mean, self.std) | ||
noise = x.item() | ||
return idx + noise | ||
def load_state_dict(self, state_dict): | ||
torch.set_rng_state(state_dict["rng"]) | ||
def state_dict(self): | ||
return {"rng": torch.get_rng_state()} | ||
# Test both single/multiprocess dataloading | ||
for num_workers in [0, 2]: | ||
print(f"{num_workers=}") | ||
dl = StatefulDataLoader(NoisyRange(5, 1, 1), sampler=MySampler(5, 1, 10), | ||
batch_size=2, drop_last=False, num_workers=num_workers) | ||
batches = [] | ||
for i, batch in enumerate(dl): | ||
batches.append(batch) | ||
if i == 2: | ||
sd = dl.state_dict() | ||
dl.load_state_dict(sd) | ||
batches2 = list(dl) | ||
print(batches[3:]) | ||
print(batches2) | ||
""" | ||
Output: | ||
num_workers=0 | ||
[tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)] | ||
[tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)] | ||
num_workers=2 | ||
[tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)] | ||
[tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)] | ||
""" | ||
Saving Custom State with Iterable-Style Datasets | ||
------------------------------------------------ | ||
|
||
Tracking iteration order with `Iterable-style datasets <https://pytorch.org/docs/stable/data.html#iterable-style-datasets>`_ requires state from each worker-level instance of the dataset to be captured. You can define ``state_dict`` / ``load_state_dict`` methods on your dataset which capture worker-level state. :class:`StatefulDataLoader` will handle aggregation across workers and distribution back to the workers. Calling ``load_state_dict`` requires :class:`StatefulDataLoader`` to have same ``num_workers`` as those of the provided ``state_dict``. | ||
|
||
.. code:: python | ||
from typing import * | ||
import torch | ||
import torch.utils.data | ||
from torchdata.stateful_dataloader import StatefulDataLoader | ||
class MyIterableDataset(torch.utils.data.IterableDataset): | ||
def __init__(self, high: int, seed: int): | ||
self.high, self.seed = high, seed | ||
self.g = torch.Generator() | ||
self.i = 0 | ||
def __iter__(self): | ||
worker_info = torch.utils.data.get_worker_info() | ||
if worker_info is not None: | ||
worker_id = worker_info.id | ||
num_workers = worker_info.num_workers | ||
else: | ||
worker_id = 0 | ||
num_workers = 1 | ||
self.g.manual_seed(self.seed) | ||
arr = torch.randperm(self.high, generator=self.g) | ||
arr = arr[worker_id:self.high:num_workers] | ||
for idx in range(self.i, len(arr)): | ||
self.i += 1 | ||
yield arr[idx] | ||
self.i = 0 | ||
def state_dict(self): | ||
return {"i": self.i} | ||
def load_state_dict(self, state_dict): | ||
self.i = state_dict["i"] | ||
# Test both single/multiprocess dataloading | ||
for num_workers in [0, 2]: | ||
print(f"{num_workers=}") | ||
dl = StatefulDataLoader( | ||
MyIterableDataset(12, 0), batch_size=2, drop_last=False, | ||
num_workers=num_workers) | ||
batches = [] | ||
for i, batch in enumerate(dl): | ||
batches.append(batch) | ||
if i == 2: | ||
sd = dl.state_dict() | ||
dl.load_state_dict(sd) | ||
batches2 = list(dl) | ||
print(batches[3:]) | ||
print(batches2) | ||
""" | ||
Output: | ||
num_workers=0 | ||
[tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])] | ||
[tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])] | ||
num_workers=2 | ||
[tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])] | ||
[tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])] | ||
""" |