Add recommendations regarding use of datapipes for multi-processing, shuffling, DDP, etc. #1755
@@ -3,6 +3,61 @@ torchtext.datasets

.. currentmodule:: torchtext.datasets

.. _datapipes_warnings:

.. warning::
    The datasets supported by torchtext are datapipes from the `torchdata
    project <https://pytorch.org/data/beta/index.html>`_, which is still in Beta
    status. This means that the API is subject to change without deprecation
    cycles. In particular, we expect a lot of the current idioms to change with
    the eventual release of ``DataLoaderV2`` from ``torchdata``.

    Here are a few recommendations regarding the use of datapipes:

    - For shuffling the datapipe, do that in the DataLoader: ``DataLoader(dp, shuffle=True)``.
      You do not need to call ``dp.shuffle()``, because ``torchtext`` has
      already done that for you. Note however that the datapipe won't be
      shuffled unless you explicitly pass ``shuffle=True`` to the DataLoader.
    - When using multi-processing (``num_workers=N``), use the builtin ``worker_init_fn``::

        from torch.utils.data.backward_compatibility import worker_init_fn
        DataLoader(dp, num_workers=4, worker_init_fn=worker_init_fn, drop_last=True)

      This will ensure that data isn't duplicated across workers.
    - We also recommend using ``drop_last=True``. Without it, the batch sizes
      at the end of an epoch may be very small in some cases (smaller than with
      other map-style datasets). This can significantly affect accuracy,
      especially when batch normalization is used. ``drop_last=True`` ensures
      that all batch sizes are equal.
    - Distributed training with ``DistributedDataParallel`` is not yet entirely
      stable / supported, and we don't recommend it at this point. It will be
      better supported in DataLoaderV2. If you still wish to use DDP, make sure
      that:

      - All workers (DDP workers *and* DataLoader workers) see a different part
        of the data. The datasets are already wrapped inside `ShardingFilter
        <https://pytorch.org/data/main/generated/torchdata.datapipes.iter.ShardingFilter.html>`_
        and you may need to call ``dp.apply_sharding(num_shards, shard_id)`` in order to shard the
        data across ranks (DDP workers) and DataLoader workers. One way to do this
        is to create a ``worker_init_fn`` that calls ``apply_sharding`` with the appropriate
        number of shards (DDP workers * DataLoader workers) and shard id (inferred through the
        rank and the worker ID of the corresponding DataLoader within the rank); a sketch is
        given after this list. Note however that this assumes an equal number of DataLoader
        workers across all ranks.
      - All DDP workers work on the same number of batches. One way to do this
        is to limit the size of the datapipe within each worker to
        ``len(datapipe) // num_ddp_workers``, but this might not suit all
        use-cases.
      - The shuffling seed is the same across all workers. You might need to
        call ``torch.utils.data.graph_settings.apply_shuffle_seed(dp, rng)``.
      - The shuffling seed is different across epochs.
      - The rest of the RNG (typically used for transformations) is
        **different** across workers, for maximal entropy and optimal accuracy.
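      A minimal sketch of such a ``worker_init_fn`` (illustrative only: it assumes the
      default process group is already initialized, that every rank uses the same number
      of DataLoader workers, and the dataset (``AG_NEWS``) and helper names are arbitrary
      placeholders)::

        from functools import partial

        import torch.distributed as dist
        import torch.utils.data.graph_settings
        from torch.utils.data import DataLoader, get_worker_info

        from torchtext.datasets import AG_NEWS

        def ddp_worker_init_fn(worker_id, num_ranks, rank):
            # Shard across (DDP ranks x DataLoader workers) so that every
            # worker process sees a disjoint slice of the datapipe.
            info = get_worker_info()
            total_shards = num_ranks * info.num_workers
            shard_id = rank * info.num_workers + worker_id
            torch.utils.data.graph_settings.apply_sharding(
                info.dataset, total_shards, shard_id)

        dp = AG_NEWS(split="train")
        init_fn = partial(ddp_worker_init_fn,
                          num_ranks=dist.get_world_size(), rank=dist.get_rank())
        # NOTE: as described above, the shuffling seed would also need to be
        # synchronized across ranks (e.g. via apply_shuffle_seed); omitted here.
        dl = DataLoader(dp, batch_size=8, shuffle=True, num_workers=4,
                        worker_init_fn=init_fn, drop_last=True)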
If the number of total batches fetched across different ranks is different, it could potentially stall the training, right? So we somehow need to ensure that the number of samples seen by each rank is the same. I wonder what the workarounds for this are?

You're right, it gets more complex real fast :) In torchvision we have a custom snippet for this. I'll mention this as well, but I think we'd rather not directly link to this torchvision snippet, as this is definitely not a recommended practice yet. WDYT?

It's not unusual to have a different number of samples in each rank. For example, if the datapipe expects the same number of tokens/characters for each batch, each sentence may have a different number of tokens/characters, so the batch size will be different. As long as the loss value is normalized in a correct way, the training should be fine.

Hm... Interesting. My current understanding is that having datapipes with different sizes in DDP will lead to either crashes or infinite hanging. Here's a snippet to illustrate this:

    import torch
    import torch.utils.data as data
    import torch.distributed as dist


    def replace_print():
        # Serialize printing across ranks and prefix each line with the rank.
        import builtins as __builtin__
        builtin_print = __builtin__.print

        def print(*args, **kwargs):
            for rank in range(dist.get_world_size()):
                if rank == dist.get_rank():
                    builtin_print(f"[DDP worker with rank={rank}]", *args, **kwargs)
                dist.barrier()

        __builtin__.print = print


    class MyIterableDS(data.IterableDataset):
        def __init__(self, size=100):
            self.size = size

        def __iter__(self):
            worker_info = data.get_worker_info()
            num_dl_workers = worker_info.num_workers
            dl_worker_id = worker_info.id

            num_ddp_workers = dist.get_world_size()
            ddp_worker_id = dist.get_rank()

            for i, s in enumerate(range(self.size)):
                if i % num_ddp_workers == ddp_worker_id:
                    if i % num_dl_workers == dl_worker_id:
                        yield s

            # EXTRA SAMPLE
            # Uncomment this and you'll get an error
            # if ddp_worker_id == 0:
            #     yield 100

        def __len__(self):
            return self.size


    dist.init_process_group(backend="gloo")
    replace_print()
    dist.barrier()

    ds = MyIterableDS()
    dl = torch.utils.data.DataLoader(ds, batch_size=10, num_workers=4)

    for i, batch in enumerate(dl):
        print(batch)

You can run this locally (without GPUs), e.g. with ``torchrun``.
Did you observe similar behaviour @nateanl?

IIRC, the total batch number should be the same, but the batch size can vary across ranks.

@NicolasHug Oh, you are referring to the number of batches in each rank. That's correct, DDP will hang forever if the number of batches is not even across all ranks. I was talking about the batch_size, and it can be different :)

Thanks all, I added a section about all this just above.

yupp, SGTM!
Thanks for the code pointer! Sounds interesting :). I guess the user would need to know the length of the dataset in advance in order to ensure that num_take doesn't exceed the size.

I think I agree! We should probably just make sure users are aware of this issue, which you have already addressed here: #1755 (comment)
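A rough sketch of the kind of workaround discussed here, assuming the total number of samples is known up front and the process group is already initialized (the numbers and the ``IterableWrapper`` stand-in are hypothetical; a real pipeline would apply this to the already-sharded torchtext datapipe):

    import torch.distributed as dist
    from torchdata.datapipes.iter import IterableWrapper

    # Hypothetical: the total dataset size is assumed to be known in advance.
    NUM_SAMPLES = 120_000
    samples_per_rank = NUM_SAMPLES // dist.get_world_size()

    # `dp` stands in for the datapipe *after* it has been sharded across ranks,
    # so each rank holds roughly NUM_SAMPLES / world_size samples, possibly off by one.
    dp = IterableWrapper(range(samples_per_rank + 1))

    # Cap every rank at exactly the same number of samples so that all ranks
    # yield the same number of batches. Uneven batch *counts* across ranks can
    # make DDP hang; uneven batch *sizes* are fine as long as the loss is
    # normalized correctly.
    dp = dp.header(samples_per_rank)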
General use cases are as follows:
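A minimal sketch of typical usage with the ``DataLoader``, following the recommendations above (``AG_NEWS``, the ``basic_english`` tokenizer, and the batch size are illustrative choices)::

    from torch.utils.data import DataLoader
    from torch.utils.data.backward_compatibility import worker_init_fn

    from torchtext.data.utils import get_tokenizer
    from torchtext.datasets import AG_NEWS

    tokenizer = get_tokenizer("basic_english")

    def collate(batch):
        # AG_NEWS yields (label, line) pairs; tokenize the raw text here.
        labels, lines = zip(*batch)
        return list(labels), [tokenizer(line) for line in lines]

    train_dp = AG_NEWS(split="train")
    train_dl = DataLoader(train_dp, batch_size=8, shuffle=True, num_workers=4,
                          worker_init_fn=worker_init_fn, drop_last=True,
                          collate_fn=collate)

    for labels, tokens in train_dl:
        ...  # training / preprocessing step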
nit: Would it be useful to have a short code snippet using a dataset with DataLoader (with …)? I do think the current examples are likely sufficient, but I wonder if some users need more hand-holding/guidance.

Yupp, I guess it would be good to add usage examples with DataLoader. I guess we can remove worker_init_fn from the code examples, as @VitalyFedyunin is working on a PR that would automatically apply this to all datapipes-based datasets :).
@NicolasHug, @VitalyFedyunin, @ejguan: I made the description rather high-level while leaving implementation details to the users. My thinking here is to make users aware of the pitfalls without necessarily providing a specific recommendation on implementing a particular solution.