Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support list/tuple inputs for special tokens in MultiSegmentPacker layer #1046

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 65 additions & 17 deletions keras_nlp/layers/multi_segment_packer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,16 @@ class MultiSegmentPacker(keras.layers.Layer):

Args:
sequence_length: The desired output length.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's add the type notes you have in the other layer for consistency

start_value: The id or token that is to be placed at the start of each
sequence (called "[CLS]" for BERT). The dtype must match the dtype
of the input tensors to the layer.
end_value: The id or token that is to be placed at the end of each
input segment (called "[SEP]" for BERT). The dtype much match the
start_value: The id(s) or token(s) that are to be placed at the start of
each sequence (called "[CLS]" for BERT). The dtype must match the
dtype of the input tensors to the layer.
end_value: The id(s) or token(s) that are to be placed at the end of
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is/are -> are

the last input segment (called "[SEP]" for BERT). The dtype must
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

much -> must

match the dtype of the input tensors to the layer.
sep_value: The id(s) or token(s) that are to be placed at the end of
every segment, except the last segment (called "[SEP]" for BERT).
If `None`, `end_value` is used. The dtype must match the dtype of
the input tensors to the layer.
pad_value: The id or token that is to be placed into the unused
positions after the last segment in the sequence
(called "[PAD]" for BERT).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we add an example below? maybe roberta double sep?

