[RFC] Simplify sharding API instantiation #9375

Closed
tchaton opened this issue Sep 8, 2021 · 8 comments · Fixed by #9920
Labels
design (Includes a design discussion) · feature (Is an improvement or enhancement) · help wanted (Open to be worked on)

Comments

tchaton (Contributor) commented Sep 8, 2021

🚀 Feature

Currently, Lightning users working with sharded models via the DeepSpeed or FSDP plugins need to know about configure_sharded_model. Here is an example:

    import torch
    from torch import nn
    import pytorch_lightning as pl
    # assumption: FairScale's fully sharded utilities, as used in the Lightning docs example
    from fairscale.nn import checkpoint_wrapper, auto_wrap, wrap

    class MyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.linear_layer = nn.Linear(32, 32)
            self.block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
            self.final_block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())

        def configure_sharded_model(self):
            # modules are sharded across processes
            # as soon as they are wrapped with ``wrap`` or ``auto_wrap``.
            # During the forward/backward passes, weights get synced across processes
            # and de-allocated once computation is complete, saving memory.

            # Wraps the layer in a Fully Sharded Wrapper automatically
            linear_layer = wrap(self.linear_layer)

            # Wraps the module recursively
            # based on a minimum number of parameters (default 100M parameters)
            block = auto_wrap(self.block)

            # For best memory efficiency,
            # add FairScale activation checkpointing
            final_block = auto_wrap(checkpoint_wrapper(self.final_block))
            self.model = nn.Sequential(linear_layer, nn.ReLU(), block, final_block)

        def configure_optimizers(self):
            return torch.optim.AdamW(self.model.parameters())

It could be possible to unify the API using the new meta device and skipped parameter initialization introduced in PyTorch 1.9.0: https://pytorch.org/tutorials/prototype/skip_param_init.html

# Proposed (not yet existing) Lightning API
from pytorch_lightning.distributed import skip_param_init, apply_param_init

class Model(LightningModule):

    def __init__(self):
        super().__init__()

        with skip_param_init():
            # create parameters as before, possibly ones which don't fit on 1 GPU,
            # e.g. (illustrative):
            self.block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())

    def setup(self, stage=None):
        ...


# This model doesn't have any materialized params for modules defined within the
# skip_param_init context manager.
model = Model()

# With DeepSpeed / FSDP, once distributed is fully available, `apply_param_init` is
# applied to create the sharded weights.


tchaton added the feature (Is an improvement or enhancement) and help wanted (Open to be worked on) labels on Sep 8, 2021
tchaton (Contributor, Author) commented Sep 8, 2021

Hey @ananthsub, @SeanNaren, @yifuwang, any thoughts?

SeanNaren (Contributor) commented Sep 9, 2021

Thanks @tchaton for this!

My knowledge around the meta device is limited so any additional docs/information/reference PRs would be appreciated.

Going off https://pytorch.org/tutorials/prototype/skip_param_init.html#implementation-details, I think the goal is to provide a way to instantiate modules within the __init__ function while skipping the actual allocation of parameters until they are actually requested later.
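
For reference, here's a minimal sketch of the deferred-initialization mechanism from the linked tutorial, assuming a PyTorch version that ships torch.nn.utils.skip_init (the stable form of that prototype):

    import torch

    # Build the module without running its parameter initialization: internally it is
    # constructed on the meta device and then given uninitialized storage via to_empty().
    layer = torch.nn.utils.skip_init(torch.nn.Linear, 1024, 1024)

    # Later, once we know how (or where) the parameters should be initialized,
    # run the module's own initialization explicitly.
    layer.reset_parameters()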

This would work well in the FSDP case where we wrap individual modules which can then be allocated and wrapped instantly. @ananthsub is the meta device going to be the method to do this in PyTorch?

ananthsub (Contributor) commented:

@SeanNaren Here's a relevant RFC for FSDP being part of torch.distributed, which would support meta devices: pytorch/pytorch#64394
(cc @zhaojuanmao)

ananthsub (Contributor) commented Sep 10, 2021

@tchaton I agree we'll need consistent recommendations around how sharded models are expected to be used with lightning.

I think there are potentially these two categories they fall into right now:

  1. Arbitrarily complex models and partitioning: we might see parts of the overall model wrapped with FSDP, parts wrapped with DeepSpeed, parts wrapped with DDP, and parts wrapped with whatever new library/wrapper comes out. In this case, there's no way a plugin can automate this, as the partitioning is user-determined and potentially leverages multiple module wrappers simultaneously. In such a scenario, it is preferable for the user to handle this (given they know exactly what they want to do). This way the user is free to compose whatever libraries they choose without burdening Lightning to come up with a unifying API around them, which might not cleanly exist. This can be thought of as manual parallelization, similar to manual optimization. Accordingly, since the user is shouldering how the model ought to be parallelized, the LightningModule becomes a simple shell for the interaction with the various loops. Users benefit from the loop structure and surrounding utilities such as profiling, logging, and checkpointing. We'd need a corresponding training type plugin and accelerator for this scenario, though. It should be incredibly simple, as it's essentially a passthrough/dummy class which does no extra device movement or wrapping. This is again similar to the loop complexity differences between automatic and manual optimization.

Assumptions behind this scenario:

  • The user is handling the distributed launch (via some job scheduler) and can control the driver code which instantiates the LightningModule and Trainer. Inside the driver code, they can leverage meta devices to construct their model before passing it to the LightningModule to be used for training/validation/test/prediction (see the sketch after this list).
  • The sticking point is that the model's state_dict and load_state_dict must be the universal APIs through which we save/load this state (ideally leveraging sharded tensors; see "Add ShardedTensor support in LightningModule" #8944).
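
A rough sketch of this driver-code flow (illustrative only: ShellModule and build_and_wrap_model are hypothetical names, not an existing Lightning API):

    import pytorch_lightning as pl
    import torch
    from torch import nn


    class ShellModule(pl.LightningModule):
        # A thin shell around an externally parallelized model: the user has already
        # constructed and wrapped `model` (FSDP / DeepSpeed / DDP / ...) in driver code.
        def __init__(self, model: nn.Module):
            super().__init__()
            self.model = model

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = nn.functional.mse_loss(self.model(x), y)
            self.log("train_loss", loss)
            return loss

        def configure_optimizers(self):
            return torch.optim.AdamW(self.model.parameters())


    # Driver code, launched by the user's own scheduler:
    # model = build_and_wrap_model(...)   # user-controlled construction and wrapping (hypothetical helper)
    # trainer = pl.Trainer(...)           # with a passthrough plugin that does no extra wrapping/movement
    # trainer.fit(ShellModule(model))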

Issues:

  1. Point 1) should not preclude Lightning from offering plugins like FSDP and DeepSpeed which assume a more fixed structure, just as offering manual optimization does not preclude users from using automatic optimization. These cases are a subset of point 1), but Lightning could offer more structure here. However, I'd need to further investigate the RFC linked above for what this could look like. One issue with relying on lazy/deferred initialization using Lightning hooks is that we lose clarity around when exactly the model is defined in order to perform checkpoint loading. We run into this today with FSDP, for instance: "Cleanup FSDP integration to not require boilerplate logic" #8722

ananthsub added the design (Includes a design discussion) label on Sep 10, 2021
SeanNaren (Contributor) commented:

Thanks @ananthsub!

Following on from our offline discussion and your post, I agree with your view of 'manual parallelization'. In a perfect world, distributed communication would be set up for the user from the module __init__ onwards, and the user could do whatever distributed communication they'd like!

On the main post, I'm hesitant to introduce any magic into Lightning that forces the user into a particular behaviour, primarily because sharding/lazy init is so experimental and could change very quickly. Considering https://github.com/NVIDIA/Megatron-LM and https://github.com/bigscience-workshop/Megatron-DeepSpeed as use cases, we see how hard-coded/sophisticated training large models is. Overall, in my opinion, let users define meta devices themselves if they want to skip initialization, separately from anything to do with sharding, for now.

tchaton (Contributor, Author) commented Sep 10, 2021

Hey @ananthsub @SeanNaren,

Thanks for sharing this RFC on the PyTorch side, I had missed it.

This is particularly relevant to this use-case.

> b. use meta device can potentially let users to use FSDP without changing model construction codes. e.g. users can pass a meta module to FSDP API, then the FSDP API can recursively wrap its submodules and materialize these modules. In this way, users can simply replace DDP API with FSDP API for larger models without changing any model codes.

The best approach would be for third-party frameworks like FairScale and DeepSpeed to support meta modules, and for PyTorch to provide a way for users to wrap their module construction in a skip_param_init context manager. In that case, Lightning would just have to resolve the meta module once distributed training is available.
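
As a rough sketch of that flow (the meta-device construction below uses existing PyTorch 1.9+ APIs; shard_and_materialize is a placeholder for whatever FairScale/DeepSpeed end up providing, not a real API):

    import torch
    from torch import nn

    # Build the module on the meta device: no parameter storage is allocated yet.
    # The device= factory kwarg is available from PyTorch 1.9.
    meta_model = nn.Sequential(
        nn.Linear(32, 32, device="meta"),
        nn.ReLU(),
        nn.Linear(32, 32, device="meta"),
    )

    # Once distributed training is available, a sharding wrapper would recursively wrap
    # the submodules and materialize only the shards each rank owns.
    # sharded_model = shard_and_materialize(meta_model)  # placeholder, not a real API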

@ananthsub, for point 1, I really like this idea of manual parallelization. I believe the plugin could take care of creating the process group and subgroups, and let the user use them for their manual parallelization.

Best,
T.C

cbalioglu commented:

Thanks @tchaton for letting me know about this issue! For other folks: I have been working on a new lazy model materialization API in PyTorch that requires no code changes in the model. You can check out its draft PR here.

And for easier readability here its rendered doc: https://docs-preview.pytorch.org/66317/distributed.html#lazy-model-materialization

Note that this is still WIP and things might change. I am working on a couple auxiliary tasks right now (e.g. LazyTensor support) related to this work and hope to have a PR soon though.

SeanNaren (Contributor) commented:

> Thanks @tchaton for letting me know about this issue! For other folks; I have been working on a new lazy model materialization API in PyTorch that requires no code changes in the model. You can check out its draft PR here.
>
> And for easier readability here its rendered doc: https://docs-preview.pytorch.org/66317/distributed.html#lazy-model-materialization
>
> Note that this is still WIP and things might change. I am working on a couple auxiliary tasks right now (e.g. LazyTensor support) related to this work and hope to have a PR soon though.

Amazing, thanks for this :) Lazily initialising the layers is one thing, but having FSDP or some distributed component control which weights get initialised on each device is a different problem. Do you know if there has already been work to combine the two? cc @ananthsub
