Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update TF test_step to match train_step #15111

Merged
merged 2 commits into from
Jan 11, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,7 @@ def compile(
logger.warning(
"No loss specified in compile() - the model's internal loss computation will be used as the "
"loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
"Please ensure your labels are passed as the 'labels' key of the input dict so that they are "
"Please ensure your labels are passed as keys in the input dict so that they are "
"accessible to the model during the forward pass. To disable this behaviour, please pass a "
"loss argument, or explicitly pass loss=None if you do not want your model to compute a loss."
)
Expand Down Expand Up @@ -920,6 +920,9 @@ def test_step(self, data):
# the input dict (and loss is computed internally)
if y is None and "labels" in x:
y = x["labels"] # Stops confusion with metric computations
elif y is None and "input_ids" in x:
# Just make any kind of dummy array to make loss work
y = tf.zeros(tf.shape(x["input_ids"])[0], dtype=tf.int64)
y_pred = self(x, training=False)
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
# Updates stateful loss metrics.
Expand Down