diff --git a/README.md b/README.md index 215d58d95..b61b9e339 100644 --- a/README.md +++ b/README.md @@ -200,10 +200,11 @@ convert it to `MapDataPipe` as needed. Q: How is multiprocessing handled with DataPipes? -A: Multi-process data loading is still handled by DataLoader, see the +A: Multi-process data loading is still handled by the `DataLoader`, see the [DataLoader documentation for more details](https://pytorch.org/docs/stable/data.html#single-and-multi-process-data-loading). -If you would like to shard data across processes, use `ShardingFilter` and provide a `worker_init_fn` as shown in the -[tutorial](https://pytorch.org/data/beta/tutorial.html#working-with-dataloader). +As of PyTorch version >= 1.12.0 (TorchData version >= 0.4.0), data sharding is automatically done for DataPipes within +the `DataLoader` as long as a `ShardingFiler` DataPipe exists in your pipeline. Please see the +[tutorial](https://pytorch.org/data/beta/tutorial.html#working-with-dataloader) for an example. Q: What is the upcoming plan for DataLoader? diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index c5c241ded..e049aa0d2 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -105,7 +105,7 @@ pass defined functions to DataPipes rather than lambda functions because the for return datapipe Lastly, we will put everything together in ``'__main__'`` and pass the DataPipe into the DataLoader. Note that -if you choose to use `Batcher` while setting `batch_size > 1` for DataLoader, your samples will be +if you choose to use ``Batcher`` while setting ``batch_size > 1`` for DataLoader, your samples will be batched more than once. You should choose one or the other. .. code:: python @@ -154,21 +154,10 @@ In order for DataPipe sharding to work with ``DataLoader``, we need to add the f def build_datapipes(root_dir="."): datapipe = ... # Add the following line to `build_datapipes` - # Note that it is somewhere after `Shuffler` in the DataPipe line + # Note that it is somewhere after `Shuffler` in the DataPipe line, but before expensive operations datapipe = datapipe.sharding_filter() return datapipe - def worker_init_fn(worker_id): - info = torch.utils.data.get_worker_info() - num_workers = info.num_workers - datapipe = info.dataset - torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id) - - # Pass `worker_init_fn` into `DataLoader` within '__main__' - ... - dl = DataLoader(dataset=datapipe, shuffle=True, batch_size=5, num_workers=2, worker_init_fn=worker_init_fn) - ... - When we re-run, we will get: .. code:: diff --git a/examples/vision/imagefolder.py b/examples/vision/imagefolder.py index c1286db4e..5bec089ef 100644 --- a/examples/vision/imagefolder.py +++ b/examples/vision/imagefolder.py @@ -9,8 +9,6 @@ import re import threading -import torch -import torch.utils.data.backward_compatibility import torchvision.datasets as datasets import torchvision.datasets.folder import torchvision.transforms as transforms @@ -168,7 +166,6 @@ def MyHTTPImageFolder(transform=None): batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, - worker_init_fn=torch.utils.data.backward_compatibility.worker_init_fn, ) items = list(dl) assert len(items) == 6 @@ -186,7 +183,6 @@ def MyHTTPImageFolder(transform=None): batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, - worker_init_fn=torch.utils.data.backward_compatibility.worker_init_fn, ) try: diff --git a/test/test_local_io.py b/test/test_local_io.py index 29f8de3b5..867517506 100644 --- a/test/test_local_io.py +++ b/test/test_local_io.py @@ -77,13 +77,6 @@ def filepath_fn(temp_dir_name, name: str) -> str: return os.path.join(temp_dir_name, os.path.basename(name)) -def init_fn(worker_id): - info = torch.utils.data.get_worker_info() - num_workers = info.num_workers - datapipe = info.dataset - torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id) - - def _unbatch(x): return x[0] @@ -753,7 +746,7 @@ def test_io_path_saver_file_lock(self): num_workers = 2 line_lengths = [] - dl = DataLoader(saver_dp, num_workers=num_workers, worker_init_fn=init_fn, multiprocessing_context="spawn") + dl = DataLoader(saver_dp, num_workers=num_workers, multiprocessing_context="spawn") for filename in dl: with open(filename[0]) as f: lines = f.readlines()