[GenerationOutputs] Fix GenerationOutputs Tests #9443
```diff
@@ -455,7 +455,7 @@ def prepare_inputs_for_generation(
             "decoder_attention_mask": decoder_attention_mask,
             "decoder_input_ids": decoder_inputs["input_ids"],
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past,
+            "past_key_values": decoder_inputs["past_key_values"],
             "use_cache": use_cache,
         }
         return input_dict
```

> Review comment on the changed line: @patil-suraj -> think this is a bit safer
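For context, a hedged sketch of the pattern this hunk touches (the surrounding method body and its exact signature are assumptions, not copied from the repository): in an encoder-decoder wrapper, the cache forwarded to the next decoding step is read back out of the inner decoder's own `prepare_inputs_for_generation` output rather than passing the raw `past` argument through, so any restructuring the decoder applies to its cache is preserved.

```python
# Minimal sketch, assuming a wrapper model that delegates to self.decoder;
# not the exact repository code.
def prepare_inputs_for_generation(
    self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
    # let the wrapped decoder normalize input_ids and its cache first
    decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
    return {
        "attention_mask": attention_mask,
        "decoder_attention_mask": decoder_inputs.get("attention_mask"),
        "decoder_input_ids": decoder_inputs["input_ids"],
        "encoder_outputs": encoder_outputs,
        # read the cache from decoder_inputs rather than forwarding `past` directly
        "past_key_values": decoder_inputs["past_key_values"],
        "use_cache": use_cache,
    }
```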
```diff
@@ -522,6 +522,7 @@ def test_greedy_generate_dict_outputs_use_cache(self):
                 return

             config.use_cache = True
+            config.is_decoder = True
             model = model_class(config).to(torch_device).eval()
             output_greedy, output_generate = self._greedy_generate(
                 model=model,
```

> Review comment on `config.is_decoder = True`: make sure to use causal mask for models like BERT, RoBERTa, ...
```diff
@@ -730,6 +731,7 @@ def test_beam_search_generate_dict_outputs_use_cache(self):
             beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)

             config.use_cache = True
+            config.is_decoder = True
             model = model_class(config).to(torch_device).eval()
             output_beam, output_generate = self._beam_search_generate(
                 model=model,
```
```diff
@@ -962,12 +964,7 @@ def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_
         # Attentions
         if config.is_encoder_decoder:
             # encoder
-            encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length)
-            self.assertIsInstance(output.encoder_attentions, tuple)
-            self.assertListEqual(
-                [layer_attentions.shape for layer_attentions in output.encoder_attentions],
-                [encoder_expected_shape] * len(output.encoder_attentions),
-            )
+            self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
             # decoder
             self._check_attentions_for_generate(
                 num_sequences_in_output,
```
```diff
@@ -993,11 +990,8 @@ def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_
         # Hidden States
         if config.is_encoder_decoder:
             # encoder
-            encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
-            self.assertIsInstance(output.encoder_hidden_states, tuple)
-            self.assertListEqual(
-                [layer_hidden_states.shape for layer_hidden_states in output.encoder_hidden_states],
-                [encoder_expected_shape] * len(output.encoder_hidden_states),
+            self._check_encoder_hidden_states_for_generate(
+                output.encoder_hidden_states, batch_size, config, seq_length
             )

             # decoder
```
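To make the refactor above concrete, here is (roughly) the kind of object `_check_outputs` inspects; a minimal sketch with an assumed tiny config: `generate(..., return_dict_in_generate=True)` returns one tuple entry per generated token, each holding one tensor per layer, which the helpers below check shape by shape.

```python
import torch
from transformers import BertConfig, BertLMHeadModel

cfg = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=4, intermediate_size=64, is_decoder=True,
)
model = BertLMHeadModel(cfg).eval()
input_ids = torch.randint(0, 100, (1, 4))

out = model.generate(
    input_ids, max_length=6, return_dict_in_generate=True,
    output_attentions=True, output_hidden_states=True, output_scores=True,
)
print(len(out.attentions), len(out.attentions[0]))  # 2 generation steps, 2 layers each
print(out.sequences.shape)                          # torch.Size([1, 6])
```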
```diff
@@ -1052,6 +1046,14 @@ def _check_attentions_for_generate(
             [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
         )

+    def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
+        encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length)
+        self.assertIsInstance(attentions, tuple)
+        self.assertListEqual(
+            [layer_attentions.shape for layer_attentions in attentions],
+            [encoder_expected_shape] * len(attentions),
+        )
+
     def _check_hidden_states_for_generate(
         self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
     ):
```
```diff
@@ -1071,6 +1073,14 @@ def _check_hidden_states_for_generate(
                 [expected_shape] * len(iter_hidden_states),
             )

+    def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
+        encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
+        self.assertIsInstance(hidden_states, tuple)
+        self.assertListEqual(
+            [layer_hidden_states.shape for layer_hidden_states in hidden_states],
+            [encoder_expected_shape] * len(hidden_states),
+        )
+

 @require_torch
 class UtilsFunctionsTest(unittest.TestCase):
```
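As a sanity check of the shapes these two helpers assert (a standalone sketch with an assumed tiny encoder, not test code from the PR): each layer's attention matrix is `(batch, heads, seq, seq)` and each layer's hidden state is `(batch, seq, hidden)`.

```python
import torch
from transformers import BertConfig, BertModel

cfg = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=4, intermediate_size=64,
)
model = BertModel(cfg).eval()
out = model(torch.randint(0, 100, (2, 7)), output_attentions=True, output_hidden_states=True)

assert all(a.shape == (2, 4, 7, 7) for a in out.attentions)   # (batch, heads, seq, seq)
assert all(h.shape == (2, 7, 32) for h in out.hidden_states)  # (batch, seq, hidden)
```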
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@patil-suraj we forgot to add this in the BERT cache PR. Bert-like models can also be used as stand-alone
BertForCausalLM
models -> so we need to returnpast_key_values
here.
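In the spirit of that comment and of the two use-cache tests above, a hedged end-to-end check (tiny assumed config, greedy decoding): once the cache is wired through correctly, cached and uncached generation should agree token for token.

```python
import torch
from transformers import BertConfig, BertLMHeadModel

torch.manual_seed(0)
cfg = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=4, intermediate_size=64, is_decoder=True,
)
model = BertLMHeadModel(cfg).eval()
input_ids = torch.randint(0, 100, (1, 4))

with_cache = model.generate(input_ids, max_length=8, use_cache=True)
without_cache = model.generate(input_ids, max_length=8, use_cache=False)
assert torch.equal(with_cache, without_cache)  # the cache must not change the result
```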