
[WIP] Fast tokenizer for debertaV2 #14928

Closed
wants to merge 4 commits

Conversation

@alcinos (Contributor) commented Dec 26, 2021

What does this PR do?

Implements a fast tokenizer for DeBERTa v2, loosely based on #11387.

Fixes #11529
Fixes #14712

This is a draft as there are some failing tests (not super clear to me why atm, will have to investigate)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@LysandreJik

Review comment from @alcinos (Contributor Author) on the documentation diff. The outdated paragraph:

    index of the token comprising a given character or the span of characters corresponding to a given token). Currently
    no "Fast" implementation is available for the SentencePiece-based tokenizers (for T5, ALBERT, CamemBERT, XLM-RoBERTa
    and XLNet models).

is replaced by:

    index of the token comprising a given character or the span of characters corresponding to a given token).

Note: this change is not technically related to the PR, but I found this bit of the documentation to be outdated. If you want a separate PR, lmk.

@alcinos (Contributor Author) commented Dec 26, 2021

I noticed that while I was working on my PR, another was submitted for the same purpose: #14923.

@stefan-it (Collaborator) commented Dec 26, 2021

Hey @alcinos ,

thanks for adding it!

I'm currently running comparisons between the slow and the fast tokenizer, and I'm seeing some mismatches.

I ran tokenization tests on README.md and README_zh-hans.md from the official Transformers repository, using this script:

import sys

from transformers import DebertaV2Tokenizer, DebertaV2TokenizerFast

model_name = "microsoft/deberta-v2-xlarge"

slow_tokenizer = DebertaV2Tokenizer.from_pretrained(model_name)
fast_tokenizer = DebertaV2TokenizerFast.from_pretrained(model_name)

filename = sys.argv[1]

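# Compare slow vs. fast tokenization line by line and report any mismatch.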
with open(filename, "rt") as f_p:
    for line in f_p:
        line = line.rstrip()

        if not line:
            continue

        slow_tokens = slow_tokenizer.tokenize(line)
        fast_tokens = fast_tokenizer.tokenize(line)

        if slow_tokens != fast_tokens:
            print("Tokenization mismatch:", line)
            print("Slow tokens:", slow_tokens)
            print("Fast tokens:", fast_tokens)

Here are some mismatches:

Original input: * 🖼️ Images, for tasks like image classification, object detection, and segmentation.
Slow tokens: ['▁*', '▁', '[UNK]', '️', '▁Images', ',', '▁for', '▁tasks', '▁like', '▁image', '▁classification', ',', '▁object', '▁detection', ',', '▁and', '▁segmentation', '.']
Fast tokens: ['▁*', '▁', '🖼', '️', '▁Images', ',', '▁for', '▁tasks', '▁like', '▁image', '▁classification', ',', '▁object', '▁detection', ',', '▁and', '▁segmentation', '.']

Another example on README_zh-hans.md:

Original input: - 对教学和实践友好且低门槛
Slow tokens: ['▁-', '▁', '对', '教', '学', '和', '实', '践', '友', '好', '且', '低', '门', '[UNK]']
Fast tokens: ['▁-', '▁', '对', '教', '学', '和', '实', '践', '友', '好', '且', '低', '门', '槛']

The original DeBERTa tokenizer outputs the same tokens as the slow tokenizer.

@SaulLu (Contributor) left a comment:

Thank you so much for working to resolve this issue @alcinos! I see that you have already understood how our tokenizers were designed: that's really great! 😄

I took the liberty of leaving a small comment on a particular point of this tokenizer; don't hesitate to ask if you need more information about it.

By the way, I noticed that you said "This is a draft as there are some failing tests (not super clear to me why atm, will have to investigate)". Are all of the failing tests unclear to you, or only some of them? In any case, we can list the failing tests together in this PR and discuss how to solve these problems.

PS: I should warn you that I will probably not be very active on GitHub this week, as I'm on vacation.

Comment on lines +115 to +116:

    do_lower_case=False,
    split_by_punct=False,

@SaulLu (Contributor) commented:

It seems to me that these arguments will require changing, respectively, the normalizer and the pre_tokenizer of the backend_tokenizer object.

To start with, I would advise adding a specific test for these arguments to check that the tokenization is identical between the slow and the fast tokenizer for all possible values of these arguments.
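
A hedged sketch of such a test (illustrative names; it assumes the fast class added in this PR and the public microsoft/deberta-v2-xlarge checkpoint):

import itertools

from transformers import DebertaV2Tokenizer, DebertaV2TokenizerFast


def test_slow_fast_parity_over_init_kwargs():
    # Illustrative sketch: compare slow and fast outputs for every combination
    # of do_lower_case and split_by_punct.
    text = "Hello, WORLD! Don't split me... or do?"
    for do_lower_case, split_by_punct in itertools.product([False, True], repeat=2):
        slow = DebertaV2Tokenizer.from_pretrained(
            "microsoft/deberta-v2-xlarge", do_lower_case=do_lower_case, split_by_punct=split_by_punct
        )
        fast = DebertaV2TokenizerFast.from_pretrained(
            "microsoft/deberta-v2-xlarge", do_lower_case=do_lower_case, split_by_punct=split_by_punct
        )
        assert slow.tokenize(text) == fast.tokenize(text)
        assert slow.encode(text) == fast.encode(text)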

@alcinos (Contributor Author) commented Dec 27, 2021:

Hello @SaulLu
Thanks a lot for your review.

As for the lower_case arg, I followed Albert's tokenizer. As you mentioned, Albert modifies the normalizer in the converter:

def normalizer(self, proto):
    list_normalizers = [
        normalizers.Replace("``", '"'),
        normalizers.Replace("''", '"'),
    ]
    if not self.original_tokenizer.keep_accents:
        list_normalizers.append(normalizers.NFKD())
        list_normalizers.append(normalizers.StripAccents())
    if self.original_tokenizer.do_lower_case:
        list_normalizers.append(normalizers.Lowercase())

which I duplicated in my PR:
if self.original_tokenizer.do_lower_case:
    list_normalizers.append(normalizers.Lowercase())

Eyeballing the init method of PreTrainedTokenizerFast makes me believe the creation process always involves using the said slow->fast conversion method, so that should be covered?

As for split_by_punct we could take the same approach and overload the pre-tokenizer method of the converter? Would a sequence [MetaSpace, Punct] do the trick? I'm a bit uncertain here since there doesn't seem to be any other converter dealing with punctuation splitting, so maybe I'm understanding this wrong.

EDIT: forgot to mention, tests are indeed a good idea, will see what is the best way to test this behavior.

@SaulLu (Contributor) replied:

> Eyeballing the init method of PreTrainedTokenizerFast makes me believe the creation process always involves using the said slow->fast conversion method, so that should be covered?

Indeed, for the first use case we initialize the fast tokenizer by using the conversion script. However, we also want to be able to initialize this fast tokenizer from the fast files only. It will thus be necessary to also modify the backend_tokenizer in the __init__ method. In case it is useful, here are the lines where it is done for Bert. Be careful: Bert has its own custom normalizer, so these lines should be adapted to the deberta-v2 normalizer.

(Note that you made me notice that we should have the same kind of thing for Albert; I'll open a new PR for that.)
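
If it helps, here is a minimal sketch of that idea (the helper name is made up; it only covers do_lower_case and reuses the normalizer sequence from the converter):

from tokenizers import normalizers


def apply_deberta_v2_normalizer(backend_tokenizer, do_lower_case: bool):
    # Sketch: something along these lines could be called from
    # DebertaV2TokenizerFast.__init__ so that a tokenizer loaded from fast files
    # only (tokenizer.json) also honours do_lower_case.
    if do_lower_case:
        backend_tokenizer.normalizer = normalizers.Sequence(
            [
                normalizers.Replace("``", '"'),
                normalizers.Replace("''", '"'),
                normalizers.Lowercase(),
            ]
        )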

> As for split_by_punct we could take the same approach and overload the pre-tokenizer method of the converter? Would a sequence [MetaSpace, Punct] do the trick? I'm a bit uncertain here since there doesn't seem to be any other converter dealing with punctuation splitting, so maybe I'm understanding this wrong.

It is indeed exactly the same approach that I would have tested first. However, I can't confirm that the pre_tokenizers.Punctuation module behaves exactly like the slow tokenizer feature. But some tests should answer this question 😄
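
To make the idea concrete, a hedged sketch of such a converter override (the converter class name is whatever this PR ends up using; whether pre_tokenizers.Punctuation reproduces the slow split_by_punct behaviour is exactly what the tests would need to confirm):

from tokenizers import pre_tokenizers


def pre_tokenizer(self, replacement, add_prefix_space):
    # Sketch of a deberta-v2 converter override: optionally split on punctuation,
    # then apply the usual SentencePiece-style Metaspace pre-tokenization.
    steps = []
    if self.original_tokenizer.split_by_punct:
        steps.append(pre_tokenizers.Punctuation())
    steps.append(pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space))
    return pre_tokenizers.Sequence(steps)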

@alcinos (Contributor Author) commented Dec 27, 2021

@stefan-it Thanks for looking into this and providing the testcases.

It seems that the issues you are reporting are all related to unknown tokens? I don't know the Rust implementation well enough; is there any reason why the fast tokenizer would not respect the vocabulary?

@stefan-it (Collaborator) commented Dec 27, 2021

Hey @alcinos, I'm currently trying to figure it out :)

@stefan-it (Collaborator) commented Dec 27, 2021

Good news: when using encode, there's no mismatch between the slow and fast tokenizers.

For the slow tokenizer, this replacement happens here:

def _norm(x):
    if x not in self.vocab or x == "<unk>":
        return "[UNK]"
    else:
        return x

Also, with regular T5 (slow and fast) there are no UNKs when using the tokenize function (but encode shows that those subtokens are UNKs), so this behavior is DeBERTa-specific.
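
A quick way to see the difference, reusing slow_tokenizer and fast_tokenizer from the comparison script above (a sketch based on the README_zh-hans.md example):

text = "- 对教学和实践友好且低门槛"

# encode agrees: the out-of-vocabulary character ends up as the unk id either way.
assert slow_tokenizer.encode(text) == fast_tokenizer.encode(text)

# tokenize differs: the slow tokenizer reports "[UNK]" for the last character,
# while the fast tokenizer returns the raw character "槛".
print(slow_tokenizer.tokenize(text))
print(fast_tokenizer.tokenize(text))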

@mingboiz (Contributor) commented Dec 28, 2021

@alcinos I think the issue has something to do with the tokenize function inherited from PreTrainedTokenizerFast.

For this line: * 🖼️ Images, for tasks like image classification, object detection, and segmentation.

the tokenize function calls encode_plus, which returns a Dict[List] converted from a BatchEncoding. For both slow and fast tokenizers, encode_plus returns {'input_ids': [943, 250, 3, 28596, 7654, 6, 14, 2930, 72, 812, 8692, 6, 2328, 5563, 6, 7, 27235], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]:
    return self.encode_plus(text=text, text_pair=pair, add_special_tokens=add_special_tokens, **kwargs).tokens()

but the .tokens() call at the end of that line doesn't return the [UNK] token but rather the token itself, which causes the discrepancy here.

This may not be entirely correct, but I was able to fix the discrepancy in the testing script @stefan-it provided by overriding the tokenize function, adding this method to the DebertaV2TokenizerFast class:

    def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]:
        enc = self.encode_plus(text=text, text_pair=pair, add_special_tokens=add_special_tokens, **kwargs)
        return self.convert_ids_to_tokens(enc['input_ids'])

But the other tests are still failing, and I'm not sure what's causing the issue; I need to investigate.

@alcinos (Contributor Author) commented Dec 28, 2021

Thanks @stefan-it and @mingboiz for looking into the tokenization issue. If I summarize the findings so far:

  • Slow tokenizer replaces unknown tokens with "[UNK]" while fast tokenizer doesn’t
  • This behavior seems specific to Deberta, as T5’s tokenizers don’t replace with "[UNK]"
  • After encoding, the results are the same for slow and fast, meaning that the issue is probably minor
  • @mingboiz found a way to have the fast tokenizer spit out the "[UNK]".

I’m not sure what the expected behavior should be, nor whether we should be concerned about this in the first place. Input from someone from HF would be appreciated :) (ping @SaulLu )

Aside from that, I pushed some fixes, more tests are passing.
Some feedback for the HF team on the issue I ran into:
In one of the common tests, the code looks for a "do_lower_case" attribute variable:

if not hasattr(tokenizer, "do_lower_case") or not tokenizer.do_lower_case:
    continue

This is problematic in my opinion since:

  • I didn’t see any mention that this attribute variable is required in the documentation I’ve come across (though I may have missed it)
  • The code silently fails if the attribute is not present
  • This argument itself is not used anywhere in the DebertaV2TokenizerFast class, hence I was not naturally inclined to add it as an attribute.

I would suggest one of the following changes to make this more dev-friendly:

  • Make it a hard requirement that any tokenizer class must have this attribute, document it, and remove the silent fail in case it’s not found
  • Additionally, IMHO this would be better suited as an overridable getter method rather than direct access to a private attribute (see the sketch below).
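
A sketch of what is meant here (purely illustrative class names, not an actual transformers API):

class ToyTokenizerBase:
    @property
    def do_lower_case(self) -> bool:
        # Sensible default for tokenizers that never lowercase; the common test
        # could rely on this instead of hasattr() on an undocumented attribute.
        return False


class ToyLowercasingTokenizer(ToyTokenizerBase):
    def __init__(self, do_lower_case: bool = False):
        self._do_lower_case = do_lower_case

    @property
    def do_lower_case(self) -> bool:
        # Tokenizers that support lowercasing override the getter explicitly.
        return self._do_lower_case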

A few tests are still failing: one most likely has the same root cause as the issue raised by @stefan-it. The others seem to fail because in some code paths the vocab_file is None, but it's not clear to me why that happens; any help on that would be appreciated.

@mingboiz (Contributor) commented Dec 28, 2021

@alcinos I can't figure out a solution yet, but I think the tests are failing because of the missing vocab file, since in all of them legacy_format=False is being selected:

tokenizer_old.save_pretrained(tmp_dir, legacy_format=False) # save only fast version

which saves only these files via the Rust tokenizer, without the spm.model vocab file:

  • tokenizer_config.json
  • special_tokens_map.json
  • tokenizer.json

so this code chunk runs instead, without calling the save_vocabulary function:

if save_fast:
    tokenizer_file = os.path.join(
        save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_FILE
    )
    self.backend_tokenizer.save(tokenizer_file)
    file_names = file_names + (tokenizer_file,)

DebertaV2Tokenizer didn't have this issue because its backend SPMTokenizer class provides its own save_pretrained method to save the spm.model, but I can't figure out why AlbertTokenizerFast passes all of these tests while the same tests fail here. I think these are the only remaining failing tests:

  • test_saving_tokenizer_trainer
  • test_training_new_tokenizer_with_special_tokens_change
  • test_training_new_tokenizer

@SaulLu (Contributor) commented Dec 28, 2021

I have noted the ping! I'll come back to you as soon as possible on this subject, because the choice to be made here is not obvious: you have highlighted that the tokenize method of DebertaV2Tokenizer does not behave in the same way as the tokenize method of the fast tokenizers. At the moment it would make more sense to me to modify DebertaV2Tokenizer's tokenize method, but in general we don't really like to introduce backward incompatibility, so I'll need to discuss it with the maintainers 🙂.

@SaulLu (Contributor) left a comment:

Thanks again for all your work, I'm sorry I couldn't be very responsive last week.

So, to follow up on a few points:

> I'm not sure what the expected behavior should be, nor whether we should be concerned about this in the first place. Input from someone from HF would be appreciated :)

@alcinos, @mingboiz, one of the maintainers also shares my opinion: we think it would be better to modify the tokenize method of the slow version of the tokenizer so that it behaves like those of the other tokenizers.

> Some feedback for the HF team on the issue I ran into:

@alcinos , thanks a lot for sharing your opinion, it's extremely useful to know the difficulties you encounter.

Indeed, when you add a new tokenizer, an error like the one you mention with do_lower_case happens easily, because not all the tests in test_tokenization_common.py are useful for every tokenizer (and some can silently pass). My advice is to deliberately introduce an error (raise ValueError("to delete")) in the most-used methods (like encode, tokenize, __call__) in order to see which tests are really executed and which are not, and then to check whether it is expected that some tests are not executed.
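
For illustration, the temporary tripwire could be as simple as this (to be deleted once you have seen which tests hit it):

def tokenize(self, text, pair=None, add_special_tokens=False, **kwargs):
    # Temporarily added to DebertaV2TokenizerFast: any common test that reaches
    # this method now fails loudly, revealing which tests really exercise it.
    raise ValueError("to delete")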

Unfortunately, not all tokenizers in the library have a do_lower_case argument and that's why this is not a hard requirement.

> Additionally, IMHO this would be better suited as an overridable getter method rather than direct access to a private attribute

@alcinos , I'm not sure what you are referring to here. What private attribute are you talking about? 🙂

> DebertaV2TokenizationTest.test_saving_tokenizer_trainer test fails

I'll try to find out more about what's going on. From what I have seen, this is due to the fact that on the first save, the key and value 'tokenizer_file': None are saved in the tokenizer_config.json (so they are not overridden afterwards because of this line). What I still need to check is why this key and value are saved with deberta-v2 but not with albert.

Comment on lines +142 to +146:

    if not os.path.isfile(vocab_file):
        raise ValueError(
            f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
            "model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
        )
@SaulLu (Contributor) commented Jan 3, 2022:

I think these lines should be removed. Indeed, for all the tokenizers that have both a slow and a fast version, we want to keep the possibility of initializing the tokenizer from either type of files: the files of the slow version or the files of the fast version. It seems to me that these lines would prevent initializing a deberta-v2 fast tokenizer with only fast files.

Suggested change (remove these lines):

    if not os.path.isfile(vocab_file):
        raise ValueError(
            f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
            "model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
        )

I think that removing these lines could solve the current problem with the test_training_new_tokenizer_with_special_tokens_change.


@SaulLu (Contributor) commented Jan 5, 2022

@alcinos and @mingboiz, while investigating the test_saving_tokenizer_trainer test further, I noticed that the VOCAB_FILES_NAMES variable did not specify the "tokenizer_file" value (and we need it for the fast version of the tokenizer). Moreover, for the test to fully succeed, the following lines must also be removed from the tokenization_deberta_v2_fast.py file:

        if not os.path.isfile(vocab_file):
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )

With these 2 changes, the test now passes 😄!

I have opened a PR here to show you the changes that should be made to solve these problems. Feel free to merge it if you agree with it. 🙂
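
For reference, a sketch of the first change in tokenization_deberta_v2_fast.py (the file names follow the existing deberta-v2 slow tokenizer; treat this as illustrative rather than the final code):

# Declare the serialized fast tokenizer file alongside the sentencepiece model
# so that saving/loading the fast-only format picks it up.
VOCAB_FILES_NAMES = {"vocab_file": "spm.model", "tokenizer_file": "tokenizer.json"}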

@stefan-it (Collaborator) commented:
@alcinos could you please have a look at the alcinos#1 PR - I think it is ready then 🤗

Review comment on the save_vocabulary code:

    if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
        copyfile(self.vocab_file, out_vocab_file)


Can you please allow the tokenizer to also be saved if the file it was loaded from has been removed?
(See tokenization_albert.py for an example.)
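
The pattern in tokenization_albert.py looks roughly like this (a sketch only; deberta-v2 keeps its SentencePiece model inside the SPMTokenizer backend, so the serialized_model_proto() call would need to be routed through that instead of a plain sp_model attribute):

import os
from shutil import copyfile


def save_vocabulary(self, save_directory, filename_prefix=None):
    out_vocab_file = os.path.join(
        save_directory, (filename_prefix + "-" if filename_prefix else "") + "spm.model"
    )
    if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
        # The file we loaded from still exists: just copy it.
        copyfile(self.vocab_file, out_vocab_file)
    elif not os.path.isfile(self.vocab_file):
        # The original file was deleted: fall back to the in-memory model.
        with open(out_vocab_file, "wb") as fi:
            fi.write(self.sp_model.serialized_model_proto())
    return (out_vocab_file,)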

A contributor replied:

That's a really good point!

@mingboiz (Contributor) commented Feb 6, 2022:

Isn't this already supported in deberta-v2? Specifically, lines 481 and 482 here:

def save_pretrained(self, path: str, filename_prefix: str = None):
    filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]]
    if filename_prefix is not None:
        filename = filename_prefix + "-" + filename
    full_path = os.path.join(path, filename)
    with open(full_path, "wb") as fs:
        fs.write(self.spm.serialized_model_proto())
    return (full_path,)

@SaulLu (Contributor) commented Feb 3, 2022

Hi @alcinos, thank you very much for your work, the addition seems to be nearly complete! Please let me know if you need help with any of it!

@mingboiz mingboiz mentioned this pull request Feb 5, 2022
@huggingface huggingface deleted a comment from github-actions bot Mar 9, 2022
@github-actions github-actions bot closed this Apr 25, 2022
@huggingface huggingface deleted a comment from github-actions bot Apr 25, 2022
@SaulLu (Contributor) commented Apr 25, 2022

Finished in PR #15529

Thanks all again for the contribution 🤗


Successfully merging this pull request may close these issues.

  • DeBERTa V3 Fast Tokenizer
  • Deberta v2 Fast Tokenizer
5 participants