Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fast tokenizer for debertaV2 #14928

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ Flax), PyTorch, and/or TensorFlow.
| ConvBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| CTRL | ✅ | ❌ | ✅ | ✅ | ❌ |
| DeBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
| DeBERTa-v2 | ✅ | | ✅ | ✅ | ❌ |
| DeBERTa-v2 | ✅ | | ✅ | ✅ | ❌ |
| DeiT | ❌ | ❌ | ✅ | ❌ | ❌ |
| DETR | ❌ | ❌ | ✅ | ❌ | ❌ |
| DistilBERT | ✅ | ✅ | ✅ | ✅ | ✅ |
Expand Down
4 changes: 1 addition & 3 deletions docs/source/main_classes/tokenizer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ Rust library [🤗 Tokenizers](https://github.com/huggingface/tokenizers). The "

1. a significant speed-up in particular when doing batched tokenization and
2. additional methods to map between the original string (character and words) and the token space (e.g. getting the
index of the token comprising a given character or the span of characters corresponding to a given token). Currently
no "Fast" implementation is available for the SentencePiece-based tokenizers (for T5, ALBERT, CamemBERT, XLM-RoBERTa
and XLNet models).
index of the token comprising a given character or the span of characters corresponding to a given token).
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this change is not technically related to the PR, but I found this bit of the documentation to be outdated. If you want a separate PR, lmk.


The base classes [`PreTrainedTokenizer`] and [`PreTrainedTokenizerFast`]
implement the common methods for encoding string inputs in model inputs (see below) and instantiating/saving python and
Expand Down
6 changes: 6 additions & 0 deletions docs/source/model_doc/deberta_v2.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code
- create_token_type_ids_from_sequences
- save_vocabulary

## DebertaV2TokenizerFast

[[autodoc]] DebertaV2TokenizerFast
- build_inputs_with_special_tokens
- create_token_type_ids_from_sequences

## DebertaV2Model

[[autodoc]] DebertaV2Model
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@
_import_structure["models.blenderbot"].append("BlenderbotTokenizerFast")
_import_structure["models.camembert"].append("CamembertTokenizerFast")
_import_structure["models.deberta"].append("DebertaTokenizerFast")
_import_structure["models.deberta_v2"].append("DebertaV2TokenizerFast")
_import_structure["models.distilbert"].append("DistilBertTokenizerFast")
_import_structure["models.dpr"].extend(
["DPRContextEncoderTokenizerFast", "DPRQuestionEncoderTokenizerFast", "DPRReaderTokenizerFast"]
Expand Down Expand Up @@ -2454,6 +2455,7 @@
from .models.clip import CLIPTokenizerFast
from .models.convbert import ConvBertTokenizerFast
from .models.deberta import DebertaTokenizerFast
from .models.deberta_v2 import DebertaV2TokenizerFast
from .models.distilbert import DistilBertTokenizerFast
from .models.dpr import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast, DPRReaderTokenizerFast
from .models.electra import ElectraTokenizerFast
Expand Down
25 changes: 25 additions & 0 deletions src/transformers/convert_slow_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,30 @@ def post_processor(self):
)


class DebertaV2Converter(SpmConverter):
def normalizer(self, proto):
list_normalizers = []
if self.original_tokenizer.do_lower_case:
list_normalizers.append(normalizers.Lowercase())

precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
if precompiled_charsmap:
list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))

return normalizers.Sequence(list_normalizers)

