Improve the speed of adding tokens from added_tokens.json #10780

Merged
merged 2 commits into from
Apr 1, 2021
Merged
Changes from all commits
Commits
src/transformers/tokenization_utils.py (22 additions, 2 deletions)
@@ -16,6 +16,7 @@
 Tokenization classes for python tokenizers. For fast tokenizers (provided by HuggingFace's tokenizers library) see
 tokenization_utils_fast.py
 """
+import bisect
 import itertools
 import re
 import unicodedata
@@ -99,6 +100,19 @@ def _is_start_of_word(text):
     return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char))


+def _insert_one_token_to_ordered_list(token_list: List[str], new_token: str):
+    """
+    Inserts one token to an ordered list if it does not already exist. Note: token_list must be sorted.
+    """
+    insertion_idx = bisect.bisect_left(token_list, new_token)
+    # Checks if new_token is already in the ordered token_list
+    if insertion_idx < len(token_list) and token_list[insertion_idx] == new_token:
+        # new_token is in token_list, don't add
+        return
+    else:
+        token_list.insert(insertion_idx, new_token)
+
+
 @add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
 class PreTrainedTokenizer(PreTrainedTokenizerBase):
     """
@@ -199,10 +213,16 @@ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_to

         # Make sure we don't split on any special tokens (even they were already in the vocab before e.g. for Albert)
         if special_tokens:
-            self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(new_tokens)))
+            if len(new_tokens) == 1:
+                _insert_one_token_to_ordered_list(self.unique_no_split_tokens, new_tokens[0])
Member:

I think we're losing the sorted order of self.unique_no_split_tokens here, which is required by _insert_one_token_to_ordered_list, right?

Contributor Author:

Hi @LysandreJik, _insert_one_token_to_ordered_list uses bisect to find the position in the sorted list at which to insert the new token, so as long as the input list is already sorted, it remains sorted after the insertion. To make sure, I added the following assertion to the benchmark script to test:

# after loading the tokenizer from a saved model
all_values = tokenizer.unique_no_split_tokens
assert all(all_values[i + 1] > all_values[i] for i in range(len(all_values) - 1))

Member:

The assumption that the input list is already sorted holds because it would have been built using this, which sorts values? Other than that particular question it looks good to me. I'll ping another member of the team, and if it's good for them we'll merge. Thanks for your patience!

Contributor Author:

Yeah, that's right. Populating an empty tokenizer.unique_no_split_tokens with this bisect approach ensures it will always be sorted. Thanks!
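To make the thread's point concrete, a quick property check (a standalone sketch reusing the helper from the example above, not part of the PR): repeatedly inserting single tokens into a list that starts out sorted keeps it sorted and free of duplicates.

import random
import string

tokens = []  # an empty list is trivially sorted
for _ in range(1000):
    word = "".join(random.choices(string.ascii_lowercase, k=5))
    _insert_one_token_to_ordered_list(tokens, word)  # helper defined in the sketch above
# Sorted and de-duplicated: exactly the invariant the reviewer asked about.
assert tokens == sorted(set(tokens))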

+            else:
+                self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(new_tokens)))
         else:
             # Or on the newly added tokens
-            self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
+            if len(tokens_to_add) == 1:
+                _insert_one_token_to_ordered_list(self.unique_no_split_tokens, tokens_to_add[0])
Member:

Same here.

Contributor Author:

Same as the above comment.

+            else:
+                self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(tokens_to_add)))

         return len(tokens_to_add)
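Why this speeds up loading added_tokens.json: tokens from that file end up being added one at a time (which is presumably why the PR adds the len(...) == 1 fast path), so the old code rebuilt a set and re-sorted the entire unique_no_split_tokens list once per token, roughly O(n log n) each time, while the new code does one binary search plus one list insertion, O(n) each time and with much smaller constants. A rough timing sketch of the two strategies in isolation (a standalone illustration, not the PR's benchmark script; absolute numbers will vary by machine):

import bisect
import timeit

def add_all_via_sort(new_tokens):
    # Old behavior: every added token triggers a set union + full re-sort.
    acc = []
    for tok in new_tokens:
        acc = sorted(set(acc).union({tok}))
    return acc

def add_all_via_bisect(new_tokens):
    # New behavior: every added token costs one binary search + one insert.
    acc = []
    for tok in new_tokens:
        idx = bisect.bisect_left(acc, tok)
        if not (idx < len(acc) and acc[idx] == tok):
            acc.insert(idx, tok)
    return acc

tokens = [f"[unused{i}]" for i in range(2000)]
assert add_all_via_sort(tokens) == add_all_via_bisect(tokens)
print("sorted-union:", timeit.timeit(lambda: add_all_via_sort(tokens), number=3))
print("bisect      :", timeit.timeit(lambda: add_all_via_bisect(tokens), number=3))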