Support list/tuple inputs for special tokens in MultiSegmentPacker layer #1046

Merged
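
This PR generalizes `MultiSegmentPacker` so that `start_value`, `end_value`, and the new `sep_value` argument accept a list/tuple of ids or tokens. That makes RoBERTa-style packing (`<s> seq1 </s></s> seq2 </s>`) expressible with the base layer, so the dedicated `RobertaMultiSegmentPacker` can be deleted. A minimal usage sketch based on the API in this diff (the token ids are hypothetical RoBERTa-style values):

import tensorflow as tf
import keras_nlp

# "<s>" = 0, "</s>" = 2, "<pad>" = 1 (hypothetical RoBERTa-style ids).
packer = keras_nlp.layers.MultiSegmentPacker(
    sequence_length=8,
    start_value=0,      # single "<s>" at the start of the sequence
    end_value=2,        # single "</s>" after the last segment
    sep_value=[2, 2],   # double "</s>" between consecutive segments
    pad_value=1,
)
seq1 = tf.constant([10, 11])
seq2 = tf.constant([20, 21])
token_ids, segment_ids = packer((seq1, seq2))
# token_ids   -> [0, 10, 11, 2, 2, 20, 21, 2]
# segment_ids -> [0,  0,  0, 0, 0,  1,  1, 1]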
111 changes: 88 additions & 23 deletions keras_nlp/layers/multi_segment_packer.py
@@ -52,19 +52,24 @@ class MultiSegmentPacker(keras.layers.Layer):
either rank-1 or rank-2.

Args:
sequence_length: The desired output length.
start_value: The id or token that is to be placed at the start of each
sequence (called "[CLS]" for BERT). The dtype must match the dtype
of the input tensors to the layer.
end_value: The id or token that is to be placed at the end of each
input segment (called "[SEP]" for BERT). The dtype much match the
dtype of the input tensors to the layer.
pad_value: The id or token that is to be placed into the unused
sequence_length: int. The desired output length.
start_value: int/str/list/tuple. The id(s) or token(s) that are to be
placed at the start of each sequence (called "[CLS]" for BERT). The
dtype must match the dtype of the input tensors to the layer.
end_value: int/str/list/tuple. The id(s) or token(s) that are to be
placed at the end of the last input segment (called "[SEP]" for
BERT). The dtype must match the dtype of the input tensors to the
layer.
sep_value: int/str/list/tuple. The id(s) or token(s) that are to be
placed at the end of every segment, except the last segment (called
"[SEP]" for BERT). If `None`, `end_value` is used. The dtype must
match the dtype of the input tensors to the layer.
pad_value: int/str. The id or token that is to be placed into the unused
positions after the last segment in the sequence
(called "[PAD]" for BERT).
Review comment (Member): should we add an example below? maybe roberta double sep?

truncate: The algorithm to truncate a list of batched segments to fit a
per-example length limit. The value can be either `round_robin` or
`waterfall`:
truncate: str. The algorithm to truncate a list of batched segments to
fit a per-example length limit. The value can be either
`"round_robin"` or `"waterfall"`:
- `"round_robin"`: Available space is assigned one token at a
time in a round-robin fashion to the inputs that still need
some, until the limit is reached.
@@ -101,6 +106,17 @@ class MultiSegmentPacker(keras.layers.Layer):
<tf.Tensor: shape=(8,), dtype=int32,
numpy=array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32)>)

*Pack multiple inputs for classification with different sep tokens.*
>>> seq1 = tf.constant([1, 2, 3, 4])
>>> seq2 = tf.constant([11, 12, 13, 14])
>>> packer = keras_nlp.layers.MultiSegmentPacker(
... 8, start_value=101, end_value=102, sep_value=[102, 102])
>>> packer((seq1, seq2))
(<tf.Tensor: shape=(8,), dtype=int32,
numpy=array([101, 1, 2, 102, 102, 11, 12, 102], dtype=int32)>,
<tf.Tensor: shape=(8,), dtype=int32,
numpy=array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32)>)

Reference:
[Devlin et al., 2018](https://arxiv.org/abs/1810.04805).
"""
@@ -110,6 +126,7 @@ def __init__(
sequence_length,
start_value,
end_value,
sep_value=None,
pad_value=None,
truncate="round_robin",
**kwargs,
@@ -124,17 +141,42 @@ def __init__(
"supported. Received %s" % truncate
)
self.truncate = truncate

# Maintain private copies of start/end values for config purposes.
self._start_value = start_value
self._sep_value = sep_value
self._end_value = end_value

def check_special_value_type(value, value_name):
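# Normalize scalar int/str values to length-1 lists so that special
# tokens are handled uniformly as (possibly multi-token) sequences.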
if isinstance(value, (int, str)):
return [value]
if value and not isinstance(value, (list, tuple)):
raise ValueError(
f"{value_name} should be of type int/str/list/tuple."
f"Received type: `{type(value)}`."
)
return value

start_value = check_special_value_type(start_value, "start_value")
if sep_value is None:
sep_value = end_value
sep_value = check_special_value_type(sep_value, "sep_value")
end_value = check_special_value_type(end_value, "end_value")

self.start_value = start_value
self.sep_value = sep_value
self.end_value = end_value

self.pad_value = pad_value

def get_config(self):
config = super().get_config()
config.update(
{
"sequence_length": self.sequence_length,
"start_value": self.start_value,
"end_value": self.end_value,
"start_value": self._start_value,
"end_value": self._end_value,
"sep_value": self._sep_value,
"pad_value": self.pad_value,
"truncate": self.truncate,
}
@@ -170,7 +212,12 @@ def _convert_dense(self, x):

def _trim_inputs(self, inputs):
"""Trim inputs to desired length."""
num_special_tokens = len(inputs) + 1
num_segments = len(inputs)
num_special_tokens = (
len(self.start_value)
+ (num_segments - 1) * len(self.sep_value)
+ len(self.end_value)
)
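# For example, with start_value=[0], sep_value=[2, 2], end_value=[2]
# (hypothetical RoBERTa-style ids) and two segments, this is
# 1 + 1 * 2 + 1 = 4 special tokens, leaving `sequence_length - 4`
# positions for segment content.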
if self.truncate == "round_robin":
return tf_text.RoundRobinTrimmer(
self.sequence_length - num_special_tokens
@@ -187,22 +234,40 @@ def _combine_inputs(self, segments):
dtype = segments[0].dtype
batch_size = segments[0].nrows()
start_value = tf.convert_to_tensor(self.start_value, dtype=dtype)
sep_value = tf.convert_to_tensor(self.sep_value, dtype=dtype)
end_value = tf.convert_to_tensor(self.end_value, dtype=dtype)

start_column = tf.fill((batch_size, 1), start_value)
end_column = tf.fill((batch_size, 1), end_value)
ones_column = tf.ones_like(start_column, dtype=tf.int32)
start_columns = tf.repeat(
start_value[tf.newaxis, :], repeats=batch_size, axis=0
)
sep_columns = tf.repeat(
sep_value[tf.newaxis, :], repeats=batch_size, axis=0
)
end_columns = tf.repeat(
end_value[tf.newaxis, :], repeats=batch_size, axis=0
)
ones_sep_columns = tf.ones_like(sep_columns, dtype=tf.int32)
ones_end_columns = tf.ones_like(end_columns, dtype=tf.int32)

segments_to_combine = [start_columns]
segment_ids_to_combine = [
tf.ones_like(start_columns, dtype=tf.int32) * 0
]

segments_to_combine = [start_column]
segment_ids_to_combine = [ones_column * 0]
for i, seg in enumerate(segments):
# Combine all segments adding end tokens.
# Combine all segments.
segments_to_combine.append(seg)
segments_to_combine.append(end_column)

# Combine segment ids accounting for end tokens.
# Combine segment ids.
segment_ids_to_combine.append(tf.ones_like(seg, dtype=tf.int32) * i)
segment_ids_to_combine.append(ones_column * i)

# Account for the sep/end tokens here.
if i == len(segments) - 1:
segments_to_combine.append(end_columns)
segment_ids_to_combine.append(ones_end_columns * i)
else:
segments_to_combine.append(sep_columns)
segment_ids_to_combine.append(ones_sep_columns * i)

token_ids = tf.concat(segments_to_combine, 1)
segment_ids = tf.concat(segment_ids_to_combine, 1)
28 changes: 27 additions & 1 deletion keras_nlp/layers/multi_segment_packer_test.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Transformer Decoder."""
"""Tests for multi-segment packing."""

import os

@@ -147,6 +147,32 @@ def test_pad_batched_inputs(self):
),
)

def test_list_special_tokens(self):
seq1 = tf.ragged.constant([["a", "b"], ["a", "b"]])
seq2 = tf.ragged.constant([["x", "y"], ["x"]])
packer = MultiSegmentPacker(
8,
start_value="<s>",
end_value="</s>",
sep_value=["</s>", "</s>"],
pad_value="<pad>",
truncate="round_robin",
)
output = packer([seq1, seq2])
self.assertAllEqual(
output,
(
[
["<s>", "a", "b", "</s>", "</s>", "x", "y", "</s>"],
["<s>", "a", "b", "</s>", "</s>", "x", "</s>", "<pad>"],
],
[
[0, 0, 0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 0],
],
),
)

def test_config(self):
seq1 = tf.ragged.constant([["a", "b", "c"], ["a", "b"]])
seq2 = tf.ragged.constant([["x", "y", "z"], ["x", "y", "z"]])
146 changes: 0 additions & 146 deletions keras_nlp/models/roberta/roberta_multi_segment_packer.py

This file was deleted.

9 changes: 4 additions & 5 deletions keras_nlp/models/roberta/roberta_preprocessor.py
@@ -17,10 +17,8 @@
import copy

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.models.roberta.roberta_multi_segment_packer import (
RobertaMultiSegmentPacker,
)
from keras_nlp.models.roberta.roberta_presets import backbone_presets
from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer
from keras_nlp.utils.keras_utils import (
@@ -145,9 +143,10 @@ def __init__(
super().__init__(**kwargs)

self.tokenizer = tokenizer
self.packer = RobertaMultiSegmentPacker(
self.packer = MultiSegmentPacker(
start_value=self.tokenizer.start_token_id,
end_value=self.tokenizer.end_token_id,
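# RoBERTa separates a pair of segments with a double end token
# ("</s></s>"), so the sep value is the end token id repeated twice.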
sep_value=[self.tokenizer.end_token_id] * 2,
pad_value=self.tokenizer.pad_token_id,
truncate=truncate,
sequence_length=sequence_length,
@@ -166,7 +165,7 @@ def get_config(self):
def call(self, x, y=None, sample_weight=None):
x = convert_inputs_to_list_of_tensor_segments(x)
x = [self.tokenizer(segment) for segment in x]
token_ids = self.packer(x)
token_ids, _ = self.packer(x)
x = {
"token_ids": token_ids,
"padding_mask": token_ids != self.tokenizer.pad_token_id,