TF: T5 can now handle a padded past (i.e. XLA generation) (#17969)
* get the right slicing index for position_bias
gante authored Jul 4, 2022
1 parent e3139ad commit f098268
Showing 2 changed files with 17 additions and 11 deletions.
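For context, an illustrative sketch (not part of this diff) of why a padded past breaks the old slicing: under XLA generation the decoder cache is pre-allocated to the maximum length and only partially filled, so taking the last seq_length rows of position_bias no longer matches the current decoding step. The shapes and the number of already-generated steps below are made-up values.

import tensorflow as tf

n_heads, max_length = 2, 8
real_seq_length = max_length      # key length seen by the attention layer (padded cache)
seq_length = 1                    # one new query position per decoding step
steps_already_generated = 3       # cache positions 0..2 are filled, the rest is padding

# Deterministic dummy bias so the two slices can be compared: (1, n_heads, query, key)
position_bias = tf.reshape(
    tf.range(n_heads * real_seq_length * real_seq_length, dtype=tf.float32),
    (1, n_heads, real_seq_length, real_seq_length),
)

# Old behaviour: always take the last seq_length rows -- only correct when the past is unpadded.
old_slice = position_bias[:, :, -seq_length:, :]

# What the current step actually needs: the row right after the filled part of the cache.
wanted_slice = position_bias[:, :, steps_already_generated : steps_already_generated + seq_length, :]

print(bool(tf.reduce_all(old_slice == wanted_slice)))  # False -- hence the dynamic_slice fix below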
16 changes: 13 additions & 3 deletions src/transformers/models/t5/modeling_tf_t5.py
@@ -23,6 +23,7 @@

 import numpy as np
 import tensorflow as tf
+from tensorflow.compiler.tf2xla.python.xla import dynamic_slice

 from ...activations_tf import get_tf_activation
 from ...modeling_tf_outputs import (
@@ -384,10 +385,19 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
             else:
                 position_bias = self.compute_bias(real_seq_length, key_length)

-            # if key and values are already calculated
-            # we want only the last query position bias
+            # if key and values are already calculated we want only the last query position bias
             if past_key_value is not None:
-                position_bias = position_bias[:, :, -seq_length:, :]
+                if not self.has_relative_attention_bias:
+                    position_bias = position_bias[:, :, -seq_length:, :]
+                else:
+                    # we might have a padded past structure, in which case we want to fetch the position bias slice
+                    # right after the most recently filled past index
+                    most_recently_filled_past_index = tf.reduce_max(tf.where(past_key_value[0][0, 0, :, 0] != 0.0))
+                    position_bias = dynamic_slice(
+                        position_bias,
+                        (0, 0, most_recently_filled_past_index + 1, 0),
+                        (1, self.n_heads, seq_length, real_seq_length),
+                    )

             if mask is not None:
                 position_bias = tf.cast(position_bias, dtype=mask.dtype)
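A minimal standalone sketch of the new slicing logic (not part of the diff), assuming an XLA-style padded past whose unfilled positions are exactly 0.0; the shapes and the number of filled steps are arbitrary.

import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_slice

n_heads, head_dim, max_length = 2, 4, 8
seq_length, real_seq_length = 1, 8    # one query position against the full padded cache

# Padded past key tensor (batch, n_heads, max_length, head_dim); only the first 3 steps are filled.
past_key = tf.concat(
    [tf.ones((1, n_heads, 3, head_dim)), tf.zeros((1, n_heads, max_length - 3, head_dim))],
    axis=2,
)
position_bias = tf.random.normal((1, n_heads, real_seq_length, real_seq_length))


@tf.function(jit_compile=True)  # dynamic_slice is an XLA op, so run it compiled, as generate() does
def slice_bias(position_bias, past_key):
    # Index of the most recently filled cache position (2 here), recovered from the zero padding.
    most_recently_filled_past_index = tf.reduce_max(tf.where(past_key[0, 0, :, 0] != 0.0))
    # Fetch the seq_length rows of position_bias that come right after the filled past.
    return dynamic_slice(
        position_bias,
        (0, 0, most_recently_filled_past_index + 1, 0),
        (1, n_heads, seq_length, real_seq_length),
    )


print(slice_bias(position_bias, past_key).shape)  # (1, 2, 1, 8)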
12 changes: 4 additions & 8 deletions tests/models/t5/test_modeling_tf_t5.py
@@ -590,21 +590,17 @@ def test_beam_search_xla_generate_simple(self):
         ]
         input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids

-        # xla_generate = tf.function(model.generate, jit_compile=True)
-        xla_generate = tf.function(model.generate)
+        xla_generate = tf.function(model.generate, jit_compile=True)

-        # TODO (joao): there is something not quite right with XLA T5 -- as we increase `max_length` the two outputs
-        # drift appart, where the XLA version clearly degrades its quality. XLA-related variables look fine (they are
-        # being padded and filled in the right places). This also happens in other generation modes. Investigate.
-        output_ids = model.generate(input_ids, num_beams=2, max_length=9)
-        output_ids_xla = xla_generate(input_ids, num_beams=2, max_length=9)
+        output_ids = model.generate(input_ids, num_beams=2)
+        output_ids_xla = xla_generate(input_ids, num_beams=2)

         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)

         expected_output_string = [
             "Aujourd'hui est une belle journée.",
-            "J'ai quatre chats,",
+            "J'ai quatre chats, trois chiens, deux oiseaux et un cheval.",
         ]

         self.assertListEqual(expected_output_string, output_strings)
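A hypothetical end-to-end usage sketch (not part of the commit), mirroring what the updated test exercises; the t5-small checkpoint and the prompts are assumptions, not taken from the diff.

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")

sentences = [
    "translate English to French: Today is a beautiful day.",
    "translate English to French: I have four cats, three dogs, two birds, and a horse.",
]
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids

# Compile generate with XLA; with this fix, the padded decoder past no longer corrupts
# T5's relative position bias during beam search.
xla_generate = tf.function(model.generate, jit_compile=True)
output_ids = xla_generate(input_ids, num_beams=2)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))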
