
Fix duplicate call to save_checkpoint when using deepspeed #14946

Merged

Conversation

MihaiBalint
Contributor

What does this PR do?

Drop the duplicate call to deepspeed.save_checkpoint(); the trainer.save_model() function already handles that case.

Following this change: https://github.com/huggingface/transformers/pull/14652/files#diff-ed55888e6665791fe92cc8fc0c499da54f4ace6738551cd9a2591881cda076deR1986, the call to save_checkpoint() was duplicated.
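
For context, a simplified, paraphrased sketch of the duplicated call path (illustrative only, not the literal trainer.py source; function bodies are stand-ins):

def save_model(trainer, output_dir):
    # Writes config.json, model weights, and tokenizer files; when running
    # under deepspeed it also saves the deepspeed checkpoint itself.
    if trainer.deepspeed is not None:
        trainer.deepspeed.save_checkpoint(output_dir)

def _save_checkpoint(trainer, output_dir):
    save_model(trainer, output_dir)  # deepspeed checkpoint saved here ...
    if trainer.deepspeed is not None:
        # ... and saved again here, producing the doubled log lines below
        trainer.deepspeed.save_checkpoint(output_dir)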

I found this issue after seeing the following logs (note the last 4 lines):

[INFO|trainer.py:2033] 2021-12-26 19:42:00,421 >> Saving model checkpoint to finetuned-ro-en-dev/checkpoint-2
[INFO|configuration_utils.py:425] 2021-12-26 19:42:00,423 >> Configuration saved in finetuned-ro-en-dev/checkpoint-2/config.json
[INFO|modeling_utils.py:1070] 2021-12-26 19:44:09,064 >> Model weights saved in finetuned-ro-en-dev/checkpoint-2/pytorch_model.bin
[INFO|tokenization_utils_base.py:2043] 2021-12-26 19:44:09,110 >> tokenizer config file saved in finetuned-ro-en-dev/checkpoint-2/tokenizer_config.json
[INFO|tokenization_utils_base.py:2049] 2021-12-26 19:44:09,112 >> Special tokens file saved in finetuned-ro-en-dev/checkpoint-2/special_tokens_map.json
[2021-12-26 19:44:09,596] [INFO] [logging.py:69:log_dist] [Rank 0] Saving model checkpoint: finetuned-ro-en-dev/checkpoint-2/global_step2/mp_rank_00_model_states.pt
[2021-12-26 19:59:09,484] [INFO] [engine.py:2964:_save_zero_checkpoint] zero checkpoint saved finetuned-ro-en-dev/checkpoint-2/global_step2/zero_pp_rank_0_mp_rank_00_optim_states.pt
[2021-12-26 19:59:09,575] [INFO] [logging.py:69:log_dist] [Rank 0] Saving model checkpoint: finetuned-ro-en-dev/checkpoint-2/global_step2/mp_rank_00_model_states.pt
[2021-12-26 20:16:17,005] [INFO] [engine.py:2964:_save_zero_checkpoint] zero checkpoint saved finetuned-ro-en-dev/checkpoint-2/global_step2/zero_pp_rank_0_mp_rank_00_optim_states.pt

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@stas00
Contributor

stas00 commented Dec 27, 2021

Sorry, I didn't realize that the two PRs were the same, just with different source branches. OK, let's work on this one.

So indeed there is a duplication, as you discovered: https://github.com/huggingface/transformers/pull/14652/files#diff-ed55888e6665791fe92cc8fc0c499da54f4ace6738551cd9a2591881cda076deR1986

So it is that added call that should be removed, not the one this PR currently removes. Would you like to fix that, and then I will merge it?

Basically, revert the change you proposed, and then have this PR revert my change that you linked to, where the duplication was added.


Notes for myself:

So after merging this, the only remaining issue is:

push_to_hub => save_model => questionable outcome 

for ZeRO stage 3 (z3) if stage3_gather_fp16_weights_on_model_save=false.
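
For reference, the flag lives in the zero_optimization section of the deepspeed config; a minimal illustrative fragment (example values only, not taken from this PR):

ds_config = {
    "zero_optimization": {
        "stage": 3,
        # When False, fp16 weights are not gathered on rank 0 at save time,
        # so save_model() cannot write a consolidated pytorch_model.bin.
        "stage3_gather_fp16_weights_on_model_save": False,
    },
}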

Otherwise, this path:

_save_checkpoint => save_model

conditionally saves the model but always saves the deepspeed checkpoint inside _save_checkpoint.

I'm thinking that perhaps save_model should have the logic to use self.deepspeed.save_checkpoint(output_dir) as a saving grace for z3 + stage3_gather_fp16_weights_on_model_save=false, since the weights can be recovered from the checkpoint in this case.

I will probably make a change on the deepspeed side and then it'll be easier for the Trainer to know whether to fall back or not.
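
A minimal sketch of that fallback, assuming hypothetical trainer helpers (a ds_config dict and save_consolidated_weights) that are not the real Trainer API:

def save_model(trainer, output_dir):
    if trainer.deepspeed is not None:
        zero_cfg = trainer.ds_config.get("zero_optimization", {})
        gather = zero_cfg.get("stage3_gather_fp16_weights_on_model_save", True)
        if zero_cfg.get("stage") == 3 and not gather:
            # Saving grace: no consolidated fp16 weights are available, but
            # the full deepspeed checkpoint still holds the sharded weights,
            # which can be reconstructed offline (e.g. with zero_to_fp32.py).
            trainer.deepspeed.save_checkpoint(output_dir)
            return
    save_consolidated_weights(trainer, output_dir)  # hypothetical normal path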

This will be resolved in PR #14948.

@stas00 stas00 self-assigned this Dec 27, 2021
Contributor

@stas00 stas00 left a comment


Looking good, @MihaiBalint - Thank you!

I will deal with the edge case correctly here: #14948

@stas00 stas00 merged commit c113827 into huggingface:master Dec 27, 2021
@MihaiBalint
Contributor Author

@stas00 many thanks for the review!

stevhliu pushed a commit to stevhliu/transformers that referenced this pull request Jan 6, 2022
Fix duplicate call to save_checkpoint when using deepspeed (huggingface#14946)

* Fix duplicate call to save_checkpoint when using deepspeed / stage3_gather_fp16_weights_on_model_save

* Revert "Fix duplicate call to save_checkpoint when using deepspeed / stage3_gather_fp16_weights_on_model_save"

This reverts commit 6a3dec0.

* Delete correct duplicate invocation of deepspeed save_checkpoint