Inconsistent settings for FSDP Precision #17506

Closed
awaelchli opened this issue Apr 27, 2023 · 1 comment · Fixed by #17670
Labels: bug (Something isn't working), strategy: fsdp (Fully Sharded Data Parallel)

Comments

awaelchli (Contributor) commented Apr 27, 2023

Bug description

The FSDPPrecision class controls the dtype of the model parameters, the autocast context, and the gradient scaler.
It has two possible inputs: precision="16-mixed" or precision="bf16-mixed":
https://github.com/Lightning-AI/lightning/blob/6464650a3c91a75d4a7c4009d08a68a84037a9a1/src/lightning/fabric/plugins/precision/fsdp.py#L32-L34

In all other precision plugins, "mixed" refers to mixed-precision training: the model weights stay in float32 while inputs and operations are autocast to the lower precision. The FSDPPrecision plugin, however, also sets the dtype of the model parameters to float16/bfloat16, so what actually runs is "16-true" precision.

https://github.com/Lightning-AI/lightning/blob/6464650a3c91a75d4a7c4009d08a68a84037a9a1/src/lightning/fabric/plugins/precision/fsdp.py#L50-L60
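
For reference, the linked property amounts to roughly the following (a paraphrase of the code at that commit, not a verbatim copy):

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# Rough paraphrase of the current behavior for precision="16-mixed":
# every dtype in the FSDP MixedPrecision config is set to the lower precision.
mp_16_mixed = MixedPrecision(
    param_dtype=torch.float16,   # parameters are stored in half precision
    reduce_dtype=torch.float16,  # gradient reduction in half precision
    buffer_dtype=torch.float16,  # buffers in half precision
)
# Because param_dtype is already float16, no float32 master weights are kept,
# which is what "16-true" would normally describe.
```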

Proposal:

For the mixed-precision settings, set param_dtype=torch.float32. For the current behavior (parameters in the lower dtype), introduce the precision settings "16-true" and "bf16-true". A sketch of the proposed mapping is below.
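
A minimal sketch of that mapping, assuming torch.distributed.fsdp.MixedPrecision continues to be used (the dictionary and its name are illustrative only, not the final API):

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# Illustrative mapping only; the final plugin API may differ.
PRECISION_TO_FSDP_CONFIG = {
    # Mixed precision: keep float32 parameters, cast compute/communication down.
    "16-mixed": MixedPrecision(
        param_dtype=torch.float32,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    ),
    "bf16-mixed": MixedPrecision(
        param_dtype=torch.float32,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    # "True" half precision: the parameters themselves live in the lower dtype
    # (this is what "16-mixed" effectively does today).
    "16-true": MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    ),
    "bf16-true": MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
}
```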

cc @awaelchli @carmocca

awaelchli added the bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), and strategy: fsdp (Fully Sharded Data Parallel) labels on Apr 27, 2023
awaelchli removed the needs triage and ver: 1.9.x labels on Apr 27, 2023
leng-yue (Contributor) commented:
This bug causes training to fail when performing certain fp32 operations, such as bicubic interpolation. I will submit a pull request later tonight to address this issue.
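
For illustration only (not part of the original comment): depending on device and PyTorch version, an op like bicubic upsampling may have no half-precision kernel, and a common user-side workaround is to cast up to float32 around the call, e.g.:

```python
import torch
import torch.nn.functional as F

# With parameters (and hence activations) forced to float16, ops that only ship
# float32 kernels can raise dtype errors. Casting up and back down avoids this.
x = torch.randn(1, 3, 8, 8, dtype=torch.float16)
y = F.interpolate(x.float(), scale_factor=2, mode="bicubic", align_corners=False).to(x.dtype)
```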
