Skip to content

Commit

Permalink
Fix callback imports
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed Jan 29, 2024
1 parent dd262c6 commit d26dd15
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/transformers/keras_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
logger = logging.getLogger(__name__)


class KerasMetricCallback(keras.Callback):
class KerasMetricCallback(keras.callbacks.Callback):
"""
Callback to compute metrics at the end of every epoch. Unlike normal Keras metrics, these do not need to be
compilable by TF. It is particularly useful for common NLP metrics like BLEU and ROUGE that require string
Expand Down Expand Up @@ -265,7 +265,7 @@ def generation_function(inputs, attention_mask):
logs.update(metric_output)


class PushToHubCallback(keras.Callback):
class PushToHubCallback(keras.callbacks.Callback):
"""
Callback that will save and push the model to the Hub regularly. By default, it pushes once per epoch, but this can
be changed with the `save_strategy` argument. Pushed models can be accessed like any other model on the hub, such
Expand Down

0 comments on commit d26dd15

Please sign in to comment.