
[GenerationOutputs] Fix GenerationOutputs Tests #9443

Merged

Conversation

@patrickvonplaten (Contributor) commented Jan 6, 2021

What does this PR do?

The GenerationOutputs PR #9150 was not rebased, so CircleCI on master is now red. This PR fixes it.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@@ -522,6 +522,7 @@ def test_greedy_generate_dict_outputs_use_cache(self):
             return

         config.use_cache = True
+        config.is_decoder = True
@patrickvonplaten (Contributor, Author) commented Jan 6, 2021:

Make sure to use a causal mask for models like BERT, RoBERTa, ...
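
A minimal, illustrative sketch (not part of this PR) of what `config.is_decoder = True` changes for BERT-like models: with it set, the model applies a causal (left-to-right) attention mask and can return `past_key_values` for caching. Class names are from the transformers library; the tiny config values are arbitrary.

```python
# Illustrative only (not from this PR): a tiny BERT decoder config.
import torch
from transformers import BertConfig, BertLMHeadModel

config = BertConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    is_decoder=True,   # apply a causal (left-to-right) attention mask
    use_cache=True,    # allow the model to return past_key_values
)
model = BertLMHeadModel(config)

input_ids = torch.tensor([[101, 7592, 2088, 102]])  # toy token ids
outputs = model(input_ids, use_cache=True)
# one (key, value) pair per layer, reusable on the next generation step
assert len(outputs.past_key_values) == config.num_hidden_layers
```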

@@ -455,7 +455,7 @@ def prepare_inputs_for_generation(
             "decoder_attention_mask": decoder_attention_mask,
             "decoder_input_ids": decoder_inputs["input_ids"],
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past,
+            "past_key_values": decoder_inputs["past_key_values"],
@patrickvonplaten (Contributor, Author) commented:

@patil-suraj -> I think this is a bit safer.
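
For context, a hedged sketch of the shape of a seq2seq prepare_inputs_for_generation (not the repository code; the surrounding method body is assumed): taking "past_key_values" from the already-prepared decoder_inputs keeps the returned dict consistent with whatever the decoder-side preparation actually produced, instead of echoing back the raw past argument.

```python
# Hedged sketch, not the actual repository code: shape of a seq2seq
# prepare_inputs_for_generation. The decoder prepares its own inputs, and the
# cache is taken from that prepared dict rather than from the raw `past` arg.
def prepare_inputs_for_generation(
    self, input_ids, past=None, attention_mask=None, encoder_outputs=None, **kwargs
):
    decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
    return {
        "attention_mask": attention_mask,
        "decoder_attention_mask": decoder_inputs.get("attention_mask"),
        "decoder_input_ids": decoder_inputs["input_ids"],
        "encoder_outputs": encoder_outputs,
        "past_key_values": decoder_inputs["past_key_values"],  # the safer source
    }
```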

@@ -570,7 +570,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
         if past is not None:
             input_ids = input_ids[:, -1:]

-        return {"input_ids": input_ids, "attention_mask": attention_mask}
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
@patrickvonplaten (Contributor, Author) commented:

@patil-suraj we forgot to add this in the BERT cache PR. BERT-like models can also be used as stand-alone BertForCausalLM models -> so we need to return past_key_values here.
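
A hedged usage sketch of why this matters (illustrative, not from the PR): when a BERT-like model runs stand-alone causal generation, generate() only feeds the cache back into the model if prepare_inputs_for_generation returns it under "past_key_values". Class names come from the transformers library; the tiny config is arbitrary.

```python
# Illustrative only: stand-alone causal generation with a BERT-like model.
import torch
from transformers import BertConfig, BertLMHeadModel

config = BertConfig(
    hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
    intermediate_size=64, is_decoder=True,
)
model = BertLMHeadModel(config)

out = model.generate(
    torch.tensor([[101]]),         # toy prompt
    max_length=8,
    use_cache=True,                # only effective if prepare_inputs_for_generation
                                   # forwards "past_key_values" to the model call
    return_dict_in_generate=True,  # the GenerationOutputs return type from #9150
)
print(out.sequences.shape)         # (1, 8)
```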

@patrickvonplaten (Contributor, Author) commented Jan 6, 2021:

This PR actually made me fix two additional bugs:

  1. past_key_values was not passed through for BertForCausalLM.
  2. T5 should not return cross attentions when used as an encoder-only model -> make sure the encoder model never has config.is_decoder=True (see the config sketch below).
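
A hedged, config-level sketch of the second point (illustrative, not the PR's diff): inside a T5 encoder-decoder model, the encoder's copy of the shared config must never be flagged as a decoder, otherwise encoder outputs would wrongly carry cross-attention (and cache) entries.

```python
# Illustrative only: how an encoder/decoder split of a shared T5Config should look.
import copy
from transformers import T5Config

config = T5Config()                     # shared seq2seq config

encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False       # no causal mask, no cross attention
encoder_config.use_cache = False        # an encoder has nothing to cache

decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True        # causal mask + cross attention over encoder states
```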

@patrickvonplaten patrickvonplaten merged commit b8462b5 into huggingface:master Jan 6, 2021
@sgugger (Collaborator) left a comment:

Thanks for fixing!

@patrickvonplaten deleted the fix_output_generate branch January 6, 2021 20:24