-
Notifications
You must be signed in to change notification settings - Fork 811
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add recommendations regarding use of datapipes for multi-processing, shuffling, DDP, etc. #1755
Conversation
…huffling, DDP, etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks so much @NicolasHug for creating this PR. I think overall it LGTM. I just have few general questions in terms of Distributed training :).
docs/source/datasets.rst
Outdated
- All workers (DDP workers *and* DataLoader workers) see a different part | ||
of the data. You might need to call ``dp = dp.apply_sharding(world_size, | ||
rank)``. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am trying to understand if we are able to do Distributed together with Multi-processing. IIUC, when we do
dp.apply_sharding(world_size, rank)
on the dataset, it would shard data across ranks and ensure each rank see unique samples. Now when I pass dp to DataLoader with num_workers>0, would I still be able to shard it further with the help of worker_init_fn? I am not sure if we can wrap the dp again inside sharding_filter in order for worker_init_fn to shard the dp further across multiple processes in each rank? cc: @ejguan
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@parmeet good catch - users would need to wrap the dp again before calling apply_sharding()
. I'll update the comment
from torchdata.datapipes.iter import IterableWrapper, ShardingFilter
a = IterableWrapper(range(30))
a = ShardingFilter(a)
a.apply_sharding(num_of_instances=2, instance_id=0)
a = ShardingFilter(a) # without this, only the second sharding (5) applies
a.apply_sharding(num_of_instances=5, instance_id=0)
list(a)
# [0, 10, 20]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Technically, users can run dp.apply_sharding(world_size * worker_num, rank * worker_num + worker_id)
in the worker_init_fn
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this sounds neat. This would also avoid applying sharding filter multiple times. Given that datasets already wrapped the dp inside sharding filter, all we need to do is provide the implementation of worker_init_fn that does sharding according to world size and worker IDs.
def worker_init_fn(worker_id, world_size, rank):
info = torch.utils.data.get_worker_info()
num_workers = info.num_workers
datapipe = info.dataset
datapipe.apply_sharding(world_size * worker_num, rank * worker_num + worker_id)
# Initialize DataLoader within each rank with above function
dl = DataLoader(dataset, worker_init_fn=functools.partial(worker_init_fn, world_size=world_size, rank=rank), ...)
Does this sounds about right @ejguan?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably best to get work_size and rank from within the function? One can use int(os.environ["RANK"])
and int(os.environ["WORLD_SIZE"])
. Also for clarity maybe we could rename these to num_ddp_workers
, num_dataloader_workers
, ddp_worker_id
, dataloader_worker_id
.
It might be worth documenting both ways? The second one is somehow more complex, because it exposes the DataLoader sharding level.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup, these env variable will exist as long as the script is launched with torchrun
, which is the recommended practice for DDP scripts. They might not exist cases in some special cases, like when using submitit
, but these are advanced use-cases and such users would know how to tweak the snippet easily.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SGTM! I think we can leave the details of implementation to user. At higher level, probably it is sufficient to communicate the workaround with worker_init_fn :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please do not use ENV variables. There are built-in things for it:
import torch.distributed as dist
if dist.is_available():
total_workers *= dist.get_world_size()
global_worker_id = dist.get_rank()*info.num_workers + global_worker_id
I will submit PR(s) to apply it automatically this week.
The only remaining problem is the possibility of having a different number of processes on different ranks. I prefer to keep one .sharding_filter
for it, but perhaps would have to introduce sharing ranks/levels into it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @VitalyFedyunin for the feedback. Just so I understand, do you mean to submit PR for updating worker_init_fn such that it not only works for multi-processing case (as of now) but shall also take into account distributed settings like the one propose here #1755 (comment)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. And also classical DataLoader will automatically add this worker_init_fn pieces to all DataPipe-based Datasets. pytorch/pytorch#78631
- The suffling seed is different across epochs. | ||
- The rest of the RNG (typically used for transformations) is | ||
**different** across workers, for maximal entropy and optimal accuracy. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the numbers of total batches fetched across different ranks is different, it could potentially stall the training right? So we somehow need to ensure that the number of samples seen by each rank is same. I wonder what are the workarounds for this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right, it gets more complex real fast :)
In torchvision we have this custom Taker
datapipe that limits the size of the dp so that it's consistent across DDP workers https://github.com/pytorch/vision/blob/main/torchvision/prototype/datasets/utils/_internal.py#L144-L145
I'll mention this is as well, but I think we'd rather not directly link to this torchvision snippet, as this is definitely not a recommended practice yet. WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not unusual to have different number of samples in each rank, for example, if the datapipe expects the same number of tokens/characters for each batch, each sentence may have different number of tokens/characters, then the batch size will be different. As long as the loss value is normalized in a correct way, the training should be fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm... Interseting. My current understanding is that having datapipes with different sizes in DDP will lead to either crashes, or infinite hanging.
Here's a snippet to illustrate this:
import torch
import torch.utils.data as data
import torch.distributed as dist
def replace_print():
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
for rank in range(dist.get_world_size()):
if rank == dist.get_rank():
builtin_print(f"[DDP worker with rank={rank}]", *args, **kwargs)
dist.barrier()
__builtin__.print = print
class MyIterableDS(data.IterableDataset):
def __init__(self, size=100):
self.size = size
def __iter__(self):
worker_info = data.get_worker_info()
num_dl_workers = worker_info.num_workers
dl_worker_id = worker_info.id
num_ddp_workers = dist.get_world_size()
ddp_worker_id = dist.get_rank()
for i, s in enumerate(range(self.size)):
if i % num_ddp_workers == ddp_worker_id:
if i % num_dl_workers == dl_worker_id:
yield s
# EXTRA SAMPLE
# Uncomment this and you'll get an error
# if ddp_worker_id == 0:
# yield 100
def __len__(self):
return self.size
dist.init_process_group(backend="gloo")
replace_print()
dist.barrier()
ds = MyIterableDS()
dl = torch.utils.data.DataLoader(ds, batch_size=10, num_workers=4)
for i, batch in enumerate(dl):
print(batch)
You can run this locally (without GPUs) with e.g. torchrun --nproc_per_node=4 scipt.py
Uncommentting the EXTRA SAMPLE
part, I get the following error:
Traceback (most recent call last):
File "/home/nicolashug/dev/vision/lol.py", line 55, in <module>
print(batch)
File "/home/nicolashug/dev/vision/lol.py", line 12, in print
dist.barrier()
File "/home/nicolashug/.miniconda3/envs/pt/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2785, in barrier
work.wait()
RuntimeError: [/opt/conda/conda-bld/pytorch_1649142626512/work/third_party/gloo/gloo/transport/tcp/pair.cc:598] Connection closed by peer [fe80::f9d3:3540:64c7:287b]:34704
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 241787) of binary: /home/nicolashug/.miniconda3/envs/pt/bin/python
Did you observe similar behaviour @nateanl ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC, the total batch number should be the same but the batch size can vary across RANKs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@NicolasHug Oh you are referring the number of batches in each rank. That's correct, DDP will hang forever if the number of batches are not even in all ranks. I was talking about the batch_size and it can be different :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks all, I added a section about all this just above
- All DDP workers work on the same number of batches. One way to do this is to by limit the size of the datapipe within each worker to
len(datapipe) // num_ddp_workers
, but this might not suit all use-cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yupp, SGTM!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In torchvision we have this custom
Taker
datapipe that limits the size of the dp so that it's consistent across DDP workers https://github.com/pytorch/vision/blob/main/torchvision/prototype/datasets/utils/_internal.py#L144-L145
Thanks for the code pointer! Sounds interesting :). I guess user would need to know the length of dataset in advance in order to ensure that num_take doesn't exceed the size.
I'll mention this is as well, but I think we'd rather not directly link to this torchvision snippet, as this is definitely not a recommended practice yet. WDYT?
I think I agree! We should probably just make sure users are aware of this issue which you have already addressed here #1755 (comment)
- The suffling seed is different across epochs. | ||
- The rest of the RNG (typically used for transformations) is | ||
**different** across workers, for maximal entropy and optimal accuracy. | ||
|
||
General use cases are as follows: :: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Would it be useful to have a short code snippet using a dataset with DataLoader (with shuffle=True
, worker_init_fn
)?
I do think the current examples are likely sufficient but I wonder if some users need more hand-holding/guidance.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yupp, I guess it would be good to add usage examples with DataLoader. I guess, we can remove worker_init_fn from code examples as @VitalyFedyunin is working on PR that would automatically apply this to all datapipes based datasets :).
docs/source/datasets.rst
Outdated
- All workers (DDP workers *and* DataLoader workers) see a different part | ||
of the data. You might need to call ``dp = dp.apply_sharding(world_size, | ||
rank)`` after wrapping the datapipe into a `ShardingFilter | ||
<https://pytorch.org/data/main/generated/torchdata.datapipes.iter.ShardingFilter.html>`_. | ||
You'll want to do that early in the datapipe, to avoid needlessly | ||
processing samples that eventually get dropped by the workers. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@NicolasHug I would suggest we wait for @VitalyFedyunin to land PRs that would take care of sharding automatically through usage of worker_init_fn and we can remove this particular point. As per the issue #1727, we have already wrapped datasets in sharding filter, which means that the users do not have to apply it explicitly nor have to take into account DataLoaderV1+worker_init_fn mechanics :).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Depends if you want to publish this doc with 1.12 as my patch will not make into it, and will be in master until 1.13
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good point. Sure, will document it with current release.
docs/source/datasets.rst
Outdated
- All workers (DDP workers *and* DataLoader workers) see a different part | ||
of the data. The datasets are already wrapped inside `ShardingFilter | ||
<https://pytorch.org/data/main/generated/torchdata.datapipes.iter.ShardingFilter.html>`_ | ||
and you may need to call ``dp.apply_sharing(num_shards, shard_id)`` in order to shard the | ||
data across ranks (DDP workers) and DataLoader workers. One way to do this | ||
is to create ``worker_init_fn`` that calls ``apply_sharding`` with appropriate | ||
number of shards (DDP workers * DataLoader workers) and shard id (inferred through rank | ||
and worker ID of corresponding DataLoader withing rank). Note however, that this assumes | ||
equal number of DataLoader workers for all the ranks. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@NicolasHug, @VitalyFedyunin , @ejguan : I made the description rather high level while leaving implementation details to the users. My thinking here is to make the user aware of pitfalls without necessarily providing specific recommendation on implementing specific solution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall LGTM! We should be good to land this. We can cherry-pick if there are additional recommendations :). Thanks so much @NicolasHug for your help with documenting various pitfalls and recommendations.
Closes #1751
Rendered docs: https://output.circle-artifacts.com/output/job/62244176-8485-40e7-8a19-92d89b4a64a3/artifacts/0/docs/datasets.html (not fully up to date)
[Edit: Parmeet]: Rendered docs: https://output.circle-artifacts.com/output/job/a328cfc0-5824-4a1a-a3bf-630e09ff37ee/artifacts/0/docs/datasets.html
@parmeet @Nayef211 the rest of the week is a bank holiday for the UK, so I might not be able to address comments right away. Please feel free to directly edit this PR!
Also FYI @ejguan @NivekT @VitalyFedyunin