Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add recommendations regarding use of datapipes for multi-processing, shuffling, DDP, etc. #1755

Merged
merged 9 commits into from
Jun 2, 2022

Conversation

NicolasHug
Copy link
Member

@NicolasHug NicolasHug commented Jun 1, 2022

Closes #1751

Rendered docs: https://output.circle-artifacts.com/output/job/62244176-8485-40e7-8a19-92d89b4a64a3/artifacts/0/docs/datasets.html (not fully up to date)

[Edit: Parmeet]: Rendered docs: https://output.circle-artifacts.com/output/job/a328cfc0-5824-4a1a-a3bf-630e09ff37ee/artifacts/0/docs/datasets.html

@parmeet @Nayef211 the rest of the week is a bank holiday for the UK, so I might not be able to address comments right away. Please feel free to directly edit this PR!

Also FYI @ejguan @NivekT @VitalyFedyunin

Copy link
Contributor

@parmeet parmeet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much @NicolasHug for creating this PR. I think overall it LGTM. I just have few general questions in terms of Distributed training :).

Comment on lines 42 to 44
- All workers (DDP workers *and* DataLoader workers) see a different part
of the data. You might need to call ``dp = dp.apply_sharding(world_size,
rank)``.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am trying to understand if we are able to do Distributed together with Multi-processing. IIUC, when we do
dp.apply_sharding(world_size, rank) on the dataset, it would shard data across ranks and ensure each rank see unique samples. Now when I pass dp to DataLoader with num_workers>0, would I still be able to shard it further with the help of worker_init_fn? I am not sure if we can wrap the dp again inside sharding_filter in order for worker_init_fn to shard the dp further across multiple processes in each rank? cc: @ejguan

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@parmeet good catch - users would need to wrap the dp again before calling apply_sharding(). I'll update the comment

from torchdata.datapipes.iter import IterableWrapper, ShardingFilter

a = IterableWrapper(range(30))

a = ShardingFilter(a)
a.apply_sharding(num_of_instances=2, instance_id=0)

a = ShardingFilter(a)  # without this, only the second sharding (5) applies
a.apply_sharding(num_of_instances=5, instance_id=0)

list(a)
# [0, 10, 20]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically, users can run dp.apply_sharding(world_size * worker_num, rank * worker_num + worker_id) in the worker_init_fn.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this sounds neat. This would also avoid applying sharding filter multiple times. Given that datasets already wrapped the dp inside sharding filter, all we need to do is provide the implementation of worker_init_fn that does sharding according to world size and worker IDs.

def worker_init_fn(worker_id, world_size, rank):
    info = torch.utils.data.get_worker_info()
    num_workers = info.num_workers
    datapipe = info.dataset
    datapipe.apply_sharding(world_size * worker_num, rank * worker_num + worker_id)

# Initialize DataLoader within each rank with above function
dl = DataLoader(dataset, worker_init_fn=functools.partial(worker_init_fn, world_size=world_size, rank=rank), ...)

Does this sounds about right @ejguan?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably best to get work_size and rank from within the function? One can use int(os.environ["RANK"]) and int(os.environ["WORLD_SIZE"]). Also for clarity maybe we could rename these to num_ddp_workers, num_dataloader_workers, ddp_worker_id, dataloader_worker_id.

It might be worth documenting both ways? The second one is somehow more complex, because it exposes the DataLoader sharding level.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, these env variable will exist as long as the script is launched with torchrun, which is the recommended practice for DDP scripts. They might not exist cases in some special cases, like when using submitit, but these are advanced use-cases and such users would know how to tweak the snippet easily.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGTM! I think we can leave the details of implementation to user. At higher level, probably it is sufficient to communicate the workaround with worker_init_fn :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please do not use ENV variables. There are built-in things for it:

import torch.distributed as dist
if dist.is_available():
        total_workers *= dist.get_world_size()
        global_worker_id = dist.get_rank()*info.num_workers + global_worker_id

