test fixes
mattdangerw committed Feb 14, 2024
1 parent 74134cc commit d19adad
Showing 9 changed files with 25 additions and 22 deletions.
8 changes: 5 additions & 3 deletions keras_nlp/layers/modeling/transformer_layer_utils.py
@@ -55,9 +55,11 @@ def compute_causal_mask(batch_size, input_length, output_length, cache_index=0):
         `(batch_size, output_length, input_length)` that can be passed to a
         attention layer.
     """
-    i = ops.expand_dims(ops.arange(output_length), axis=1) + cache_index
-    j = ops.arange(input_length)
-    mask = ops.expand_dims(ops.cast(i >= j, dtype="int32"), axis=0)
+    i = ops.arange(output_length, dtype="float32")
+    i = i + ops.cast(cache_index, "float32")
+    i = ops.expand_dims(i, axis=1)
+    j = ops.arange(input_length, dtype="float32")
+    mask = ops.expand_dims(i >= j, axis=0)
     return ops.broadcast_to(mask, (batch_size, output_length, input_length))
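The rewritten mask builds the causal pattern by offsetting the row index with cache_index, expanding it to a column, and comparing it against the column index; the boolean result is then broadcast over the batch. A minimal NumPy sketch of that broadcasting logic (an illustration only, not the KerasNLP code; the function name is made up):

import numpy as np

# Hypothetical NumPy sketch: the row index i is offset by cache_index,
# expanded to a column, and compared against the column index j; the
# boolean comparison broadcasts into the lower-triangular causal mask.
def causal_mask_sketch(batch_size, input_length, output_length, cache_index=0):
    i = np.arange(output_length, dtype="float32") + cache_index
    i = np.expand_dims(i, axis=1)                 # (output_length, 1)
    j = np.arange(input_length, dtype="float32")  # (input_length,)
    mask = np.expand_dims(i >= j, axis=0)         # (1, output_length, input_length)
    return np.broadcast_to(mask, (batch_size, output_length, input_length))

print(causal_mask_sketch(1, 3, 3).astype("int32"))
# [[[1 0 0]
#   [1 1 0]
#   [1 1 1]]]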


2 changes: 1 addition & 1 deletion keras_nlp/models/bart/bart_seq_2_seq_lm.py
@@ -257,7 +257,7 @@ def call_with_cache(
     ):
         tokens = self.backbone.token_embedding(token_ids)
         positions = self.backbone.decoder_position_embedding(
-            tokens, start_index=index,
+            tokens, start_index=index
         )
         # Sum, normalize and apply dropout to embeddings.
         x = self.backbone.decoder_embeddings_add((tokens, positions))
14 changes: 5 additions & 9 deletions keras_nlp/models/bart/bart_seq_2_seq_lm_test.py
@@ -100,30 +100,26 @@ def test_generate(self):

     def test_early_stopping(self):
         seq_2_seq_lm = BartSeq2SeqLM(**self.init_kwargs)
-        call_decoder_with_cache = seq_2_seq_lm.call_decoder_with_cache
+        call_with_cache = seq_2_seq_lm.call_with_cache

         def wrapper(*args, **kwargs):
             """Modify output logits to always favor end_token_id"""
             (
                 logits,
                 hidden_states,
-                self_attention_cache,
-                cross_attention_cache,
-            ) = call_decoder_with_cache(*args, **kwargs)
+                cache,
+            ) = call_with_cache(*args, **kwargs)
             index = self.preprocessor.tokenizer.end_token_id
             update = ops.ones_like(logits)[:, :, index] * 1.0e9
             update = ops.expand_dims(update, axis=-1)
             logits = ops.slice_update(logits, (0, 0, index), update)
             return (
                 logits,
                 hidden_states,
-                self_attention_cache,
-                cross_attention_cache,
+                cache,
             )

-        with patch.object(
-            seq_2_seq_lm, "call_decoder_with_cache", wraps=wrapper
-        ):
+        with patch.object(seq_2_seq_lm, "call_with_cache", wraps=wrapper):
             inputs = {
                 "encoder_text": [
                     " airplane at airport",
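The test's wrapper forces early stopping by boosting the end token's logits so every decoding step selects it. A rough NumPy sketch of the same trick (a hypothetical helper, not the test code itself):

import numpy as np

# Hypothetical sketch of the logit-forcing trick: add a large constant to
# the end token's logits so argmax/sampling always picks it and generation
# stops immediately.
def force_end_token(logits, end_token_id):
    logits = np.array(logits, dtype="float32")  # copy so the input is untouched
    logits[:, :, end_token_id] += 1.0e9         # boost the end token everywhere
    return logits

logits = np.zeros((1, 1, 5), dtype="float32")   # (batch, sequence, vocab)
forced = force_end_token(logits, end_token_id=3)
print(forced.argmax(axis=-1))                   # [[3]]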
4 changes: 4 additions & 0 deletions
@@ -141,6 +141,10 @@ def generate_postprocess(
             token_ids = ops.convert_to_numpy(token_ids)
         if not isinstance(padding_mask, tf.Tensor):
             padding_mask = ops.convert_to_numpy(padding_mask)
+        # Make sure the numpy array has type `int32` since
+        # `SentencePieceProcessor.detokenize` only accepts `int32` arrays.
+        token_ids = tf.cast(token_ids, "int32")
+        padding_mask = tf.cast(padding_mask, "bool")
         # Strip any special tokens during detokenization (e.g. the start and
         # end markers). In the future we could make this configurable.
         padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
8 changes: 4 additions & 4 deletions keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py
@@ -155,12 +155,12 @@ def generate_postprocess(
         # Convert the inputs to numpy arrays if they aren't a tensor already.
         if not isinstance(token_ids, tf.Tensor):
             token_ids = ops.convert_to_numpy(token_ids)
-            # Make sure the numpy array has type `int32` since
-            # `SentencePieceProcessor.detokenize` only accepts `int32` arrays.
-            token_ids = token_ids.astype("int32")
         if not isinstance(padding_mask, tf.Tensor):
             padding_mask = ops.convert_to_numpy(padding_mask)
-            padding_mask = padding_mask.astype("bool")
+        # Make sure the numpy array has type `int32` since
+        # `SentencePieceProcessor.detokenize` only accepts `int32` arrays.
+        token_ids = tf.cast(token_ids, "int32")
+        padding_mask = tf.cast(padding_mask, "bool")
         # Strip any special tokens during detokenization (e.g. the start and
         # end markers). In the future we could make this configurable.
         padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
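In both preprocessors the casts move out of the isinstance branches and switch from NumPy's .astype to tf.cast, which accepts NumPy arrays as well as tensors, so the dtype fix applies regardless of which branch ran. A hedged sketch of the resulting behavior (the ids, mask, and end token id below are made up):

import numpy as np
import tensorflow as tf

# Sketch of the postprocess dtype normalization: whatever dtype the
# generated ids arrive in, they are cast to int32 and the padding mask
# to bool before special tokens are stripped and the ids are detokenized.
token_ids = np.array([[2, 15, 37, 3, 0]], dtype="int64")
padding_mask = np.array([[1, 1, 1, 1, 0]], dtype="int64")

token_ids = tf.cast(token_ids, "int32")
padding_mask = tf.cast(padding_mask, "bool")
padding_mask = padding_mask & (token_ids != 3)  # drop a made-up end marker
print(token_ids.dtype, padding_mask.dtype)      # <dtype: 'int32'> <dtype: 'bool'>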
5 changes: 3 additions & 2 deletions keras_nlp/models/mistral/mistral_transformer_decoder.py
@@ -214,8 +214,9 @@ def _compute_self_attention_mask(
         # Below is a workaround for `ops.triu` for Keras 2.
         # TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is removed.
         # causal_mask = ops.triu(causal_mask_lower, k=-self.sliding_window)
-        i = ops.arange(output_length)[:, None] + cache_update_index
-        j = ops.arange(input_length)[None, :]
+        i = ops.arange(output_length, dtype="int32")[:, None]
+        i = i + ops.cast(cache_update_index, "int32")
+        j = ops.arange(input_length, dtype="int32")[None, :]
         causal_mask_upper = ops.cast(i < j + self.sliding_window, "int32")
         causal_mask = ops.minimum(causal_mask_lower, causal_mask_upper)

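The workaround builds the sliding-window band as the element-wise minimum of a causal lower mask and an "upper" mask that cuts off positions more than sliding_window tokens in the past; the change here only makes the index dtype explicit. A hypothetical NumPy sketch of the band it produces:

import numpy as np

# Hypothetical NumPy sketch of the banded mask: the lower mask blocks
# future positions and the upper mask blocks positions too far in the
# past; their minimum is the sliding-window attention band.
def sliding_window_mask_sketch(length, sliding_window, cache_update_index=0):
    i = np.arange(length, dtype="int32")[:, None] + cache_update_index
    j = np.arange(length, dtype="int32")[None, :]
    causal_lower = (i >= j).astype("int32")                  # no attending to the future
    causal_upper = (i < j + sliding_window).astype("int32")  # no looking too far back
    return np.minimum(causal_lower, causal_upper)

print(sliding_window_mask_sketch(5, sliding_window=3))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [0 1 1 1 0]
#  [0 0 1 1 1]]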
2 changes: 1 addition & 1 deletion keras_nlp/models/t5/t5_transformer_layer.py
@@ -131,7 +131,7 @@ def call(
             shape = ops.shape(hidden_states)
             batch_size, length = shape[0], shape[1]
             causal_mask = compute_causal_mask(batch_size, length, length)
-            attention_mask = ops.cast(attention_mask, "int32")
+            attention_mask = ops.cast(attention_mask, "bool")
             attention_mask = causal_mask & attention_mask

         x = hidden_states  # Intermediate result.
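The cast switches to "bool", presumably because compute_causal_mask (changed above) now yields a boolean mask and combining masks with & wants both operands to match. A tiny sketch with made-up values:

import numpy as np

# Small sketch: intersect a boolean causal mask with a boolean padding-style
# attention mask; positions must be allowed by both to survive.
causal_mask = np.array([[True, False], [True, True]])
attention_mask = np.array([[1, 1], [1, 0]]).astype("bool")
print((causal_mask & attention_mask).astype("int32"))
# [[1 0]
#  [1 0]]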
2 changes: 1 addition & 1 deletion keras_nlp/samplers/beam_sampler.py
@@ -47,7 +47,7 @@ def start(self, data):
         data = tree.map_structure(self.create_beams, data)
         # Setup the initial beam log-likelihoods.
         log_probs = [[0.0] + [-1e9] * (self.num_beams - 1)]
-        log_probs = ops.array(log_probs)
+        log_probs = ops.array(log_probs, dtype="float32")
         log_probs = self.flatten_beams(ops.repeat(log_probs, batch_size, 0))
         return {**data, "log_probabilities": log_probs}

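The explicit dtype pins the initial beam log-likelihoods to float32; only the first beam per batch item starts "alive" at 0.0 while the rest sit at -1e9. A small NumPy sketch mirroring that setup (num_beams and batch_size are made up):

import numpy as np

# Sketch of the initial beam log-likelihoods: one live beam per batch item,
# the rest effectively dead at -1e9, repeated per batch item and flattened.
# The explicit float32 presumably keeps the dtype predictable across backends.
num_beams, batch_size = 3, 2
log_probs = np.array([[0.0] + [-1e9] * (num_beams - 1)], dtype="float32")
log_probs = np.repeat(log_probs, batch_size, axis=0).reshape(-1)
print(log_probs)  # [0, -1e9, -1e9, 0, -1e9, -1e9]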
2 changes: 1 addition & 1 deletion keras_nlp/samplers/serialization.py
@@ -14,11 +14,11 @@

 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import keras
-from keras_nlp.samplers.sampler import Sampler
 from keras_nlp.samplers.beam_sampler import BeamSampler
 from keras_nlp.samplers.contrastive_sampler import ContrastiveSampler
 from keras_nlp.samplers.greedy_sampler import GreedySampler
 from keras_nlp.samplers.random_sampler import RandomSampler
+from keras_nlp.samplers.sampler import Sampler
 from keras_nlp.samplers.top_k_sampler import TopKSampler
 from keras_nlp.samplers.top_p_sampler import TopPSampler

