[tokenizers] Fixing #8001 - Adding tests on tokenizers serialization #8006

Merged · 2 commits · Oct 26, 2020
3 changes: 3 additions & 0 deletions src/transformers/tokenization_albert.py
@@ -129,6 +129,9 @@ def __init__(
**kwargs
):
super().__init__(
+ do_lower_case=do_lower_case,
+ remove_space=remove_space,
+ keep_accents=keep_accents,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
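The per-tokenizer changes in this diff all follow the same pattern: every keyword argument accepted by a subclass `__init__` is now forwarded to `super().__init__()`, because the base class records whatever it receives in `init_kwargs`, and that is what `save_pretrained` later serializes. A minimal sketch of why the forwarding matters, using toy classes rather than the actual transformers code:

```python
# Toy illustration of the forwarding pattern (not the actual transformers classes).
class ToyBaseTokenizer:
    def __init__(self, **kwargs):
        # The base class can only remember what actually reaches it.
        self.init_kwargs = dict(kwargs)

    def save_config(self):
        # Stand-in for what save_pretrained would write to tokenizer_config.json.
        return dict(self.init_kwargs)


class ToyAlbertLikeTokenizer(ToyBaseTokenizer):
    def __init__(self, do_lower_case=True, keep_accents=False, **kwargs):
        self.do_lower_case = do_lower_case
        self.keep_accents = keep_accents
        # Forward every option so it survives a save/load round trip;
        # before this PR several tokenizers dropped options at this point.
        super().__init__(do_lower_case=do_lower_case, keep_accents=keep_accents, **kwargs)


tok = ToyAlbertLikeTokenizer(keep_accents=True)
assert tok.save_config()["keep_accents"] is True
```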
5 changes: 5 additions & 0 deletions src/transformers/tokenization_bert.py
@@ -178,11 +178,16 @@ def __init__(
**kwargs
):
super().__init__(
+ do_lower_case=do_lower_case,
+ do_basic_tokenize=do_basic_tokenize,
+ never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
**kwargs,
)

3 changes: 2 additions & 1 deletion src/transformers/tokenization_bertweet.py
@@ -129,11 +129,12 @@ def __init__(
**kwargs
):
super().__init__(
+ normalization=normalization,
bos_token=bos_token,
eos_token=eos_token,
+ unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
- unk_token=unk_token,
pad_token=pad_token,
mask_token=mask_token,
**kwargs,
6 changes: 2 additions & 4 deletions src/transformers/tokenization_deberta.py
@@ -308,16 +308,13 @@ class GPT2Tokenizer(object):

- We remapped the token ids in our dictionary with regard to the new special tokens, `[PAD]` => 0, `[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264

- do_lower_case (:obj:`bool`, optional):
- Whether to convert inputs to lower case. **Not used in GPT2 tokenizer**.
-
special_tokens (:obj:`list`, optional):
List of special tokens to be added to the end of the vocabulary.


"""

- def __init__(self, vocab_file=None, do_lower_case=True, special_tokens=None):

Review comment (PR author): do_lower_case was not used in the class, so I think it's better to remove it from the init args.

+ def __init__(self, vocab_file=None, special_tokens=None):
self.pad_token = "[PAD]"
self.sep_token = "[SEP]"
self.unk_token = "[UNK]"
@@ -523,6 +520,7 @@ def __init__(
**kwargs
):
super().__init__(
+ do_lower_case=do_lower_case,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
3 changes: 3 additions & 0 deletions src/transformers/tokenization_fsmt.py
@@ -194,6 +194,9 @@ def __init__(
):
super().__init__(
langs=langs,
+ src_vocab_file=src_vocab_file,
+ tgt_vocab_file=tgt_vocab_file,
+ merges_file=merges_file,
unk_token=unk_token,
bos_token=bos_token,
sep_token=sep_token,
9 changes: 8 additions & 1 deletion src/transformers/tokenization_gpt2.py
@@ -164,7 +164,14 @@ def __init__(
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
- super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
+ super().__init__(
+ errors=errors,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ add_prefix_space=add_prefix_space,
+ **kwargs,
+ )

with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
6 changes: 4 additions & 2 deletions src/transformers/tokenization_marian.py
@@ -97,10 +97,12 @@ def __init__(
):
super().__init__(
# bos_token=bos_token, unused. Start decoding with config.decoder_start_token_id
- model_max_length=model_max_length,
- eos_token=eos_token,
+ source_lang=source_lang,
+ target_lang=target_lang,
unk_token=unk_token,
+ eos_token=eos_token,
pad_token=pad_token,
+ model_max_length=model_max_length,
**kwargs,
)
assert Path(source_spm).exists(), f"cannot find spm source {source_spm}"
7 changes: 6 additions & 1 deletion src/transformers/tokenization_prophetnet.py
@@ -119,11 +119,16 @@ def __init__(
**kwargs
):
super().__init__(
+ do_lower_case=do_lower_case,
+ do_basic_tokenize=do_basic_tokenize,
+ never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
+ x_sep_token=x_sep_token,
pad_token=pad_token,
mask_token=mask_token,
- x_sep_token=x_sep_token,
+ tokenize_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
**kwargs,
)
self.unique_no_split_tokens.append(x_sep_token)
15 changes: 11 additions & 4 deletions src/transformers/tokenization_t5.py
@@ -112,15 +112,22 @@ def __init__(
**kwargs
):
# Add extra_ids to the special token list
- if extra_ids > 0:
- if additional_special_tokens is None:
- additional_special_tokens = []
- additional_special_tokens.extend(["<extra_id_{}>".format(i) for i in range(extra_ids)])
+ if extra_ids > 0 and additional_special_tokens is None:
+ additional_special_tokens = ["<extra_id_{}>".format(i) for i in range(extra_ids)]
+ elif extra_ids > 0 and additional_special_tokens is not None:
+ # Check that we have the right number of extra_id special tokens
+ extra_tokens = len(set(filter(lambda x: bool("extra_id" in x), additional_special_tokens)))
+ if extra_tokens != extra_ids:
+ raise ValueError(
+ f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. "
+ "In this case the additional_special_tokens must include the extra_ids tokens"
+ )

super().__init__(
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
+ extra_ids=extra_ids,
additional_special_tokens=additional_special_tokens,
**kwargs,
)
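The new `extra_ids` handling above rejects inconsistent inputs instead of silently extending the token list. A standalone condensation of that check, run with hypothetical values (this helper is illustrative, not part of the transformers API):

```python
# Condensed version of the validation added above (illustrative helper, not the library API).
def resolve_extra_ids(extra_ids, additional_special_tokens):
    if extra_ids > 0 and additional_special_tokens is None:
        return ["<extra_id_{}>".format(i) for i in range(extra_ids)]
    if extra_ids > 0:
        extra_tokens = len(set(tok for tok in additional_special_tokens if "extra_id" in tok))
        if extra_tokens != extra_ids:
            raise ValueError(
                "additional_special_tokens must include the extra_id tokens when extra_ids is given"
            )
    return additional_special_tokens


print(resolve_extra_ids(2, None))
# ['<extra_id_0>', '<extra_id_1>']
print(resolve_extra_ids(2, ["<extra_id_0>", "<extra_id_1>", "<special>"]))
# passes through unchanged
# resolve_extra_ids(2, ["<special>"]) would raise ValueError
```

The same check is duplicated in the fast tokenizer below.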
19 changes: 12 additions & 7 deletions src/transformers/tokenization_t5_fast.py
@@ -126,6 +126,18 @@ def __init__(
additional_special_tokens=None,
**kwargs
):
+ # Add extra_ids to the special token list
+ if extra_ids > 0 and additional_special_tokens is None:
+ additional_special_tokens = ["<extra_id_{}>".format(i) for i in range(extra_ids)]
+ elif extra_ids > 0 and additional_special_tokens is not None:
+ # Check that we have the right number of extra special tokens
+ extra_tokens = len(set(filter(lambda x: bool("extra_id_" in x), additional_special_tokens)))
+ if extra_tokens != extra_ids:
+ raise ValueError(
+ f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. "
+ "In this case the additional_special_tokens must include the extra_ids tokens"
+ )
+
super().__init__(
vocab_file,
tokenizer_file=tokenizer_file,
@@ -137,13 +149,6 @@ def __init__(
**kwargs,
)

- if extra_ids > 0:
- all_extra_tokens = ["<extra_id_{}>".format(i) for i in range(extra_ids)]
- if all(tok not in self.additional_special_tokens for tok in all_extra_tokens):
- self.additional_special_tokens = self.additional_special_tokens + [
- "<extra_id_{}>".format(i) for i in range(extra_ids)
- ]
-
self.vocab_file = vocab_file
self._extra_ids = extra_ids

14 changes: 13 additions & 1 deletion src/transformers/tokenization_transfo_xl.py
@@ -164,7 +164,19 @@ def __init__(
**kwargs
):
super().__init__(
- unk_token=unk_token, eos_token=eos_token, additional_special_tokens=additional_special_tokens, **kwargs
+ special=special,
+ min_freq=min_freq,
+ max_size=max_size,
+ lower_case=lower_case,
+ delimiter=delimiter,
+ vocab_file=vocab_file,
+ pretrained_vocab_file=pretrained_vocab_file,
+ never_split=never_split,
+ unk_token=unk_token,
+ eos_token=eos_token,
+ additional_special_tokens=additional_special_tokens,
+ language=language,
+ **kwargs,
)

if never_split is None:
2 changes: 1 addition & 1 deletion src/transformers/tokenization_utils_base.py
@@ -1673,7 +1673,7 @@ def _from_pretrained(
if (
"tokenizer_file" not in resolved_vocab_files or resolved_vocab_files["tokenizer_file"] is None
) and cls.slow_tokenizer_class is not None:
- slow_tokenizer = cls.slow_tokenizer_class._from_pretrained(
+ slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained(
copy.deepcopy(resolved_vocab_files),
pretrained_model_name_or_path,
copy.deepcopy(init_configuration),
3 changes: 1 addition & 2 deletions src/transformers/tokenization_utils_fast.py
@@ -16,7 +16,6 @@
For slow (python) tokenizers see tokenization_utils.py
"""

- import copy
import json
import os
import warnings
@@ -105,7 +104,7 @@ def __init__(self, *args, **kwargs):
self._tokenizer = fast_tokenizer

if slow_tokenizer is not None:
- kwargs = copy.deepcopy(slow_tokenizer.init_kwargs)
+ kwargs.update(slow_tokenizer.init_kwargs)

# We call this after having initialized the backend tokenizer because we update it.
super().__init__(**kwargs)
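The switch from `kwargs = copy.deepcopy(slow_tokenizer.init_kwargs)` to `kwargs.update(slow_tokenizer.init_kwargs)` changes the merge semantics: the slow tokenizer's recorded options still win for overlapping keys, but arguments that only the caller supplied are no longer discarded. A small sketch of that difference with hypothetical values:

```python
# Hypothetical values illustrating replace-vs-merge semantics of the change above.
explicit_kwargs = {"unk_token": "<unk>", "model_max_length": 512}
slow_init_kwargs = {"unk_token": "[UNK]", "do_lower_case": True}

# Old behaviour: the explicit kwargs were thrown away entirely.
old = dict(slow_init_kwargs)
# {'unk_token': '[UNK]', 'do_lower_case': True}

# New behaviour: merge, so caller-only keys such as model_max_length survive.
new = dict(explicit_kwargs)
new.update(slow_init_kwargs)
# {'unk_token': '[UNK]', 'model_max_length': 512, 'do_lower_case': True}
```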
3 changes: 3 additions & 0 deletions src/transformers/tokenization_xlm.py
@@ -621,6 +621,9 @@ def __init__(
cls_token=cls_token,
mask_token=mask_token,
additional_special_tokens=additional_special_tokens,
+ lang2id=lang2id,
+ id2lang=id2lang,
+ do_lowercase_and_remove_accent=do_lowercase_and_remove_accent,
**kwargs,
)

3 changes: 2 additions & 1 deletion src/transformers/tokenization_xlm_prophetnet.py
@@ -123,9 +123,10 @@ def __init__(
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
- unk_token=unk_token,
sep_token=sep_token,
+ unk_token=unk_token,
pad_token=pad_token,
+ cls_token=cls_token,
mask_token=mask_token,
**kwargs,
)
3 changes: 3 additions & 0 deletions src/transformers/tokenization_xlnet.py
@@ -128,6 +128,9 @@ def __init__(
**kwargs
):
super().__init__(
+ do_lower_case=do_lower_case,
+ remove_space=remove_space,
+ keep_accents=keep_accents,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
19 changes: 19 additions & 0 deletions tests/test_tokenization_common.py
@@ -177,6 +177,25 @@ def test_rust_tokenizer_signature(self):
self.assertIn("tokenizer_file", signature.parameters)
self.assertIsNone(signature.parameters["tokenizer_file"].default)

+ def test_tokenizer_slow_store_full_signature(self):
+ signature = inspect.signature(self.tokenizer_class.__init__)
+ tokenizer = self.get_tokenizer()
+
+ for parameter_name, parameter in signature.parameters.items():
+ if parameter.default != inspect.Parameter.empty:
+ self.assertIn(parameter_name, tokenizer.init_kwargs)
+
+ def test_tokenizer_fast_store_full_signature(self):
+ if not self.test_rust_tokenizer:
+ return
+
+ signature = inspect.signature(self.rust_tokenizer_class.__init__)
+ tokenizer = self.get_rust_tokenizer()
+
+ for parameter_name, parameter in signature.parameters.items():
+ if parameter.default != inspect.Parameter.empty:
+ self.assertIn(parameter_name, tokenizer.init_kwargs)

def test_rust_and_python_full_tokenizers(self):
if not self.test_rust_tokenizer:
return
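The two new tests assert that every `__init__` parameter with a default value ends up in `init_kwargs`, which is what the per-tokenizer changes above guarantee. A minimal sketch of the same check on a toy class (the toy sets `init_kwargs` directly; in transformers the base class records it):

```python
# Minimal sketch of what the new tests verify, using a toy class rather than a real tokenizer.
import inspect


class ToyTokenizer:
    def __init__(self, vocab_file, do_lower_case=True, unk_token="<unk>", **kwargs):
        # In transformers, the base class records these; here we do it by hand.
        self.init_kwargs = {"do_lower_case": do_lower_case, "unk_token": unk_token, **kwargs}


tok = ToyTokenizer("vocab.txt", do_lower_case=False)
signature = inspect.signature(ToyTokenizer.__init__)
for parameter_name, parameter in signature.parameters.items():
    if parameter.default != inspect.Parameter.empty:
        # Every defaulted argument must have been captured for serialization.
        assert parameter_name in tok.init_kwargs, parameter_name
```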