Commit

report grad_norm during training
shijie-wu committed Jan 17, 2024
1 parent fa6d12f commit fb2c2ce
Showing 1 changed file with 10 additions and 6 deletions.
src/transformers/trainer.py: 16 changes (10 additions & 6 deletions)
@@ -1783,6 +1783,7 @@ def _inner_training_loop(
         self._total_loss_scalar = 0.0
         self._globalstep_last_logged = self.state.global_step
         model.zero_grad()
+        grad_norm: Optional[float] = None
 
         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
 
@@ -1900,18 +1901,19 @@ def _inner_training_loop(
                         # deepspeed does its own clipping
 
                         if is_sagemaker_mp_enabled() and args.fp16:
-                            self.optimizer.clip_master_grads(args.max_grad_norm)
+                            _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
                         elif self.use_apex:
                             # Revert to normal clipping otherwise, handling Apex or full precision
-                            nn.utils.clip_grad_norm_(
+                            _grad_norm = nn.utils.clip_grad_norm_(
                                 amp.master_params(self.optimizer),
                                 args.max_grad_norm,
                             )
                         else:
-                            self.accelerator.clip_grad_norm_(
+                            _grad_norm = self.accelerator.clip_grad_norm_(
                                 model.parameters(),
                                 args.max_grad_norm,
                             )
+                        grad_norm = _grad_norm.item() if _grad_norm is not None else None
 
                     # Optimizer step
                     self.optimizer.step()
@@ -1926,7 +1928,7 @@ def _inner_training_loop(
                     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
 
-                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+                    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
 
@@ -1941,7 +1943,7 @@ def _inner_training_loop(
                 self.control.should_training_stop = True
 
             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
 
             if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                 if is_torch_tpu_available():
@@ -2246,7 +2248,7 @@ def _issue_warnings_after_load(self, load_result):
                 f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
             )
 
-    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
+    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
        if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
            if is_torch_tpu_available():
                xm.mark_step()
@@ -2260,6 +2262,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
             tr_loss -= tr_loss
 
             logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+            if grad_norm is not None:
+                logs["grad_norm"] = grad_norm
             logs["learning_rate"] = self._get_learning_rate()
 
             self._total_loss_scalar += tr_loss_scalar
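
For reference, a minimal standalone sketch (not part of this commit) of the pattern the diff introduces: torch.nn.utils.clip_grad_norm_ returns the total gradient norm, so the value computed while clipping can be converted to a float and reported next to the loss. The toy model, data, and logs dict below are illustrative stand-ins, not Trainer internals.

# Minimal sketch, assuming plain PyTorch (no Trainer, Accelerate, or DeepSpeed).
# clip_grad_norm_ returns the total norm of the gradients, which we record per step.
import torch
import torch.nn as nn

model = nn.Linear(4, 1)  # illustrative stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
max_grad_norm = 1.0

for step in range(3):
    inputs = torch.randn(8, 4)
    targets = torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()

    # Total gradient norm (a tensor), computed while clipping to max_grad_norm.
    _grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    grad_norm = _grad_norm.item() if _grad_norm is not None else None

    logs = {"loss": round(loss.item(), 4)}
    if grad_norm is not None:
        logs["grad_norm"] = grad_norm
    print(step, logs)

    optimizer.step()
    optimizer.zero_grad()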
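On the None guard around _grad_norm: Accelerate's clip_grad_norm_ returns None when DeepSpeed drives training (DeepSpeed clips gradients itself, per the comment in the diff), and the SageMaker Model Parallel clip_master_grads path may likewise not yield a usable norm, so grad_norm is only added to logs when a value was actually produced. When it is, it appears in the same log dict as loss and learning_rate.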
