Skip to content

Commit

Permalink
Stop confusing the TF compiler with ModelOutput objects (#28712)
Browse files Browse the repository at this point in the history
* Stop confusing the TF compiler with ModelOutput objects

* Stop confusing the TF compiler with ModelOutput objects
  • Loading branch information
Rocketknight1 authored Jan 26, 2024
1 parent a638de1 commit 708b19e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
17 changes: 12 additions & 5 deletions src/transformers/models/blip/modeling_tf_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,20 +1171,27 @@ def call(
attention_mask=attention_mask,
encoder_hidden_states=image_embeds,
labels=labels,
return_dict=return_dict,
return_dict=False,
training=training,
)

if not return_dict:
outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:]
return tuple(output for output in outputs if output is not None)

if outputs.loss is not None and outputs.loss.shape.rank == 0:
outputs.loss = tf.reshape(outputs.loss, (1,))
if labels is not None:
loss = outputs[0]
logits = outputs[1]
else:
loss = None
logits = outputs[0]

if loss is not None and loss.shape.rank == 0:
loss = tf.reshape(loss, (1,))

return TFBlipForConditionalGenerationModelOutput(
loss=outputs.loss,
logits=outputs.logits,
loss=loss,
logits=logits,
image_embeds=image_embeds,
last_hidden_state=vision_outputs.last_hidden_state,
hidden_states=vision_outputs.hidden_states,
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/blip/modeling_tf_blip_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,8 @@ def call(
labels = labels[:, 1:]
labels = tf.reshape(labels, (-1,))
# Keras won't give us label smoothing for sparse CE, so we de-sparsify things here
one_hot_labels = tf.one_hot(labels, depth=self.config.vocab_size, dtype=tf.float32)
# Use relu to clamp masked labels at 0 to avoid NaN (we will be zeroing those out later anyway)
one_hot_labels = tf.one_hot(tf.nn.relu(labels), depth=self.config.vocab_size, dtype=tf.float32)
loss_fct = tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1, reduction="none")
masked_positions = tf.cast(tf.not_equal(labels, -100), dtype=tf.float32)
lm_loss = loss_fct(one_hot_labels, shifted_prediction_scores)
Expand Down

0 comments on commit 708b19e

Please sign in to comment.