Update examples and tutorial after automatic sharding has landed
ghstack-source-id: 37f8c80c20998c15438b7e4129cb518c49cdd1ac
Pull Request resolved: #505
NivekT committed Jun 9, 2022
1 parent b056a6f commit 171e657
Showing 4 changed files with 8 additions and 28 deletions.
8 changes: 5 additions & 3 deletions README.md
@@ -200,10 +200,12 @@ convert it to `MapDataPipe` as needed.

Q: How is multiprocessing handled with DataPipes?

A: Multi-process data loading is still handled by DataLoader, see the
A: Multi-process data loading is still handled by the `DataLoader`; see the
[DataLoader documentation for more details](https://pytorch.org/docs/stable/data.html#single-and-multi-process-data-loading).
If you would like to shard data across processes, use `ShardingFilter` and provide a `worker_init_fn` as shown in the
[tutorial](https://pytorch.org/data/beta/tutorial.html#working-with-dataloader).
As of PyTorch version >= 1.12.0 (TorchData version >= 0.4.0), data sharding is automatically done for DataPipes within
the `DataLoader` as long as a `ShardingFilter` DataPipe exists in your pipeline; you no longer need to pass in a
backward-compatibility `worker_init_fn`. Please see the
[tutorial](https://pytorch.org/data/beta/tutorial.html#working-with-dataloader) for an example.
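
For illustration only (this sketch is not part of the diff; `build_datapipe` and the toy `range(100)` data are placeholders), a pipeline that relies on automatic sharding might look like:

```python
# Minimal sketch, assuming torch >= 1.12.0 and torchdata >= 0.4.0.
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper


def build_datapipe():
    dp = IterableWrapper(range(100))
    dp = dp.shuffle()            # shuffling sits before the sharding filter, as the tutorial recommends
    dp = dp.sharding_filter()    # each DataLoader worker keeps only its own shard
    return dp


if __name__ == "__main__":
    dp = build_datapipe()
    # No `worker_init_fn` is needed; the DataLoader applies sharding automatically.
    dl = DataLoader(dp, batch_size=4, num_workers=2)
    for batch in dl:
        pass
```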

Q: What is the upcoming plan for DataLoader?

15 changes: 2 additions & 13 deletions docs/source/tutorial.rst
@@ -105,7 +105,7 @@ pass defined functions to DataPipes rather than lambda functions because the for
return datapipe
Lastly, we will put everything together in ``'__main__'`` and pass the DataPipe into the DataLoader. Note that
if you choose to use `Batcher` while setting `batch_size > 1` for DataLoader, your samples will be
if you choose to use ``Batcher`` while setting ``batch_size > 1`` for DataLoader, your samples will be
batched more than once. You should choose one or the other.
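
As an illustrative aside (not part of the original tutorial; ``build_datapipes`` is the helper defined above and ``DataLoader`` is ``torch.utils.data.DataLoader``), the two mutually exclusive options might look like:

.. code:: python

    # Option A (sketch): batch inside the pipeline and disable DataLoader batching
    dp = build_datapipes().batch(5)      # ``Batcher`` via its functional form
    dl = DataLoader(dataset=dp, batch_size=None)

    # Option B (sketch): leave batching entirely to DataLoader
    dp = build_datapipes()
    dl = DataLoader(dataset=dp, batch_size=5)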

.. code:: python
@@ -154,21 +154,10 @@ In order for DataPipe sharding to work with ``DataLoader``, we need to add the f
def build_datapipes(root_dir="."):
datapipe = ...
# Add the following line to `build_datapipes`
# Note that it is somewhere after `Shuffler` in the DataPipe line
# Note that it is somewhere after `Shuffler` in the DataPipe line, but before expensive operations
datapipe = datapipe.sharding_filter()
return datapipe
def worker_init_fn(worker_id):
info = torch.utils.data.get_worker_info()
num_workers = info.num_workers
datapipe = info.dataset
torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id)
# Pass `worker_init_fn` into `DataLoader` within '__main__'
...
dl = DataLoader(dataset=datapipe, shuffle=True, batch_size=5, num_workers=2, worker_init_fn=worker_init_fn)
...
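
The ``worker_init_fn`` plumbing above is exactly what this commit removes; with automatic sharding, a minimal sketch (illustrative, not part of the diff) of the simplified ``'__main__'`` is:

.. code:: python

    from torch.utils.data import DataLoader

    if __name__ == "__main__":
        # ``build_datapipes`` is the tutorial helper above; sharding happens
        # automatically because the pipeline contains ``sharding_filter``.
        datapipe = build_datapipes()
        dl = DataLoader(dataset=datapipe, shuffle=True, batch_size=5, num_workers=2)
        for batch in dl:
            pass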
When we re-run, we will get:

.. code::
4 changes: 0 additions & 4 deletions examples/vision/imagefolder.py
@@ -9,8 +9,6 @@
import re
import threading

import torch
import torch.utils.data.backward_compatibility
import torchvision.datasets as datasets
import torchvision.datasets.folder
import torchvision.transforms as transforms
@@ -168,7 +166,6 @@ def MyHTTPImageFolder(transform=None):
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=NUM_WORKERS,
worker_init_fn=torch.utils.data.backward_compatibility.worker_init_fn,
)
items = list(dl)
assert len(items) == 6
@@ -186,7 +183,6 @@ def MyHTTPImageFolder(transform=None):
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=NUM_WORKERS,
worker_init_fn=torch.utils.data.backward_compatibility.worker_init_fn,
)

try:
9 changes: 1 addition & 8 deletions test/test_local_io.py
@@ -77,13 +77,6 @@ def filepath_fn(temp_dir_name, name: str) -> str:
return os.path.join(temp_dir_name, os.path.basename(name))


def init_fn(worker_id):
info = torch.utils.data.get_worker_info()
num_workers = info.num_workers
datapipe = info.dataset
torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id)


def _unbatch(x):
return x[0]

@@ -753,7 +746,7 @@ def test_io_path_saver_file_lock(self):

num_workers = 2
line_lengths = []
dl = DataLoader(saver_dp, num_workers=num_workers, worker_init_fn=init_fn, multiprocessing_context="spawn")
dl = DataLoader(saver_dp, num_workers=num_workers, multiprocessing_context="spawn")
for filename in dl:
with open(filename[0]) as f:
lines = f.readlines()
