storing & logging gradient norm in trainer (#27326)
* report grad_norm during training

* support getting grad_norm from deepspeed
shijie-wu authored and Ita Zaporozhets committed May 14, 2024
1 parent f7eb3dc commit 6adc8d3
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions src/transformers/trainer.py
@@ -198,6 +198,7 @@
     from accelerate import __version__ as accelerate_version
     from accelerate.utils import (
         DistributedDataParallelKwargs,
+        DistributedType,
         GradientAccumulationPlugin,
         load_fsdp_model,
         load_fsdp_optimizer,
@@ -1856,6 +1857,7 @@ def _inner_training_loop(
         self._total_loss_scalar = 0.0
         self._globalstep_last_logged = self.state.global_step
         model.zero_grad()
+        grad_norm: Optional[float] = None

         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

@@ -1973,19 +1975,27 @@ def _inner_training_loop(
                         # deepspeed does its own clipping

                         if is_sagemaker_mp_enabled() and args.fp16:
-                            self.optimizer.clip_master_grads(args.max_grad_norm)
+                            _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
                         elif self.use_apex:
                             # Revert to normal clipping otherwise, handling Apex or full precision
-                            nn.utils.clip_grad_norm_(
+                            _grad_norm = nn.utils.clip_grad_norm_(
                                 amp.master_params(self.optimizer),
                                 args.max_grad_norm,
                             )
                         else:
-                            self.accelerator.clip_grad_norm_(
+                            _grad_norm = self.accelerator.clip_grad_norm_(
                                 model.parameters(),
                                 args.max_grad_norm,
                             )

+                        if (
+                            is_accelerate_available()
+                            and self.accelerator.distributed_type == DistributedType.DEEPSPEED
+                        ):
+                            grad_norm = model.get_global_grad_norm()
+                        else:
+                            grad_norm = _grad_norm.item() if _grad_norm is not None else None
+
                     # Optimizer step
                     self.optimizer.step()
                     optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
@@ -1999,7 +2009,7 @@ def _inner_training_loop(
                     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)

-                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+                    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

@@ -2019,7 +2029,7 @@ def _inner_training_loop(
                 self.control.should_training_stop = True

             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)

             if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                 if is_torch_tpu_available():
@@ -2356,7 +2366,7 @@ def _issue_warnings_after_load(self, load_result):
                 f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
             )

-    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
+    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
         if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
             if is_torch_tpu_available():
                 xm.mark_step()
@@ -2370,6 +2380,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
             tr_loss -= tr_loss

             logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+            if grad_norm is not None:
+                logs["grad_norm"] = grad_norm
             logs["learning_rate"] = self._get_learning_rate()

             self._total_loss_scalar += tr_loss_scalar
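For readers who want the pattern outside of Trainer, here is a minimal sketch (plain PyTorch, no accelerate or DeepSpeed) of what the non-DeepSpeed branch above amounts to: clip_grad_norm_ already returns the total gradient norm, so the commit keeps that return value, converts it with .item(), and logs it next to the loss, just as _maybe_log_save_evaluate now emits logs["grad_norm"]. The model, data, and hyperparameters below are invented for illustration and are not part of the commit.

from typing import Optional

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
max_grad_norm = 1.0

grad_norm: Optional[float] = None  # mirrors the new grad_norm state in _inner_training_loop

for step in range(5):
    inputs = torch.randn(8, 10)
    targets = torch.randn(8, 1)

    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()

    # clip_grad_norm_ clips in place and returns the total norm (a tensor) computed before clipping
    _grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    # the None guard mirrors the trainer, where clipping (and hence the norm) may be skipped
    grad_norm = _grad_norm.item() if _grad_norm is not None else None

    optimizer.step()
    optimizer.zero_grad()

    # roughly what _maybe_log_save_evaluate now adds to the logs
    logs = {"loss": round(loss.item(), 4)}
    if grad_norm is not None:
        logs["grad_norm"] = grad_norm
    print(logs)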

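On the DeepSpeed path the commit takes a different route: as the existing comment in the hunk notes, DeepSpeed does its own clipping, so there is no locally computed norm to reuse. The new branch therefore checks self.accelerator.distributed_type == DistributedType.DEEPSPEED (hence the added DistributedType import) and reads the value back from the wrapped engine via model.get_global_grad_norm(). In every other case the tensor returned by the clipping call is converted with .item(); if no norm was captured at all, grad_norm stays None and the grad_norm entry is simply left out of the logs.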