
Commit

Rename unsplittable to special
abuelnasr0 committed Apr 2, 2024
1 parent c231372 commit 74c3557
Showing 11 changed files with 68 additions and 71 deletions.
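
This commit renames the `unsplittable_tokens` / `unsplittable_tokens_in_strings` arguments of `BytePairTokenizer` and its model subclasses to `special_tokens` / `special_tokens_in_strings`. For orientation, a minimal sketch of the renamed API; the toy vocabulary and merges below are illustrative stand-ins, not taken from the repository:

```python
import keras_nlp

# Illustrative toy vocabulary/merges; real models load these from assets.
vocab = {"Ġair": 0, "plane": 1, "<|endoftext|>": 2}
merges = ["Ġ a", "Ġa i", "Ġai r", "p l", "a n", "pl an", "plan e"]

# Before this commit the last two arguments were spelled
# `unsplittable_tokens` and `unsplittable_tokens_in_strings`.
tokenizer = keras_nlp.tokenizers.BytePairTokenizer(
    vocabulary=vocab,
    merges=merges,
    special_tokens=["<|endoftext|>"],
    special_tokens_in_strings=True,
)

# The special token is kept intact and mapped to its own id instead of being
# sheared apart by the word-level split pattern.
print(tokenizer(" airplane<|endoftext|>"))  # expected ids: [0, 1, 2]
```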
6 changes: 3 additions & 3 deletions keras_nlp/models/bart/bart_tokenizer.py
@@ -90,12 +90,12 @@ def __init__(
super().__init__(
vocabulary=vocabulary,
merges=merges,
unsplittable_tokens=[
special_tokens=[
self.start_token,
self.pad_token,
self.end_token,
],
unsplittable_tokens_in_strings=special_tokens_in_strings,
special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)

@@ -113,5 +113,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):

def get_config(self):
config = super().get_config()
del config["unsplittable_tokens"] # Not configurable; set in __init__.
del config["special_tokens"] # Not configurable; set in __init__.
return config
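
Each model tokenizer pins its special tokens in `__init__`, which is why `get_config()` deletes the `special_tokens` entry (the `# Not configurable; set in __init__.` comments above). A hedged sketch of the effect on serialization, again with an illustrative toy vocabulary:

```python
import keras_nlp

# Illustrative toy vocabulary/merges; BART's fixed special tokens
# ("<s>", "<pad>", "</s>") must be present in the vocabulary.
vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "Ġair": 3, "plane": 4}
merges = ["Ġ a", "Ġa i", "Ġai r", "p l", "a n", "pl an", "plan e"]

tokenizer = keras_nlp.models.BartTokenizer(vocabulary=vocab, merges=merges)

config = tokenizer.get_config()
# The token list is fixed by the subclass, so it is not serialized and does
# not need to be supplied when rebuilding the tokenizer from its config.
assert "special_tokens" not in config
```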
6 changes: 3 additions & 3 deletions keras_nlp/models/bloom/bloom_tokenizer.py
@@ -82,12 +82,12 @@ def __init__(
super().__init__(
vocabulary=vocabulary,
merges=merges,
unsplittable_tokens=[
special_tokens=[
self.start_token,
self.end_token,
self.pad_token,
],
unsplittable_tokens_in_strings=special_tokens_in_strings,
special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)

@@ -105,5 +105,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):

def get_config(self):
config = super().get_config()
del config["unsplittable_tokens"] # Not configurable; set in __init__.
del config["special_tokens"] # Not configurable; set in __init__.
return config
6 changes: 3 additions & 3 deletions keras_nlp/models/falcon/falcon_tokenizer.py
@@ -81,8 +81,8 @@ def __init__(
super().__init__(
vocabulary=vocabulary,
merges=merges,
unsplittable_tokens=[self.end_token],
unsplittable_tokens_in_strings=special_tokens_in_strings,
special_tokens=[self.end_token],
special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)

@@ -100,5 +100,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):

def get_config(self):
config = super().get_config()
del config["unsplittable_tokens"] # Not configurable; set in __init__.
del config["special_tokens"] # Not configurable; set in __init__.
return config
2 changes: 1 addition & 1 deletion keras_nlp/models/falcon/falcon_tokenizer_test.py
@@ -29,7 +29,7 @@ def setUp(self):
"vocabulary": self.vocab,
"merges": self.merges,
"special_tokens_in_strings": True,
}
}
self.input_data = [
" airplane at airport<|endoftext|>",
" airplane airport",
6 changes: 3 additions & 3 deletions keras_nlp/models/gpt2/gpt2_tokenizer.py
@@ -81,8 +81,8 @@ def __init__(
super().__init__(
vocabulary=vocabulary,
merges=merges,
unsplittable_tokens=[self.end_token],
unsplittable_tokens_in_strings=special_tokens_in_strings,
special_tokens=[self.end_token],
special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)

@@ -100,5 +100,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):

def get_config(self):
config = super().get_config()
del config["unsplittable_tokens"] # Not configurable; set in __init__.
del config["special_tokens"] # Not configurable; set in __init__.
return config
6 changes: 3 additions & 3 deletions keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer.py
@@ -59,8 +59,8 @@ def __init__(
super().__init__(
vocabulary=vocabulary,
merges=merges,
unsplittable_tokens=[self.end_token],
unsplittable_tokens_in_strings=special_tokens_in_strings,
special_tokens=[self.end_token],
special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)

@@ -78,5 +78,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):

def get_config(self):
config = super().get_config()
del config["unsplittable_tokens"] # Not configurable; set in __init__.
del config["special_tokens"] # Not configurable; set in __init__.
return config
6 changes: 3 additions & 3 deletions keras_nlp/models/opt/opt_tokenizer.py
@@ -82,12 +82,12 @@ def __init__(
super().__init__(
vocabulary=vocabulary,
merges=merges,
unsplittable_tokens=[
special_tokens=[
self.start_token,
self.pad_token,
self.end_token,
],
unsplittable_tokens_in_strings=special_tokens_in_strings,
special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)

@@ -105,5 +105,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):

def get_config(self):
config = super().get_config()
del config["unsplittable_tokens"] # Not configurable; set in __init__.
del config["special_tokens"] # Not configurable; set in __init__.
return config
6 changes: 3 additions & 3 deletions keras_nlp/models/roberta/roberta_tokenizer.py
@@ -90,13 +90,13 @@ def __init__(
super().__init__(
vocabulary=vocabulary,
merges=merges,
unsplittable_tokens=[
special_tokens=[
self.start_token,
self.pad_token,
self.end_token,
self.mask_token,
],
unsplittable_tokens_in_strings=special_tokens_in_strings,
special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)

@@ -116,5 +116,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):

def get_config(self):
config = super().get_config()
del config["unsplittable_tokens"] # Not configurable; set in __init__.
del config["special_tokens"] # Not configurable; set in __init__.
return config
13 changes: 7 additions & 6 deletions keras_nlp/models/whisper/whisper_tokenizer.py
@@ -99,7 +99,8 @@ def __init__(
self.translate_token_id = special_tokens[self.translate_token]
self.transcribe_token_id = special_tokens[self.transcribe_token]

self.special_tokens = special_tokens
# Underscore to distinguish it from `self.special_tokens` in base class.
self._special_tokens = special_tokens
self.language_tokens = language_tokens

# TODO: Add language tokens to `unsplittable_tokens` once we figure
@@ -109,8 +110,8 @@ def __init__(
super().__init__(
vocabulary=vocabulary,
merges=merges,
unsplittable_tokens=unsplittable_tokens,
unsplittable_tokens_in_strings=special_tokens_in_strings,
special_tokens=unsplittable_tokens,
special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)

@@ -146,18 +147,18 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
self.translate_token,
self.transcribe_token,
]:
vocabulary[token] = self.special_tokens[token]
vocabulary[token] = self._special_tokens[token]
else:
self._initial_vocabulary = None

super().set_vocabulary_and_merges(vocabulary, merges)

def get_config(self):
config = super().get_config()
del config["unsplittable_tokens"] # Not configurable; set in __init__.
del config["special_tokens"] # Not configurable; set in __init__.
config.update(
{
"special_tokens": self.special_tokens,
"special_tokens": self._special_tokens,
"language_tokens": self.language_tokens,
}
)
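
The underscore in `self._special_tokens` avoids an attribute collision: after this rename the base `BytePairTokenizer` stores its `special_tokens` list as `self.special_tokens`, while `WhisperTokenizer`'s constructor argument of the same name is a dict mapping token strings to ids. A minimal, self-contained sketch of the collision being sidestepped (the class names here are hypothetical stand-ins, not the real classes):

```python
class Base:
    def __init__(self, special_tokens=None):
        # The base class keeps the list of special tokens under this name.
        self.special_tokens = special_tokens


class Child(Base):
    def __init__(self, special_tokens):
        # Here `special_tokens` is a dict of token -> id. Keep it under a
        # private name so the base class attribute is not overwritten.
        self._special_tokens = special_tokens
        super().__init__(special_tokens=list(special_tokens))


child = Child({"<|startoftranscript|>": 50258, "<|endoftext|>": 50257})
print(child.special_tokens)   # ['<|startoftranscript|>', '<|endoftext|>']
print(child._special_tokens)  # {'<|startoftranscript|>': 50258, '<|endoftext|>': 50257}
```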
68 changes: 32 additions & 36 deletions keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -59,10 +59,10 @@
SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""


def get_unsplittable_tokens_pattern(unsplittable_tokens):
if unsplittable_tokens is None or len(unsplittable_tokens) == 0:
def get_special_tokens_pattern(special_tokens):
if special_tokens is None or len(special_tokens) == 0:
return None
return r"|".join([re.escape(token) for token in unsplittable_tokens])
return r"|".join([re.escape(token) for token in special_tokens])


def bytes_to_unicode():
@@ -97,7 +97,7 @@ def remove_strings_from_inputs(tensor, string_to_remove):
return result


def split_strings_for_bpe(inputs, unsplittable_tokens_pattern=None):
def split_strings_for_bpe(inputs, special_tokens_pattern=None):
# We need to recreate the exact behavior of token presplitting in the
# original gpt2 tokenizer which uses a lookahead. As re2 does not
# support lookahead match, we are using an alternative insert a special
@@ -110,23 +110,23 @@ def split_strings_for_bpe(inputs, special_tokens_pattern=None):
inputs, rf"(\s{SPECIAL_WHITESPACES})$", r"\1६"
)

if unsplittable_tokens_pattern is not None:
# First split the unsplittable tokens from the input.
if special_tokens_pattern is not None:
# First split the special tokens from the input.
raw_tokens = tf_text.regex_split(
inputs, unsplittable_tokens_pattern, unsplittable_tokens_pattern
inputs, special_tokens_pattern, special_tokens_pattern
)
# Then split using both `unsplittable_tokens_pattern` and
# Then split using both `special_tokens_pattern` and
# `SPLIT_PATTERN_1` to split inputs like original gpt2, while not
# affecting the unsplittable tokens.
# We split unsplittable tokens first then apply this split instead of
# affecting the special tokens.
# We split special tokens first then apply this split instead of
# applying this split directly, because otherwise we will not split
# unsplittable tokens from inputs properly, because of this pattern
# special tokens from inputs properly, because of this pattern
# ` ?[^\s\p{L}\p{N}{special_spaces}]+`.
# e.g., [" </s>"] will be [" </", "s", ">"] instead of [" ", "</s>"]
raw_tokens = tf_text.regex_split(
raw_tokens,
r"|".join([unsplittable_tokens_pattern, SPLIT_PATTERN_1]),
r"|".join([unsplittable_tokens_pattern, SPLIT_PATTERN_1]),
r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]),
r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]),
)
raw_tokens = raw_tokens.merge_dims(-2, -1)
else:
@@ -238,16 +238,16 @@ class BytePairTokenizer(tokenizer.Tokenizer):
a prefix space to the first word will cause it to be tokenized
equivalently to all subsequent words in the sequence.
Defaults to `False`.
unsplittable_tokens: list. A list of unsplittable tokens. when
`unsplittable_tokens_in_strings` is set to `True`, unsplittable
special_tokens: list. A list of special tokens. When
`special_tokens_in_strings` is set to `True`, special
tokens will never be split during the word-level splitting applied
before the byte-pair encoding. This can be used to ensure special
tokens map to unique indices in the vocabulary, even if these
unsplittable tokens contain splittable characters such as
punctuation. Unsplittable tokens must still be included in
special tokens contain splittable characters such as
punctuation. Special tokens must still be included in
`vocabulary`. Defaults to `None`.
unsplittable_tokens_in_strings: bool. To indicate if the tokenizer
should expect unsplittable tokens in input strings that should be
special_tokens_in_strings: bool. To indicate if the tokenizer
should expect special tokens in input strings that should be
tokenized and mapped correctly to their ids. Defaults to False.
Examples:
@@ -287,8 +287,8 @@ def __init__(
merges=None,
sequence_length=None,
add_prefix_space=False,
unsplittable_tokens=None,
unsplittable_tokens_in_strings=False,
special_tokens=None,
special_tokens_in_strings=False,
dtype="int32",
**kwargs,
) -> None:
@@ -303,11 +303,11 @@ def __init__(
super().__init__(dtype=dtype, **kwargs)
self.sequence_length = sequence_length
self.add_prefix_space = add_prefix_space
self.unsplittable_tokens = unsplittable_tokens
self._unsplittable_tokens_pattern = None
if unsplittable_tokens_in_strings:
self._unsplittable_tokens_pattern = get_unsplittable_tokens_pattern(
unsplittable_tokens
self.special_tokens = special_tokens
self._special_tokens_pattern = None
if special_tokens_in_strings:
self._special_tokens_pattern = get_special_tokens_pattern(
special_tokens
)

# Create byte <=> unicode mapping. This is useful for handling
@@ -362,8 +362,8 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
)

# Check for special tokens in vocabulary.
if self.unsplittable_tokens is not None:
for token in self.unsplittable_tokens:
if self.special_tokens is not None:
for token in self.special_tokens:
if token not in self.get_vocabulary():
raise ValueError(
f"Cannot find token `'{token}'` in the provided "
@@ -383,12 +383,10 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
)

self.cache = BytePairTokenizerCache()
if self.unsplittable_tokens:
if self.special_tokens and self._special_tokens_pattern is not None:
# Put special tokens into cache, so it won't be further split and
# merged.
self.cache.insert(
self.unsplittable_tokens, self.unsplittable_tokens
)
self.cache.insert(self.special_tokens, self.special_tokens)

# Create mapping between string tokens to int ids, and vice versa.
byte_pairs = [x[0] for x in self.vocabulary.items()]
@@ -566,9 +564,7 @@ def tokenize(self, inputs):
if scalar_input:
inputs = tf.expand_dims(inputs, 0)

raw_tokens = split_strings_for_bpe(
inputs, self._unsplittable_tokens_pattern
)
raw_tokens = split_strings_for_bpe(inputs, self._special_tokens_pattern)
token_row_splits = raw_tokens.row_splits
flat_tokens = raw_tokens.flat_values

@@ -662,7 +658,7 @@ def get_config(self):
{
"sequence_length": self.sequence_length,
"add_prefix_space": self.add_prefix_space,
"unsplittable_tokens": self.unsplittable_tokens,
"special_tokens": self.special_tokens,
}
)
return config
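
The comments in `split_strings_for_bpe` above explain why special tokens are split out before the GPT-2 word-level pattern is applied. A standalone sketch of that idea in plain `re` (not the actual `tf_text` pipeline), reusing the `"</s>"` example from the comment:

```python
import re

special_tokens = ["<s>", "</s>"]
# Same construction as `get_special_tokens_pattern`: escape each token and
# join the results with "|".
pattern = "|".join(re.escape(token) for token in special_tokens)

text = "<s> a quick fox</s>"
# Splitting on the special tokens first (and keeping the delimiters) leaves
# them intact; the word-level pattern alone would shear "</s>" into pieces.
pieces = [piece for piece in re.split(f"({pattern})", text) if piece]
print(pieces)  # ['<s>', ' a quick fox', '</s>']
```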
14 changes: 7 additions & 7 deletions keras_nlp/tokenizers/byte_pair_tokenizer_test.py
@@ -67,17 +67,17 @@ def test_tokenize_with_special_tokens(self):
tokenizer = BytePairTokenizer(
vocabulary=vocab,
merges=merges,
unsplittable_tokens=["s", "p"],
unsplittable_tokens_in_strings=True,
special_tokens=["s", "p"],
special_tokens_in_strings=True,
)
output = tokenizer("sp")
self.assertAllEqual(output, [1, 2])

# If not unsplittable_tokens_in_strings is `True`, "sp" is one token.
# If `special_tokens_in_strings` is not `True`, "sp" is one token.
tokenizer = BytePairTokenizer(
vocabulary=vocab,
merges=merges,
unsplittable_tokens=["s", "p"],
special_tokens=["s", "p"],
)
output = tokenizer("sp")
self.assertAllEqual(output, [0])
@@ -89,16 +89,16 @@ def test_tokenize_with_special_tokens(self):
tokenizer = BytePairTokenizer(
vocabulary=vocab,
merges=merges,
unsplittable_tokens=["<s>", "</s>"],
unsplittable_tokens_in_strings=True,
special_tokens=["<s>", "</s>"],
special_tokens_in_strings=True,
)
output = tokenizer("<s>a quick fox</s>")
self.assertAllEqual(output, [0, 2, 3, 4, 1])

def test_errors_missing_special_tokens(self):
with self.assertRaises(ValueError):
BytePairTokenizer(
vocabulary=["a", "b", "c"], merges=[], unsplittable_tokens=["d"]
vocabulary=["a", "b", "c"], merges=[], special_tokens=["d"]
)

def test_tokenize_prefix_space(self):