Expand Down Expand Up @@ -110,6 +114,7 @@ def __init__(
sequence_length,
start_value,
end_value,
sep_value=None,
pad_value=None,
truncate="round_robin",
**kwargs,
Expand All @@ -124,17 +129,37 @@ def __init__(
"supported. Received %s" % truncate
)
self.truncate = truncate

# Maintain private copies of start/end values for config purposes.
self._start_value = start_value
self._sep_value = sep_value
self._end_value = end_value

if not isinstance(start_value, (list, tuple)):
start_value = [start_value]

if sep_value is None:
sep_value = end_value
if not isinstance(sep_value, (list, tuple)):
sep_value = [sep_value]

if not isinstance(end_value, (list, tuple)):
end_value = [end_value]

self.start_value = start_value
self.sep_value = sep_value
self.end_value = end_value

self.pad_value = pad_value

def get_config(self):
config = super().get_config()
config.update(
{
"sequence_length": self.sequence_length,
"start_value": self.start_value,
"end_value": self.end_value,
"start_value": self._start_value,
"end_value": self._end_value,
"sep_value": self._sep_value,
"pad_value": self.pad_value,
"truncate": self.truncate,
}
Expand Down Expand Up @@ -170,7 +195,12 @@ def _convert_dense(self, x):

def _trim_inputs(self, inputs):
"""Trim inputs to desired length."""
num_special_tokens = len(inputs) + 1
num_segments = len(inputs)
num_special_tokens = (
len(self.start_value)
+ (num_segments - 1) * len(self.sep_value)
+ len(self.end_value)
)
if self.truncate == "round_robin":
return tf_text.RoundRobinTrimmer(
self.sequence_length - num_special_tokens
Expand All @@ -187,22 +217,40 @@ def _combine_inputs(self, segments):
dtype = segments[0].dtype
batch_size = segments[0].nrows()
start_value = tf.convert_to_tensor(self.start_value, dtype=dtype)
sep_value = tf.convert_to_tensor(self.sep_value, dtype=dtype)
end_value = tf.convert_to_tensor(self.end_value, dtype=dtype)

start_column = tf.fill((batch_size, 1), start_value)
end_column = tf.fill((batch_size, 1), end_value)
ones_column = tf.ones_like(start_column, dtype=tf.int32)
start_values_tensor = tf.repeat(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these names are a little confusing start_value is already a tensor. should we co back to _column naming?

start_column, end_column, sep_column?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can, but it isn't exactly a column :P. I'll call it start_columns

start_value[tf.newaxis, :], repeats=batch_size, axis=0
)
end_values_tensor = tf.repeat(
end_value[tf.newaxis, :], repeats=batch_size, axis=0
)
sep_values_tensor = tf.repeat(
sep_value[tf.newaxis, :], repeats=batch_size, axis=0
)
ones_sep_tensor = tf.ones_like(sep_values_tensor, dtype=tf.int32)
ones_end_tensor = tf.ones_like(end_values_tensor, dtype=tf.int32)

segments_to_combine = [start_values_tensor]
segment_ids_to_combine = [
tf.ones_like(start_values_tensor, dtype=tf.int32) * 0
]

segments_to_combine = [start_column]
segment_ids_to_combine = [ones_column * 0]
for i, seg in enumerate(segments):
# Combine all segments adding end tokens.
# Combine all segments.
segments_to_combine.append(seg)
segments_to_combine.append(end_column)

# Combine segment ids accounting for end tokens.
# Combine segment ids.
segment_ids_to_combine.append(tf.ones_like(seg, dtype=tf.int32) * i)
segment_ids_to_combine.append(ones_column * i)

# Account for the sep/end tokens here.
if i == len(segments) - 1:
segments_to_combine.append(end_values_tensor)
segment_ids_to_combine.append(ones_end_tensor * i)
else:
segments_to_combine.append(sep_values_tensor)
segment_ids_to_combine.append(ones_sep_tensor * i)

token_ids = tf.concat(segments_to_combine, 1)
segment_ids = tf.concat(segment_ids_to_combine, 1)
Expand Down
48 changes: 47 additions & 1 deletion keras_nlp/layers/multi_segment_packer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Transformer Decoder."""
"""Tests for multi-segment packing."""

import os

Expand Down Expand Up @@ -147,6 +147,52 @@ def test_pad_batched_inputs(self):
),
)

def test_list_special_tokens(self):
seq1 = tf.ragged.constant([["a", "b", "c"], ["a", "b", "c"]])
seq2 = tf.ragged.constant([["x", "y", "z"], ["x"]])
packer = MultiSegmentPacker(
9,
start_value="[CLS]",
end_value="[SEP]",
sep_value=["[SEP]", "[SEP]"],
pad_value="[PAD]",
truncate="round_robin",
)
output = packer([seq1, seq2])
self.assertAllEqual(
output,
(
[
[
"[CLS]",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try to come up with a slightly shorter test case that will format the lists to one line

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops!

"a",
"b",
"c",
"[SEP]",
"[SEP]",
"x",
"y",
"[SEP]",
],
[
"[CLS]",
"a",
"b",
"c",
"[SEP]",
"[SEP]",
"x",
"[SEP]",
"[PAD]",
],
],
[
[0, 0, 0, 0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 1, 1, 0],
],
),
)

def test_config(self):
seq1 = tf.ragged.constant([["a", "b", "c"], ["a", "b"]])
seq2 = tf.ragged.constant([["x", "y", "z"], ["x", "y", "z"]])
Expand Down
146 changes: 0 additions & 146 deletions keras_nlp/models/roberta/roberta_multi_segment_packer.py

This file was deleted.

9 changes: 4 additions & 5 deletions keras_nlp/models/roberta/roberta_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@
import copy

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.models.roberta.roberta_multi_segment_packer import (
RobertaMultiSegmentPacker,
)
from keras_nlp.models.roberta.roberta_presets import backbone_presets
from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer
from keras_nlp.utils.keras_utils import (
Expand Down Expand Up @@ -145,9 +143,10 @@ def __init__(
super().__init__(**kwargs)

self.tokenizer = tokenizer
self.packer = RobertaMultiSegmentPacker(
self.packer = MultiSegmentPacker(
start_value=self.tokenizer.start_token_id,
end_value=self.tokenizer.end_token_id,
sep_value=[self.tokenizer.end_token_id] * 2,
pad_value=self.tokenizer.pad_token_id,
truncate=truncate,
sequence_length=sequence_length,
Expand All @@ -166,7 +165,7 @@ def get_config(self):
def call(self, x, y=None, sample_weight=None):
x = convert_inputs_to_list_of_tensor_segments(x)
x = [self.tokenizer(segment) for segment in x]
token_ids = self.packer(x)
token_ids, _ = self.packer(x)
x = {
"token_ids": token_ids,
"padding_mask": token_ids != self.tokenizer.pad_token_id,
Expand Down
Loading