From 609629263126f429be339ecddeeeac87f033935c Mon Sep 17 00:00:00 2001
From: Sylvain Gugger
Date: Tue, 27 Apr 2021 20:01:31 -0400
Subject: [PATCH] Accumulate opt state dict on dp_rank 0

---
 src/transformers/trainer.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 5565bdb2eab4fb..536c838a63dee5 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1416,14 +1416,15 @@ def _save_checkpoint(self, model, trial, metrics=None):
                 xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                 reissue_pt_warnings(caught_warnings)
         elif is_sagemaker_mp_enabled():
-            # Consolidate the state dict on all processes of dp_rank 0
-            opt_state_dict = self.optimizer.state_dict()
-            # Save it and the scheduler on the main process
-            if self.is_world_process_zero():
-                torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
-                with warnings.catch_warnings(record=True) as caught_warnings:
-                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
-                reissue_pt_warnings(caught_warnings)
+            if smp.dp_rank() == 0:
+                # Consolidate the state dict on all processes of dp_rank 0
+                opt_state_dict = self.optimizer.state_dict()
+                # Save it and the scheduler on the main process
+                if self.is_world_process_zero():
+                    torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
+                    with warnings.catch_warnings(record=True) as caught_warnings:
+                        torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+                    reissue_pt_warnings(caught_warnings)
         elif self.is_world_process_zero() and not self.deepspeed:
             # deepspeed.save_checkpoint above saves model/optim/sched
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
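
For context, a minimal standalone sketch of the guard this patch introduces, assuming the SageMaker model-parallel library is imported as `smp` (as trainer.py does). The helper name `save_smp_optimizer` and the `is_main_process` flag are hypothetical stand-ins for the Trainer internals (`self.is_world_process_zero()`); this is a sketch of the pattern, not the Trainer code itself.

import os

import torch
import smdistributed.modelparallel.torch as smp


def save_smp_optimizer(optimizer, lr_scheduler, output_dir, is_main_process):
    # Under model parallelism the optimizer state is sharded, so state_dict()
    # is a collective gather across the model-parallel group. Guarding on
    # dp_rank() == 0 means only the processes of the first data-parallel
    # replica take part in the consolidation, instead of every process.
    if smp.dp_rank() == 0:
        opt_state_dict = optimizer.state_dict()
        # Of the dp_rank-0 processes, only the single main process writes
        # the files to disk.
        if is_main_process:
            torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
            torch.save(lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))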