Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix resume fsdp #23111

Merged
merged 4 commits into from
May 4, 2023
Merged

fix resume fsdp #23111

merged 4 commits into from
May 4, 2023

Conversation

qywu
Copy link
Contributor

@qywu qywu commented May 2, 2023

What does this PR do?

Fixes # 23034

When training a model with FSDP, the checkpoint is not saved and loaded correctly. Only rank 0's optimizer state dict is saved. This PR fixes this issue.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@pacman100

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented May 2, 2023

The documentation is not available anymore as the PR was closed or merged.

Copy link
Contributor

@pacman100 pacman100 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @qywu for the super quick fix, LGTM! 🤗

@pacman100
Copy link
Contributor

Please run make style and make quality to fix the quality issues

Copy link
Contributor

@pacman100 pacman100 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello, I went over it again and noticed that you aren't saving and loading the optimizer state only on rank 0. Please do that. Refer the implementation in accelerate here for reference: https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L924-L952

@qywu
Copy link
Contributor Author

qywu commented May 3, 2023

I have fixed the issues. The optimizer saving had no problems. For using scatter_full_optim_state_dict, indeed loading on rank 0 is enough, which can save CPU memory usage.

@qywu qywu requested a review from pacman100 May 3, 2023 20:22
@@ -2388,7 +2394,11 @@ def _save_checkpoint(self, model, trial, metrics=None):
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
elif self.args.should_save and not self.deepspeed:
# deepspeed.save_checkpoint above saves model/optim/sched
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
if self.fsdp:
torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

saving on rank 0 should be efficient and enough, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe self.args.should_save in this case is already handling saving on rank 0

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, okay, got it, thank you!

@qywu qywu requested a review from pacman100 May 3, 2023 22:35
Copy link
Contributor

@pacman100 pacman100 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @qywu for iterating 🤗

@pacman100
Copy link
Contributor

cc @sgugger for a second look

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix!

@sgugger sgugger merged commit adb0760 into huggingface:main May 4, 2023
@wentinghome
Copy link

thanks for the fix!

gojiteji pushed a commit to gojiteji/transformers that referenced this pull request Jun 5, 2023
* fix resume fsdp

* fix rank 0 loading

* fix style and quality
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
* fix resume fsdp

* fix rank 0 loading

* fix style and quality
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Cannot resume FSDP optimizer state
5 participants