Skip to content

Commit

Permalink
Fix a bug for CallbackHandler.callback_list (huggingface#8052)
Browse files Browse the repository at this point in the history
* Fix callback_list

* Add test

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>

* Fix test

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
  • Loading branch information
harupy authored and fabiocapsouza committed Nov 15, 2020
1 parent 278a2d4 commit d8380db
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def remove_callback(self, callback):

@property
def callback_list(self):
return "\n".join(self.callbacks)
return "\n".join(cb.__class__.__name__ for cb in self.callbacks)

def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_init_end", args, state, control)
Expand Down
7 changes: 7 additions & 0 deletions tests/test_trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,10 @@ def test_event_flow(self):
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))

# warning should be emitted for duplicated callbacks
with unittest.mock.patch("transformers.trainer_callback.logger.warn") as warn_mock:
trainer = self.get_trainer(
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
)
assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0]

0 comments on commit d8380db

Please sign in to comment.