Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix initialized weights resetting in Fabric.setup() when using FSDP #19755

Merged
merged 7 commits into from
Apr 11, 2024

Conversation

awaelchli
Copy link
Contributor

@awaelchli awaelchli commented Apr 10, 2024

What does this PR do?

Fixes a subtle issue with weight initialization in FSDP triggered by the change #19152. The real cause is that FSDP overrides the apply method of nn.Module:
https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/fully_sharded_data_parallel.py/#L567-L605

We call it indirectly here:

_update_properties(
module, device=self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device
)

and then root.apply here:

which by itself is innocent because all we want to do is apply a function to each submodule.
Since all we want to do is update an attribute recursively, we don't need the parameters to be unshareded anyway. So this PR simply iterates over the modules instead of calling apply. Consequently, this should also make fabric.setup() faster because no longer does each submodule be unsharded.

The added test fails on master.
The bug is critical enough that we will need to prioritize a release after this is merged.

Needs #19756 to unblock failing tests.


📚 Documentation preview 📚: https://pytorch-lightning--19755.org.readthedocs.build/en/19755/

cc @Borda @tchaton @carmocca @justusschock @awaelchli

@github-actions github-actions bot added the fabric lightning.fabric.Fabric label Apr 10, 2024
@awaelchli awaelchli added bug Something isn't working priority: 0 High priority task strategy: fsdp Fully Sharded Data Parallel labels Apr 10, 2024
@awaelchli awaelchli added this to the 2.2.x milestone Apr 10, 2024
@awaelchli awaelchli marked this pull request as ready for review April 10, 2024 14:34
Copy link
Contributor

github-actions bot commented Apr 10, 2024

⚡ Required checks status: All passing 🟢

Groups summary

🟢 pytorch_lightning: Tests workflow
Check ID Status
pl-cpu (macOS-11, lightning, 3.8, 1.13, oldest) success
pl-cpu (macOS-11, lightning, 3.10, 1.13) success
pl-cpu (macOS-11, lightning, 3.10, 2.1) success
pl-cpu (macOS-11, lightning, 3.10, 2.2) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 1.13, oldest) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.2) success
pl-cpu (windows-2022, lightning, 3.8, 1.13, oldest) success
pl-cpu (windows-2022, lightning, 3.10, 1.13) success
pl-cpu (windows-2022, lightning, 3.10, 2.1) success
pl-cpu (windows-2022, lightning, 3.10, 2.2) success
pl-cpu (macOS-11, pytorch, 3.8, 1.13) success
pl-cpu (ubuntu-20.04, pytorch, 3.8, 1.13) success
pl-cpu (windows-2022, pytorch, 3.8, 1.13) success
pl-cpu (macOS-12, pytorch, 3.11, 2.0) success
pl-cpu (macOS-12, pytorch, 3.11, 2.1) success
pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.0) success
pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.1) success
pl-cpu (windows-2022, pytorch, 3.11, 2.0) success
pl-cpu (windows-2022, pytorch, 3.11, 2.1) success

These checks are required after the changes to src/lightning/fabric/utilities/device_dtype_mixin.py.

🟢 pytorch_lightning: Azure GPU
Check ID Status
pytorch-lightning (GPUs) (testing Lightning | latest) success
pytorch-lightning (GPUs) (testing PyTorch | latest) success

These checks are required after the changes to src/lightning/fabric/utilities/device_dtype_mixin.py.

🟢 pytorch_lightning: Benchmarks
Check ID Status
lightning.Benchmarks success

These checks are required after the changes to src/lightning/fabric/utilities/device_dtype_mixin.py.

🟢 fabric: Docs
Check ID Status
docs-make (fabric, doctest) success
docs-make (fabric, html) success

These checks are required after the changes to src/lightning/fabric/utilities/device_dtype_mixin.py.

🟢 lightning_fabric: CPU workflow
Check ID Status
fabric-cpu (macOS-11, lightning, 3.8, 1.13, oldest) success
fabric-cpu (macOS-11, lightning, 3.10, 1.13) success
fabric-cpu (macOS-11, lightning, 3.11, 2.1) success
fabric-cpu (macOS-11, lightning, 3.11, 2.2) success
fabric-cpu (ubuntu-20.04, lightning, 3.8, 1.13, oldest) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.1) success
fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2) success
fabric-cpu (windows-2022, lightning, 3.8, 1.13, oldest) success
fabric-cpu (windows-2022, lightning, 3.10, 1.13) success
fabric-cpu (windows-2022, lightning, 3.11, 2.1) success
fabric-cpu (windows-2022, lightning, 3.11, 2.2) success
fabric-cpu (macOS-11, fabric, 3.8, 1.13) success
fabric-cpu (ubuntu-20.04, fabric, 3.8, 1.13) success
fabric-cpu (windows-2022, fabric, 3.8, 1.13) success
fabric-cpu (macOS-12, fabric, 3.11, 2.0) success
fabric-cpu (macOS-12, fabric, 3.11, 2.1) success
fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.0) success
fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.1) success
fabric-cpu (windows-2022, fabric, 3.11, 2.0) success
fabric-cpu (windows-2022, fabric, 3.11, 2.1) success

