[GenerationOutputs] Fix GenerationOutputs Tests #9443

Merged

71 changes: 34 additions & 37 deletions src/transformers/generation_utils.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/transformers/models/bert/modeling_bert.py
@@ -1227,7 +1227,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
        if past is not None:
            input_ids = input_ids[:, -1:]

-        return {"input_ids": input_ids, "attention_mask": attention_mask}
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}

    def _reorder_cache(self, past, beam_idx):
        reordered_past = ()
@@ -570,7 +570,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
        if past is not None:
            input_ids = input_ids[:, -1:]

-        return {"input_ids": input_ids, "attention_mask": attention_mask}
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
Comment from the Contributor Author (@patrickvonplaten):
@patil-suraj we forgot to add this in the BERT cache PR. Bert-like models can also be used as stand-alone BertForCausalLM models -> so we need to return past_key_values here.


    def _reorder_cache(self, past, beam_idx):
        reordered_past = ()
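
To see why the cache has to travel through the returned dict (a minimal, self-contained sketch; ToyCausalLM and greedy_generate are made-up illustrative names, not the actual transformers generate() implementation): once past is set, prepare_inputs_for_generation truncates input_ids to the last token, so the only way the model still sees the earlier context is the "past_key_values" entry of the dict it returns.

from typing import Optional

import torch


class ToyCausalLM(torch.nn.Module):
    """Stand-in decoder: an embedding plus LM head, with a toy running-sum "cache"."""

    def __init__(self, vocab_size: int = 16, hidden: int = 8):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.lm_head = torch.nn.Linear(hidden, vocab_size)

    def forward(self, input_ids, attention_mask=None, past_key_values=None):
        # A real model would attend over cached keys/values; here the "cache" is just
        # the running sum of token embeddings, to keep the sketch self-contained.
        hidden = self.embed(input_ids).sum(dim=1)
        if past_key_values is not None:
            hidden = hidden + past_key_values
        return self.lm_head(hidden), hidden  # (logits, updated cache)

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None):
        if past is not None:
            input_ids = input_ids[:, -1:]  # only the newest token once a cache exists
        # Dropping "past_key_values" here would silently discard the cache.
        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}


def greedy_generate(model: ToyCausalLM, input_ids: torch.Tensor, steps: int = 3) -> torch.Tensor:
    past: Optional[torch.Tensor] = None
    for _ in range(steps):
        model_inputs = model.prepare_inputs_for_generation(input_ids, past=past)
        logits, past = model(**model_inputs)
        next_token = logits.argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids


print(greedy_generate(ToyCausalLM(), torch.tensor([[1, 2, 3]])))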
@@ -455,7 +455,7 @@ def prepare_inputs_for_generation(
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
-            "past_key_values": past,
+            "past_key_values": decoder_inputs["past_key_values"],
Comment from the Contributor Author (@patrickvonplaten):

@patil-suraj -> think this is a bit safer

            "use_cache": use_cache,
        }
        return input_dict
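
A hedged sketch of why taking the cache from decoder_inputs is the safer choice (reusing the ToyCausalLM from the sketch above; ToyEncoderDecoder is an illustrative name, not the real transformers class): the composite model delegates input preparation to its decoder, so forwarding decoder_inputs["past_key_values"] guarantees that the cache reaching the forward pass is exactly the one the decoder prepared, even if the decoder ever reshapes it.

class ToyEncoderDecoder:
    """Composite wrapper that delegates decoder-side input preparation to its decoder."""

    def __init__(self, decoder):
        self.decoder = decoder

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, encoder_outputs=None, use_cache=None, **kwargs
    ):
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
        return {
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            # take the cache from the decoder-prepared inputs rather than the raw `past`
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
        }


wrapper = ToyEncoderDecoder(decoder=ToyCausalLM())
print(wrapper.prepare_inputs_for_generation(torch.tensor([[1, 2, 3]]), past=torch.zeros(1, 8)))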
2 changes: 1 addition & 1 deletion src/transformers/models/roberta/modeling_roberta.py
@@ -962,7 +962,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
        if past is not None:
            input_ids = input_ids[:, -1:]

-        return {"input_ids": input_ids, "attention_mask": attention_mask}
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}

    def _reorder_cache(self, past, beam_idx):
        reordered_past = ()
@@ -1132,7 +1132,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
        if past is not None:
            input_ids = input_ids[:, -1:]

-        return {"input_ids": input_ids, "attention_mask": attention_mask}
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}

    def _reorder_cache(self, past, beam_idx):
        reordered_past = ()
32 changes: 21 additions & 11 deletions tests/test_generation_utils.py
@@ -522,6 +522,7 @@ def test_greedy_generate_dict_outputs_use_cache(self):
                return

            config.use_cache = True
+            config.is_decoder = True
Comment from the Contributor Author (@patrickvonplaten, Jan 6, 2021):
make sure to use causal mask for models like BERT, RoBERTa, ...

            model = model_class(config).to(torch_device).eval()
            output_greedy, output_generate = self._greedy_generate(
                model=model,
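
For context on what the flag changes (a simplified illustration, not the exact masking code in transformers): with config.is_decoder set, the padding mask gets combined with a lower-triangular causal mask, so position i can only attend to positions at or before i, which is what cached, one-token-at-a-time generation assumes.

import torch

seq_len = 5
padding_mask = torch.ones(1, seq_len)                           # (batch, seq): 1 = real token
causal_mask = torch.tril(torch.ones(seq_len, seq_len))          # lower-triangular causal mask
combined = causal_mask[None, :, :] * padding_mask[:, None, :]   # (batch, seq, seq)
print(combined[0])
# row i has ones only up to column i, so step i cannot attend to future positions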
@@ -730,6 +731,7 @@ def test_beam_search_generate_dict_outputs_use_cache(self):
            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)

            config.use_cache = True
+            config.is_decoder = True
            model = model_class(config).to(torch_device).eval()
            output_beam, output_generate = self._beam_search_generate(
                model=model,
@@ -962,12 +964,7 @@ def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_
        # Attentions
        if config.is_encoder_decoder:
            # encoder
-            encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length)
-            self.assertIsInstance(output.encoder_attentions, tuple)
-            self.assertListEqual(
-                [layer_attentions.shape for layer_attentions in output.encoder_attentions],
-                [encoder_expected_shape] * len(output.encoder_attentions),
-            )
+            self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
            # decoder
            self._check_attentions_for_generate(
                num_sequences_in_output,
@@ -993,11 +990,8 @@ def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_
        # Hidden States
        if config.is_encoder_decoder:
            # encoder
-            encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
-            self.assertIsInstance(output.encoder_hidden_states, tuple)
-            self.assertListEqual(
-                [layer_hidden_states.shape for layer_hidden_states in output.encoder_hidden_states],
-                [encoder_expected_shape] * len(output.encoder_hidden_states),
+            self._check_encoder_hidden_states_for_generate(
+                output.encoder_hidden_states, batch_size, config, seq_length
            )

            # decoder
@@ -1052,6 +1046,14 @@ def _check_attentions_for_generate(
                [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
            )

+    def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
+        encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length)
+        self.assertIsInstance(attentions, tuple)
+        self.assertListEqual(
+            [layer_attentions.shape for layer_attentions in attentions],
+            [encoder_expected_shape] * len(attentions),
+        )

    def _check_hidden_states_for_generate(
        self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
    ):

@@ -1071,6 +1073,14 @@ def _check_hidden_states_for_generate(
                [expected_shape] * len(iter_hidden_states),
            )

+    def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
+        encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
+        self.assertIsInstance(hidden_states, tuple)
+        self.assertListEqual(
+            [layer_hidden_states.shape for layer_hidden_states in hidden_states],
+            [encoder_expected_shape] * len(hidden_states),
+        )


@require_torch
class UtilsFunctionsTest(unittest.TestCase):
26 changes: 26 additions & 0 deletions tests/test_modeling_led.py
@@ -327,6 +327,32 @@ def test_retain_grad_hidden_states_attentions(self):
        # longformer cannot keep gradients in attentions or hidden states
        return

+    def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
+        # make sure tgt_length is padded
+        tgt_length = (
+            seq_length // config.attention_window[0] + (seq_length % config.attention_window[0] != 0)
+        ) * config.attention_window[0]
+
+        encoder_expected_shape = (batch_size, config.num_attention_heads, tgt_length, seq_length)
+        self.assertIsInstance(attentions, tuple)
+        self.assertListEqual(
+            [layer_attentions.shape for layer_attentions in attentions],
+            [encoder_expected_shape] * len(attentions),
+        )
+
+    def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
+        # make sure seq_length is padded
+        seq_length = (
+            seq_length // config.attention_window[0] + (seq_length % config.attention_window[0] != 0)
+        ) * config.attention_window[0]
+
+        encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
+        self.assertIsInstance(hidden_states, tuple)
+        self.assertListEqual(
+            [layer_hidden_states.shape for layer_hidden_states in hidden_states],
+            [encoder_expected_shape] * len(hidden_states),
+        )

    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True
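
The round-up in the two _check_encoder_*_for_generate overrides above pads the length to the next multiple of the local attention window. A small worked example (illustrative numbers, not taken from the PR; the helper name pad_to_window is hypothetical):

def pad_to_window(seq_length: int, window: int) -> int:
    """Round seq_length up to the nearest multiple of the attention window."""
    return (seq_length // window + (seq_length % window != 0)) * window

assert pad_to_window(7, 4) == 8   # 7 // 4 + (7 % 4 != 0) -> 1 + 1 = 2; 2 * 4 = 8
assert pad_to_window(8, 4) == 8   # already a multiple: 2 + 0 = 2; 2 * 4 = 8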