Skip to content

Commit

Permalink
Remove decoder_position_ids from `check_decoder_model_past_large_in…
Browse files Browse the repository at this point in the history
…puts` (#18980)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
  • Loading branch information
ydshieh and ydshieh authored Sep 12, 2022
1 parent a86acb7 commit 0b36970
Showing 1 changed file with 2 additions and 13 deletions.
15 changes: 2 additions & 13 deletions tests/models/bart/test_modeling_tf_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,21 +125,10 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict):
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)

decoder_position_ids = tf.cast(tf.cumsum(next_attention_mask, axis=1, exclusive=True), dtype=tf.int32)
output_from_no_past = model(
next_input_ids, attention_mask=next_attention_mask, position_ids=decoder_position_ids
)
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)
output_from_no_past = output_from_no_past[0]

decoder_position_ids = (
tf.cast(tf.cumsum(next_attn_mask, axis=1, exclusive=True), dtype=tf.int32) + past_key_values[0][0].shape[2]
)
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
past_key_values=past_key_values,
position_ids=decoder_position_ids,
)
output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)
output_from_past = output_from_past[0]

self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
Expand Down

0 comments on commit 0b36970

Please sign in to comment.