Skip to content

Commit

Permalink
Rollback serving_output and add TODO
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Moriarity committed Apr 28, 2022
1 parent 177684a commit a4abfa6
Showing 1 changed file with 3 additions and 38 deletions.
41 changes: 3 additions & 38 deletions src/transformers/models/clip/modeling_tf_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,42 +1401,7 @@ def call(
return outputs

def serving_output(self, output: TFCLIPOutput) -> TFCLIPOutput:
text_hs = (
tf.convert_to_tensor(output.text_model_output.hidden_states) if self.config.output_hidden_states else None
)
text_attns = (
tf.convert_to_tensor(output.text_model_output.attentions) if self.config.output_attentions else None
)
text_model_output = TFBaseModelOutputWithPooling(
last_hidden_state=output.text_model_output.last_hidden_state,
pooler_output=output.text_model_output.pooler_output,
hidden_states=text_hs,
attentions=text_attns,
)

vision_hs = (
tf.convert_to_tensor(output.vision_model_output.hidden_states)
if self.config.output_hidden_states
else None
)
vision_attns = (
tf.convert_to_tensor(output.vision_model_output.attentions) if self.config.output_attentions else None
)
vision_model_output = TFBaseModelOutputWithPooling(
last_hidden_state=output.vision_model_output.last_hidden_state,
pooler_output=output.vision_model_output.pooler_output,
hidden_states=vision_hs,
attentions=vision_attns,
)

output = TFCLIPOutput(
loss=output.loss,
logits_per_image=output.logits_per_image,
logits_per_text=output.logits_per_text,
text_embeds=output.text_embeds,
image_embeds=output.image_embeds,
text_model_output=text_model_output,
vision_model_output=vision_model_output,
)

# TODO: As is this currently fails with saved_model=True, because
# TensorFlow cannot trace through nested dataclasses. Reference:
# https://github.com/huggingface/transformers/pull/16886
return output

0 comments on commit a4abfa6

Please sign in to comment.