Add Bloom preprocessor #1424

Merged · 7 commits · Feb 10, 2024
4 changes: 2 additions & 2 deletions keras_nlp/models/bloom/bloom_backbone.py
@@ -28,7 +28,7 @@ def _bloom_kernel_initializer(stddev=0.02):

@keras_nlp_export("keras_nlp.models.BloomBackbone")
class BloomBackbone(Backbone):
"""A Bloom decoder network.
"""A BLOOM decoder network.

This network implements a Transformer-based decoder network, BigScience
Large Open-science Open-access Multilingual (BLOOM), as described in
@@ -92,7 +92,7 @@ def __init__(
intermediate_dim,
dropout=0.0,
layer_norm_epsilon=1e-5,
max_sequence_length=512,
max_sequence_length=2048,
**kwargs,
):
token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
182 changes: 182 additions & 0 deletions keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py
@@ -0,0 +1,182 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from absl import logging

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import ops
from keras_nlp.models.bloom.bloom_preprocessor import BloomPreprocessor
from keras_nlp.utils.keras_utils import (
convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras_nlp_export("keras_nlp.models.BloomCausalLMPreprocessor")
class BloomCausalLMPreprocessor(BloomPreprocessor):
"""BLOOM Causal LM preprocessor.

This preprocessing layer is meant for use with
`keras_nlp.models.BloomCausalLM`. By default, it will take in batches of
strings, and return outputs in a `(x, y, sample_weight)` format, where the
`y` label is the next token id in the `x` sequence.

For use with generation, the layer also exposes two methods
`generate_preprocess()` and `generate_postprocess()`. When this preprocessor
is attached to a `keras_nlp.models.BloomCausalLM` instance, these methods
will be called implicitly in `generate()`. They can also be called
standalone (e.g. to precompute preprocessing inputs for generation in a
separate process).

Args:
tokenizer: A `keras_nlp.models.BloomTokenizer` instance.
sequence_length: The length of the packed inputs.
add_start_token: If `True`, the preprocessor will prepend the tokenizer
start token to each input sequence.
add_end_token: If `True`, the preprocessor will append the tokenizer
end token to each input sequence.

Call arguments:
x: A string, `tf.Tensor` or list of python strings.
y: Label data. Should always be `None` as the layer generates labels.
sample_weight: Label weights. Should always be `None` as the layer
generates label weights.
sequence_length: Pass to override the configured `sequence_length` of
the layer.

Examples:
```python
# Load the preprocessor from a preset.
preprocessor = keras_nlp.models.BloomCausalLMPreprocessor.from_preset(
"bloom_560m_multi"
)

# Tokenize and pack a single sentence.
sentence = tf.constant("League of legends")
preprocessor(sentence)
# Same output.
preprocessor("League of legends")

# Tokenize a batch of sentences.
sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
preprocessor(sentences)
# Same output.
preprocessor(["Taco tuesday", "Fish taco please!"])

# Map a dataset to preprocess a single sentence.
features = tf.constant(
[
"Avatar 2 is amazing!",
"Well, I am not sure.",
]
)
labels = tf.constant([1, 0])
ds = tf.data.Dataset.from_tensor_slices((features, labels))
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

# Map a dataset to preprocess unlabeled sentences.
ds = tf.data.Dataset.from_tensor_slices(features)
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
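
# `generate_preprocess()` and `generate_postprocess()` can also be called
# directly, e.g. to precompute generation inputs (a minimal sketch).
x = preprocessor.generate_preprocess("Avatar 2 is amazing!")
preprocessor.generate_postprocess(x)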
```
"""

def call(
self,
x,
y=None,
sample_weight=None,
sequence_length=None,
):
if y is not None or sample_weight is not None:
logging.warning(
"`BloomCausalLMPreprocessor` generates `y` and `sample_weight` "
"based on your input data, but your data already contains `y` "
"or `sample_weight`. Your `y` and `sample_weight` will be "
"ignored."
)
sequence_length = sequence_length or self.sequence_length

x = convert_inputs_to_list_of_tensor_segments(x)[0]
x = self.tokenizer(x)
# Pad with one extra token to account for the truncation below.
token_ids, padding_mask = self.packer(
x,
sequence_length=sequence_length + 1,
add_start_value=self.add_start_token,
add_end_value=self.add_end_token,
)
# The last token does not have a next token, so we truncate it out.
x = {
"token_ids": token_ids[..., :-1],
"padding_mask": padding_mask[..., :-1],
}
# Target `y` will be the next token.
y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
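# Worked example with hypothetical ids: a packed sequence [1, 7, 8, 2, 0]
# yields x["token_ids"] = [1, 7, 8, 2] and y = [7, 8, 2, 0], so each
# position is labeled with the token that follows it.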
return pack_x_y_sample_weight(x, y, sample_weight)

def generate_preprocess(
self,
x,
sequence_length=None,
):
"""Covert strings to integer token input for generation.

Similar to calling the layer for training, this method takes in strings
or tensor strings, tokenizes and packs the input, and computes a padding
mask masking all inputs not filled in with a padded value.

Unlike calling the layer for training, this method does not compute
labels and will never append a `tokenizer.end_token_id` to the end of
the sequence (as generation is expected to continue at the end of the
inputted prompt).
"""
if not self.built:
self.build(None)

x = convert_inputs_to_list_of_tensor_segments(x)[0]
x = self.tokenizer(x)
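# Pack without an end token so that generation can continue from the end
# of the prompt.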
token_ids, padding_mask = self.packer(
x, sequence_length=sequence_length, add_end_value=False
)
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
}

def generate_postprocess(
self,
x,
):
"""Covert integer token output to strings for generation.

This method reverses `generate_preprocess()`, by first removing all
padding and start/end tokens, and then converting the integer sequence
back to a string.
"""
if not self.built:
self.build(None)

token_ids, padding_mask = x["token_ids"], x["padding_mask"]
token_ids = ops.convert_to_numpy(token_ids)
padding_mask = ops.convert_to_numpy(padding_mask)
# Strip any special tokens during detokenization (e.g. the start and
# end markers). In the future we could make this configurable.
padding_mask = (
padding_mask
& (token_ids != self.tokenizer.eos_token_id)
& (token_ids != self.tokenizer.bos_token_id)
)
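# E.g. a packed sequence like [<s>, air, plane, </s>, <pad>] (hypothetical
# ids) detokenizes to just "airplane" after the special tokens are masked.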
token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
return self.tokenizer.detokenize(token_ids)
94 changes: 94 additions & 0 deletions keras_nlp/models/bloom/bloom_causal_lm_preprocessor_test.py
@@ -0,0 +1,94 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from keras_nlp.models.bloom.bloom_causal_lm_preprocessor import (
BloomCausalLMPreprocessor,
)
from keras_nlp.models.bloom.bloom_tokenizer import BloomTokenizer
from keras_nlp.tests.test_case import TestCase


class BloomCausalLMPreprocessorTest(TestCase):
def setUp(self):
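# A tiny BPE vocabulary and merge list, just large enough to tokenize
# the test phrase "airplane at airport".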
self.vocab = ["<pad>", "<s>", "</s>"]
self.vocab += ["!", "air", "Ġair", "plane", "Ġat", "port"]
self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
self.tokenizer = BloomTokenizer(
vocabulary=self.vocab,
merges=self.merges,
)
self.init_kwargs = {
"tokenizer": self.tokenizer,
"sequence_length": 8,
}
self.input_data = ["airplane at airport"]

def test_preprocessor_basics(self):
self.run_preprocessing_layer_test(
cls=BloomCausalLMPreprocessor,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output=(
{
"token_ids": [[1, 4, 6, 7, 5, 8, 2, 0]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]],
},
[[4, 6, 7, 5, 8, 2, 0, 0]],  # Labels: the next-token ids.
[[1, 1, 1, 1, 1, 1, 0, 0]],  # Sample weights: mask over real label positions.
),
)

def test_no_start_end_token(self):
input_data = ["airplane at airport"] * 4

preprocessor = BloomCausalLMPreprocessor(
**self.init_kwargs,
add_start_token=False,
add_end_token=False,
)
x, y, sw = preprocessor(input_data)
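# With start/end tokens disabled, only the five BPE ids [4, 6, 7, 5, 8]
# remain in the packed output.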
self.assertAllEqual(x["token_ids"], [[4, 6, 7, 5, 8, 0, 0, 0]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4)
self.assertAllEqual(y, [[6, 7, 5, 8, 0, 0, 0, 0]] * 4)
self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)

def test_generate_preprocess(self):
input_data = "airplane at airport"
preprocessor = BloomCausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_preprocess(input_data)
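# The start token is kept, but no end token is appended, so generation
# can continue from the prompt.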
self.assertAllEqual(x["token_ids"], [1, 4, 6, 7, 5, 8, 0, 0])
self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0])

def test_generate_postprocess(self):
input_data = {
"token_ids": [1, 4, 6, 7, 5, 8, 0, 0],
"padding_mask": [1, 1, 1, 1, 1, 1, 0, 0],
}
preprocessor = BloomCausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_postprocess(input_data)
self.assertAllEqual(x, "airplane at airport")

@pytest.mark.extra_large
def test_all_presets(self):
for preset in BloomCausalLMPreprocessor.presets:
self.run_preset_test(
cls=BloomCausalLMPreprocessor,
preset=preset,
input_data=self.input_data,
)