Add stateful dataloader tutorial docs #1303

Merged (2 commits) on Jul 30, 2024
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -48,6 +48,7 @@ Features described in this documentation are classified by release status:
   :maxdepth: 2
   :caption: Tutorial and Examples:

   stateful_dataloader_tutorial.rst
   dp_tutorial.rst
   dlv2_tutorial.rst
   examples.rst
177 changes: 177 additions & 0 deletions docs/source/stateful_dataloader_tutorial.rst
@@ -0,0 +1,177 @@
Stateful DataLoader Tutorial
============================

Saving and loading state
------------------------

Stateful DataLoader adds ``state_dict`` and ``load_state_dict`` methods to ``torch.utils.data.DataLoader``. State can be fetched and set as follows:

.. code:: python

    from torchdata.stateful_dataloader import StatefulDataLoader

    dataloader = StatefulDataLoader(dataset, num_workers=2)
    for i, batch in enumerate(dataloader):
        ...
        if i == 10:
            state_dict = dataloader.state_dict()
            break

    # Training run resumes with the previous checkpoint
    dataloader = StatefulDataLoader(dataset, num_workers=2)
    # Resume state with DataLoader
    dataloader.load_state_dict(state_dict)
    for i, batch in enumerate(dataloader):
        ...
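
In practice, the fetched ``state_dict`` is usually written to disk so that training can resume in a new process. A minimal sketch, assuming the same ``dataset`` as above (the ``checkpoint.pt`` path is illustrative):

.. code:: python

    import torch
    from torchdata.stateful_dataloader import StatefulDataLoader

    # Save the dataloader state alongside the rest of your training checkpoint
    torch.save({"dataloader": state_dict}, "checkpoint.pt")

    # Later, possibly after a process restart, restore the state before iterating
    checkpoint = torch.load("checkpoint.pt")
    dataloader = StatefulDataLoader(dataset, num_workers=2)
    dataloader.load_state_dict(checkpoint["dataloader"])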

Saving Custom State with Map-Style Datasets
-------------------------------------------

To resume iteration efficiently with `Map-style datasets <https://pytorch.org/docs/stable/data.html#map-style-datasets>`_, define ``state_dict`` / ``load_state_dict`` methods on your sampler. If your dataset also has worker-specific state (e.g. RNG transform state), you can add ``state_dict`` / ``load_state_dict`` methods to your dataset.

.. code:: python

    from typing import *
    import torch
    import torch.utils.data
    from torchdata.stateful_dataloader import StatefulDataLoader

    # If you are using the default RandomSampler and BatchSampler in torch.utils.data, they are patched when you import torchdata.stateful_dataloader, so defining a custom sampler here is unnecessary
    class MySampler(torch.utils.data.Sampler[int]):
        def __init__(self, high: int, seed: int, limit: int):
            self.seed, self.high, self.limit = seed, high, limit
            self.g = torch.Generator()
            self.g.manual_seed(self.seed)
            self.i = 0

        def __iter__(self):
            while self.i < self.limit:
                val = int(torch.randint(high=self.high, size=(1,), generator=self.g))
                self.i += 1
                yield val

        def load_state_dict(self, state_dict: Dict[str, Any]):
            self.i = state_dict["i"]
            self.g.set_state(state_dict["rng"])

        def state_dict(self) -> Dict[str, Any]:
            return {"i": self.i, "rng": self.g.get_state()}

    # Optional: save dataset random transform state
    class NoisyRange(torch.utils.data.Dataset):
        def __init__(self, high: int, mean: float, std: float):
            self.high, self.mean, self.std = high, torch.tensor([float(mean)]), float(std)

        def __len__(self):
            return self.high

        def __getitem__(self, idx: int) -> float:
            if not (0 <= idx < self.high):
                raise IndexError()
            x = torch.normal(self.mean, self.std)
            noise = x.item()
            return idx + noise

        def load_state_dict(self, state_dict):
            torch.set_rng_state(state_dict["rng"])

        def state_dict(self):
            return {"rng": torch.get_rng_state()}

    # Test both single/multiprocess dataloading
    for num_workers in [0, 2]:
        print(f"{num_workers=}")
        dl = StatefulDataLoader(NoisyRange(5, 1, 1), sampler=MySampler(5, 1, 10),
                                batch_size=2, drop_last=False, num_workers=num_workers)

        batches = []
        for i, batch in enumerate(dl):
            batches.append(batch)
            if i == 2:
                sd = dl.state_dict()

        dl.load_state_dict(sd)
        batches2 = list(dl)

        print(batches[3:])
        print(batches2)

    """
    Output:
    num_workers=0
    [tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
    [tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
    num_workers=2
    [tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
    [tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
    """

Saving Custom State with Iterable-Style Datasets
------------------------------------------------

Tracking iteration order with `Iterable-style datasets <https://pytorch.org/docs/stable/data.html#iterable-style-datasets>`_ requires capturing state from each worker-level instance of the dataset. Define ``state_dict`` / ``load_state_dict`` methods on your dataset to capture this worker-level state; :class:`StatefulDataLoader` handles aggregation across workers and distribution back to them. Calling ``load_state_dict`` requires the :class:`StatefulDataLoader` to have the same ``num_workers`` as the loader that produced the ``state_dict``.

.. code:: python

    from typing import *
    import torch
    import torch.utils.data
    from torchdata.stateful_dataloader import StatefulDataLoader


    class MyIterableDataset(torch.utils.data.IterableDataset):
        def __init__(self, high: int, seed: int):
            self.high, self.seed = high, seed
            self.g = torch.Generator()
            self.i = 0

        def __iter__(self):
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                worker_id = worker_info.id
                num_workers = worker_info.num_workers
            else:
                worker_id = 0
                num_workers = 1
            self.g.manual_seed(self.seed)
            arr = torch.randperm(self.high, generator=self.g)
            arr = arr[worker_id:self.high:num_workers]
            for idx in range(self.i, len(arr)):
                self.i += 1
                yield arr[idx]
            self.i = 0

        def state_dict(self):
            return {"i": self.i}

        def load_state_dict(self, state_dict):
            self.i = state_dict["i"]

    # Test both single/multiprocess dataloading
    for num_workers in [0, 2]:
        print(f"{num_workers=}")
        dl = StatefulDataLoader(
            MyIterableDataset(12, 0), batch_size=2, drop_last=False,
            num_workers=num_workers)

        batches = []
        for i, batch in enumerate(dl):
            batches.append(batch)
            if i == 2:
                sd = dl.state_dict()

        dl.load_state_dict(sd)
        batches2 = list(dl)

        print(batches[3:])
        print(batches2)

    """
    Output:
    num_workers=0
    [tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])]
    [tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])]
    num_workers=2
    [tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
    [tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
    """