fix resume fsdp #23111

Merged (4 commits) on May 4, 2023
src/transformers/trainer.py: 29 changes (24 additions, 5 deletions)
@@ -2114,7 +2114,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
         safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)

         if not any(
-            [os.path.isfile(f) for f in [weights_file, safe_weights_file, weights_index_file, safe_weights_index_file]]
+            os.path.isfile(f) for f in [weights_file, safe_weights_file, weights_index_file, safe_weights_index_file]
         ):
             raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

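For context (not part of the patch): dropping the square brackets turns the argument of any() into a generator expression, so the check can short-circuit as soon as one candidate file exists instead of evaluating every os.path.isfile() call up front. A tiny illustration with a hypothetical candidates list:

import os

# Hypothetical file names, just for illustration.
candidates = ["pytorch_model.bin", "model.safetensors"]

# Generator expression: any() stops calling os.path.isfile() at the first hit.
found = any(os.path.isfile(f) for f in candidates)

# List comprehension: every os.path.isfile() call runs before any() sees the results.
found_eager = any([os.path.isfile(f) for f in candidates])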
@@ -2364,6 +2364,12 @@ def _save_checkpoint(self, model, trial, metrics=None):
         if self.sharded_ddp == ShardedDDPOption.SIMPLE:
             self.optimizer.consolidate_state_dict()

+        if self.fsdp:
+            # FSDP has a different interface for saving optimizer states.
+            # Needs to be called on all ranks to gather all states.
+            # full_optim_state_dict will be deprecated after Pytorch 2.2!
+            full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)
+
         if is_torch_tpu_available():
             xm.rendezvous("saving_optimizer_states")
             xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
@@ -2388,7 +2394,11 @@ def _save_checkpoint(self, model, trial, metrics=None):
                     torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
         elif self.args.should_save and not self.deepspeed:
             # deepspeed.save_checkpoint above saves model/optim/sched
-            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            if self.fsdp:
+                torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
Review thread on the torch.save(full_osd, ...) line above:

Contributor: saving on rank 0 should be efficient and enough, right?

Contributor (author): I believe self.args.should_save in this case is already handling saving on rank 0.

Contributor: oh, okay, got it, thank you!
+            else:
+                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+
             with warnings.catch_warnings(record=True) as caught_warnings:
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
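The diff comments and the review thread above spell out the save-side contract: full_optim_state_dict is a collective that must run on every rank so FSDP can gather the sharded optimizer state, while the actual torch.save only runs on the process where self.args.should_save is true. A minimal standalone sketch of that pattern outside the Trainer, assuming torch.distributed is initialized and the model is already FSDP-wrapped (model, optimizer, and save_dir are placeholder names):

import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Placeholders: `model` is an FSDP-wrapped module, `optimizer` its optimizer,
# `save_dir` an existing output directory.

# Collective call: every rank participates so FSDP can gather the sharded
# optimizer state; with the default rank0_only=True the full dict lands on rank 0.
full_osd = FSDP.full_optim_state_dict(model, optimizer)

if dist.get_rank() == 0:
    # Only rank 0 holds the consolidated dict, so only rank 0 writes it
    # (in the Trainer, self.args.should_save plays this gating role).
    torch.save(full_osd, os.path.join(save_dir, "optimizer.pt"))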
@@ -2498,9 +2508,18 @@ def opt_load_hook(mod, opt):
                     # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
                     # likely to get OOM on CPU (since we load num_gpu times the optimizer state
                     map_location = self.args.device if self.args.world_size > 1 else "cpu"
-                    self.optimizer.load_state_dict(
-                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
-                    )
+                    if self.fsdp:
+                        full_osd = None
+                        # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it
+                        if self.args.process_index == 0:
+                            full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME))
+                        # call scatter_full_optim_state_dict on all ranks
+                        sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model)
+                        self.optimizer.load_state_dict(sharded_osd)
+                    else:
+                        self.optimizer.load_state_dict(
+                            torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
+                        )
                 with warnings.catch_warnings(record=True) as caught_warnings:
                     self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
                 reissue_pt_warnings(caught_warnings)
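The load path mirrors the save path: only rank 0 reads the full optimizer state dict from disk, and scatter_full_optim_state_dict, another collective, hands each rank the shard its local optimizer should load. A matching sketch under the same assumptions (checkpoint_dir is a placeholder):

import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Placeholders: `model` is FSDP-wrapped, `optimizer` is its optimizer,
# `checkpoint_dir` points at a checkpoint written as in the sketch above.
full_osd = None
if dist.get_rank() == 0:
    # Only rank 0 materializes the full (unsharded) optimizer state in memory.
    full_osd = torch.load(os.path.join(checkpoint_dir, "optimizer.pt"))

# Collective call on every rank: returns this rank's shard of the optimizer state.
sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
optimizer.load_state_dict(sharded_osd)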