These checks are required after the changes to src/lightning/fabric/utilities/device_dtype_mixin.py, tests/tests_fabric/strategies/test_fsdp_integration.py.

🟢 lightning_fabric: Azure GPU
Check ID Status
lightning-fabric (GPUs) (testing Fabric | latest) success
lightning-fabric (GPUs) (testing Lightning | latest) success

These checks are required after the changes to src/lightning/fabric/utilities/device_dtype_mixin.py, tests/tests_fabric/strategies/test_fsdp_integration.py.

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to src/lightning/fabric/utilities/device_dtype_mixin.py.

🟢 install
Check ID Status
install-pkg (ubuntu-22.04, app, 3.8) success
install-pkg (ubuntu-22.04, app, 3.11) success
install-pkg (ubuntu-22.04, fabric, 3.8) success
install-pkg (ubuntu-22.04, fabric, 3.11) success
install-pkg (ubuntu-22.04, pytorch, 3.8) success
install-pkg (ubuntu-22.04, pytorch, 3.11) success
install-pkg (ubuntu-22.04, lightning, 3.8) success
install-pkg (ubuntu-22.04, lightning, 3.11) success
install-pkg (ubuntu-22.04, notset, 3.8) success
install-pkg (ubuntu-22.04, notset, 3.11) success
install-pkg (macOS-12, app, 3.8) success
install-pkg (macOS-12, app, 3.11) success
install-pkg (macOS-12, fabric, 3.8) success
install-pkg (macOS-12, fabric, 3.11) success
install-pkg (macOS-12, pytorch, 3.8) success
install-pkg (macOS-12, pytorch, 3.11) success
install-pkg (macOS-12, lightning, 3.8) success
install-pkg (macOS-12, lightning, 3.11) success
install-pkg (macOS-12, notset, 3.8) success
install-pkg (macOS-12, notset, 3.11) success
install-pkg (windows-2022, app, 3.8) success
install-pkg (windows-2022, app, 3.11) success
install-pkg (windows-2022, fabric, 3.8) success
install-pkg (windows-2022, fabric, 3.11) success
install-pkg (windows-2022, pytorch, 3.8) success
install-pkg (windows-2022, pytorch, 3.11) success
install-pkg (windows-2022, lightning, 3.8) success
install-pkg (windows-2022, lightning, 3.11) success
install-pkg (windows-2022, notset, 3.8) success
install-pkg (windows-2022, notset, 3.11) success

These checks are required after the changes to src/lightning/fabric/utilities/device_dtype_mixin.py.


Thank you for your contribution! 💜

Note
This comment is automatically generated and updates for 60 minutes every 180 seconds. If you have any other questions, contact carmocca for help.

Copy link

codecov bot commented Apr 10, 2024

Codecov Report

Merging #19755 (10b28da) into master (316cc71) will decrease coverage by 26%.
The diff coverage is 100%.

Additional details and impacted files
@@            Coverage Diff            @@
##           master   #19755     +/-   ##
=========================================
- Coverage      84%      59%    -26%     
=========================================
  Files         424      419      -5     
  Lines       34919    34820     -99     
=========================================
- Hits        29349    20374   -8975     
- Misses       5570    14446   +8876     

Copy link
Contributor

@carmocca carmocca left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you understand why unsharding to summon the parameters would impact the optimization?

@awaelchli awaelchli mentioned this pull request Apr 10, 2024
@mergify mergify bot added the ready PRs ready to be merged label Apr 11, 2024
@awaelchli awaelchli merged commit dcb91d5 into master Apr 11, 2024
111 of 112 checks passed
@awaelchli awaelchli deleted the bugfix/fabric-apply-fsdp branch April 11, 2024 09:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working fabric lightning.fabric.Fabric priority: 0 High priority task ready PRs ready to be merged strategy: fsdp Fully Sharded Data Parallel
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants