[split_special_tokens] Add support for split_special_tokens argument to encode #25081

Merged: 26 commits, Aug 18, 2023
Changes from 13 commits
Commits (26)
efec7cc
draft changes
ArthurZucker Jul 25, 2023
643d830
update and add tests
ArthurZucker Jul 26, 2023
f092997
styling for no
ArthurZucker Jul 26, 2023
d1aabd0
move test
ArthurZucker Jul 26, 2023
63e665e
path to usable model
ArthurZucker Jul 26, 2023
6c86c8b
update test
ArthurZucker Jul 31, 2023
d869f58
small update
ArthurZucker Jul 31, 2023
8cd03d4
update bertbased tokenizers
ArthurZucker Jul 31, 2023
8e1ecd7
don't use kwargs for _tokenize
ArthurZucker Jul 31, 2023
964b311
don't use kwargs for _tokenize
ArthurZucker Jul 31, 2023
707a570
fix copies
ArthurZucker Jul 31, 2023
d38096c
update
ArthurZucker Jul 31, 2023
7478b8b
update test for special tokenizers
ArthurZucker Aug 1, 2023
8e9c0f6
Merge branch 'main' of https://github.com/huggingface/transformers in…
ArthurZucker Aug 1, 2023
2fe79dc
fixup
ArthurZucker Aug 1, 2023
d2edd3d
skip two tests
ArthurZucker Aug 1, 2023
5bff8ec
remove pdb breakpoint()
ArthurZucker Aug 1, 2023
6cb52e3
Merge branch 'main' of https://github.com/huggingface/transformers in…
ArthurZucker Aug 1, 2023
6cff6a6
wowo
ArthurZucker Aug 3, 2023
49b60d0
Merge branch 'main' of https://github.com/huggingface/transformers in…
ArthurZucker Aug 11, 2023
bf08a7e
Merge branch 'main' of https://github.com/huggingface/transformers in…
ArthurZucker Aug 17, 2023
5b62626
rewrite custom tests
ArthurZucker Aug 17, 2023
ae3a65a
nits
ArthurZucker Aug 17, 2023
4201bf2
revert change in target keys
ArthurZucker Aug 17, 2023
24cdd94
fix markup lm
ArthurZucker Aug 17, 2023
98eb560
update documentation of the argument
ArthurZucker Aug 18, 2023
6 changes: 4 additions & 2 deletions src/transformers/models/bert/tokenization_bert.py
@@ -238,10 +238,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
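
To illustrate what the new argument is meant to do for a slow, BERT-style tokenizer, here is a minimal usage sketch. It is not part of the diff: the `bert-base-uncased` checkpoint and the exact sub-token outputs are assumptions, and it relies on `tokenize`/`encode` forwarding `split_special_tokens` as implemented in the `tokenization_utils.py` change further down.

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Default behaviour: "[CLS]" is a registered special token and is never split.
    print(tokenizer.tokenize("[CLS] hello"))
    # -> ['[CLS]', 'hello']

    # With split_special_tokens=True, "[CLS]" is treated like ordinary text and goes
    # through the basic + wordpiece tokenizers, e.g. ['[', 'cl', '##s', ']', 'hello'].
    print(tokenizer.tokenize("[CLS] hello", split_special_tokens=True))
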
6 changes: 4 additions & 2 deletions src/transformers/models/convbert/tokenization_convbert.py
@@ -177,10 +177,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
@@ -178,10 +178,12 @@ def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
6 changes: 4 additions & 2 deletions src/transformers/models/distilbert/tokenization_distilbert.py
@@ -195,10 +195,12 @@ def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
6 changes: 4 additions & 2 deletions src/transformers/models/electra/tokenization_electra.py
@@ -194,10 +194,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
6 changes: 4 additions & 2 deletions src/transformers/models/funnel/tokenization_funnel.py
@@ -205,10 +205,12 @@ def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
6 changes: 4 additions & 2 deletions src/transformers/models/layoutlm/tokenization_layoutlm.py
@@ -176,10 +176,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
        if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
6 changes: 4 additions & 2 deletions src/transformers/models/lxmert/tokenization_lxmert.py
@@ -168,10 +168,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
6 changes: 4 additions & 2 deletions src/transformers/models/mobilebert/tokenization_mobilebert.py
@@ -166,10 +166,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
6 changes: 4 additions & 2 deletions src/transformers/models/roc_bert/tokenization_roc_bert.py
@@ -210,10 +210,12 @@ def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
@@ -180,10 +180,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
12 changes: 10 additions & 2 deletions src/transformers/tokenization_utils.py
@@ -499,6 +499,8 @@ def tokenize(self, text: TextInput, **kwargs) -> List[str]:
             str(t): t for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
         }
 
+        split_special_tokens = kwargs.pop("split_special_tokens", self.split_special_tokens)
+
         text, kwargs = self.prepare_for_tokenization(text, **kwargs)
 
         if kwargs:
@@ -513,8 +515,14 @@ def tokenize(self, text: TextInput, **kwargs) -> List[str]:
             pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
             text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
 
-        no_split_token = set(self.unique_no_split_tokens)
-        tokens = self.tokens_trie.split(text)
+        # split_special_tokens: empty `no_split_token`
+        if split_special_tokens:
+            no_split_token = []
+            tokens = [text]
+        else:
+            no_split_token = set(self.unique_no_split_tokens)
+            tokens = self.tokens_trie.split(text)
 
         # ["This is something", "<special_token_1>", " else"]
         for i, token in enumerate(tokens):
             if token in no_split_token:
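
For intuition about the two branches above, here is a small standalone sketch of how the pre-splitting step differs (illustration only, not part of the diff; the `Trie` import path and the sample strings are assumptions). In the default branch the tokens trie cuts the text around registered special tokens so they never reach the model-specific `_tokenize`; with `split_special_tokens=True` the raw string is passed through in one piece.

    from transformers.tokenization_utils import Trie

    trie = Trie()
    trie.add("[SPECIAL_TOKEN]")
    text = "hello [SPECIAL_TOKEN] world"

    # Default path: special tokens are isolated before _tokenize is called.
    print(trie.split(text))
    # -> ['hello ', '[SPECIAL_TOKEN]', ' world']

    # split_special_tokens=True path: no pre-splitting, the whole string is handed
    # to _tokenize, so "[SPECIAL_TOKEN]" gets broken up like any other text.
    print([text])
    # -> ['hello [SPECIAL_TOKEN] world']
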
10 changes: 9 additions & 1 deletion src/transformers/tokenization_utils_base.py
@@ -1487,6 +1487,11 @@ def all_special_ids(self) -> List[int]:
     clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
         Whether or not the model should cleanup the spaces that were added when splitting the input text during the
         tokenization process.
+    split_special_tokens (`bool`, *optional*, defaults to `False`):
+        Whether or not the special tokens should be split during the tokenization process. The default behaviour is
+        to not split special tokens. This means that if `<s>` is the `bos_token`, then
+        `tokenizer.tokenize("<s>") = ['<s>']`. Otherwise, if `split_special_tokens=True`, then
+        `tokenizer.tokenize("<s>")` will give `['<', 's', '>']`.
 """
 
 
@@ -1541,6 +1546,9 @@ def __init__(self, **kwargs):
         # By default, cleaning tokenization spaces for both fast and slow tokenizers
         self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", True)
 
+        # By default, do not split special tokens for both fast and slow tokenizers
+        self.split_special_tokens = kwargs.pop("split_special_tokens", False)
+
         self.deprecation_warnings = (
             {}
         )  # Use to store when we have already noticed a deprecation warning (avoid overlogging).
@@ -2157,7 +2165,7 @@ def save_pretrained(
 
         # TODO: Ensure the modified attributes (those are also in the __init__ kwargs) will give identical tokenizers
         # target_keys = self.init_kwargs.keys()
-        target_keys = ["model_max_length", "clean_up_tokenization_spaces"]
+        target_keys = ["model_max_length", "clean_up_tokenization_spaces", "split_special_tokens"]
         for k in target_keys:
             if hasattr(self, k):
                 tokenizer_config[k] = getattr(self, k)
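
Since `split_special_tokens` is now accepted in `__init__` and listed in `target_keys`, the preference can be set once at load time and should round-trip through `save_pretrained`. A hedged sketch (the checkpoint name, the local directory, and the exact contents of `tokenizer_config.json` are assumptions):

    import json

    from transformers import AutoTokenizer

    # Set the default once instead of passing split_special_tokens on every call.
    tokenizer = AutoTokenizer.from_pretrained(
        "bert-base-uncased", use_fast=False, split_special_tokens=True
    )
    tokenizer.save_pretrained("./tokenizer-with-split")

    # The flag is expected to be serialized next to model_max_length and
    # clean_up_tokenization_spaces.
    with open("./tokenizer-with-split/tokenizer_config.json") as f:
        print(json.load(f).get("split_special_tokens"))  # -> True
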
5 changes: 5 additions & 0 deletions tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
@@ -384,6 +384,11 @@ def test_encode_decode_with_spaces(self):
     def test_right_and_left_truncation(self):
         pass
 
+    @unittest.skip("Not implemented")
+    def test_split_special_tokens(self):
+        ...
+
+
     def test_encode_plus_with_padding(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
4 changes: 4 additions & 0 deletions tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
@@ -264,6 +264,10 @@ def test_encode_decode_with_spaces(self):
     def test_right_and_left_truncation(self):
         pass
 
+    @unittest.skip("Not implemented")
+    def test_split_special_tokens(self):
+        ...
+
     def test_encode_plus_with_padding(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
26 changes: 26 additions & 0 deletions tests/test_tokenization_common.py
@@ -3909,6 +3909,7 @@ def test_save_slow_from_fast_and_reload_fast(self):
                 # Should not raise an error
                 self.rust_tokenizer_class.from_pretrained(tmp_dir_2)
 
+    # TODO This is run for all models but only tests bert...
     def test_clean_up_tokenization_spaces(self):
         tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
         assert tokenizer.clean_up_tokenization_spaces is True
@@ -3953,3 +3954,28 @@ def test_clean_up_tokenization_spaces(self):
         tokenizer.clean_up_tokenization_spaces = True
         decoded = tokenizer.decode(tokens)
         assert decoded == "[CLS] this shouldn't be! he'll go. [SEP]"
+
+    def test_split_special_tokens(self):
+        if not self.test_slow_tokenizer:
+            return
+
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            special_token = "[SPECIAL_TOKEN]"
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+                if not tokenizer.is_fast:
+                    # bloom, gptneox etc. only have a fast tokenizer
+                    tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
+                    encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False)
+                    self.assertEqual(len(encoded_special_token), 1)
+
+                    encoded_split_special_token = tokenizer.encode(
+                        special_token, add_special_tokens=False, split_special_tokens=True
+                    )
+                    if len(encoded_split_special_token) == 1:
+                        # if we have subword tokenization or special vocab
+                        self.assertTrue(
+                            encoded_split_special_token[0] != tokenizer.convert_tokens_to_ids(special_token)
+                        )
+                    else:
+                        self.assertTrue(len(encoded_split_special_token) > 1)