diff --git a/keras_nlp/layers/multi_segment_packer.py b/keras_nlp/layers/multi_segment_packer.py
index 2cc3b65076..c6f291ae86 100644
--- a/keras_nlp/layers/multi_segment_packer.py
+++ b/keras_nlp/layers/multi_segment_packer.py
@@ -52,19 +52,24 @@ class MultiSegmentPacker(keras.layers.Layer):
     either rank-1 or rank-2.

     Args:
-        sequence_length: The desired output length.
-        start_value: The id or token that is to be placed at the start of each
-            sequence (called "[CLS]" for BERT). The dtype must match the dtype
-            of the input tensors to the layer.
-        end_value: The id or token that is to be placed at the end of each
-            input segment (called "[SEP]" for BERT). The dtype much match the
-            dtype of the input tensors to the layer.
-        pad_value: The id or token that is to be placed into the unused
+        sequence_length: int. The desired output length.
+        start_value: int/str/list/tuple. The id(s) or token(s) that are to be
+            placed at the start of each sequence (called "[CLS]" for BERT). The
+            dtype must match the dtype of the input tensors to the layer.
+        end_value: int/str/list/tuple. The id(s) or token(s) that are to be
+            placed at the end of the last input segment (called "[SEP]" for
+            BERT). The dtype must match the dtype of the input tensors to the
+            layer.
+        sep_value: int/str/list/tuple. The id(s) or token(s) that are to be
+            placed at the end of every segment, except the last segment (called
+            "[SEP]" for BERT). If `None`, `end_value` is used. The dtype must
+            match the dtype of the input tensors to the layer.
+        pad_value: int/str. The id or token that is to be placed into the unused
             positions after the last segment in the sequence (called "[PAD]"
             for BERT).
-        truncate: The algorithm to truncate a list of batched segments to fit a
-            per-example length limit. The value can be either `round_robin` or
-            `waterfall`:
+        truncate: str. The algorithm to truncate a list of batched segments to
+            fit a per-example length limit. The value can be either
+            `"round_robin"` or `"waterfall"`:
             - `"round_robin"`: Available space is assigned one token at a
                 time in a round-robin fashion to the inputs that still need
                 some, until the limit is reached.
@@ -101,6 +106,17 @@ class MultiSegmentPacker(keras.layers.Layer):
     )

+    *Pack multiple inputs for classification with different sep tokens.*
+    >>> seq1 = tf.constant([1, 2, 3, 4])
+    >>> seq2 = tf.constant([11, 12, 13, 14])
+    >>> packer = keras_nlp.layers.MultiSegmentPacker(
+    ...     8, start_value=101, end_value=102, sep_value=[102, 102])
+    >>> packer((seq1, seq2))
+    (<tf.Tensor: shape=(8,), dtype=int32, numpy=
+    array([101,   1,   2, 102, 102,  11,  12, 102], dtype=int32)>,
+    <tf.Tensor: shape=(8,), dtype=int32, numpy=
+    array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32)>)
+
     Reference:
         [Devlin et al., 2018](https://arxiv.org/abs/1810.04805).
     """
@@ -110,6 +126,7 @@ def __init__(
         sequence_length,
         start_value,
         end_value,
+        sep_value=None,
         pad_value=None,
         truncate="round_robin",
         **kwargs,
     ):
@@ -124,8 +141,32 @@ def __init__(
                 "supported. Received %s" % truncate
             )
         self.truncate = truncate
+
+        # Maintain private copies of start/end values for config purposes.
+        self._start_value = start_value
+        self._sep_value = sep_value
+        self._end_value = end_value
+
+        def check_special_value_type(value, value_name):
+            if isinstance(value, (int, str)):
+                return [value]
+            if value and not isinstance(value, (list, tuple)):
+                raise ValueError(
+                    f"{value_name} should be of type int/str/list/tuple. "
+                    f"Received type: `{type(value)}`."
+                )
+            return value
+
+        start_value = check_special_value_type(start_value, "start_value")
+        if sep_value is None:
+            sep_value = end_value
+        sep_value = check_special_value_type(sep_value, "sep_value")
+        end_value = check_special_value_type(end_value, "end_value")
+
         self.start_value = start_value
+        self.sep_value = sep_value
         self.end_value = end_value
+
         self.pad_value = pad_value

     def get_config(self):
@@ -133,8 +174,9 @@ def get_config(self):
         config.update(
             {
                 "sequence_length": self.sequence_length,
-                "start_value": self.start_value,
-                "end_value": self.end_value,
+                "start_value": self._start_value,
+                "end_value": self._end_value,
+                "sep_value": self._sep_value,
                 "pad_value": self.pad_value,
                 "truncate": self.truncate,
             }
         )
@@ -170,7 +212,12 @@ def _convert_dense(self, x):

     def _trim_inputs(self, inputs):
         """Trim inputs to desired length."""
-        num_special_tokens = len(inputs) + 1
+        num_segments = len(inputs)
+        num_special_tokens = (
+            len(self.start_value)
+            + (num_segments - 1) * len(self.sep_value)
+            + len(self.end_value)
+        )
         if self.truncate == "round_robin":
             return tf_text.RoundRobinTrimmer(
                 self.sequence_length - num_special_tokens
@@ -187,22 +234,40 @@ def _combine_inputs(self, segments):
         dtype = segments[0].dtype
         batch_size = segments[0].nrows()

         start_value = tf.convert_to_tensor(self.start_value, dtype=dtype)
+        sep_value = tf.convert_to_tensor(self.sep_value, dtype=dtype)
         end_value = tf.convert_to_tensor(self.end_value, dtype=dtype)

-        start_column = tf.fill((batch_size, 1), start_value)
-        end_column = tf.fill((batch_size, 1), end_value)
-        ones_column = tf.ones_like(start_column, dtype=tf.int32)
+        start_columns = tf.repeat(
+            start_value[tf.newaxis, :], repeats=batch_size, axis=0
+        )
+        sep_columns = tf.repeat(
+            sep_value[tf.newaxis, :], repeats=batch_size, axis=0
+        )
+        end_columns = tf.repeat(
+            end_value[tf.newaxis, :], repeats=batch_size, axis=0
+        )
+        ones_sep_columns = tf.ones_like(sep_columns, dtype=tf.int32)
+        ones_end_columns = tf.ones_like(end_columns, dtype=tf.int32)
+
+        segments_to_combine = [start_columns]
+        segment_ids_to_combine = [
+            tf.ones_like(start_columns, dtype=tf.int32) * 0
+        ]
-        segments_to_combine = [start_column]
-        segment_ids_to_combine = [ones_column * 0]
         for i, seg in enumerate(segments):
-            # Combine all segments adding end tokens.
+            # Combine all segments.
             segments_to_combine.append(seg)
-            segments_to_combine.append(end_column)
-            # Combine segment ids accounting for end tokens.
+            # Combine segment ids.
             segment_ids_to_combine.append(tf.ones_like(seg, dtype=tf.int32) * i)
-            segment_ids_to_combine.append(ones_column * i)
+
+            # Account for the sep/end tokens here.
+            if i == len(segments) - 1:
+                segments_to_combine.append(end_columns)
+                segment_ids_to_combine.append(ones_end_columns * i)
+            else:
+                segments_to_combine.append(sep_columns)
+                segment_ids_to_combine.append(ones_sep_columns * i)

         token_ids = tf.concat(segments_to_combine, 1)
         segment_ids = tf.concat(segment_ids_to_combine, 1)
diff --git a/keras_nlp/layers/multi_segment_packer_test.py b/keras_nlp/layers/multi_segment_packer_test.py
index 4434486a77..1984ae06c4 100644
--- a/keras_nlp/layers/multi_segment_packer_test.py
+++ b/keras_nlp/layers/multi_segment_packer_test.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for Transformer Decoder.""" +"""Tests for multi-segment packing.""" import os @@ -147,6 +147,32 @@ def test_pad_batched_inputs(self): ), ) + def test_list_special_tokens(self): + seq1 = tf.ragged.constant([["a", "b"], ["a", "b"]]) + seq2 = tf.ragged.constant([["x", "y"], ["x"]]) + packer = MultiSegmentPacker( + 8, + start_value="", + end_value="", + sep_value=["", ""], + pad_value="", + truncate="round_robin", + ) + output = packer([seq1, seq2]) + self.assertAllEqual( + output, + ( + [ + ["", "a", "b", "", "", "x", "y", ""], + ["", "a", "b", "", "", "x", "", ""], + ], + [ + [0, 0, 0, 0, 0, 1, 1, 1], + [0, 0, 0, 0, 0, 1, 1, 0], + ], + ), + ) + def test_config(self): seq1 = tf.ragged.constant([["a", "b", "c"], ["a", "b"]]) seq2 = tf.ragged.constant([["x", "y", "z"], ["x", "y", "z"]]) diff --git a/keras_nlp/models/roberta/roberta_multi_segment_packer.py b/keras_nlp/models/roberta/roberta_multi_segment_packer.py deleted file mode 100644 index bbac5fdbe1..0000000000 --- a/keras_nlp/models/roberta/roberta_multi_segment_packer.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2022 The KerasNLP Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import tensorflow as tf -from tensorflow import keras - -from keras_nlp.api_export import keras_nlp_export -from keras_nlp.utils.tf_utils import assert_tf_text_installed - -try: - import tensorflow_text as tf_text -except ImportError: - tf_text = None - - -# TODO: This is a temporary, unexported layer until we find a way to make the -# `MultiSegmentPacker` layer more generic. -@keras_nlp_export("keras_nlp.models.RobertaMultiSegmentPacker") -class RobertaMultiSegmentPacker(keras.layers.Layer): - """Packs multiple sequences into a single fixed width model input. - - This layer packs multiple input sequences into a single fixed width sequence - containing start and end delimiters, forming a dense input suitable for a - classification task for RoBERTa. - - Takes as input a list or tuple of token segments. The layer will process - inputs as follows: - - Truncate all input segments to fit within `sequence_length` according to - the `truncate` strategy. - - Concatenate all input segments, adding a single `start_value` at the - start of the entire sequence, `[end_value, end_value]` at the end of - each segment save the last, and a single `end_value` at the end of the - entire sequence. - - Pad the resulting sequence to `sequence_length` using `pad_tokens`. - - Input should be either a `tf.RaggedTensor` or a dense `tf.Tensor`, and - either rank-1 or rank-2. - - Please refer to the arguments of `keras_nlp.layers.MultiSegmentPacker` for - more details. - """ - - def __init__( - self, - sequence_length, - start_value, - end_value, - pad_value=None, - truncate="round_robin", - **kwargs, - ): - assert_tf_text_installed(self.__class__.__name__) - - super().__init__(**kwargs) - self.sequence_length = sequence_length - if truncate not in ("round_robin", "waterfall"): - raise ValueError( - "Only 'round_robin' and 'waterfall' algorithms are " - "supported. 
Received %s" % truncate - ) - self.truncate = truncate - self.start_value = start_value - self.end_value = end_value - self.pad_value = pad_value - - def get_config(self): - config = super().get_config() - config.update( - { - "sequence_length": self.sequence_length, - "start_value": self.start_value, - "end_value": self.end_value, - "pad_value": self.pad_value, - "truncate": self.truncate, - } - ) - return config - - def _trim_inputs(self, inputs): - """Trim inputs to desired length.""" - # Special tokens include the start token at the beginning of the - # sequence, two `end_value` at the end of every segment save the last, - # and the `end_value` at the end of the sequence. - num_special_tokens = 2 * len(inputs) - if self.truncate == "round_robin": - return tf_text.RoundRobinTrimmer( - self.sequence_length - num_special_tokens - ).trim(inputs) - elif self.truncate == "waterfall": - return tf_text.WaterfallTrimmer( - self.sequence_length - num_special_tokens - ).trim(inputs) - else: - raise ValueError("Unsupported truncate: %s" % self.truncate) - - def _combine_inputs(self, segments): - """Combine inputs with start and end values added.""" - dtype = segments[0].dtype - batch_size = segments[0].nrows() - - start_value = tf.convert_to_tensor(self.start_value, dtype=dtype) - end_value = tf.convert_to_tensor(self.end_value, dtype=dtype) - - start_column = tf.fill((batch_size, 1), start_value) - end_column = tf.fill((batch_size, 1), end_value) - - segments_to_combine = [] - for i, seg in enumerate(segments): - segments_to_combine.append(start_column if i == 0 else end_column) - segments_to_combine.append(seg) - segments_to_combine.append(end_column) - - token_ids = tf.concat(segments_to_combine, 1) - return token_ids - - def call(self, inputs): - def to_ragged(x): - return tf.RaggedTensor.from_tensor(x[tf.newaxis, :]) - - # If rank 1, add a batch dim. - rank_1 = inputs[0].shape.rank == 1 - if rank_1: - inputs = [to_ragged(x) for x in inputs] - - segments = self._trim_inputs(inputs) - token_ids = self._combine_inputs(segments) - # Pad to dense tensor output. - shape = tf.cast([-1, self.sequence_length], tf.int64) - token_ids = token_ids.to_tensor( - shape=shape, default_value=self.pad_value - ) - # Remove the batch dim if added. 
-        if rank_1:
-            token_ids = tf.squeeze(token_ids, 0)
-
-        return token_ids
diff --git a/keras_nlp/models/roberta/roberta_preprocessor.py b/keras_nlp/models/roberta/roberta_preprocessor.py
index ca2c2560e7..44ed27ce12 100644
--- a/keras_nlp/models/roberta/roberta_preprocessor.py
+++ b/keras_nlp/models/roberta/roberta_preprocessor.py
@@ -17,10 +17,8 @@
 import copy

 from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
 from keras_nlp.models.preprocessor import Preprocessor
-from keras_nlp.models.roberta.roberta_multi_segment_packer import (
-    RobertaMultiSegmentPacker,
-)
 from keras_nlp.models.roberta.roberta_presets import backbone_presets
 from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer
 from keras_nlp.utils.keras_utils import (
@@ -145,9 +143,10 @@ def __init__(
         super().__init__(**kwargs)

         self.tokenizer = tokenizer
-        self.packer = RobertaMultiSegmentPacker(
+        self.packer = MultiSegmentPacker(
             start_value=self.tokenizer.start_token_id,
             end_value=self.tokenizer.end_token_id,
+            sep_value=[self.tokenizer.end_token_id] * 2,
             pad_value=self.tokenizer.pad_token_id,
             truncate=truncate,
             sequence_length=sequence_length,
@@ -166,7 +165,7 @@ def get_config(self):
     def call(self, x, y=None, sample_weight=None):
         x = convert_inputs_to_list_of_tensor_segments(x)
         x = [self.tokenizer(segment) for segment in x]
-        token_ids = self.packer(x)
+        token_ids, _ = self.packer(x)
         x = {
             "token_ids": token_ids,
             "padding_mask": token_ids != self.tokenizer.pad_token_id,
diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py
index 122c372c01..475d778c31 100644
--- a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py
@@ -17,10 +17,8 @@
 import copy

 from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
 from keras_nlp.models.preprocessor import Preprocessor
-from keras_nlp.models.roberta.roberta_preprocessor import (
-    RobertaMultiSegmentPacker,
-)
 from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets
 from keras_nlp.models.xlm_roberta.xlm_roberta_tokenizer import (
     XLMRobertaTokenizer,
 )
@@ -159,9 +157,10 @@ def __init__(

         self.tokenizer = tokenizer

-        self.packer = RobertaMultiSegmentPacker(
+        self.packer = MultiSegmentPacker(
             start_value=self.tokenizer.start_token_id,
             end_value=self.tokenizer.end_token_id,
+            sep_value=[self.tokenizer.end_token_id] * 2,
             pad_value=self.tokenizer.pad_token_id,
             truncate=truncate,
             sequence_length=sequence_length,
@@ -180,7 +179,7 @@ def get_config(self):
     def call(self, x, y=None, sample_weight=None):
         x = convert_inputs_to_list_of_tensor_segments(x)
         x = [self.tokenizer(segment) for segment in x]
-        token_ids = self.packer(x)
+        token_ids, _ = self.packer(x)
         x = {
             "token_ids": token_ids,
             "padding_mask": token_ids != self.tokenizer.pad_token_id,
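
Usage sketch (illustrative, not part of the patch): with the `sep_value` argument introduced above, the RoBERTa-style doubled separator between segments can be expressed directly on the shared `keras_nlp.layers.MultiSegmentPacker`, which is what the two preprocessor changes rely on. The token ids below (0 for "<s>", 2 for "</s>", 1 for "<pad>") are assumed values for illustration only; in practice they come from the model's tokenizer, as in the preprocessors above.

    import tensorflow as tf
    import keras_nlp

    # Minimal sketch of RoBERTa-style packing of two segments.
    packer = keras_nlp.layers.MultiSegmentPacker(
        sequence_length=12,
        start_value=0,     # assumed "<s>" id
        end_value=2,       # assumed "</s>" id
        sep_value=[2, 2],  # doubled "</s>" between segments, single "</s>" at the end
        pad_value=1,       # assumed "<pad>" id
    )
    seq1 = tf.constant([10, 11, 12])
    seq2 = tf.constant([20, 21])
    token_ids, segment_ids = packer((seq1, seq2))
    # token_ids   -> [0, 10, 11, 12, 2, 2, 20, 21, 2, 1, 1, 1]
    # segment_ids -> [0,  0,  0,  0, 0, 0,  1,  1, 1, 0, 0, 0]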