Stop confusing the TF compiler with ModelOutput objects #28712

Merged: 2 commits, Jan 26, 2024
17 changes: 12 additions & 5 deletions src/transformers/models/blip/modeling_tf_blip.py
@@ -1171,20 +1171,27 @@ def call(
     attention_mask=attention_mask,
     encoder_hidden_states=image_embeds,
     labels=labels,
-    return_dict=return_dict,
+    return_dict=False,
Collaborator: Is it necessary to do this if we still have the correct indexing logic below?

Member Author: Regrettably yes - this is the bit that actually fixes the bug! The change to the loss logic turned out to be unrelated - I just spotted a NaN risk there while I was hunting for the bug. Still, I think that change is worth keeping.

Collaborator: OK. As it's already done in BLIP and it's a pattern we also use for Flax models, I think it's fine to do :)

     training=training,
 )

 if not return_dict:
     outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:]
     return tuple(output for output in outputs if output is not None)

-if outputs.loss is not None and outputs.loss.shape.rank == 0:
-    outputs.loss = tf.reshape(outputs.loss, (1,))
+if labels is not None:
+    loss = outputs[0]
+    logits = outputs[1]
+else:
+    loss = None
+    logits = outputs[0]
+
+if loss is not None and loss.shape.rank == 0:
+    loss = tf.reshape(loss, (1,))

 return TFBlipForConditionalGenerationModelOutput(
-    loss=outputs.loss,
-    logits=outputs.logits,
+    loss=loss,
+    logits=logits,
     image_embeds=image_embeds,
     last_hidden_state=vision_outputs.last_hidden_state,
     hidden_states=vision_outputs.hidden_states,
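The pattern in the hunk above - forcing the inner call to return a plain tuple (`return_dict=False`) and then indexing positionally - can be sketched without TensorFlow. The names below are illustrative stand-ins, not the actual transformers API; the point is that positional indexing on a tuple is unambiguous for a graph compiler, unlike attribute access on a ModelOutput whose fields may or may not be populated:

```python
# Minimal sketch of the tuple-indexing pattern, assuming a hypothetical
# inner model whose tuple layout depends on whether labels were supplied:
# (loss, logits, ...) with labels, (logits, ...) without.

def inner_call(labels=None):
    # Stand-in for the text decoder called with return_dict=False.
    logits = [0.1, 0.9]
    if labels is not None:
        loss = 0.42  # some scalar loss value
        return (loss, logits)
    return (logits,)

def outer_call(labels=None):
    outputs = inner_call(labels=labels)
    # Positional indexing keyed on `labels`, mirroring the diff above:
    # the compiler never has to reason about a possibly-absent attribute.
    if labels is not None:
        loss, logits = outputs[0], outputs[1]
    else:
        loss, logits = None, outputs[0]
    return loss, logits

print(outer_call(labels=[1]))  # (0.42, [0.1, 0.9])
print(outer_call())            # (None, [0.1, 0.9])
```

The design choice here is to make the branch on `labels` explicit in Python, where it is resolved at trace time, rather than letting the TF compiler encounter an output object of uncertain shape.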
3 changes: 2 additions & 1 deletion src/transformers/models/blip/modeling_tf_blip_text.py
@@ -1060,7 +1060,8 @@ def call(
     labels = labels[:, 1:]
     labels = tf.reshape(labels, (-1,))
     # Keras won't give us label smoothing for sparse CE, so we de-sparsify things here
-    one_hot_labels = tf.one_hot(labels, depth=self.config.vocab_size, dtype=tf.float32)
+    # Use relu to clamp masked labels at 0 to avoid NaN (we will be zeroing those out later anyway)
+    one_hot_labels = tf.one_hot(tf.nn.relu(labels), depth=self.config.vocab_size, dtype=tf.float32)
Collaborator: Silly question, but won't this massively increase the loss, since it will treat the model as predicting token 0 in all the cases where the label is -100?

Member Author (@Rocketknight1, Jan 25, 2024): Yes, but the masked positions are set to 0 by the line lm_loss *= masked_positions! However, if we get a NaN value because we passed a negative label to an internal TF function, that won't be zeroed out correctly.

In other words, it doesn't matter what value the masked labels have, as long as it's non-negative - any non-NaN value will be masked to 0, but NaNs survive multiplication by zero and infest the rest of the code.

Collaborator: OK - works for me!

     loss_fct = tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1, reduction="none")
     masked_positions = tf.cast(tf.not_equal(labels, -100), dtype=tf.float32)
     lm_loss = loss_fct(one_hot_labels, shifted_prediction_scores)
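The NaN hazard discussed in the thread above can be seen without TensorFlow: IEEE 754 defines NaN * 0.0 as NaN, so masking by multiplication does not neutralize a loss value that has already gone NaN, while any finite dummy value at a masked position is zeroed out correctly. A minimal illustration in plain Python (the values are illustrative, not the actual TF ops):

```python
import math

# A loss value that went NaN because an internal op saw an invalid input
# (e.g. a -100 "ignore" label used as an index).
bad_loss = float("nan")

# Multiplying by a 0.0 mask does NOT zero out a NaN:
masked = bad_loss * 0.0
print(math.isnan(masked))  # True - the NaN survives the mask

# The fix clamps labels to be non-negative first (mimicking tf.nn.relu),
# producing a finite dummy loss at masked positions, which masking zeroes.
clamped_label = max(0, -100)  # relu-style clamp: -100 -> 0
dummy_loss = 1.234            # any finite loss at a masked position
print(dummy_loss * 0.0)       # 0.0 - finite values are masked correctly
```

This is why the clamp has to happen before `tf.one_hot` rather than relying on the later `lm_loss *= masked_positions` step alone.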