From 093850eaa06686261fbef48fff4a094e01b2414b Mon Sep 17 00:00:00 2001 From: ydshieh Date: Mon, 12 Sep 2022 10:40:21 +0200 Subject: [PATCH] Remove decoder_position_ids from check_decoder_model_past_large_inputs --- tests/models/bart/test_modeling_tf_bart.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/tests/models/bart/test_modeling_tf_bart.py b/tests/models/bart/test_modeling_tf_bart.py index db06c84e0f5b86..69cf530ee6b322 100644 --- a/tests/models/bart/test_modeling_tf_bart.py +++ b/tests/models/bart/test_modeling_tf_bart.py @@ -125,21 +125,10 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict): next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1) - decoder_position_ids = tf.cast(tf.cumsum(next_attention_mask, axis=1, exclusive=True), dtype=tf.int32) - output_from_no_past = model( - next_input_ids, attention_mask=next_attention_mask, position_ids=decoder_position_ids - ) + output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask) output_from_no_past = output_from_no_past[0] - decoder_position_ids = ( - tf.cast(tf.cumsum(next_attn_mask, axis=1, exclusive=True), dtype=tf.int32) + past_key_values[0][0].shape[2] - ) - output_from_past = model( - next_tokens, - attention_mask=next_attention_mask, - past_key_values=past_key_values, - position_ids=decoder_position_ids, - ) + output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values) output_from_past = output_from_past[0] self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])