diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index e537b3b6357adb..64d5a3fadf4d6d 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1935,7 +1935,9 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
         else:
             # We load the model state dict on the CPU to avoid an OOM error.
             state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
-            load_result = model.load_state_dict(state_dict, strict=False)
+            # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+            # which takes *args instead of **kwargs
+            load_result = model.load_state_dict(state_dict, False)
             # release memory
             del state_dict
             self._issue_warnings_after_load(load_result)
@@ -1989,7 +1991,9 @@ def _load_best_model(self):
                     # We load the model state dict on the CPU to avoid an OOM error.
                     state_dict = torch.load(best_model_path, map_location="cpu")
                     # If the model is on the GPU, it still works!
-                    load_result = model.load_state_dict(state_dict, strict=False)
+                    # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+                    # which takes *args instead of **kwargs
+                    load_result = model.load_state_dict(state_dict, False)
                 if not is_sagemaker_mp_enabled():
                     self._issue_warnings_after_load(load_result)
         elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
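
For context, here is a minimal standalone sketch of the same workaround outside the Trainer. The helper name `load_checkpoint_on_cpu` and the free-standing function are illustrative, not part of the patch; the relevant point is passing the `strict` flag positionally so the call also works on modules wrapped by FSDP, whose `load_state_dict` override forwards `*args` rather than `**kwargs` (see the linked issue pytorch/pytorch#82963).

```python
import torch


def load_checkpoint_on_cpu(model: torch.nn.Module, checkpoint_path: str):
    """Hypothetical helper mirroring the Trainer logic in the patch above."""
    # Load the state dict on CPU first to avoid a GPU OOM, as in the Trainer code.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    # Pass `False` positionally instead of `strict=False`: the FSDP override
    # takes *args instead of **kwargs, so the keyword form is not accepted
    # (workaround for https://github.com/pytorch/pytorch/issues/82963).
    load_result = model.load_state_dict(state_dict, False)
    # release memory
    del state_dict
    return load_result
```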