diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index b2278c579fc9..e809e187a025 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1884,9 +1884,12 @@ def load_checkpoint(self, load_module_only=load_module_only) if self.zero_optimization() and load_path is not None: - self._load_zero_checkpoint(load_dir, - tag, - load_optimizer_states=load_optimizer_states) + success = self._load_zero_checkpoint( + load_dir, + tag, + load_optimizer_states=load_optimizer_states) + if not success: + self.optimizer._restore_from_fp16_weights() return load_path, client_states @@ -1998,7 +2001,7 @@ def _load_checkpoint(self, def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag) if zero_sd_list is None: - return + return False self.optimizer.load_state_dict( state_dict_list=zero_sd_list, @@ -2007,6 +2010,7 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): print( f'loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}' ) + return True def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size): zero_ckpt_names = []