Skip to content

Commit

Permalink
add a callback hook right before the optimizer step (huggingface#33444)
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian authored and amyeroberts committed Oct 2, 2024
1 parent a090350 commit 4d66dbb
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2445,6 +2445,8 @@ def _inner_training_loop(
else:
grad_norm = _grad_norm

self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)

self.optimizer.step()

self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
Expand Down
9 changes: 9 additions & 0 deletions src/transformers/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,12 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T
"""
pass

def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
    """
    Event fired right before ``optimizer.step()`` is executed, i.e. after gradient clipping has already
    run. A convenient place to inspect or log the (clipped) gradients. No-op by default.
    """
    pass

def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients.
Expand Down Expand Up @@ -477,6 +483,9 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T
control.should_save = False
return self.call_event("on_step_begin", args, state, control)

def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
    """Dispatch the ``on_pre_optimizer_step`` event to every registered callback and return the updated control."""
    updated_control = self.call_event("on_pre_optimizer_step", args, state, control)
    return updated_control

def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
    """Dispatch the ``on_optimizer_step`` event to every registered callback and return the updated control."""
    updated_control = self.call_event("on_optimizer_step", args, state, control)
    return updated_control

Expand Down
5 changes: 4 additions & 1 deletion tests/trainer/test_trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def on_epoch_end(self, args, state, control, **kwargs):
def on_step_begin(self, args, state, control, **kwargs):
    # Record the event name in arrival order; the test compares the full sequence later.
    self.events += ["on_step_begin"]

def on_pre_optimizer_step(self, args, state, control, **kwargs):
    # Record the event name in arrival order; the test compares the full sequence later.
    self.events += ["on_pre_optimizer_step"]

def on_optimizer_step(self, args, state, control, **kwargs):
    # Record the event name in arrival order; the test compares the full sequence later.
    self.events += ["on_optimizer_step"]

Expand Down Expand Up @@ -151,7 +154,7 @@ def get_expected_events(self, trainer):
expected_events.append("on_epoch_begin")
for _ in range(train_dl_len):
step += 1
expected_events += ["on_step_begin", "on_optimizer_step", "on_step_end"]
expected_events += ["on_step_begin", "on_pre_optimizer_step", "on_optimizer_step", "on_step_end"]
if step % trainer.args.logging_steps == 0:
expected_events.append("on_log")
if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
Expand Down

0 comments on commit 4d66dbb

Please sign in to comment.