diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 5791b3fab85d..28e49d655930 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -36,6 +36,7 @@
     ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS, ZERO_OPTIMIZATION_WEIGHTS
 from deepspeed.runtime.csr_tensor import CSRTensor
 import deepspeed.runtime.lr_schedules as lr_schedules
+from deepspeed.runtime.utils import get_grad_norm
 from deepspeed.utils import logger, log_dist, init_distributed
 from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
 from deepspeed.utils.debug import debug_extract_module_and_param_names
@@ -123,6 +124,7 @@ def __init__(self,
         self.gas_boundary_ctr = 0
         self.dist_backend = "nccl"
         self._step_applied = False
+        self._global_grad_norm = None
 
         # for debug purposes - can then debug print: debug_get_module_name(module)
         debug_extract_module_and_param_names(model)
@@ -256,6 +258,30 @@ def set_train_batch_size(self, train_batch_size):
         self._config.train_batch_size = train_batch_size
         self._config.gradient_accumulation_steps = new_gas
 
+    def _compute_global_grad_norm(self):
+        params = [p for p in self.module.parameters() if p.grad is not None]
+        return get_grad_norm(params, mpu=self.mpu)
+
+    def get_global_grad_norm(self, force_compute=False) -> float:
+        """Return the 2-norm of all gradients. If there is model parallelism,
+        the norm will be global.
+
+        The computed norm will be cached and reused until the next step() pass
+        unless ``force_compute=True``.
+        .. note::
+            In the presence of model parallelism, this is a collective call
+            and acts as a barrier among ``mpu.get_model_parallel_group()``.
+        Args:
+            force_compute (bool, optional): Force a recomputation of the norm. Defaults to False.
+        Returns:
+            float: norm
+        """
+        # Check for an outdated parameter norm.
+        if force_compute or self._global_grad_norm is None:
+            self._global_grad_norm = self._compute_global_grad_norm()
+
+        return self._global_grad_norm
+
     def checkpoint_tag_validation_enabled(self):
         return self._config.checkpoint_tag_validation_enabled
 
@@ -1315,6 +1341,9 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
 
         self.optimizer.step()
 
+        if hasattr(self.optimizer, '_global_grad_norm'):
+            self._global_grad_norm = self.optimizer._global_grad_norm
+
         # Quantize the updated parameter if there no overflow
         if self.quantizer:
             self.quantizer.quantize(
diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py
index fba0d6b1fd59..1d223c15a9b2 100755
--- a/deepspeed/runtime/fp16/fused_optimizer.py
+++ b/deepspeed/runtime/fp16/fused_optimizer.py
@@ -45,6 +45,8 @@ def __init__(self,
         self.fp16_groups_flat = []
         self.fp32_groups_flat = []
 
+        self._global_grad_norm = 0.
+
         # loop to deal with groups
         for i, param_group in enumerate(self.optimizer.param_groups):
             # push this group to list before modify
@@ -251,6 +253,8 @@ def step(self, closure=None):
         all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
         self.stop_timers([COMPUTE_NORM])
 
+        self._global_grad_norm = all_groups_norm
+
         self.start_timers([UNSCALE_AND_CLIP])
         self.unscale_and_clip_grads(grads_groups_flat, [all_groups_norm])
         self.stop_timers([UNSCALE_AND_CLIP])
diff --git a/deepspeed/runtime/fp16/unfused_optimizer.py b/deepspeed/runtime/fp16/unfused_optimizer.py
index c30df0bef1d0..27dbc45dca3e 100755
--- a/deepspeed/runtime/fp16/unfused_optimizer.py
+++ b/deepspeed/runtime/fp16/unfused_optimizer.py
@@ -32,6 +32,7 @@ def __init__(self,
                  fused_lamb_legacy=False):
 
         self.fused_lamb_legacy = fused_lamb_legacy
+        self._global_grad_norm = 0.
 
         if torch.distributed.get_rank() == 0:
             logger.info(f'Fused Lamb Legacy : {self.fused_lamb_legacy} ')
@@ -217,6 +218,7 @@ def unscale_and_clip_grads(self, norm_groups, apply_scale=True):
         for norm in norm_groups:
             total_norm += norm**2.0
         total_norm = math.sqrt(total_norm)
+        self._global_grad_norm = total_norm
 
         # compute combined scale factor for this group
         combined_scale = self.cur_scale
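For context on how the new API is meant to be consumed, here is a minimal, hypothetical training-loop sketch; it is not part of the diff. Names such as `data_loader`, `batch`, the loss computation, and the logging interval are placeholders; only `engine.backward()`, `engine.step()`, `engine.global_rank`, and the new `engine.get_global_grad_norm()` come from the DeepSpeed engine. After `engine.step()`, the engine caches the norm recorded by the FP16 optimizer during its own step, so the call below is a cheap attribute read unless `force_compute=True` forces a recomputation (which, under model parallelism, is a collective over `mpu.get_model_parallel_group()`).

# Hypothetical usage sketch of the new API; `engine`, `data_loader`, and the
# loss computation are placeholders, not names introduced by this PR.
for step, batch in enumerate(data_loader):
    loss = engine(batch)      # forward pass through the wrapped module
    engine.backward(loss)     # backward pass with loss scaling / ZeRO hooks
    engine.step()             # optimizer step; caches the global grad norm

    if step % 100 == 0:
        # Cheap read of the cached value; pass force_compute=True to recompute.
        grad_norm = engine.get_global_grad_norm()
        if engine.global_rank == 0:
            print(f"step {step}: global grad norm = {grad_norm:.4f}")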