
TFEncoderDecoder not handling labels correctly #14357

Closed
jorgemcgomes opened this issue Nov 10, 2021 · 8 comments · Fixed by #15001

Comments

@jorgemcgomes (Contributor) commented Nov 10, 2021

Environment info

Google Colab

  • transformers version: master branch. With the latest release (4.12.3) you can't reproduce this problem, because it fails with a different issue that has already been fixed on master (support for cross-attention in TF GPT2).
  • Tensorflow version: 2.7.0

Who can help

Tagging @patrickvonplaten, as he made the latest merges on TFEncoderDecoder.

Information

In TFEncoderDecoder, when the input is passed as a dict, the encoder's input_processing function unpacks it, including the labels if they are present. The labels then end up being passed to the encoder call, which shouldn't happen: labels are only needed by the decoder, and the unexpected argument causes the encoder call to fail.

The consequence is that trying to train a TFEncoderDecoder with .fit() on a tf.data.Dataset results in the error below; a sketch of that training setup follows.
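For context, this is roughly the training setup that hits the failing path. It is only a minimal sketch with dummy shapes; the compile() call is illustrative, since the error is raised during the forward pass, before any loss is computed.

import tensorflow as tf
from transformers import TFEncoderDecoderModel

model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "gpt2")

# Dummy batches; labels live in the same dict so the model can compute its own loss.
features = {
    "input_ids": tf.ones((4, 5), dtype=tf.int32),
    "decoder_input_ids": tf.ones((4, 5), dtype=tf.int32),
    "labels": tf.ones((4, 5), dtype=tf.int32),
}
dataset = tf.data.Dataset.from_tensor_slices(features).batch(2)

model.compile(optimizer="adam")
model.fit(dataset)  # train_step calls self(x, training=True) with the whole dict, labels included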

To reproduce

from transformers import TFEncoderDecoderModel
model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "gpt2")
model(model.dummy_inputs)  # works fine

with_labels = dict(labels=model.dummy_inputs["decoder_input_ids"], **model.dummy_inputs)
model(**with_labels)  # works fine
model(with_labels)  # fails with the error below
/usr/local/lib/python3.7/dist-packages/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py in call(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, training, **kwargs)
    493                 decoder_attention_mask = encoder_inputs.pop("decoder_attention_mask")
    494 
--> 495             encoder_outputs = self.encoder(**encoder_inputs)
    496 
    497         encoder_hidden_states = encoder_outputs[0]

/usr/local/lib/python3.7/dist-packages/transformers/models/bert/modeling_tf_bert.py in call(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, training, **kwargs)
   1125             return_dict=return_dict,
   1126             training=training,
-> 1127             kwargs_call=kwargs,
   1128         )
   1129         outputs = self.bert(

/usr/local/lib/python3.7/dist-packages/transformers/modeling_tf_utils.py in input_processing(func, config, input_ids, **kwargs)
    386     if len(kwargs["kwargs_call"]) > 0:
    387         raise ValueError(
--> 388             f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
    389         )
    390 

ValueError: Exception encountered when calling layer "encoder" (type TFBertModel).

The following keyword arguments are not supported by this model: ['labels'].

Call arguments received:
  • input_ids=tf.Tensor(shape=(3, 5), dtype=int32)
  • attention_mask=None
  • token_type_ids=None
  • position_ids=None
  • head_mask=None
  • inputs_embeds=None
  • encoder_hidden_states=None
  • encoder_attention_mask=None
  • past_key_values=None
  • use_cache=True
  • output_attentions=False
  • output_hidden_states=False
  • return_dict=True
  • training=False
  • kwargs={'labels': 'tf.Tensor(shape=(3, 5), dtype=int32)'}

Expected behavior

The model should handle labels passed inside the input dict correctly, as they are needed in order to fit the model.

A workaround that works is adding this snippet to the call() method:

encoder_inputs = input_processing(**encoder_processing_inputs)
# start new code: labels are only needed by the decoder, so pop them here
# before the unpacked inputs are forwarded to the encoder
if "labels" in encoder_inputs:
    labels = encoder_inputs.pop("labels")
# end new code
...
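Pending a proper fix in the library, a user-side variant of the same idea is to unpack dict inputs into keyword arguments before they reach the parent call(), mirroring the observation above that model(**with_labels) works while model(with_labels) does not. This is an untested sketch; the subclass name and the override are hypothetical, not part of transformers.

from transformers import TFEncoderDecoderModel

class DictUnpackingEncoderDecoder(TFEncoderDecoderModel):
    # Hypothetical workaround: when the whole batch arrives as one dict (as it
    # does from train_step), forward it as keyword arguments instead, so labels
    # are handled by the decoder path and never leak into the encoder's inputs.
    def call(self, inputs=None, **kwargs):
        if isinstance(inputs, dict):
            return super().call(**{**inputs, **kwargs})
        return super().call(inputs, **kwargs)

model = DictUnpackingEncoderDecoder.from_encoder_decoder_pretrained("bert-base-uncased", "gpt2")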
@qqaatw (Contributor) commented Nov 13, 2021

Hi,

I think all the inputs should be unpacked into keyword arguments before being passed to TFEncoderDecoderModel.__call__, as stated in the docs.

Is there any reason that you want to pass a dict directly?

@jorgemcgomes (Contributor, Author) commented Nov 13, 2021

The inputs are not unpacked in the model's train_step(), which is what is used when you train the model with fit().

See TFPreTrainedModel.train_step (line 802):

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
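A minimal sketch (dummy shapes) of how Keras unpacks such a batch: a lone dict is treated as x with no separate y, so the labels stay inside what gets forwarded to self(x, training=True).

import tensorflow as tf

batch = {
    "input_ids": tf.ones((2, 5), dtype=tf.int32),
    "decoder_input_ids": tf.ones((2, 5), dtype=tf.int32),
    "labels": tf.ones((2, 5), dtype=tf.int32),
}
# Keras' own unpacking helper: a dict is returned as x, with y and
# sample_weight left as None, so the labels never get split out.
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
print(type(x), y, sample_weight)  # <class 'dict'> None None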

@patrickvonplaten (Contributor) commented:

@Rocketknight1,

Do you maybe find some time to look into this? :-)

huggingface deleted a comment from github-actions bot Dec 13, 2021
@patrickvonplaten (Contributor) commented:

@Rocketknight1, @NielsRogge, @ydshieh - I think we can solve this issue with the new design now, no?

@ydshieh (Collaborator) commented Dec 13, 2021

I hadn't been following this issue until now. I can try to look at it if @Rocketknight1 is OK with that.

@Rocketknight1 (Member) commented:

@ydshieh Sure, yes! I'm sorry I've been slow with it.

ydshieh added a commit to ydshieh/transformers that referenced this issue Jan 1, 2022
github-actions bot commented Jan 6, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ydshieh (Collaborator) commented Jan 6, 2022

activate :-)

Rocketknight1 pushed a commit that referenced this issue Jan 12, 2022
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>