
Add support for init_meta_context, materialize_module #9920

Merged: 42 commits into master on Oct 21, 2021

Conversation

@tchaton (Contributor) commented Oct 13, 2021

What does this PR do?

Fixes #9375

This PR builds on top of pytorch/pytorch#66317; the code section copied from that PR will be dropped once it is merged into PyTorch.

The goal is to require no code changes from end users:

import torch.nn as nn

from pytorch_lightning.utilities.meta import init_meta_context, materialize_module

class MLP(nn.Module):
    def __init__(self, num_convs: int):
        super().__init__()
        self.lins = []
        for _ in range(num_convs):
            self.lins.append(nn.Linear(1, 1))
        self.layer = nn.Sequential(*self.lins)

# Inside the context, modules are created on the meta device:
# their parameters carry shape and dtype but allocate no storage.
with init_meta_context():
    m = nn.Linear(in_features=1, out_features=1)
    assert m.weight.device.type == "meta"
    mlp = MLP(4)
    assert mlp.layer[0].weight.device.type == "meta"

    # materialize_module re-creates the module's parameters on a real device.
    materialize_module(mlp)
    assert mlp.layer[0].weight.device.type == "cpu"

# Outside the context, module instantiation behaves as usual...
m = nn.Linear(in_features=1, out_features=1)
assert m.weight.device.type == "cpu"

# ...and the context can be entered again at any time.
with init_meta_context():
    m = nn.Linear(in_features=1, out_features=1)
    assert m.weight.device.type == "meta"

# Leaving the context restores normal behavior.
m = nn.Linear(in_features=1, out_features=1)
assert m.weight.device.type == "cpu"
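For background, a "meta" tensor is a PyTorch tensor that records shape and dtype without allocating any storage, which is what makes it possible to describe arbitrarily large models cheaply before deciding where (and how) to allocate them. A minimal standalone sketch of the concept, independent of this PR:

import torch

# A tensor on the meta device has shape and dtype but no storage,
# so even a very large tensor costs essentially nothing to "create".
w = torch.empty(1024, 1024, device="meta")
assert w.device.type == "meta"
assert w.shape == torch.Size([1024, 1024])

# "Materializing" means re-creating the tensor on a real device. Here the
# values are uninitialized; real code would also re-run the module's
# parameter initialization.
w_cpu = torch.empty_like(w, device="cpu")
assert w_cpu.device.type == "cpu"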

TODO:

  • Verify it works with DeepSpeed on BoringModel, e.g.:

    with init_meta_context():
        model = BoringModel()
    assert model.layer.weight.device.type == "meta"
    trainer = Trainer(
        default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16
    )
    trainer.fit(model)
    assert model.layer.weight.device.type == "cpu"


Does your PR introduce any breaking changes? If yes, please list them.

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the Review guidelines. In short:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

@tchaton changed the title from "Add support for meta device" to "Add support for use_meta_device" on Oct 13, 2021
@tchaton added this to the v1.5 milestone on Oct 14, 2021
@tchaton self-assigned this on Oct 14, 2021
@tchaton added the feature label (Is an improvement or enhancement) on Oct 14, 2021
@tchaton marked this pull request as ready for review on October 14, 2021 11:56
@tchaton changed the title from "Add support for use_meta_device" to "Add support for set_device" on Oct 14, 2021
@SeanNaren (Contributor) commented Oct 14, 2021

This is incredible work: users don't have to change their model definition, and hopefully in the future FSDP will support sharding directly from meta-device modules.

(cc @myleott @blefaudeux @anj-s who may have some stuff to say about the API/future integration!)

Should we wait for the code to be merged into PyTorch and available in the nightly build?

EDIT: looping in @jeffra and @tjruwase from the DeepSpeed team as well :)

@tchaton (Contributor, Author) commented Oct 14, 2021

> This is incredible work: users don't have to change their model definition, and hopefully in the future FSDP will support sharding directly from meta-device modules.
>
> (cc @myleott @blefaudeux @anj-s who may have some stuff to say about the API/future integration!)
>
> Should we wait for the code to be merged into PyTorch and available in the nightly build?

Adding @cbalioglu to the conversation.

IMO, we definitely want this for Lightning v1.5, as it is already working with the PyTorch 1.10 nightly.
Hopefully, this feature will be merged into PyTorch 1.10 and released before Lightning v1.5, so we can remove the copied code.

Best,
T.C

@tchaton requested a review from @SeanNaren on October 14, 2021 13:17
@SeanNaren self-requested a review on October 17, 2021 18:33
@SeanNaren (Contributor) commented:

Let's follow up on @zou3519's comment before proceeding further with this PR! We shouldn't merge until we understand the edge cases that are causing instability.

Also, I think it would be beneficial to show a concrete use case of the benefit (which I think @tchaton has with DeepSpeed, though it may need more testing of instantiation times vs. configure_sharded_model) before merging this PR!
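For context, configure_sharded_model is the LightningModule hook that, at the time, let users defer layer construction until the strategy's sharding context was active; the comparison above is between its instantiation cost and the meta-device approach. A minimal sketch of the hook-based style (the model and layer sizes here are illustrative, not from this PR):

import torch.nn as nn
import pytorch_lightning as pl

class ShardedModel(pl.LightningModule):
    def configure_sharded_model(self):
        # Layers built inside this hook are constructed within the sharding
        # context (e.g. DeepSpeed ZeRO stage 3), so the full model never has
        # to be materialized on a single device. The trade-off is that the
        # model definition itself must change, which is exactly what
        # init_meta_context aims to avoid.
        self.block = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))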

@cbalioglu commented:

Looks good to me as an experimental API. Please consider me your point of contact for any feedback or issues specific to the parts you copied over from the PyTorch PR #66317.

@tchaton merged commit 454e93b into master on Oct 21, 2021
@tchaton deleted the set_meta_device branch on October 21, 2021 14:48
awaelchli added a commit that referenced this pull request Oct 22, 2021
Labels: feature (Is an improvement or enhancement), ready (PRs ready to be merged)

Successfully merging this pull request may close: [RFC] Simplify sharding API instantiation