FSDP bug fix for load_state_dict (#18596)
pacman100 authored Aug 12, 2022
1 parent d344534 commit 4eed2be
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/transformers/trainer.py
@@ -1935,7 +1935,9 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
             else:
                 # We load the model state dict on the CPU to avoid an OOM error.
                 state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
-                load_result = model.load_state_dict(state_dict, strict=False)
+                # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+                # which takes *args instead of **kwargs
+                load_result = model.load_state_dict(state_dict, False)
                 # release memory
                 del state_dict
                 self._issue_warnings_after_load(load_result)
@@ -1989,7 +1991,9 @@ def _load_best_model(self):
                 # We load the model state dict on the CPU to avoid an OOM error.
                 state_dict = torch.load(best_model_path, map_location="cpu")
                 # If the model is on the GPU, it still works!
-                load_result = model.load_state_dict(state_dict, strict=False)
+                # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+                # which takes *args instead of **kwargs
+                load_result = model.load_state_dict(state_dict, False)
                 if not is_sagemaker_mp_enabled():
                     self._issue_warnings_after_load(load_result)
         elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
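
Why the positional argument: per the commit's own comments, the FSDP `load_state_dict` override at the time forwarded extra arguments as `*args` rather than `**kwargs` (pytorch/pytorch#82963), so passing `strict=False` as a keyword raised a `TypeError`. The sketch below is a minimal illustration of that failure mode using a hypothetical `ArgsOnlyWrapper` class (not the real FSDP implementation), showing why the positional call succeeds where the keyword call fails:

```python
import torch.nn as nn


class ArgsOnlyWrapper(nn.Module):
    """Hypothetical stand-in for the buggy FSDP wrapper: it forwards
    *args only, so keyword arguments such as strict=False are rejected."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def load_state_dict(self, *args):
        # No **kwargs here, mirroring the signature described in
        # pytorch/pytorch#82963.
        return self.module.load_state_dict(*args)


model = ArgsOnlyWrapper(nn.Linear(4, 4))
state_dict = nn.Linear(4, 4).state_dict()

model.load_state_dict(state_dict, False)  # works: strict passed positionally
try:
    model.load_state_dict(state_dict, strict=False)  # TypeError on the buggy wrapper
except TypeError as err:
    print(err)
```

Passing `False` positionally is also safe for unwrapped models, since the second parameter of `nn.Module.load_state_dict(state_dict, strict=True)` is `strict`, so the workaround behaves identically to `strict=False` in the non-FSDP path.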