I will submit PR(s) to apply it automatically this week.

The only remaining problem is the possibility of having a different number of processes on different ranks. I prefer to keep one .sharding_filter for it, but perhaps would have to introduce sharing ranks/levels into it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @VitalyFedyunin for the feedback. Just so I understand, do you mean to submit PR for updating worker_init_fn such that it not only works for multi-processing case (as of now) but shall also take into account distributed settings like the one propose here #1755 (comment)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. And also classical DataLoader will automatically add this worker_init_fn pieces to all DataPipe-based Datasets. pytorch/pytorch#78631

- The suffling seed is different across epochs.
- The rest of the RNG (typically used for transformations) is
**different** across workers, for maximal entropy and optimal accuracy.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the numbers of total batches fetched across different ranks is different, it could potentially stall the training right? So we somehow need to ensure that the number of samples seen by each rank is same. I wonder what are the workarounds for this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, it gets more complex real fast :)

In torchvision we have this custom Taker datapipe that limits the size of the dp so that it's consistent across DDP workers https://github.com/pytorch/vision/blob/main/torchvision/prototype/datasets/utils/_internal.py#L144-L145

I'll mention this is as well, but I think we'd rather not directly link to this torchvision snippet, as this is definitely not a recommended practice yet. WDYT?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not unusual to have different number of samples in each rank, for example, if the datapipe expects the same number of tokens/characters for each batch, each sentence may have different number of tokens/characters, then the batch size will be different. As long as the loss value is normalized in a correct way, the training should be fine.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm... Interseting. My current understanding is that having datapipes with different sizes in DDP will lead to either crashes, or infinite hanging.

Here's a snippet to illustrate this:

import torch
import torch.utils.data as data
import torch.distributed as dist

def replace_print():
    import builtins as __builtin__
    builtin_print = __builtin__.print
    def print(*args, **kwargs):
        for rank in range(dist.get_world_size()):
            if rank == dist.get_rank():
                builtin_print(f"[DDP worker with rank={rank}]", *args, **kwargs)
            dist.barrier()

    __builtin__.print = print


class MyIterableDS(data.IterableDataset):
    
    def __init__(self, size=100):
        self.size = size
        
    def __iter__(self):

        worker_info = data.get_worker_info()
        num_dl_workers = worker_info.num_workers
        dl_worker_id = worker_info.id

        num_ddp_workers = dist.get_world_size()
        ddp_worker_id = dist.get_rank()
        
        for i, s in enumerate(range(self.size)):
            if i % num_ddp_workers == ddp_worker_id:
                if i % num_dl_workers == dl_worker_id:
                    yield s
        
        # EXTRA SAMPLE
        # Uncomment this and you'll get an error
        # if ddp_worker_id == 0:
        #     yield 100
    
    def __len__(self):
        return self.size

dist.init_process_group(backend="gloo")
replace_print()
dist.barrier()

ds = MyIterableDS()
dl = torch.utils.data.DataLoader(ds, batch_size=10, num_workers=4)

for i, batch in enumerate(dl):    
    print(batch)

You can run this locally (without GPUs) with e.g. torchrun --nproc_per_node=4 scipt.py
Uncommentting the EXTRA SAMPLE part, I get the following error:

Traceback (most recent call last):
  File "/home/nicolashug/dev/vision/lol.py", line 55, in <module>
    print(batch)
  File "/home/nicolashug/dev/vision/lol.py", line 12, in print
    dist.barrier()
  File "/home/nicolashug/.miniconda3/envs/pt/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2785, in barrier
    work.wait()
RuntimeError: [/opt/conda/conda-bld/pytorch_1649142626512/work/third_party/gloo/gloo/transport/tcp/pair.cc:598] Connection closed by peer [fe80::f9d3:3540:64c7:287b]:34704
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 241787) of binary: /home/nicolashug/.miniconda3/envs/pt/bin/python

