Generation refactor
mattdangerw committed Apr 9, 2024
1 parent a2a9602 commit 52b3d77
Showing 28 changed files with 739 additions and 1,567 deletions.
26 changes: 25 additions & 1 deletion keras_nlp/layers/modeling/transformer_decoder.py
@@ -250,6 +250,28 @@ def build(
# Create layers based on input shape.
self.built = True

def compute_self_attention_cache(
self,
decoder_sequence,
):
x = decoder_sequence
if self.normalize_first:
x = self._self_attention_layer_norm(x)
key = self._self_attention_layer._key_dense(x)
value = self._self_attention_layer._value_dense(x)
return ops.stack((key, value), axis=1)

def compute_cross_attention_cache(
self,
encoder_sequence,
):
x = encoder_sequence
if self.normalize_first:
x = self._cross_attention_layer_norm(x)
key = self._cross_attention_layer._key_dense(x)
value = self._cross_attention_layer._value_dense(x)
return ops.stack((key, value), axis=1)

def __call__(
self,
decoder_sequence,
@@ -325,7 +347,9 @@ def call(
the layer has cross-attention.
"""

has_encoder_sequence = encoder_sequence is not None
has_encoder_sequence = (
encoder_sequence is not None or cross_attention_cache is not None
)

has_cross_attention = self._cross_attention_layer is not None
if not has_cross_attention and has_encoder_sequence:
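
As a quick illustration of the two helpers added above, the following sketch (not part of the commit) precomputes a decoder layer's key/value caches. The layer sizes, random inputs, and the build-by-calling step are assumptions chosen only for demonstration.

import numpy as np
from keras_nlp.layers import TransformerDecoder

# Toy sizes, chosen only for illustration.
batch_size, dec_len, enc_len, hidden_dim = 2, 12, 16, 64
decoder_layer = TransformerDecoder(intermediate_dim=128, num_heads=4)

decoder_sequence = np.random.uniform(size=(batch_size, dec_len, hidden_dim)).astype("float32")
encoder_sequence = np.random.uniform(size=(batch_size, enc_len, hidden_dim)).astype("float32")

# Call the layer once so it (and its attention sublayers) gets built.
decoder_layer(decoder_sequence, encoder_sequence=encoder_sequence)

# Both helpers stack key and value along axis 1, producing caches of shape
# (batch_size, 2, sequence_length, num_heads, head_dim) that can later be
# passed back in via `self_attention_cache` / `cross_attention_cache`.
self_cache = decoder_layer.compute_self_attention_cache(decoder_sequence)
cross_cache = decoder_layer.compute_cross_attention_cache(encoder_sequence)
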
2 changes: 1 addition & 1 deletion keras_nlp/layers/preprocessing/start_end_packer.py
@@ -189,7 +189,7 @@ def call(
outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs

if self.return_padding_mask:
mask = tf.ones_like(x, dtype="bool")
mask = tf.ones_like(x, dtype="int32")
mask = mask.to_tensor(shape=(batch_size, sequence_length))
mask = tf.squeeze(mask, axis=0) if unbatched else mask
return outputs, mask
1 change: 0 additions & 1 deletion keras_nlp/models/bart/bart_backbone.py
@@ -254,5 +254,4 @@ def get_config(self):
"max_sequence_length": self.max_sequence_length,
}
)

return config
310 changes: 43 additions & 267 deletions keras_nlp/models/bart/bart_seq_2_seq_lm.py
@@ -20,7 +20,6 @@
BartSeq2SeqLMPreprocessor,
)
from keras_nlp.models.seq_2_seq_lm import Seq2SeqLM
from keras_nlp.utils.tensor_utils import any_equal


@keras_nlp_export("keras_nlp.models.BartSeq2SeqLM")
@@ -199,291 +198,68 @@ def __init__(
**kwargs,
)

def call_decoder_with_cache(
def build_cache(self, batch_size, max_length):
num_layers = self.backbone.num_layers
num_heads = self.backbone.num_heads
head_dim = self.backbone.hidden_dim // self.backbone.num_heads
shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
return ops.zeros(shape, dtype=self.compute_dtype)

def compute_cross_attention_cache(
self, encoder_token_ids, encoder_padding_mask
):
"""Does a forward pass on the encoder and returns the encoder output."""
# Embedding layers.
tokens = self.backbone.token_embedding(encoder_token_ids)
positions = self.backbone.encoder_position_embedding(tokens)
# Sum, normalize and apply dropout to embeddings.
x = self.backbone.encoder_embeddings_add((tokens, positions))
x = self.backbone.encoder_embeddings_layer_norm(x)
x = self.backbone.encoder_embeddings_dropout(x)
# Transformer encoder layers.
for layer in self.backbone.encoder_transformer_layers:
x = layer(x, padding_mask=encoder_padding_mask)
# Compute a cross-attention cache for each decoder layer.
caches = []
for layer in self.backbone.decoder_transformer_layers:
caches.append(layer.compute_cross_attention_cache(x))
return ops.stack(caches, axis=1)
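
For orientation, here is a small sketch (not from the commit) of the stacking this method performs; all dimensions below are made-up placeholders, and `keras.ops` is assumed to be available (Keras 3).

from keras import ops

# Made-up dimensions, for shape illustration only.
batch, num_layers, enc_len, num_heads, head_dim = 2, 6, 16, 12, 64

# Each decoder layer contributes a (key, value) pair stacked along axis 1 ...
per_layer_caches = [
    ops.stack(
        (
            ops.zeros((batch, enc_len, num_heads, head_dim)),  # key
            ops.zeros((batch, enc_len, num_heads, head_dim)),  # value
        ),
        axis=1,
    )
    for _ in range(num_layers)
]
# ... and stacking the per-layer caches along axis 1 yields the 6-D cache
# documented below: (batch, num_layers, 2, enc_len, num_heads, head_dim).
cross_attention_cache = ops.stack(per_layer_caches, axis=1)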

def call_with_cache(
self,
encoder_hidden_states,
token_ids,
cache,
index,
*,
encoder_padding_mask,
decoder_token_ids,
self_attention_cache=None,
self_attention_cache_update_index=None,
cross_attention_cache=None,
cross_attention_cache_update_index=None,
cross_attention_cache,
):
"""Forward pass with a key/value caches for generative decoding..
`call_decoder_with_cache` adds an additional inference-time forward pass
for the model for seq2seq text generation. Unlike calling the model
directly, this method does two things to optimize text generation:
- Allows caching previous key/value tensors in the decoder's
self-attention layer to avoid recomputing the outputs of seen tokens.
- Allows caching key/value tensors in the decoder's cross-attention
layer to avoid recomputing the encoder outputs.
Args:
encoder_hidden_states: a dense float Tensor of shape
`(batch_size, encoder_sequence_length, hidden_dim)`. The
sequence of hidden states at the output of the encoder's last
layer.
encoder_padding_mask: a dense float Tensor of shape
`(batch_size, encoder_sequence_length)`. The padding mask for
the encoder input.
decoder_token_ids: a dense int Tensor of shape
`(batch_size, max_length)`. Input token ids to be fed to
the decoder.
self_attention_cache: a dense float Tensor of shape
`(batch_size, num_layers, 2, max_length, num_heads, key_dims)`.
The cached key/value tensors of previously seen tokens in the
decoder's self-attention layer.
self_attention_cache_update_index: an int or int Tensor, the index
at which to update the `self_attention_cache`. Usually, this is
the index of the current token being processed during decoding.
cross_attention_cache: a dense float Tensor of shape
`(batch_size, num_layers, 2, encoder_sequence_length, num_heads, key_dims)`.
The cached key/value tensors of the encoder outputs in the
decoder's cross-attention layer.
cross_attention_cache_update_index: an int or int Tensor, the index
at which to update the `cross_attention_cache`. Usually, this is
either `0` (compute the entire `cross_attention_cache`), or
`None` (reuse a previously computed `cross_attention_cache`).
Returns:
A `(logits, hidden_states, self_attention_cache, cross_attention_cache)`
tuple, where `logits` is the language model logits for the input
`decoder_token_ids`, `hidden_states` is the final hidden
representation of the input tokens, `self_attention_cache` is the
key/value cache in the decoder's self-attention layer and
`cross_attention_cache` is the key/value cache in the decoder's
cross-attention layer.
"""
# Embedding layers.
tokens = self.backbone.token_embedding(decoder_token_ids)
tokens = self.backbone.token_embedding(token_ids)
positions = self.backbone.decoder_position_embedding(
tokens,
start_index=self_attention_cache_update_index,
tokens, start_index=index
)
# Sum, normalize and apply dropout to embeddings.
x = self.backbone.decoder_embeddings_add((tokens, positions))
x = self.backbone.decoder_embeddings_layer_norm(x)
x = self.backbone.decoder_embeddings_dropout(x)

# Every decoder layer has a separate cache for the self-attention layer
# and the cross-attention layer. We update all of them separately.
self_attention_caches = []
cross_attention_caches = []
# Each decoder layer has a cache; we update them separately.
caches = []
for i, layer in enumerate(self.backbone.decoder_transformer_layers):
current_self_attention_cache = self_attention_cache[:, i, ...]
current_self_attention_cache = cache[:, i, ...]
current_cross_attention_cache = cross_attention_cache[:, i, ...]
(
x,
next_self_attention_cache,
next_cross_attention_cache,
) = layer(
x, next_cache, _ = layer(
decoder_sequence=x,
encoder_sequence=encoder_hidden_states,
encoder_padding_mask=encoder_padding_mask,
self_attention_cache=current_self_attention_cache,
self_attention_cache_update_index=self_attention_cache_update_index,
self_attention_cache_update_index=index,
cross_attention_cache=current_cross_attention_cache,
cross_attention_cache_update_index=cross_attention_cache_update_index,
)
if self_attention_cache_update_index is not None:
self_attention_caches.append(next_self_attention_cache)
if cross_attention_cache_update_index is not None:
cross_attention_caches.append(next_cross_attention_cache)

if self_attention_cache_update_index is not None:
self_attention_cache = ops.stack(self_attention_caches, axis=1)
if cross_attention_cache_update_index is not None:
cross_attention_cache = ops.stack(cross_attention_caches, axis=1)

caches.append(next_cache)
cache = ops.stack(caches, axis=1)
hidden_states = x
logits = self.backbone.token_embedding(hidden_states, reverse=True)
return (
logits,
hidden_states,
self_attention_cache,
cross_attention_cache,
cache,
)
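
Putting the refactored pieces together, a hypothetical seeding pass might look like the sketch below. This is not taken from the commit: the preset name, `preprocessor=None`, and the toy token ids are assumptions, and the call signatures are read off the methods shown above.

import numpy as np
from keras_nlp.models import BartSeq2SeqLM

# Assumption: preset name, preprocessor=None, and the toy ids are placeholders.
lm = BartSeq2SeqLM.from_preset("bart_base_en", preprocessor=None)

encoder_token_ids = np.array([[0, 31414, 232, 2]])
encoder_padding_mask = np.ones_like(encoder_token_ids, dtype=bool)
decoder_token_ids = np.array([[2, 0, 0, 0, 0, 0, 0, 0]])
batch_size, max_length = decoder_token_ids.shape

# 1. Empty self-attention cache sized for the full generation length.
cache = lm.build_cache(batch_size, max_length)
# 2. Cross-attention cache computed once from the encoder inputs.
cross_cache = lm.compute_cross_attention_cache(
    encoder_token_ids, encoder_padding_mask
)
# 3. Seed the self-attention cache with the prompt in a single pass.
logits, hidden_states, cache = lm.call_with_cache(
    decoder_token_ids, cache, 0, cross_attention_cache=cross_cache
)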

def call_encoder(self, token_ids, padding_mask):
"""Does a forward pass on the encoder and returns the encoder output."""
tokens = self.backbone.token_embedding(token_ids)
positions = self.backbone.encoder_position_embedding(tokens)
x = self.backbone.decoder_embeddings_add((tokens, positions))
x = self.backbone.encoder_embeddings_layer_norm(x)
x = self.backbone.encoder_embeddings_dropout(x)
for transformer_layer in self.backbone.encoder_transformer_layers:
x = transformer_layer(x, padding_mask=padding_mask)
return x

def _initialize_cache(self, encoder_token_ids, decoder_token_ids):
"""Initializes empty self-attention cache and cross-attention cache."""
batch_size = ops.shape(encoder_token_ids)[0]
encoder_max_length = ops.shape(encoder_token_ids)[1]
decoder_max_length = ops.shape(decoder_token_ids)[1]

num_layers = self.backbone.num_layers
num_heads = self.backbone.num_heads
head_dim = self.backbone.hidden_dim // self.backbone.num_heads

shape = [
batch_size,
num_layers,
2,
decoder_max_length,
num_heads,
head_dim,
]
self_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)

shape[3] = encoder_max_length
cross_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)

return (self_attention_cache, cross_attention_cache)

def _build_cache(
self, encoder_token_ids, encoder_padding_mask, decoder_token_ids
):
"""Builds the self-attention cache and the cross-attention cache (key/value pairs)."""
encoder_hidden_states = self.call_encoder(
token_ids=encoder_token_ids, padding_mask=encoder_padding_mask
)
self_attention_cache, cross_attention_cache = self._initialize_cache(
encoder_token_ids, decoder_token_ids
)

# Seed the self-attention cache and the cross-attention cache.
(
_,
hidden_states,
self_attention_cache,
cross_attention_cache,
) = self.call_decoder_with_cache(
encoder_hidden_states=encoder_hidden_states,
encoder_padding_mask=encoder_padding_mask,
decoder_token_ids=decoder_token_ids,
self_attention_cache=self_attention_cache,
self_attention_cache_update_index=0,
cross_attention_cache=cross_attention_cache,
cross_attention_cache_update_index=0,
)
return (
hidden_states,
encoder_hidden_states,
self_attention_cache,
cross_attention_cache,
)

def generate_step(
self,
inputs,
stop_token_ids=None,
):
"""A compilable generation function for a batch of inputs.
This function represents the inner, XLA-compilable, generation function
for a single batch of inputs. Inputs should have the same structure as
model inputs, a dictionary with keys `"encoder_token_ids"`,
`"encoder_padding_mask"`, `"decoder_token_ids"` and
`"decoder_padding_mask"`.
Args:
inputs: A dictionary with four keys - `"encoder_token_ids"`,
`"encoder_padding_mask"`, `"decoder_token_ids"` and
`"decoder_padding_mask"`, with batched tensor values.
stop_token_ids: Tuple of ids of end tokens to stop on. If all
sequences have produced a new stop token, generation
will stop.
"""
(
encoder_token_ids,
encoder_padding_mask,
decoder_token_ids,
decoder_padding_mask,
) = (
inputs["encoder_token_ids"],
inputs["encoder_padding_mask"],
inputs["decoder_token_ids"],
inputs["decoder_padding_mask"],
)

batch_size = ops.shape(encoder_token_ids)[0]

# Create and seed cache with a single forward pass.
(
hidden_states,
encoder_hidden_states,
self_attention_cache,
cross_attention_cache,
) = self._build_cache(
encoder_token_ids, encoder_padding_mask, decoder_token_ids
)
# Compute the lengths of all user-provided token ids.
row_lengths = ops.sum(ops.cast(decoder_padding_mask, "int32"), axis=-1)
# Start at the first index that has no user-provided id.
index = ops.min(row_lengths)

def next(prompt, cache, index):
# The cache index is the index of our previous token.
cache_index = index - 1
num_samples = ops.shape(prompt)[0]
prompt = ops.slice(prompt, [0, cache_index], [num_samples, 1])

def repeat_tensor(x):
"""Repeats tensors along batch axis to match dim for beam search."""
if ops.shape(x)[0] == num_samples:
return x
return ops.repeat(x, repeats=num_samples // batch_size, axis=0)

logits, hidden_states, cache, _ = self.call_decoder_with_cache(
encoder_hidden_states=repeat_tensor(encoder_hidden_states),
encoder_padding_mask=repeat_tensor(encoder_padding_mask),
decoder_token_ids=prompt,
self_attention_cache=cache,
self_attention_cache_update_index=cache_index,
cross_attention_cache=repeat_tensor(cross_attention_cache),
cross_attention_cache_update_index=None,
)
return (
ops.squeeze(logits, axis=1),
ops.squeeze(hidden_states, axis=1),
cache,
)

decoder_token_ids = self.sampler(
next=next,
prompt=decoder_token_ids,
cache=self_attention_cache,
index=index,
mask=decoder_padding_mask,
stop_token_ids=stop_token_ids,
hidden_states=hidden_states,
model=self,
)

# Compute an output padding mask with the token ids we updated.
if stop_token_ids is not None:
# Build a mask of `stop_token_ids` locations not in the original
# prompt (not in locations where `decoder_padding_mask` is True).
end_locations = any_equal(
decoder_token_ids,
stop_token_ids,
ops.logical_not(decoder_padding_mask),
)
end_locations = ops.cast(end_locations, "int32")
# Use cumsum to get ones in all locations after `end_locations`.
cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
overflow = cumsum - end_locations
# Our padding mask is the inverse of these overflow locations.
decoder_padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
else:
# Without early stopping, all locations will have been updated.
decoder_padding_mask = ops.ones_like(
decoder_token_ids, dtype="bool"
)

return {
"decoder_token_ids": decoder_token_ids,
"decoder_padding_mask": decoder_padding_mask,
}
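
To make the cumsum-based masking in the stop-token branch above concrete, here is a small NumPy walk-through; the token ids, prompt length, and stop token id 2 are invented for the example.

import numpy as np

# The prompt fills the first two positions; token id 2 is the stop token.
decoder_token_ids = np.array([[0, 7, 9, 2, 5, 5]])
decoder_padding_mask = np.array([[True, True, False, False, False, False]])

# Stop tokens that were generated, i.e. not part of the original prompt.
end_locations = (decoder_token_ids == 2) & ~decoder_padding_mask
end_locations = end_locations.astype("int32")  # [[0, 0, 0, 1, 0, 0]]
cumsum = np.cumsum(end_locations, axis=-1)     # [[0, 0, 0, 1, 1, 1]]
overflow = cumsum - end_locations              # [[0, 0, 0, 0, 1, 1]]
new_padding_mask = ~overflow.astype(bool)      # [[True, True, True, True, False, False]]
# Everything up to and including the first generated stop token is kept;
# positions after it are masked out.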