diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index eebea8b4a2dd72..74654241b14419 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1420,14 +1420,15 @@ def _save_checkpoint(self, model, trial, metrics=None):
                 xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                 reissue_pt_warnings(caught_warnings)
         elif is_sagemaker_mp_enabled():
-            # Consolidate the state dict on all processed of dp_rank 0
-            opt_state_dict = self.optimizer.state_dict()
-            # Save it and the scheduler on the main process
-            if self.is_world_process_zero():
-                torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
-                with warnings.catch_warnings(record=True) as caught_warnings:
-                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
-                reissue_pt_warnings(caught_warnings)
+            if smp.dp_rank() == 0:
+                # Consolidate the state dict on all processed of dp_rank 0
+                opt_state_dict = self.optimizer.state_dict()
+                # Save it and the scheduler on the main process
+                if self.is_world_process_zero():
+                    torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
+                    with warnings.catch_warnings(record=True) as caught_warnings:
+                        torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+                    reissue_pt_warnings(caught_warnings)
         elif self.is_world_process_zero() and not self.deepspeed:
             # deepspeed.save_checkpoint above saves model/optim/sched
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
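
The change above gates the `self.optimizer.state_dict()` call itself behind `smp.dp_rank() == 0`, so under SageMaker model parallelism only data-parallel rank 0 consolidates the optimizer state before the main process writes it to disk. Below is a minimal sketch of that rank-gated save pattern, assuming the `smdistributed.modelparallel.torch` import that `trainer.py` uses when SageMaker MP is enabled; the `save_optimizer_and_scheduler` helper and its arguments are hypothetical, not part of the Trainer API.

```python
import os
import warnings

import torch
import smdistributed.modelparallel.torch as smp  # only available inside a SageMaker MP job


def save_optimizer_and_scheduler(optimizer, lr_scheduler, output_dir, is_world_process_zero):
    # Hypothetical helper mirroring the patched branch: only data-parallel rank 0
    # consolidates the optimizer state; other ranks skip the call entirely.
    if smp.dp_rank() == 0:
        opt_state_dict = optimizer.state_dict()
        # Among the dp_rank-0 processes, only the global main process writes to disk.
        if is_world_process_zero:
            torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
            # In the Trainer, the caught warnings are then re-issued via reissue_pt_warnings().
```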