Skip to content

Commit

Permalink
Implement TopP, TopK and Beam samplers (#652)
Browse files Browse the repository at this point in the history
* initial commit

* Add keras_nlp.samplers

* Change padding to left to right

* Add serialization support, and move some args from constructor to call

* Add string example

* small changes

* Address comments: fix docstring, remove multicase support

* Address comments: move token_probability_fn to the second place

* some initials

* add more sampler class, and a few changes on the base sampler class

* add some arg defaults

* refactor the interface

* add get_config
  • Loading branch information
chenmoneygithub authored Jan 18, 2023
1 parent 1cb93f2 commit 612bbf5
Show file tree
Hide file tree
Showing 12 changed files with 1,105 additions and 191 deletions.
13 changes: 10 additions & 3 deletions keras_nlp/samplers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@

from tensorflow import keras

from keras_nlp.samplers.greedy import Greedy
from keras_nlp.samplers.beam_sampler import BeamSampler
from keras_nlp.samplers.greedy_sampler import GreedySampler
from keras_nlp.samplers.sampler import Sampler
from keras_nlp.samplers.top_k_sampler import TopKSampler
from keras_nlp.samplers.top_p_sampler import TopPSampler


def serialize(sampler):
Expand All @@ -24,7 +28,10 @@ def serialize(sampler):
def deserialize(config, custom_objects=None):
"""Return a `Sampler` object from its config."""
all_classes = {
"greedy": Greedy,
"beam": BeamSampler,
"greedy": GreedySampler,
"top_k": TopKSampler,
"top_p": TopPSampler,
}
return keras.utils.deserialize_keras_object(
config,
Expand All @@ -46,7 +53,7 @@ def get(identifier):
dict containing `class_name` and `config` as an identifier. Also note that
the `class_name` must map to a `Sampler` class.
>>> cfg = {'class_name': 'keras_nlp>Greedy', 'config': {}}
>>> cfg = {'class_name': 'keras_nlp>GreedySampler', 'config': {}}
>>> sampler = keras_nlp.samplers.get(cfg)
In the case that the `identifier` is a class, this method will return a new
Expand Down
218 changes: 218 additions & 0 deletions keras_nlp/samplers/beam_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Beam Sampler."""

import tensorflow as tf
from tensorflow import keras

from keras_nlp.samplers.sampler import Sampler
from keras_nlp.samplers.sampler import base_sampler_args_docstring
from keras_nlp.samplers.sampler import call_args_docstring
from keras_nlp.utils.python_utils import format_docstring


@format_docstring(
base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring
)
@keras.utils.register_keras_serializable(package="keras_nlp")
class BeamSampler(Sampler):
"""Beam Sampler class.
This sampler implements beam search algorithm. At each time-step, beam
search keeps the beams (sequences) of the top `num_beams` highest
accumulated probabilities, and uses each one of the beams to predict
candidate next tokens.
Args:
num_beams: int. The number of beams that should be kept at each
time-step. `num_beams` should be strictly positive.
{{base_sampler_args}}
Call Args:
{{call_args}}
Examples:
```python
VOCAB_SIZE = 10
# Create a dummy model to predict the next token.
model = keras.Sequential(
[
keras.Input(shape=[None]),
keras.layers.Embedding(
input_dim=VOCAB_SIZE,
output_dim=16,
),
keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
]
)
# Define a function that outputs the next token's probability for each token
# in the input sequence.
def token_probability_fn(inputs, mask):
return model(inputs)
prompt = tf.fill((8, 1), 1)
sampler = keras_nlp.samplers.BeamSampler(num_beams=3)
# Print the generated sequence (token ids).
print(sampler(prompt, token_probability_fn, max_length=10))
```
"""

def __init__(
self,
num_beams=5,
jit_compile=True,
run_eagerly=False,
):
self.num_beams = num_beams
super().__init__(jit_compile=jit_compile, run_eagerly=run_eagerly)

def get_next_token(self, next_token_probs):
# Beam search overrides the whole `sample` method.
pass

def sample(
self, prompt, token_probability_fn, mask, num_steps, from_logits=True
):
"""Sampling logic implementation.
Because beam search uses a different loop body, we have to override the
whole `sample` method instead of just the `get_next_token` method.
"""
batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1]
max_length = tf.cast(max_length, num_steps.dtype)
length = max_length - num_steps
dummy_preds = token_probability_fn(prompt, mask=mask)
vocab_size = tf.shape(dummy_preds)[-1]
pred_dtype = dummy_preds.dtype

num_beams = self.num_beams

# Initialize beam with shape `(batch_size, num_beams, length)`.
beams = tf.repeat(tf.expand_dims(prompt, axis=1), num_beams, axis=1)
# Initialize `beams_prob` with shape `(batch_size, num_beams)`.
beams_prob = tf.zeros([batch_size, 1], dtype=pred_dtype)
beams_prob = tf.concat(
[beams_prob, tf.fill((batch_size, num_beams - 1), pred_dtype.min)],
axis=-1,
)

def one_step(beams, beams_prob, length, mask):

flattened_beams = tf.reshape(
beams, shape=[batch_size * num_beams, -1]
)
repeated_mask = tf.tile(mask, [num_beams, 1])
probs = token_probability_fn(flattened_beams, repeated_mask)
preds = tf.gather(
probs,
tf.repeat(length - 1, batch_size * num_beams),
axis=1,
batch_dims=1,
)
if from_logits:
preds = keras.activations.softmax(preds, axis=-1)
# Reshape `preds` to shape `(batch_size, num_beams * vocab_size)`.

preds = tf.reshape(preds, shape=[batch_size, -1])

cum_probs = tf.math.log(preds) + tf.repeat(
beams_prob, repeats=vocab_size, axis=1
)

candidate_prob, candidate_indexes = tf.math.top_k(
cum_probs, k=num_beams, sorted=False
)

candidate_beam_indexes = candidate_indexes // vocab_size
next_token = candidate_indexes % vocab_size

beams = tf.gather(
beams, candidate_beam_indexes, axis=1, batch_dims=1
)

# Build a new column of updates to scatter into the beam tensor.
next_token = tf.where(
condition=mask[..., length, tf.newaxis],
x=beams[..., length],
y=next_token,
)
next_token = tf.reshape(next_token, shape=[-1])

mask = tf.tensor_scatter_nd_update(
tensor=mask,
indices=tf.stack(
(
tf.cast(tf.range(batch_size), dtype=length.dtype),
tf.repeat(length, batch_size),
),
axis=1,
),
updates=tf.repeat(True, batch_size),
)

# Generate `(batch_index, beam_index)` tuples for each beam.
beam_indices = tf.where(tf.ones((batch_size, num_beams), tf.bool))
beam_indices = tf.cast(beam_indices, dtype=length.dtype)
# Build a tensor of repeated `length` values.
length_indices = tf.fill((batch_size * num_beams, 1), length)
# Concatenate to a triplet of `(batch_index, beam_index, length)`.
indices = tf.concat([beam_indices, length_indices], axis=-1)

# Update `beams[:, :, length]` with `next_token`.
beams = tf.tensor_scatter_nd_update(
tensor=beams,
indices=indices,
updates=next_token,
)

beams_prob = candidate_prob

length = tf.add(length, 1)
return beams, beams_prob, length, mask

# Run a while loop till text of length `max_length` has been generated.
beams, beams_prob, length, mask = tf.while_loop(
cond=lambda beams, beams_prob, length, mask: tf.less(
length, max_length
),
body=one_step,
loop_vars=[beams, beams_prob, length, mask],
# There is a strange issue that when `batch_size=1`, the first loop
# iteration changes `beams_prob`'s shape from [1, None] to
# [None, None], which does not happen for `batch_size>1`.
# As a workaround, we set shape invariants.
shape_invariants=[
beams.get_shape(),
tf.TensorShape([None, None]),
length.get_shape(),
mask.get_shape(),
],
)

# Get the beam with the maximum probability.
max_indexes = tf.math.argmax(beams_prob, axis=-1)
max_beams = tf.gather(
beams, max_indexes[:, tf.newaxis], axis=1, batch_dims=1
)

return tf.squeeze(max_beams, axis=1)

def get_config(self):
config = super().get_config()

config.update({"num_beams": self.num_beams})
return config
137 changes: 137 additions & 0 deletions keras_nlp/samplers/beam_sampler_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Beam sampler."""

import random

import tensorflow as tf
from absl.testing import parameterized
from tensorflow import keras

from keras_nlp.samplers.beam_sampler import BeamSampler
from keras_nlp.samplers.greedy_sampler import GreedySampler


class BeamSamplerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.vocab_size = 10
self.feature_size = 16

# Create a dummy model to predict the next token.
model = keras.Sequential(
[
keras.Input(shape=[None]),
keras.layers.Embedding(
input_dim=self.vocab_size,
output_dim=self.feature_size,
),
keras.layers.Dense(self.vocab_size),
keras.layers.Softmax(),
]
)

def token_probability_fn(inputs, mask):
return model(inputs)

self.token_probability_fn = token_probability_fn
self.sampler = BeamSampler(num_beams=2)

def test_generate_with_1d_prompt(self):
inputs = tf.constant([1])
outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
self.assertEqual(outputs.shape, [5])

def test_generate_with_2d_prompt(self):
inputs = tf.constant([[1], [1]])
outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
self.assertEqual(outputs.shape, [2, 5])

def test_generate_with_list_prompt(self):
inputs = [[1], [1]]
outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
self.assertEqual(outputs.shape, [2, 5])

def test_generate_with_ragged_prompt(self):
def token_probability_fn(inputs, mask):
batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
return tf.tile(prob, [batch_size, seq_length, 1])

inputs = tf.ragged.constant([[1], [2, 1, 2]])
outputs = self.sampler(inputs, token_probability_fn, max_length=5)
self.assertEqual(outputs.shape, [2, 5])

def test_one_beam_generation(self):
for _ in range(5):
inputs = tf.constant([random.randint(0, 9)])
beam_sampler = BeamSampler(num_beams=1)
greedy_sampler = GreedySampler()
beam_output = beam_sampler(
inputs,
self.token_probability_fn,
max_length=5,
)
greedy_output = greedy_sampler(
inputs,
self.token_probability_fn,
max_length=5,
)
self.assertAllEqual(beam_output, greedy_output)

@parameterized.named_parameters(
("xla_graph", True, False),
("non_xla_graph", False, False),
("eager", False, True),
)
def test_assert_generation_is_correct(self, jit_compile, run_eagerly):
def token_probability_fn(inputs, mask):
batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
return tf.tile(prob, [batch_size, seq_length, 1])

batch_size = 10
inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
max_length = 3
for i in range(1, 5):
sampler = BeamSampler(
num_beams=i,
jit_compile=jit_compile,
run_eagerly=run_eagerly,
)
outputs = sampler(
inputs,
token_probability_fn,
max_length=max_length,
)
self.assertAllEqual(
outputs, 3 * tf.ones(shape=[batch_size, max_length])
)

def test_end_token_id(self):
def token_probability_fn(inputs, mask):
batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
return tf.tile(prob, [batch_size, seq_length, 1])

max_length = 5
inputs = tf.constant([[0, 1], [1, 2]])
outputs = self.sampler(
inputs,
token_probability_fn,
max_length=max_length,
end_token_id=2,
)
expected_outputs = tf.ragged.constant([[0, 1, 3, 3, 3], [1]])
self.assertAllEqual(outputs, expected_outputs)
Loading

0 comments on commit 612bbf5

Please sign in to comment.