Support list/tuple inputs for special tokens in MultiSegmentPacker layer #1046
Changes from 1 commit
@@ -53,12 +53,16 @@ class MultiSegmentPacker(keras.layers.Layer):

     Args:
         sequence_length: The desired output length.
-        start_value: The id or token that is to be placed at the start of each
-            sequence (called "[CLS]" for BERT). The dtype must match the dtype
-            of the input tensors to the layer.
-        end_value: The id or token that is to be placed at the end of each
-            input segment (called "[SEP]" for BERT). The dtype much match the
-            dtype of the input tensors to the layer.
+        start_value: The id(s) or token(s) that are to be placed at the start of
+            each sequence (called "[CLS]" for BERT). The dtype must match the
+            dtype of the input tensors to the layer.
+        end_value: The id(s) or token(s) that is/are to be placed at the end of
+            the last input segment (called "[SEP]" for BERT). The dtype much
+            match the dtype of the input tensors to the layer.
+        sep_value: The id(s) or token(s) that is/are to be placed at the end of
+            every segment, except the last segment (called "[SEP]" for BERT).
+            If `None`, `end_value` is used. The dtype much match the dtype of
+            the input tensors to the layer.
         pad_value: The id or token that is to be placed into the unused
             positions after the last segment in the sequence
             (called "[PAD]" for BERT).

Review comment (on "is/are to be placed"): is/are -> are
Review comment (on "The dtype much match"): much -> must
Review comment (on the new docstring): should we add an example below? maybe roberta double sep?
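Picking up the reviewer's RoBERTa suggestion, here is a minimal usage sketch of the double-separator case. The token strings and sequence length are illustrative, not from the PR, and this assumes the layer is exposed as `keras_nlp.layers.MultiSegmentPacker` as elsewhere in this repo:

```python
import tensorflow as tf

from keras_nlp.layers import MultiSegmentPacker

# RoBERTa-style packing: "<s> seq1 </s></s> seq2 </s>", i.e. one start
# token, a two-token separator between segments, and a single end token.
packer = MultiSegmentPacker(
    sequence_length=10,
    start_value="<s>",
    end_value="</s>",
    sep_value=["</s>", "</s>"],
    pad_value="<pad>",
)
seq1 = tf.ragged.constant([["a", "b", "c"]])
seq2 = tf.ragged.constant([["x", "y"]])
token_ids, segment_ids = packer([seq1, seq2])
# token_ids   -> [["<s>", "a", "b", "c", "</s>", "</s>", "x", "y", "</s>", "<pad>"]]
# segment_ids -> [[0, 0, 0, 0, 0, 0, 1, 1, 1, 0]]
```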
@@ -110,6 +114,7 @@ def __init__(
         sequence_length,
         start_value,
         end_value,
+        sep_value=None,
         pad_value=None,
         truncate="round_robin",
         **kwargs,
@@ -124,17 +129,37 @@ def __init__( | |
"supported. Received %s" % truncate | ||
) | ||
self.truncate = truncate | ||
|
||
# Maintain private copies of start/end values for config purposes. | ||
self._start_value = start_value | ||
self._sep_value = sep_value | ||
self._end_value = end_value | ||
|
||
if not isinstance(start_value, (list, tuple)): | ||
start_value = [start_value] | ||
|
||
if sep_value is None: | ||
sep_value = end_value | ||
if not isinstance(sep_value, (list, tuple)): | ||
sep_value = [sep_value] | ||
|
||
if not isinstance(end_value, (list, tuple)): | ||
end_value = [end_value] | ||
|
||
self.start_value = start_value | ||
self.sep_value = sep_value | ||
self.end_value = end_value | ||
|
||
self.pad_value = pad_value | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"sequence_length": self.sequence_length, | ||
"start_value": self.start_value, | ||
"end_value": self.end_value, | ||
"start_value": self._start_value, | ||
"end_value": self._end_value, | ||
"sep_value": self._sep_value, | ||
"pad_value": self.pad_value, | ||
"truncate": self.truncate, | ||
} | ||
|
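A short sketch of why the private copies matter: `get_config()` should serialize what the user passed, not the normalized lists, so that `from_config()` round-trips through the same `__init__` normalization. The token ids below are illustrative:

```python
packer = MultiSegmentPacker(8, start_value=101, end_value=102)

# __init__ normalizes the scalar to a list internally...
assert packer.start_value == [101]

# ...but the config records the original user-facing scalar.
config = packer.get_config()
assert config["start_value"] == 101

# Restoring from config re-runs __init__, so normalization happens again.
restored = MultiSegmentPacker.from_config(config)
assert restored.start_value == [101]
```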
@@ -170,7 +195,12 @@ def _convert_dense(self, x):

     def _trim_inputs(self, inputs):
         """Trim inputs to desired length."""
-        num_special_tokens = len(inputs) + 1
+        num_segments = len(inputs)
+        num_special_tokens = (
+            len(self.start_value)
+            + (num_segments - 1) * len(self.sep_value)
+            + len(self.end_value)
+        )
         if self.truncate == "round_robin":
             return tf_text.RoundRobinTrimmer(
                 self.sequence_length - num_special_tokens
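A worked example of the new budget computation, using the double-separator setup from the test below (numbers chosen for illustration):

```python
start_value = ["[CLS]"]
sep_value = ["[SEP]", "[SEP]"]
end_value = ["[SEP]"]
num_segments = 2

num_special_tokens = (
    len(start_value)
    + (num_segments - 1) * len(sep_value)
    + len(end_value)
)
assert num_special_tokens == 4  # 1 + 1 * 2 + 1

# With sequence_length=9, the round-robin trimmer budget is 9 - 4 = 5
# real tokens, split across the two segments.
```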
@@ -187,22 +217,40 @@ def _combine_inputs(self, segments):
         dtype = segments[0].dtype
         batch_size = segments[0].nrows()
         start_value = tf.convert_to_tensor(self.start_value, dtype=dtype)
+        sep_value = tf.convert_to_tensor(self.sep_value, dtype=dtype)
         end_value = tf.convert_to_tensor(self.end_value, dtype=dtype)

-        start_column = tf.fill((batch_size, 1), start_value)
-        end_column = tf.fill((batch_size, 1), end_value)
-        ones_column = tf.ones_like(start_column, dtype=tf.int32)
+        start_values_tensor = tf.repeat(
+            start_value[tf.newaxis, :], repeats=batch_size, axis=0
+        )
+        end_values_tensor = tf.repeat(
+            end_value[tf.newaxis, :], repeats=batch_size, axis=0
+        )
+        sep_values_tensor = tf.repeat(
+            sep_value[tf.newaxis, :], repeats=batch_size, axis=0
+        )
+        ones_sep_tensor = tf.ones_like(sep_values_tensor, dtype=tf.int32)
+        ones_end_tensor = tf.ones_like(end_values_tensor, dtype=tf.int32)

-        segments_to_combine = [start_column]
-        segment_ids_to_combine = [ones_column * 0]
+        segments_to_combine = [start_values_tensor]
+        segment_ids_to_combine = [
+            tf.ones_like(start_values_tensor, dtype=tf.int32) * 0
+        ]
+
         for i, seg in enumerate(segments):
-            # Combine all segments adding end tokens.
+            # Combine all segments.
             segments_to_combine.append(seg)
-            segments_to_combine.append(end_column)

-            # Combine segment ids accounting for end tokens.
+            # Combine segment ids.
             segment_ids_to_combine.append(tf.ones_like(seg, dtype=tf.int32) * i)
-            segment_ids_to_combine.append(ones_column * i)
+
+            # Account for the sep/end tokens here.
+            if i == len(segments) - 1:
+                segments_to_combine.append(end_values_tensor)
+                segment_ids_to_combine.append(ones_end_tensor * i)
+            else:
+                segments_to_combine.append(sep_values_tensor)
+                segment_ids_to_combine.append(ones_sep_tensor * i)

         token_ids = tf.concat(segments_to_combine, 1)
         segment_ids = tf.concat(segment_ids_to_combine, 1)

Review comment (on the new tensor names): these names are a little confusing
Author reply: We can, but it isn't exactly a column :P. I'll call it
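For reference, the `tf.repeat` pattern replaces the old `tf.fill` column trick because the special values are now rank-1 rather than scalar. A standalone sketch with toy ids (values illustrative):

```python
import tensorflow as tf

batch_size = 2
sep_value = tf.constant([102, 102])  # e.g. a doubled "[SEP]" id

# (2,) -> (1, 2) via tf.newaxis, then one copy of the row per batch element.
sep_values_tensor = tf.repeat(
    sep_value[tf.newaxis, :], repeats=batch_size, axis=0
)
print(sep_values_tensor.numpy())
# [[102 102]
#  [102 102]]
```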
The second changed file is the packer's test suite:

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for Transformer Decoder."""
+"""Tests for multi-segment packing."""

 import os
@@ -147,6 +147,52 @@ def test_pad_batched_inputs(self):
             ),
         )

+    def test_list_special_tokens(self):
+        seq1 = tf.ragged.constant([["a", "b", "c"], ["a", "b", "c"]])
+        seq2 = tf.ragged.constant([["x", "y", "z"], ["x"]])
+        packer = MultiSegmentPacker(
+            9,
+            start_value="[CLS]",
+            end_value="[SEP]",
+            sep_value=["[SEP]", "[SEP]"],
+            pad_value="[PAD]",
+            truncate="round_robin",
+        )
+        output = packer([seq1, seq2])
+        self.assertAllEqual(
+            output,
+            (
+                [
+                    [
+                        "[CLS]",
+                        "a",
+                        "b",
+                        "c",
+                        "[SEP]",
+                        "[SEP]",
+                        "x",
+                        "y",
+                        "[SEP]",
+                    ],
+                    [
+                        "[CLS]",
+                        "a",
+                        "b",
+                        "c",
+                        "[SEP]",
+                        "[SEP]",
+                        "x",
+                        "[SEP]",
+                        "[PAD]",
+                    ],
+                ],
+                [
+                    [0, 0, 0, 0, 0, 0, 1, 1, 1],
+                    [0, 0, 0, 0, 0, 0, 1, 1, 0],
+                ],
+            ),
+        )

     def test_config(self):
         seq1 = tf.ragged.constant([["a", "b", "c"], ["a", "b"]])
         seq2 = tf.ragged.constant([["x", "y", "z"], ["x", "y", "z"]])

Review comment (on the expected-output literals): try to come up with a slightly shorter test case that will format the lists to one line
Author reply: Whoops!
A third file in the diff was deleted.

Review comment: let's add the type notes you have in the other layer for consistency