Update examples and tutorial after automatic sharding has landed #505

Closed · wants to merge 3 commits
7 changes: 4 additions & 3 deletions README.md
@@ -200,10 +200,11 @@ convert it to `MapDataPipe` as needed.

Q: How is multiprocessing handled with DataPipes?

A: Multi-process data loading is still handled by DataLoader, see the
A: Multi-process data loading is still handled by the `DataLoader`, see the
[DataLoader documentation for more details](https://pytorch.org/docs/stable/data.html#single-and-multi-process-data-loading).
If you would like to shard data across processes, use `ShardingFilter` and provide a `worker_init_fn` as shown in the
[tutorial](https://pytorch.org/data/beta/tutorial.html#working-with-dataloader).
As of PyTorch version >= 1.12.0 (TorchData version >= 0.4.0), data sharding is done automatically for DataPipes within
the `DataLoader` as long as a `ShardingFilter` DataPipe exists in your pipeline. Please see the
[tutorial](https://pytorch.org/data/beta/tutorial.html#working-with-dataloader) for an example.

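A minimal sketch of the automatic-sharding behavior (the toy pipeline below is illustrative, not part of this change):

```python
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper

# Shuffle first, then shard so each DataLoader worker sees a distinct subset.
datapipe = IterableWrapper(range(10)).shuffle().sharding_filter()

# Because a `ShardingFilter` is in the graph, `DataLoader` shards it automatically
# across `num_workers` -- no `worker_init_fn` is needed.
dl = DataLoader(datapipe, batch_size=2, num_workers=2)
for batch in dl:
    print(batch)
```
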
Q: What is the upcoming plan for DataLoader?

15 changes: 2 additions & 13 deletions docs/source/tutorial.rst
@@ -105,7 +105,7 @@ pass defined functions to DataPipes rather than lambda functions because the for
return datapipe

Lastly, we will put everything together in ``'__main__'`` and pass the DataPipe into the DataLoader. Note that
if you choose to use `Batcher` while setting `batch_size > 1` for DataLoader, your samples will be
if you choose to use ``Batcher`` while setting ``batch_size > 1`` for DataLoader, your samples will be
batched more than once. You should choose one or the other.
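For illustration (a toy pipeline, not the tutorial's own ``datapipe``), combining both would batch the samples twice:

.. code:: python

    from torch.utils.data import DataLoader
    from torchdata.datapipes.iter import IterableWrapper

    datapipe = IterableWrapper(range(8)).batch(2)  # pipeline-side batching: [0, 1], [2, 3], ...
    dl = DataLoader(datapipe, batch_size=2)        # DataLoader batches those batches again
    print(next(iter(dl)))                          # each DataLoader batch now covers 4 samples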

.. code:: python
@@ -154,21 +154,10 @@ In order for DataPipe sharding to work with ``DataLoader``, we need to add the f
def build_datapipes(root_dir="."):
datapipe = ...
# Add the following line to `build_datapipes`
# Note that it is somewhere after `Shuffler` in the DataPipe line
# Note that it should be placed somewhere after `Shuffler` in the pipeline, but before expensive operations
datapipe = datapipe.sharding_filter()
return datapipe

def worker_init_fn(worker_id):
info = torch.utils.data.get_worker_info()
num_workers = info.num_workers
datapipe = info.dataset
torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id)

# Pass `worker_init_fn` into `DataLoader` within '__main__'
...
dl = DataLoader(dataset=datapipe, shuffle=True, batch_size=5, num_workers=2, worker_init_fn=worker_init_fn)
...

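With automatic sharding, the ``DataLoader`` call in ``'__main__'`` no longer needs a ``worker_init_fn``; a minimal sketch, assuming the same ``datapipe`` and settings used earlier in the tutorial:

.. code:: python

    # Sharding happens automatically because `sharding_filter()` is in the pipeline,
    # so no `worker_init_fn` is passed here.
    dl = DataLoader(dataset=datapipe, shuffle=True, batch_size=5, num_workers=2)
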
When we re-run, we will get:

.. code::
4 changes: 0 additions & 4 deletions examples/vision/imagefolder.py
@@ -9,8 +9,6 @@
import re
import threading

import torch
import torch.utils.data.backward_compatibility
import torchvision.datasets as datasets
import torchvision.datasets.folder
import torchvision.transforms as transforms
@@ -168,7 +166,6 @@ def MyHTTPImageFolder(transform=None):
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=NUM_WORKERS,
worker_init_fn=torch.utils.data.backward_compatibility.worker_init_fn,
)
items = list(dl)
assert len(items) == 6
@@ -186,7 +183,6 @@ def MyHTTPImageFolder(transform=None):
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=NUM_WORKERS,
worker_init_fn=torch.utils.data.backward_compatibility.worker_init_fn,
)

try:
9 changes: 1 addition & 8 deletions test/test_local_io.py
@@ -77,13 +77,6 @@ def filepath_fn(temp_dir_name, name: str) -> str:
return os.path.join(temp_dir_name, os.path.basename(name))


def init_fn(worker_id):
info = torch.utils.data.get_worker_info()
num_workers = info.num_workers
datapipe = info.dataset
torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id)


def _unbatch(x):
return x[0]

@@ -753,7 +746,7 @@ def test_io_path_saver_file_lock(self):

num_workers = 2
line_lengths = []
dl = DataLoader(saver_dp, num_workers=num_workers, worker_init_fn=init_fn, multiprocessing_context="spawn")
dl = DataLoader(saver_dp, num_workers=num_workers, multiprocessing_context="spawn")
for filename in dl:
with open(filename[0]) as f:
lines = f.readlines()