Did you observe similar behaviour @nateanl ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC, the total batch number should be the same but the batch size can vary across RANKs.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NicolasHug Oh you are referring the number of batches in each rank. That's correct, DDP will hang forever if the number of batches are not even in all ranks. I was talking about the batch_size and it can be different :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks all, I added a section about all this just above

  • All DDP workers work on the same number of batches. One way to do this is to by limit the size of the datapipe within each worker to len(datapipe) // num_ddp_workers, but this might not suit all use-cases.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yupp, SGTM!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In torchvision we have this custom Taker datapipe that limits the size of the dp so that it's consistent across DDP workers https://github.com/pytorch/vision/blob/main/torchvision/prototype/datasets/utils/_internal.py#L144-L145

Thanks for the code pointer! Sounds interesting :). I guess user would need to know the length of dataset in advance in order to ensure that num_take doesn't exceed the size.

I'll mention this is as well, but I think we'd rather not directly link to this torchvision snippet, as this is definitely not a recommended practice yet. WDYT?

I think I agree! We should probably just make sure users are aware of this issue which you have already addressed here #1755 (comment)

- The suffling seed is different across epochs.
- The rest of the RNG (typically used for transformations) is
**different** across workers, for maximal entropy and optimal accuracy.

General use cases are as follows: ::
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Would it be useful to have a short code snippet using a dataset with DataLoader (with shuffle=True, worker_init_fn)?

I do think the current examples are likely sufficient but I wonder if some users need more hand-holding/guidance.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yupp, I guess it would be good to add usage examples with DataLoader. I guess, we can remove worker_init_fn from code examples as @VitalyFedyunin is working on PR that would automatically apply this to all datapipes based datasets :).

docs/source/datasets.rst Outdated Show resolved Hide resolved
Comment on lines 42 to 47
- All workers (DDP workers *and* DataLoader workers) see a different part
of the data. You might need to call ``dp = dp.apply_sharding(world_size,
rank)`` after wrapping the datapipe into a `ShardingFilter
<https://pytorch.org/data/main/generated/torchdata.datapipes.iter.ShardingFilter.html>`_.
You'll want to do that early in the datapipe, to avoid needlessly
processing samples that eventually get dropped by the workers.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NicolasHug I would suggest we wait for @VitalyFedyunin to land PRs that would take care of sharding automatically through usage of worker_init_fn and we can remove this particular point. As per the issue #1727, we have already wrapped datasets in sharding filter, which means that the users do not have to apply it explicitly nor have to take into account DataLoaderV1+worker_init_fn mechanics :).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Depends if you want to publish this doc with 1.12 as my patch will not make into it, and will be in master until 1.13

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point. Sure, will document it with current release.

Comment on lines 42 to 50
- All workers (DDP workers *and* DataLoader workers) see a different part
of the data. The datasets are already wrapped inside `ShardingFilter
<https://pytorch.org/data/main/generated/torchdata.datapipes.iter.ShardingFilter.html>`_
and you may need to call ``dp.apply_sharing(num_shards, shard_id)`` in order to shard the
data across ranks (DDP workers) and DataLoader workers. One way to do this
is to create ``worker_init_fn`` that calls ``apply_sharding`` with appropriate
number of shards (DDP workers * DataLoader workers) and shard id (inferred through rank
and worker ID of corresponding DataLoader withing rank). Note however, that this assumes
equal number of DataLoader workers for all the ranks.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NicolasHug, @VitalyFedyunin , @ejguan : I made the description rather high level while leaving implementation details to the users. My thinking here is to make the user aware of pitfalls without necessarily providing specific recommendation on implementing specific solution.

@parmeet parmeet requested a review from Nayef211 June 2, 2022 20:23
Copy link
Contributor

@parmeet parmeet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM! We should be good to land this. We can cherry-pick if there are additional recommendations :). Thanks so much @NicolasHug for your help with documenting various pitfalls and recommendations.

@parmeet parmeet merged commit 2978507 into pytorch:main Jun 2, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Document current limitations of datapipes
7 participants