diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 1030ec3e1120..38539ba57033 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1844,8 +1844,7 @@ def _overflow_clean_up(self, prev_scale): def _overflow_check_and_loss_scale_update(self): # First compute norm for all group so we know if there is overflow - if self.dtype == torch.float16: - self.check_overflow() + self.check_overflow() #loss scaling related computation prev_scale = self.loss_scale diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 6fd1ac14ec51..c3b4160ebf31 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1705,8 +1705,7 @@ def step(self, closure=None): see_memory_usage(f"In step before checking overflow") # First compute norm for all group so we know if there is overflow - if self.dtype == torch.float16: - self.check_overflow() + self.check_overflow() prev_scale = self.loss_scale self._update_scale(self.overflow)