Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix FSDP model resume optimizer & scheduler #25852

Merged
merged 2 commits into from
Sep 1, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
OPTIMIZER_NAME_BIN = "optimizer.bin"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"

Expand Down Expand Up @@ -2359,16 +2360,12 @@ def _save_checkpoint(self, model, trial, metrics=None):
partial=True,
v3=smp.state.cfg.shard_optimizer_state,
)
if self.args.should_save:
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
if self.do_grad_scaling:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
elif self.args.should_save and not self.is_deepspeed_enabled and not (self.fsdp or self.is_fsdp_enabled):
# deepspeed.save_checkpoint above saves model/optim/sched
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

# Save SCHEDULER & SCALER
if self.args.should_save and not self.is_deepspeed_enabled and not is_torch_tpu_available():
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
Expand Down Expand Up @@ -2439,7 +2436,8 @@ def _load_optimizer_and_scheduler(self, checkpoint):
checkpoint_file_exists = (
glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
if is_sagemaker_mp_enabled()
else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
else (os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
or os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME_BIN)))
)
if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
# Load in optimizer and scheduler states
Expand Down