def post_processor(self):
return processors.TemplateProcessing(
single="[CLS]:0 $A:0 [SEP]:0",
pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
special_tokens=[
("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
],
)


class CLIPConverter(Converter):
def converted(self) -> Tokenizer:
vocab = self.original_tokenizer.encoder
Expand Down Expand Up @@ -921,6 +945,7 @@ def converted(self) -> Tokenizer:
"CLIPTokenizer": CLIPConverter,
"ConvBertTokenizer": BertConverter,
"DebertaTokenizer": DebertaConverter,
"DebertaV2Tokenizer": DebertaV2Converter,
"DistilBertTokenizer": BertConverter,
"DPRReaderTokenizer": BertConverter,
"DPRQuestionEncoderTokenizer": BertConverter,
Expand Down
8 changes: 7 additions & 1 deletion src/transformers/models/auto/tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,13 @@
("fsmt", ("FSMTTokenizer", None)),
("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
("deberta-v2", ("DebertaV2Tokenizer" if is_sentencepiece_available() else None, None)),
(
"deberta-v2",
(
"DebertaV2Tokenizer" if is_sentencepiece_available() else None,
"DebertaV2TokenizerFast" if is_tokenizers_available() else None,
),
),
("rag", ("RagTokenizer", None)),
("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
Expand Down
8 changes: 7 additions & 1 deletion src/transformers/models/deberta_v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@

from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_tf_available, is_torch_available
from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available


_import_structure = {
"configuration_deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config"],
"tokenization_deberta_v2": ["DebertaV2Tokenizer"],
}

if is_tokenizers_available():
_import_structure["tokenization_deberta_v2_fast"] = ["DebertaV2TokenizerFast"]

if is_tf_available():
_import_structure["modeling_tf_deberta_v2"] = [
"TF_DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -53,6 +56,9 @@
from .configuration_deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config
from .tokenization_deberta_v2 import DebertaV2Tokenizer

if is_tokenizers_available():
from .tokenization_deberta_v2_fast import DebertaV2TokenizerFast

if is_tf_available():
from .modeling_tf_deberta_v2 import (
TF_DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def __init__(
)
self.do_lower_case = do_lower_case
self.split_by_punct = split_by_punct
self.vocab_file = vocab_file
self._tokenizer = SPMTokenizer(vocab_file, split_by_punct=split_by_punct, sp_model_kwargs=self.sp_model_kwargs)

@property
Expand Down
247 changes: 247 additions & 0 deletions src/transformers/models/deberta_v2/tokenization_deberta_v2_fast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
# coding=utf-8
# Copyright 2020 Microsoft and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for model DeBERTa."""

import os
from shutil import copyfile
from typing import Optional, Tuple

from ...file_utils import is_sentencepiece_available
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging


if is_sentencepiece_available():
from .tokenization_deberta_v2 import DebertaV2Tokenizer
else:
DebertaV2Tokenizer = None

logger = logging.get_logger(__name__)

PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"microsoft/deberta-v2-xlarge": "https://huggingface.co/microsoft/deberta-v2-xlarge/resolve/main/spm.model",
"microsoft/deberta-v2-xxlarge": "https://huggingface.co/microsoft/deberta-v2-xxlarge/resolve/main/spm.model",
"microsoft/deberta-v2-xlarge-mnli": "https://huggingface.co/microsoft/deberta-v2-xlarge-mnli/resolve/main/spm.model",
"microsoft/deberta-v2-xxlarge-mnli": "https://huggingface.co/microsoft/deberta-v2-xxlarge-mnli/resolve/main/spm.model",
}
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"microsoft/deberta-v2-xlarge": 512,
"microsoft/deberta-v2-xxlarge": 512,
"microsoft/deberta-v2-xlarge-mnli": 512,
"microsoft/deberta-v2-xxlarge-mnli": 512,
}

PRETRAINED_INIT_CONFIGURATION = {
"microsoft/deberta-v2-xlarge": {"do_lower_case": False},
"microsoft/deberta-v2-xxlarge": {"do_lower_case": False},
"microsoft/deberta-v2-xlarge-mnli": {"do_lower_case": False},
"microsoft/deberta-v2-xxlarge-mnli": {"do_lower_case": False},
}

VOCAB_FILES_NAMES = {"vocab_file": "spm.model"}


class DebertaV2TokenizerFast(PreTrainedTokenizerFast):
r"""
Constructs a DeBERTa-v2 fast tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

Args:
vocab_file (`str`):
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
contains the vocabulary necessary to instantiate a tokenizer.
do_lower_case (`bool`, *optional*, defaults to `False`):
Whether or not to lowercase the input when tokenizing.
bos_token (`string`, *optional*, defaults to "[CLS]"):
The beginning of sequence token that was used during pre-training. Can be used a sequence classifier token.
When building a sequence using special tokens, this is not the token that is used for the beginning of
sequence. The token used is the `cls_token`.
eos_token (`string`, *optional*, defaults to "[SEP]"):
The end of sequence token. When building a sequence using special tokens, this is not the token that is
used for the end of sequence. The token used is the `sep_token`.
unk_token (`str`, *optional*, defaults to `"[UNK]"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
sep_token (`str`, *optional*, defaults to `"[SEP]"`):
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
sequence classification or for a text and a question for question answering. It is also used as the last
token of a sequence built with special tokens.
pad_token (`str`, *optional*, defaults to `"[PAD]"`):
The token used for padding, for example when batching sequences of different lengths.
cls_token (`str`, *optional*, defaults to `"[CLS]"`):
The classifier token which is used when doing sequence classification (classification of the whole sequence
instead of per-token classification). It is the first token of the sequence when built with special tokens.
mask_token (`str`, *optional*, defaults to `"[MASK]"`):
The token used for masking values. This is the token used when training this model with masked language
modeling. This is the token which the model will try to predict.
sp_model_kwargs (`dict`, *optional*):
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, to set:

- `enable_sampling`: Enable subword regularization.
- `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

- `nbest_size = {0,1}`: No sampling is performed.
- `nbest_size > 1`: samples from the nbest_size results.
- `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
using forward-filtering-and-backward-sampling algorithm.

- `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
BPE-dropout.
"""

vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
slow_tokenizer_class = DebertaV2Tokenizer

def __init__(
self,
vocab_file=None,
tokenizer_file=None,
do_lower_case=False,
split_by_punct=False,
Comment on lines +116 to +117
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me that these arguments will require to change respectively the normalizer and the pre_tokenizer of the backend_tokenizer object.

To start with, I would advise to add a specific test for these arguments which would allow to check that the tokenization is identical between the slow tokenizer and the fast tokenizer for all possible values for these arguments.

Copy link
Contributor Author

@alcinos alcinos Dec 27, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @SaulLu
Thanks a lot for your review.

As for the lower_case arg, I followed Alberts’s tokenizer. As you mentioned, Albert has a modification to the normalizer in the converter:

def normalizer(self, proto):
list_normalizers = [
normalizers.Replace("``", '"'),
normalizers.Replace("''", '"'),
]
if not self.original_tokenizer.keep_accents:
list_normalizers.append(normalizers.NFKD())
list_normalizers.append(normalizers.StripAccents())
if self.original_tokenizer.do_lower_case:
list_normalizers.append(normalizers.Lowercase())

which I duplicated in my PR:
if self.original_tokenizer.do_lower_case:
list_normalizers.append(normalizers.Lowercase())

Eyeballing the init method of PreTrainedTokenizerFast makes me believe the creation process always involves using the said slow->fast conversion method, so that should be covered?

As for split_by_punct we could take the same approach and overload the pre-tokenizer method of the converter? Would a sequence [MetaSpace, Punct] do the trick? I’m a bit uncertain here since there doesn’t seem to be any other converter that seem to be dealing with punctuation splitting so maybe I’m understanding this wrong.

EDIT: forgot to mention, tests are indeed a good idea, will see what is the best way to test this behavior.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eyeballing the init method of PreTrainedTokenizerFast makes me believe the creation process always involves using the said slow->fast conversion method, so that should be covered?

Indeed, for the first use case we initialize the fast tokenizer by using the conversion script. However, we also want to be able to initialize this fast tokenizer from the fast files only. It will thus be necessary to also modify the backend_tokenizer in the __init__ method. If ever it is useful, here are the lines where it is done for Bert. Be careful, bert has a custom normalizer just for it so we should adapt these lines to the normalizer of deberta-v2.

(note that you allowed me to notice that we should have the same kind of thing for Albert, I'll open a new PR for that).

As for split_by_punct we could take the same approach and overload the pre-tokenizer method of the converter? Would a sequence [MetaSpace, Punct] do the trick? I’m a bit uncertain here since there doesn’t seem to be any other converter that seem to be dealing with punctuation splitting so maybe I’m understanding this wrong.

It is indeed exactly the same approach that I would have tested first. However, I can't confirm that the pre_tokenizers.Punctuation module behaves exactly like the slow tokenizer feature. But some tests should answer this question 😄

bos_token="[CLS]",
eos_token="[SEP]",
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
**kwargs
) -> None:
super().__init__(
vocab_file,
tokenizer_file=tokenizer_file,
do_lower_case=do_lower_case,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
split_by_punct=split_by_punct,
**kwargs,
)

if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
"model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
Comment on lines +142 to +146
Copy link
Contributor

@SaulLu SaulLu Jan 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these lines should be removed. Indeed, for all the tokenizers having a slow version and a fast version we wish to leave the possibility of initializing the tokenizer starting from the two types of files: the files of the slow version or the files of the fast version. It seems to me that these lines would prevent to initialize a deberta-v2 fast tokenizer with only fast files.

Suggested change
if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
"model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)

I think that removing these lines could solve the current problem with the test_training_new_tokenizer_with_special_tokens_change.


self.do_lower_case = do_lower_case
self.vocab_file = vocab_file
self.can_save_slow_tokenizer = False if not self.vocab_file else True

def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A DeBERTa sequence has the following format:

- single sequence: [CLS] X [SEP]
- pair of sequences: [CLS] A [SEP] B [SEP]

Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.

Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""

if token_ids_1 is None:
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
cls = [self.cls_token_id]
sep = [self.sep_token_id]
return cls + token_ids_0 + sep + token_ids_1 + sep

def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.

Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.

Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""

if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)

if token_ids_1 is not None:
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1]

def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A DeBERTa
sequence pair mask has the following format:

```
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```

If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.

Returns:
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given
sequence(s).
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
"tokenizer."
)

if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)

if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please allow the tokenizer to be also saved, if the file it loaded from is removed?
(see here for example tokenization_albert.py )

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a really good point!

Copy link
Contributor

@mingboiz mingboiz Feb 6, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this already supported in deberta-v2? specifically line 481, 482

def save_pretrained(self, path: str, filename_prefix: str = None):
filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]]
if filename_prefix is not None:
filename = filename_prefix + "-" + filename
full_path = os.path.join(path, filename)
with open(full_path, "wb") as fs:
fs.write(self.spm.serialized_model_proto())
return (full_path,)


return (out_vocab_file,)
Loading