Add tokenizers class mismatch detection between cls and checkpoint #12619

Merged (12 commits) on Jul 17, 2021
@@ -132,7 +132,7 @@ def __init__(
         if not os.path.isfile(vocab_file):
             raise ValueError(
                 f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
-                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+                "model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
Contributor:
It seems to me that the other tokenizers in the transformers library name the specific tokenizer class here instead of the generic AutoTokenizer. Was there any particular reason to prefer AutoTokenizer to BertJapaneseTokenizer? 🙂

Suggested change:
-    "model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+    "model use `tokenizer = BertJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"

Contributor Author:
I don't have a strong opinion about it either. I chose AutoTokenizer because I thought leading a user to AutoTokenizer would avoid a problem like the one in this issue.
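
For reference, a minimal sketch of the difference being discussed (assuming Hub access; the checkpoint is the one used in the tests below):

    from transformers import AutoTokenizer, BertTokenizer

    # AutoTokenizer resolves the concrete class from the checkpoint, so the user gets
    # the intended BertJapaneseTokenizer (MeCab word segmentation + WordPiece).
    tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
    print(type(tokenizer).__name__)  # BertJapaneseTokenizer

    # Calling a concrete class directly still loads the vocabulary, but applies the
    # wrong pre-tokenization logic -- the silent mismatch this PR detects and warns about.
    mismatched = BertTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")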

Collaborator:
I think it's better to encourage users to use the AutoTokenizer class.

Contributor:
Thanks a lot for your feedback @sgugger! In that case, @europeanplaice, your proposal is great; you can ignore my previous comment.

@sgugger, should we take this opportunity to make the same change to the other tokenizers that log the same type of message (cf. PR #12745)?

Collaborator:
Yes, that's a great idea!

            )
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
45 changes: 45 additions & 0 deletions src/transformers/tokenization_utils_base.py
@@ -1745,13 +1745,58 @@ def _from_pretrained(
        if tokenizer_config_file is not None:
            with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
                init_kwargs = json.load(tokenizer_config_handle)
            # First attempt. We get tokenizer_class from tokenizer_config to check for a mismatch between tokenizer classes.
            config_tokenizer_class = init_kwargs.get("tokenizer_class")
            init_kwargs.pop("tokenizer_class", None)
            saved_init_inputs = init_kwargs.pop("init_inputs", ())
            if not init_inputs:
                init_inputs = saved_init_inputs
        else:
            config_tokenizer_class = None
            init_kwargs = init_configuration
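
As a sketch, this "first attempt" reads the class name straight out of the checkpoint's tokenizer_config.json; the file payload below is illustrative:

    import json

    # Hypothetical tokenizer_config.json contents; only "tokenizer_class" matters here.
    raw = '{"tokenizer_class": "BertJapaneseTokenizer", "do_lower_case": false}'
    init_kwargs = json.loads(raw)
    config_tokenizer_class = init_kwargs.get("tokenizer_class")  # "BertJapaneseTokenizer"
    init_kwargs.pop("tokenizer_class", None)  # dropped so it is not forwarded to __init__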

        if config_tokenizer_class is None:
            from .models.auto.configuration_auto import AutoConfig

            # Second attempt. If we have not yet found tokenizer_class, let's try to use the config.
            try:
                config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
                config_tokenizer_class = config.tokenizer_class
            except (OSError, ValueError, KeyError):
                # skip if an error occurred.
                config = None
            if config_tokenizer_class is None:
                # Third attempt. If we have not yet found the original type of the tokenizer we are loading,
                # we see if we can infer it from the type of the configuration file.
                from .models.auto.configuration_auto import CONFIG_MAPPING
                from .models.auto.tokenization_auto import TOKENIZER_MAPPING

                if hasattr(config, "model_type"):
                    config_class = CONFIG_MAPPING.get(config.model_type)
                else:
                    # Fallback: use pattern matching on the string.
                    config_class = None
                    for pattern, config_class_tmp in CONFIG_MAPPING.items():
                        if pattern in str(pretrained_model_name_or_path):
                            config_class = config_class_tmp
                            break

                if config_class in TOKENIZER_MAPPING.keys():
                    config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING[config_class]
                    if config_tokenizer_class is not None:
                        config_tokenizer_class = config_tokenizer_class.__name__
                    else:
                        config_tokenizer_class = config_tokenizer_class_fast.__name__

Contributor @SaulLu commented on Jul 13, 2021:
EDIT: Reading @sgugger's answer, I also agree with him that we can simplify this part and use AutoConfig directly.

Old comment:
The addition of the snippet below could therefore solve the limitation that you have shown in the test that you named test_limit_of_match_validation.

It would have to be checked by running all the tests, but I have the impression that by doing the imports at this level we don't have a circular import problem.

        # If we have not yet found the original type of the tokenizer we are loading we see if we can infer it from the
        # type of the configuration file
        if config_dict is not None and config_tokenizer_class is None:
            from .models.auto.configuration_auto import CONFIG_MAPPING
            from .models.auto.tokenization_auto import TOKENIZER_MAPPING
            
            if "model_type" in config_dict:
                config_class = CONFIG_MAPPING[config_dict["model_type"]]
            else:
                # Fallback: use pattern matching on the string.
                config_class = None  # initialize so the membership check below is safe if nothing matches
                for pattern, config_class_tmp in CONFIG_MAPPING.items():
                    if pattern in str(pretrained_model_name_or_path):
                        config_class = config_class_tmp
                        break

            if config_class in TOKENIZER_MAPPING.keys():
                config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING[config_class]
                if config_tokenizer_class is not None:
                    config_tokenizer_class = config_tokenizer_class.__name__
                else:
                    config_tokenizer_class = config_tokenizer_class_fast.__name__
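
One subtlety of this substring fallback, sketched below with a hypothetical pattern list: the iteration order over CONFIG_MAPPING matters, because some model-type names contain others.

    # "bert" is a substring of "roberta", so "roberta" must be tried first.
    patterns = ["roberta", "bert"]

    def first_match(name_or_path: str):
        for pattern in patterns:
            if pattern in name_or_path:
                return pattern
        return None

    assert first_match("roberta-base") == "roberta"
    assert first_match("bert-base-cased") == "bert"
    assert first_match("gpt2") is None  # nothing matches; config_class stays None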

Contributor Author:

Thank you for an excellent suggestion!

        if config_tokenizer_class is not None:
            if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""):
                logger.warning(
                    "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. "
                    "It may result in unexpected tokenization. \n"
                    f"The tokenizer class you load from this checkpoint is '{config_tokenizer_class}'. \n"
                    f"The class this function is called from is '{cls.__name__}'."
Contributor (on lines +1794 to +1797):

Great 👍! This will really help future users.

                )

        # Update with newly provided kwargs
        init_kwargs.update(kwargs)
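
To illustrate the new behavior, a sketch of what a user would see once this is merged (the warning text is copied from the logger.warning call above; Hub access assumed):

    from transformers import BertTokenizer, BertTokenizerFast

    BertTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
    # WARNING: The tokenizer class you load from this checkpoint is not the same type
    # as the class this function is called from. It may result in unexpected tokenization.
    # The tokenizer class you load from this checkpoint is 'BertJapaneseTokenizer'.
    # The class this function is called from is 'BertTokenizer'.

    # No warning here: the "Fast" suffix is stripped from both class names before
    # comparing, so fast/slow variants of the same tokenizer are not a mismatch.
    BertTokenizerFast.from_pretrained("bert-base-cased")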

21 changes: 21 additions & 0 deletions tests/test_tokenization_bert_japanese.py
@@ -22,6 +22,7 @@
from transformers.models.bert_japanese.tokenization_bert_japanese import (
    VOCAB_FILES_NAMES,
    BertJapaneseTokenizer,
    BertTokenizer,
    CharacterTokenizer,
    MecabTokenizer,
    WordpieceTokenizer,
@@ -278,3 +279,23 @@ def test_tokenizer_bert_japanese(self):
        EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
        tokenizer = AutoTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
        self.assertIsInstance(tokenizer, BertJapaneseTokenizer)


class BertTokenizerMismatchTest(unittest.TestCase):
    def test_tokenizer_mismatch_warning(self):
        EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
        with self.assertLogs("transformers", level="WARNING") as cm:
            BertTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
            self.assertTrue(
                cm.records[0].message.startswith(
                    "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
                )
            )
        EXAMPLE_BERT_ID = "bert-base-cased"
        with self.assertLogs("transformers", level="WARNING") as cm:
            BertJapaneseTokenizer.from_pretrained(EXAMPLE_BERT_ID)
            self.assertTrue(
                cm.records[0].message.startswith(
                    "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
                )
            )
38 changes: 38 additions & 0 deletions tests/test_tokenization_common.py
@@ -29,7 +29,10 @@
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import (
    AlbertTokenizer,
    AlbertTokenizerFast,
    BertTokenizer,
    BertTokenizerFast,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
@@ -3288,6 +3291,41 @@ def test_training_new_tokenizer_with_special_tokens_change(self):
        expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result)
        self.assertEqual(expected_result, decoded_input)

    def test_tokenizer_mismatch_warning(self):
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                with self.assertLogs("transformers", level="WARNING") as cm:
                    try:
                        if self.tokenizer_class == BertTokenizer:
                            AlbertTokenizer.from_pretrained(pretrained_name)
                        else:
                            BertTokenizer.from_pretrained(pretrained_name)
                    except (TypeError, AttributeError):
                        # Some tokenizers cannot be loaded into the target tokenizer at all, in which case an error is
                        # raised; here we just check that the warning has been logged before the error is raised
                        pass
                    finally:
                        self.assertTrue(
                            cm.records[0].message.startswith(
                                "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
                            )
                        )
                    try:
                        if self.rust_tokenizer_class == BertTokenizerFast:
                            AlbertTokenizerFast.from_pretrained(pretrained_name)
                        else:
                            BertTokenizerFast.from_pretrained(pretrained_name)
                    except (TypeError, AttributeError):
                        # Some tokenizers cannot be loaded into the target tokenizer at all, in which case an error is
                        # raised; here we just check that the warning has been logged before the error is raised
                        pass
                    finally:
                        self.assertTrue(
                            cm.records[0].message.startswith(
                                "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
                            )
                        )


@is_staging_test
class TokenizerPushToHubTester(unittest.TestCase):
5 changes: 5 additions & 0 deletions tests/test_tokenization_fast.py
@@ -44,6 +44,11 @@ def setUp(self):
        tokenizer = PreTrainedTokenizerFast.from_pretrained(model_paths[0])
        tokenizer.save_pretrained(self.tmpdirname)

    def test_tokenizer_mismatch_warning(self):
        # We disable this test for PreTrainedTokenizerFast because it is the only tokenizer that is not linked to any
        # model
        pass
Contributor (on lines +47 to +50):
👍


    def test_pretrained_model_lists(self):
        # We disable this test for PreTrainedTokenizerFast because it is the only tokenizer that is not linked to any
        # model