-
Notifications
You must be signed in to change notification settings - Fork 27.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GenerationOutputs] Fix GenerationOutputs Tests #9443
[GenerationOutputs] Fix GenerationOutputs Tests #9443
Conversation
@@ -522,6 +522,7 @@ def test_greedy_generate_dict_outputs_use_cache(self): | |||
return | |||
|
|||
config.use_cache = True | |||
config.is_decoder = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make sure to use causal mask for models like BERT, RoBERTa, ...
@@ -455,7 +455,7 @@ def prepare_inputs_for_generation( | |||
"decoder_attention_mask": decoder_attention_mask, | |||
"decoder_input_ids": decoder_inputs["input_ids"], | |||
"encoder_outputs": encoder_outputs, | |||
"past_key_values": past, | |||
"past_key_values": decoder_inputs["past_key_values"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@patil-suraj -> think this is a bit safer
@@ -570,7 +570,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non | |||
if past is not None: | |||
input_ids = input_ids[:, -1:] | |||
|
|||
return {"input_ids": input_ids, "attention_mask": attention_mask} | |||
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@patil-suraj we forgot to add this in the BERT cache PR. Bert-like models can also be used as stand-alone BertForCausalLM
models -> so we need to return past_key_values
here.
This PR actually made me correct 2 bugs additionally:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for fixing!
What does this PR do?
The
GenerationOutputs
PR: #9150 was not rebased, so that the cicrle ci on master is red now. This PR fixes it.Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors which may be interested in your PR.