Skip to content

Commit

Permalink
[tokenizers] Fixing huggingface#8001 - Adding tests on tokenizers ser…
Browse files Browse the repository at this point in the history
…ialization (huggingface#8006)

* fixing huggingface#8001

* make T5 tokenizer serialization more robust - style
  • Loading branch information
thomwolf authored and fabiocapsouza committed Nov 15, 2020
1 parent fd48878 commit 9f79b9a
Show file tree
Hide file tree
Showing 17 changed files with 98 additions and 25 deletions.
3 changes: 3 additions & 0 deletions src/transformers/tokenization_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def __init__(
**kwargs
):
super().__init__(
do_lower_case=do_lower_case,
remove_space=remove_space,
keep_accents=keep_accents,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/tokenization_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,16 @@ def __init__(
**kwargs
):
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)

Expand Down
3 changes: 2 additions & 1 deletion src/transformers/tokenization_bertweet.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,12 @@ def __init__(
**kwargs
):
super().__init__(
normalization=normalization,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
unk_token=unk_token,
pad_token=pad_token,
mask_token=mask_token,
**kwargs,
Expand Down
6 changes: 2 additions & 4 deletions src/transformers/tokenization_deberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,16 +308,13 @@ class GPT2Tokenizer(object):
- We remapped the token ids in our dictionary with regarding to the new special tokens, `[PAD]` => 0, `[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264
do_lower_case (:obj:`bool`, optional):
Whether to convert inputs to lower case. **Not used in GPT2 tokenizer**.
special_tokens (:obj:`list`, optional):
List of special tokens to be added to the end of the vocabulary.
"""

def __init__(self, vocab_file=None, do_lower_case=True, special_tokens=None):
def __init__(self, vocab_file=None, special_tokens=None):
self.pad_token = "[PAD]"
self.sep_token = "[SEP]"
self.unk_token = "[UNK]"
Expand Down Expand Up @@ -523,6 +520,7 @@ def __init__(
**kwargs
):
super().__init__(
do_lower_case=do_lower_case,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/tokenization_fsmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ def __init__(
):
super().__init__(
langs=langs,
src_vocab_file=src_vocab_file,
tgt_vocab_file=tgt_vocab_file,
merges_file=merges_file,
unk_token=unk_token,
bos_token=bos_token,
sep_token=sep_token,
Expand Down
9 changes: 8 additions & 1 deletion src/transformers/tokenization_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,14 @@ def __init__(
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
super().__init__(
errors=errors,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
add_prefix_space=add_prefix_space,
**kwargs,
)

with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
Expand Down
6 changes: 4 additions & 2 deletions src/transformers/tokenization_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,12 @@ def __init__(
):
super().__init__(
# bos_token=bos_token, unused. Start decoding with config.decoder_start_token_id
model_max_length=model_max_length,
eos_token=eos_token,
source_lang=source_lang,
target_lang=target_lang,
unk_token=unk_token,
eos_token=eos_token,
pad_token=pad_token,
model_max_length=model_max_length,
**kwargs,
)
assert Path(source_spm).exists(), f"cannot find spm source {source_spm}"
Expand Down
7 changes: 6 additions & 1 deletion src/transformers/tokenization_prophetnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,16 @@ def __init__(
**kwargs
):
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
x_sep_token=x_sep_token,
pad_token=pad_token,
mask_token=mask_token,
x_sep_token=x_sep_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)
self.unique_no_split_tokens.append(x_sep_token)
Expand Down
15 changes: 11 additions & 4 deletions src/transformers/tokenization_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,22 @@ def __init__(
**kwargs
):
# Add extra_ids to the special token list
if extra_ids > 0:
if additional_special_tokens is None:
additional_special_tokens = []
additional_special_tokens.extend(["<extra_id_{}>".format(i) for i in range(extra_ids)])
if extra_ids > 0 and additional_special_tokens is None:
additional_special_tokens = ["<extra_id_{}>".format(i) for i in range(extra_ids)]
elif extra_ids > 0 and additional_special_tokens is not None:
# Check that we have the right number of extra_id special tokens
extra_tokens = len(set(filter(lambda x: bool("extra_id" in x), additional_special_tokens)))
if extra_tokens != extra_ids:
raise ValueError(
f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. "
"In this case the additional_special_tokens must include the extra_ids tokens"
)

super().__init__(
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
extra_ids=extra_ids,
additional_special_tokens=additional_special_tokens,
**kwargs,
)
Expand Down
19 changes: 12 additions & 7 deletions src/transformers/tokenization_t5_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,18 @@ def __init__(
additional_special_tokens=None,
**kwargs
):
# Add extra_ids to the special token list
if extra_ids > 0 and additional_special_tokens is None:
additional_special_tokens = ["<extra_id_{}>".format(i) for i in range(extra_ids)]
elif extra_ids > 0 and additional_special_tokens is not None:
# Check that we have the right number of extra special tokens
extra_tokens = len(set(filter(lambda x: bool("extra_id_" in x), additional_special_tokens)))
if extra_tokens != extra_ids:
raise ValueError(
f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. "
"In this case the additional_special_tokens must include the extra_ids tokens"
)

super().__init__(
vocab_file,
tokenizer_file=tokenizer_file,
Expand All @@ -137,13 +149,6 @@ def __init__(
**kwargs,
)

if extra_ids > 0:
all_extra_tokens = ["<extra_id_{}>".format(i) for i in range(extra_ids)]
if all(tok not in self.additional_special_tokens for tok in all_extra_tokens):
self.additional_special_tokens = self.additional_special_tokens + [
"<extra_id_{}>".format(i) for i in range(extra_ids)
]

self.vocab_file = vocab_file
self._extra_ids = extra_ids

Expand Down
14 changes: 13 additions & 1 deletion src/transformers/tokenization_transfo_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,19 @@ def __init__(
**kwargs
):
super().__init__(
unk_token=unk_token, eos_token=eos_token, additional_special_tokens=additional_special_tokens, **kwargs
special=special,
min_freq=min_freq,
max_size=max_size,
lower_case=lower_case,
delimiter=delimiter,
vocab_file=vocab_file,
pretrained_vocab_file=pretrained_vocab_file,
never_split=never_split,
unk_token=unk_token,
eos_token=eos_token,
additional_special_tokens=additional_special_tokens,
language=language,
**kwargs,
)

if never_split is None:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1673,7 +1673,7 @@ def _from_pretrained(
if (
"tokenizer_file" not in resolved_vocab_files or resolved_vocab_files["tokenizer_file"] is None
) and cls.slow_tokenizer_class is not None:
slow_tokenizer = cls.slow_tokenizer_class._from_pretrained(
slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained(
copy.deepcopy(resolved_vocab_files),
pretrained_model_name_or_path,
copy.deepcopy(init_configuration),
Expand Down
3 changes: 1 addition & 2 deletions src/transformers/tokenization_utils_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
For slow (python) tokenizers see tokenization_utils.py
"""

import copy
import json
import os
import warnings
Expand Down Expand Up @@ -105,7 +104,7 @@ def __init__(self, *args, **kwargs):
self._tokenizer = fast_tokenizer

if slow_tokenizer is not None:
kwargs = copy.deepcopy(slow_tokenizer.init_kwargs)
kwargs.update(slow_tokenizer.init_kwargs)

# We call this after having initialized the backend tokenizer because we update it.
super().__init__(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/tokenization_xlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,9 @@ def __init__(
cls_token=cls_token,
mask_token=mask_token,
additional_special_tokens=additional_special_tokens,
lang2id=lang2id,
id2lang=id2lang,
do_lowercase_and_remove_accent=do_lowercase_and_remove_accent,
**kwargs,
)

Expand Down
3 changes: 2 additions & 1 deletion src/transformers/tokenization_xlm_prophetnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,10 @@ def __init__(
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
unk_token=unk_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
**kwargs,
)
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/tokenization_xlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def __init__(
**kwargs
):
super().__init__(
do_lower_case=do_lower_case,
remove_space=remove_space,
keep_accents=keep_accents,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
Expand Down
19 changes: 19 additions & 0 deletions tests/test_tokenization_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,25 @@ def test_rust_tokenizer_signature(self):
self.assertIn("tokenizer_file", signature.parameters)
self.assertIsNone(signature.parameters["tokenizer_file"].default)

def test_tokenizer_slow_store_full_signature(self):
signature = inspect.signature(self.tokenizer_class.__init__)
tokenizer = self.get_tokenizer()

for parameter_name, parameter in signature.parameters.items():
if parameter.default != inspect.Parameter.empty:
self.assertIn(parameter_name, tokenizer.init_kwargs)

def test_tokenizer_fast_store_full_signature(self):
if not self.test_rust_tokenizer:
return

signature = inspect.signature(self.rust_tokenizer_class.__init__)
tokenizer = self.get_rust_tokenizer()

for parameter_name, parameter in signature.parameters.items():
if parameter.default != inspect.Parameter.empty:
self.assertIn(parameter_name, tokenizer.init_kwargs)

def test_rust_and_python_full_tokenizers(self):
if not self.test_rust_tokenizer:
return
Expand Down

0 comments on commit 9f79b9a

Please sign in to comment.