diff --git a/guided_diffusion/train_util.py b/guided_diffusion/train_util.py
index 97c7db38..33440bb0 100644
--- a/guided_diffusion/train_util.py
+++ b/guided_diffusion/train_util.py
@@ -112,13 +112,12 @@ def _load_and_sync_parameters(self):
 
         if resume_checkpoint:
             self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
-            if dist.get_rank() == 0:
-                logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
-                self.model.load_state_dict(
-                    dist_util.load_state_dict(
-                        resume_checkpoint, map_location=dist_util.dev()
-                    )
+            logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
+            self.model.load_state_dict(
+                dist_util.load_state_dict(
+                    resume_checkpoint, map_location=dist_util.dev()
                 )
+            )
 
         dist_util.sync_params(self.model.parameters())
 
@@ -128,12 +127,11 @@ def _load_ema_parameters(self, rate):
 
         main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
         ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
         if ema_checkpoint:
-            if dist.get_rank() == 0:
-                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
-                state_dict = dist_util.load_state_dict(
-                    ema_checkpoint, map_location=dist_util.dev()
-                )
-                ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)
+            logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
+            state_dict = dist_util.load_state_dict(
+                ema_checkpoint, map_location=dist_util.dev()
+            )
+            ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)
 
         dist_util.sync_params(ema_params)
         return ema_params
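The effect of this change is that every rank now loads the checkpoint (and the EMA weights) from disk itself, rather than only rank 0 loading them and the other ranks receiving the tensors through the subsequent `dist_util.sync_params` broadcast. A minimal sketch of the two patterns, assuming `sync_params` broadcasts from rank 0 as in the upstream guided-diffusion `dist_util`; the helper names below are illustrative, not from the repo:

```python
import torch
import torch.distributed as dist

def load_on_rank0_then_broadcast(model, path, device):
    # Old behaviour: only rank 0 reads the file; the other ranks keep their
    # current weights until the broadcast overwrites them with rank 0's tensors.
    if dist.get_rank() == 0:
        model.load_state_dict(torch.load(path, map_location=device))
    for p in model.parameters():
        with torch.no_grad():
            dist.broadcast(p, src=0)

def load_on_every_rank(model, path, device):
    # New behaviour: each rank reads the checkpoint itself, so the later
    # broadcast only re-sends weights that are already identical everywhere.
    model.load_state_dict(torch.load(path, map_location=device))
    for p in model.parameters():
        with torch.no_grad():
            dist.broadcast(p, src=0)
```

Loading on every rank avoids non-rank-0 processes briefly holding unloaded weights before the broadcast, at the cost of each rank reading the checkpoint file itself.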