
[REQUEST] Distributed data parallel training #2

Closed
sangkeun00 opened this issue Jul 1, 2022 · 4 comments
Labels
enhancement New feature or request

Comments

@sangkeun00
Contributor

Currently, Betty only supports torch.nn.DataParallel. Compared to torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel is much slower even in single-machine multi-GPU settings. Therefore, we need to replace torch.nn.DataParallel with torch.nn.parallel.DistributedDataParallel for better training speed and multi-machine multi-GPU support.

@sangkeun00 added the enhancement (New feature or request) label on Jul 1, 2022
@gongbudaizhe

can't agree more!

@sangkeun00
Contributor Author

sangkeun00 commented Jul 12, 2022

The main issue is that the highly efficient gradient synchronization (all-reduce) of DistributedDataParallel only works with torch.autograd.backward. However, meta-learning/MLO code relies heavily on torch.autograd.grad instead of torch.autograd.backward.

An ad-hoc solution is to perform gradient synchronization manually; however, this may degrade throughput because we can't apply tricks like computation-communication overlapping (register_hook is also not supported for torch.autograd.grad 😞). This may not be a huge issue if your setting is single-machine multi-GPU with very high communication bandwidth (e.g., NVLink) or if your model is not very large (probably anything smaller than BERT).
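For illustration, a minimal sketch of this manual synchronization is below (not Betty's actual implementation; the function name and calling convention are made up, and it assumes torch.distributed is already initialized):

import torch
import torch.distributed as dist

def grad_and_sync(loss, params):
    # Compute gradients without .backward(), as meta-learning/MLO code typically does.
    grads = torch.autograd.grad(loss, params)
    synced = []
    for g in grads:
        g = g.clone()
        # Blocking all-reduce; no computation-communication overlap is possible here.
        dist.all_reduce(g, op=dist.ReduceOp.SUM)
        # Average across data-parallel workers.
        synced.append(g / dist.get_world_size())
    return synced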

Therefore, I may implement this ad-hoc option soon, and once PyTorch supports efficient synchronization for torch.autograd.grad, we may update our strategy. If you have any questions or are willing to work on this feature, feel free to let me know!

Best,
Sang

@gongbudaizhe

gongbudaizhe commented Jul 13, 2022

I understand the difficulty of implementing this feature, but in my opinion, this is the single most important feature that Betty should have, and one that similar repos like facebookresearch/higher don't (see facebookresearch/higher#116). To make GML/MLO more impactful, it should be applicable to large-scale problems in the wild. Distributed training is a must-have for real applications.

The ad-hoc option looks good to me, can't wait to try it 👍

@sangkeun00
Contributor Author

sangkeun00 commented Oct 25, 2022

Hello @gongbudaizhe,

I apologize for the late reply! I finally implemented the distributed training feature for the multi-node/multi-GPU setting.
To try this feature, install the nightly version by cloning the repository at its most recent commit:

git clone https://github.com/leopard-ai/betty.git
cd betty
pip install .

In detail, you can enable distributed training by 1) setting distributed=True in EngineConfig as:

engine_config=EngineConfig(distributed=True, ...)

and 2) launching the training script with torch.distributed.launch on every node, as in standard PyTorch distributed training.
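For example (a rough sketch, not Betty-specific: the script name, IP address, and port below are placeholders), a two-node run with four GPUs per node could be launched on node 0 as follows, with --node_rank=1 on the other node:

python -m torch.distributed.launch --nnodes=2 --node_rank=0 \
    --nproc_per_node=4 --master_addr=<node0-ip> --master_port=29500 \
    train.py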

In the future, we plan to simplify the launching procedure, similar to Microsoft's DeepSpeed or Hugging Face's Accelerate. To do this, we need to write a custom launcher. If you are willing to contribute to this, that would be very welcome!

I am sorry again for the delay, and let me know if you have any questions!

Best,
Sang
