diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 9898abce801d..6b63efbb23f7 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -534,7 +534,7 @@ def state(self): def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx): assert self.immediate_grad_update - self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=True) + self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=False) def create_grad_acc_hooks(self): self.grad_accs = []