Add tokenizers class mismatch detection between cls and checkpoint #12619
Conversation
Thank you very much for working on the implementation of this check. 😄
I mainly left comments to 1) allow retrieving the tokenizer type from the config type and 2) extend the tests.
Feel free to let me know if I missed anything or if you want me to take over the extension of the tests!
with open(config_file, encoding="utf-8") as config_handle:
    config_dict = json.load(config_handle)
    config_tokenizer_class = config_dict.get("tokenizer_class")
-------- EDIT: --------
Reading @sgugger's answer, I also agree with him that we can simplify this part and use AutoConfig directly.
-------- Old comment: --------
Adding the snippet below could therefore solve the limitation you showed in the test you named test_limit_of_match_validation.
It would have to be checked by running all the tests, but I have the impression that doing the imports at this level does not create a circular import problem.
# If we have not yet found the original type of the tokenizer we are loading we see if we can infer it from the
# type of the configuration file
if config_dict is not None and config_tokenizer_class is None:
    from .models.auto.configuration_auto import CONFIG_MAPPING
    from .models.auto.tokenization_auto import TOKENIZER_MAPPING

    if "model_type" in config_dict:
        config_class = CONFIG_MAPPING[config_dict["model_type"]]
    else:
        # Fallback: use pattern matching on the string.
        for pattern, config_class_tmp in CONFIG_MAPPING.items():
            if pattern in str(pretrained_model_name_or_path):
                config_class = config_class_tmp
                break

    if config_class in TOKENIZER_MAPPING.keys():
        config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING[config_class]
        if config_tokenizer_class is not None:
            config_tokenizer_class = config_tokenizer_class.__name__
        else:
            config_tokenizer_class = config_tokenizer_class_fast.__name__
Thank you for an excellent suggestion!
"The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. " | ||
"It may result in unexpected tokenization. \n" | ||
f"The tokenizer class you load from this checkpoint is '{config_tokenizer_class}'. \n" | ||
f"The class this function is called from is '{cls.__name__}'." |
Great 👍 ! This will really help future users
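For context, a minimal sketch of how a message like this could be emitted from the mismatch check; the helper name and the exact comparison are assumptions for illustration, not code quoted from this PR:

import logging

logger = logging.getLogger("transformers")


def warn_on_tokenizer_mismatch(cls, config_tokenizer_class):
    # Hypothetical helper: warn when the tokenizer class recorded for the checkpoint
    # differs from the class `from_pretrained` was called on.
    if config_tokenizer_class is None:
        return  # nothing recorded for this checkpoint, nothing to compare
    # Assumption: ignore the slow/fast suffix when comparing class names.
    if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""):
        logger.warning(
            "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. "
            "It may result in unexpected tokenization. \n"
            f"The tokenizer class you load from this checkpoint is '{config_tokenizer_class}'. \n"
            f"The class this function is called from is '{cls.__name__}'."
        )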
tests/test_tokenization_base.py
Outdated
import unittest

from transformers.models.bert.tokenization_bert import BertTokenizer
from transformers.models.bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer


class ClassMismatchTest(unittest.TestCase):
    def test_mismatch_error(self):
        PRETRAINED_MODEL = "cl-tohoku/bert-base-japanese"
        with self.assertRaises(ValueError):
            BertTokenizer.from_pretrained(PRETRAINED_MODEL)

    def test_limit_of_match_validation(self):
        # Can't detect mismatch because this model's config
        # doesn't have information about the tokenizer model.
        PRETRAINED_MODEL = "bert-base-uncased"
        BertJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL)
Thank you very much for writing this test: we immediately understand the new feature!
As the added changes concern all tokenizers, not only BertTokenizer and BertJapaneseTokenizer, I think it would be interesting to test the logged warning on all tokenizers by adding a new test to TokenizerTesterMixin in the test_tokenization_common.py file. This new test could, for example, look something like this:
def test_tokenizer_mismatch_warning(self):
    for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
        with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
            if self.tokenizer_class != BertTokenizer:
                with self.assertLogs("transformers", level="WARNING") as cm:
                    try:
                        BertTokenizer.from_pretrained(pretrained_name)
                    except (TypeError, AttributeError):
                        # Some tokenizers cannot be loaded into `BertTokenizer` at all and errors are returned,
                        # here we just check that the warning has been logged before the error is raised
                        pass
                    finally:
                        self.assertTrue(
                            cm.records[0].message.startswith(
                                "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
                            )
                        )
            if self.rust_tokenizer_class != BertTokenizerFast:
                with self.assertLogs("transformers", level="WARNING") as cm:
                    try:
                        BertTokenizerFast.from_pretrained(pretrained_name)
                    except (TypeError, AttributeError):
                        # Some tokenizers cannot be loaded into `BertTokenizerFast` at all and errors are returned,
                        # here we just check that the warning has been logged before the error is raised
                        pass
                    finally:
                        self.assertTrue(
                            cm.records[0].message.startswith(
                                "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
                            )
                        )
What do you think?
PS: I can of course help make this change if needed, especially as an adaptation will have to be made for PreTrainedTokenizerFast.
It's an excellent idea, and I'd like this test to cover all tokenizers, including BertTokenizer and BertJapaneseTokenizer. I changed your suggestion to the following. Is anything missing from this test?
def test_tokenizer_mismatch_warning(self):
    for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
        with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
            with self.assertLogs("transformers", level="WARNING") as cm:
                try:
                    if self.tokenizer_class == BertTokenizer:
                        AlbertTokenizer.from_pretrained(pretrained_name)
                    else:
                        BertTokenizer.from_pretrained(pretrained_name)
                except (TypeError, AttributeError):
                    # Some tokenizers cannot be loaded into the target tokenizer at all and errors are returned,
                    # here we just check that the warning has been logged before the error is raised
                    pass
                finally:
                    self.assertTrue(
                        cm.records[0].message.startswith(
                            "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
                        )
                    )

                try:
                    if self.rust_tokenizer_class == BertTokenizerFast:
                        AlbertTokenizerFast.from_pretrained(pretrained_name)
                    else:
                        BertTokenizerFast.from_pretrained(pretrained_name)
                except (TypeError, AttributeError):
                    # Some tokenizers cannot be loaded into the target tokenizer at all and errors are returned,
                    # here we just check that the warning has been logged before the error is raised
                    pass
                finally:
                    self.assertTrue(
                        cm.records[0].message.startswith(
                            "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
                        )
                    )
Looks great to me! 🙂
@@ -132,7 +132,7 @@ def __init__(
         if not os.path.isfile(vocab_file):
             raise ValueError(
                 f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
-                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+                "model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
It seems to me that the other tokenizers in the transformers library specify the concrete tokenizer class here instead of the generic AutoTokenizer. Was there any particular reason to prefer AutoTokenizer to BertJapaneseTokenizer? 🙂
"model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" | |
"model use `tokenizer = BertJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" |
I don't have a strong opinion about it either. I chose AutoTokenizer because I thought pointing users to AutoTokenizer would avoid problems like this issue.
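As a concrete illustration (not part of this PR), AutoTokenizer resolves the class recorded for the checkpoint, which is exactly what sidesteps this kind of mismatch; the checkpoint name is taken from the tests above, and loading it requires the Japanese tokenization extras (e.g. fugashi):

from transformers import AutoTokenizer

# AutoTokenizer reads the checkpoint's configuration and instantiates the
# tokenizer class recorded there, so users don't need to know the concrete class.
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
print(type(tokenizer).__name__)  # expected: BertJapaneseTokenizer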
I think it's better to encourage users to use the AutoTokenizer class.
Thanks a lot for your feedback @sgugger! In that case, @europeanplaice, your proposal is great - you can ignore my previous comment.
@sgugger, should we take this opportunity to make the same change to the other tokenizers that log the same type of message (cf. PR #12745)?
Yes, that was a great idea!
Thanks for the PR. I think it needs to be adjusted a bit to keep model configuration independent from tokenizers, as much as possible.
@@ -111,6 +111,7 @@ class EncodingFast:
 SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
 ADDED_TOKENS_FILE = "added_tokens.json"
 TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
+CONFIG_FILE = "config.json"
No, this file is the model configuration. It has nothing to do with the tokenizer and should not be put here.
I agree. AutoConfig.from_pretrained makes this line unnecessary.
@@ -1639,6 +1640,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
         "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
         "tokenizer_config_file": TOKENIZER_CONFIG_FILE,
         "tokenizer_file": FULL_TOKENIZER_FILE,
+        "config_file": CONFIG_FILE,
Same here.
I agree too.
if tokenizer_config_file is None or config_tokenizer_class is None:
    config_file = resolved_vocab_files.pop("config_file", None)
    if config_file is not None:
        with open(config_file, encoding="utf-8") as config_handle:
            config_dict = json.load(config_handle)
            config_tokenizer_class = config_dict.get("tokenizer_class")
We should rely on AutoConfig.from_pretrained for this blob (inside a try block).
Thank you for your review. It is better than my code, and I avoided a circular import by importing AutoConfig inside _from_pretrained rather than at the top level (thanks to #12619 (comment)).
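Put together, a rough sketch of that approach (the helper name, its placement, and the exception list are assumptions for illustration, not the exact code merged here):

def _infer_tokenizer_class_from_config(pretrained_model_name_or_path):
    # Hypothetical helper: best-effort lookup of the tokenizer class recorded in the model config.
    # AutoConfig is imported locally, as discussed above, to avoid a circular import
    # between the tokenizer base module and the auto classes.
    from transformers import AutoConfig

    try:
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
    except (OSError, ValueError, KeyError):
        # If the model config cannot be loaded, skip the check rather than fail loading.
        return None
    return config.tokenizer_class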
tests/test_tokenization_base.py
Outdated
@@ -0,0 +1,17 @@
+import unittest
+
+from transformers.models.bert.tokenization_bert import BertTokenizer
This test should go in an existing test file, for instance the one already testing BertJapaneseTokenizer or the common tokenizer test file.
I'll remove test_tokenization_base.py and introduce the test from #12619 (comment) instead.
Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
…aice/transformers into tokenizer_class_check
I revised the code based on your reviews.
There is still a last failure in the tests; I left a pointer below to where it originates.
I'm thinking this logic of getting the tokenizer class (that is used here and in the AutoClass) could probably be refactored into a function; I can do that in a follow-up if you prefer. Let me know!
# Second attempt. If we have not yet found tokenizer_class, let's try to use the config.
try:
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
Passing the kwargs along here seems to break the tests. Also I don't think we need them?
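In other words, a sketch of the suggested adjustment (the checkpoint name is only an example taken from the tests in this thread):

from transformers import AutoConfig

# Call AutoConfig without forwarding the tokenizer's **kwargs.
config = AutoConfig.from_pretrained("cl-tohoku/bert-base-japanese")
print(config.tokenizer_class)  # expected to print something like BertJapaneseTokenizer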
I'd like to ask you to refactor the logic.
@SaulLu could you confirm you're happy with the changes? I think this is good to be merged on my side, thanks for the adjustments @europeanplaice.
I also share your opinion @sgugger ! Thanks a lot for the addition @europeanplaice , it's great to have this warning logged!
def test_tokenizer_mismatch_warning(self):
    # We disable this test for PreTrainedTokenizerFast because it is the only tokenizer that is not linked to any
    # model
    pass
👍
What does this PR do?
Fixes #12416
This PR detects a mismatch between cls and the checkpoint a user intends to load. However, it can't detect a mismatch when the config doesn't contain information about the tokenizer.
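For illustration, this is roughly what the detection looks like from a user's point of view (a sketch; the checkpoint name comes from the tests in this thread):

from transformers import BertTokenizer

# Deliberately load a Japanese BERT checkpoint with the plain BertTokenizer.
# With this change, a warning along these lines is expected in the logs:
#   "The tokenizer class you load from this checkpoint is not the same type as
#    the class this function is called from. It may result in unexpected tokenization."
tokenizer = BertTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")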
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.