[RFC] Simplify sharding API instantiation #9375
Comments
Hey @ananthsub, @SeanNaren, @yifuwang, any thoughts?
Thanks @tchaton for this! Going off https://pytorch.org/tutorials/prototype/skip_param_init.html#implementation-details, I think the goal is to provide a way to instantiate modules within the init function and skip the actual allocation of parameters until later, when they are actually requested. This would work well in the FSDP case, where we wrap individual modules which can then be allocated and wrapped instantly.
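For concreteness, a minimal sketch of the deferred-allocation idea from that tutorial, assuming a PyTorch version that ships `torch.nn.utils.skip_init` (1.10+); this is only an illustration of the pattern, not Lightning code:

```python
import torch.nn as nn
from torch.nn.utils import skip_init

# skip_init() constructs the module on the meta device and then swaps in
# uninitialized (empty) storage, so construction allocates no real weights.
layer = skip_init(nn.Linear, 10_000, 10_000)

# The parameters exist but stay uninitialized until we explicitly do it,
# e.g. right before the layer gets wrapped/sharded.
nn.init.xavier_uniform_(layer.weight)
nn.init.zeros_(layer.bias)
```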
@SeanNaren Here's a relevant RFC for FSDP being part of torch.distributed, which would support meta devices: pytorch/pytorch#64394
@tchaton I agree we'll need consistent recommendations around how sharded models are expected to be used with Lightning. I think they potentially fall into these two categories right now:
Assumptions behind this scenario:
issues:
Thanks @ananthsub! Following on our offline discussion and your post, I agree with your view of 'manual parallelization'; in a perfect world, distributed communications would be set up for the user. On the main post, I'm hesitant to introduce any magic into Lightning to force the user into a particular behaviour, primarily because sharding/lazy init is so experimental and could change very quickly. Considering https://github.com/NVIDIA/Megatron-LM and https://github.com/bigscience-workshop/Megatron-DeepSpeed as use cases, we see how hard-coded/sophisticated the setup for training large models is. Overall, in my opinion, let the user define meta devices if they want to skip initialization, separately from anything to do with sharding, for now.
Hey @ananthsub @SeanNaren, thanks for sharing this RFC on the PyTorch side, I had missed it. This is particularly relevant to this use case.

> b. Using the meta device can potentially let users adopt FSDP without changing model construction code, e.g. users can pass a meta module to the FSDP API, which can then recursively wrap its submodules and materialize them. In this way, users can simply replace the DDP API with the FSDP API for larger models without changing any model code.

The best approach would be for third-party frameworks like FairScale and DeepSpeed to support meta modules, and for PyTorch to provide a way for users to decorate their modules. @ananthsub

Best,
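To make the "recursively wrap and materialize" idea a bit more concrete, here is a rough, purely illustrative sketch in plain PyTorch (assuming 1.9+ for the `device=` factory kwarg and `Module.to_empty()`); it only shows the materialization half, not an actual FairScale/DeepSpeed FSDP wrap:

```python
import torch.nn as nn

def materialize(module: nn.Module, device: str = "cpu") -> nn.Module:
    """Illustrative helper: give a meta-constructed module real storage.

    A meta-aware FSDP-style API could call something like this per submodule
    right before wrapping it, so only one submodule is fully allocated at a time.
    """
    module.to_empty(device=device)      # allocate uninitialized storage
    for m in module.modules():          # re-run each layer's own init logic
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()
    return module

# Built on the meta device: no real memory is allocated up front.
model = nn.Sequential(
    nn.Linear(4096, 4096, device="meta"),
    nn.ReLU(),
    nn.Linear(4096, 10, device="meta"),
)

# A meta-aware FSDP would walk the children, materializing and wrapping each
# one in turn; here we only show the materialization step.
for child in model.children():
    materialize(child)
```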
Thanks @tchaton for letting me know about this issue! For other folks: I have been working on a new lazy model materialization API in PyTorch that requires no code changes in the model. You can check out its draft PR here, and for easier readability here is its rendered doc: https://docs-preview.pytorch.org/66317/distributed.html#lazy-model-materialization Note that this is still WIP and things might change. I am working on a couple of auxiliary tasks right now (e.g. LazyTensor support) related to this work, and hope to have a PR soon though.
Amazing, thanks for this :) Lazily initialising the layers is one thing, but having FSDP or some distributed component control which weights get initialised on each device is a different problem. Do you know if there has already been work to combine the two? cc @ananthsub
🚀 Feature
Currently, Lightning users working with sharded models and the DeepSpeed or FSDP plugins need to know about the `configure_sharded_model` hook. Here is an example.
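The original snippet is not preserved in this thread, so the following is only a representative sketch of the hook, assuming a Lightning version where `configure_sharded_model` is available:

```python
import torch.nn as nn
import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Large layers are intentionally NOT created here...

    def configure_sharded_model(self):
        # ...but in this hook, which the DeepSpeed/FSDP plugin invokes once
        # the sharded context is active, so each rank only ever allocates
        # its own shard of the parameters.
        self.block = nn.Sequential(
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 2),
        )

    def forward(self, x):
        return self.block(x)
```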
It could be possible to unify the API using the new meta device introduced in PyTorch 1.9.0: https://pytorch.org/tutorials/prototype/skip_param_init.html
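For comparison, a hypothetical sketch of what the unified user code could look like if layers were simply declared on the meta device in `__init__` and the sharding plugin materialized them later; nothing on the plugin side works this way today, this is just the shape of the proposal:

```python
import torch.nn as nn
import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Declared up front, but on the meta device, so (almost) no memory
        # is allocated. A meta-aware plugin could later materialize and
        # shard these weights, removing the need for a separate
        # configure_sharded_model hook.
        self.block = nn.Sequential(
            nn.Linear(32, 32, device="meta"),
            nn.ReLU(),
            nn.Linear(32, 2, device="meta"),
        )

    def forward(self, x):
        return self.block(x)
```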
Motivation
Pitch
Alternatives
Additional context
If you enjoy Lightning, check out our other projects! ⚡
Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, finetuning and solving problems with deep learning
Bolts: Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch
Lightning Transformers: Flexible interface for high performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.