Add tokenizers class mismatch detection between cls and checkpoint #12619
@@ -111,6 +111,7 @@ class EncodingFast:

SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
CONFIG_FILE = "config.json"

Review comment: No, this file is the model configuration. It has nothing to do with the tokenizer and should not be put here.
Reply: I agree.

# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
FULL_TOKENIZER_FILE = "tokenizer.json"
@@ -1639,6 +1640,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],

            "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
            "tokenizer_config_file": TOKENIZER_CONFIG_FILE,
            "tokenizer_file": FULL_TOKENIZER_FILE,
            "config_file": CONFIG_FILE,

Review comment: Same here.
Reply: I agree too.

        }
        # Look for the tokenizer files
        for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
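For context (not part of this diff), the tokenizer class recorded in a checkpoint's model configuration can also be inspected through AutoConfig. This is a minimal sketch, assuming network access to the Hub and the two public checkpoints used in the new test; the exact field values are an assumption based on that test.

from transformers import AutoConfig

# config.json of this checkpoint records which tokenizer it was saved with.
config = AutoConfig.from_pretrained("cl-tohoku/bert-base-japanese")
print(config.tokenizer_class)  # expected: "BertJapaneseTokenizer"

# Older configs often omit the field, so no mismatch can be inferred from config.json alone.
print(AutoConfig.from_pretrained("bert-base-uncased").tokenizer_class)  # expected: None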
@@ -1742,16 +1744,34 @@ def _from_pretrained(

        # Prepare tokenizer initialization kwargs
        # Did we saved some inputs and kwargs to reload ?
        tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None)
        config_tokenizer_class = None
        if tokenizer_config_file is not None:
            with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
                init_kwargs = json.load(tokenizer_config_handle)
            config_tokenizer_class = init_kwargs.get("tokenizer_class")
            init_kwargs.pop("tokenizer_class", None)
            saved_init_inputs = init_kwargs.pop("init_inputs", ())
            if not init_inputs:
                init_inputs = saved_init_inputs
        else:
            init_kwargs = init_configuration

        if tokenizer_config_file is None or config_tokenizer_class is None:
            config_file = resolved_vocab_files.pop("config_file", None)
            if config_file is not None:
                with open(config_file, encoding="utf-8") as config_handle:
                    config_dict = json.load(config_handle)
                config_tokenizer_class = config_dict.get("tokenizer_class")

Review comment: We should rely on …
Reply: Thank you for your review. It is better than my code, and I avoided a circular import by importing …
Review comment (edited; original text): It would have to be checked by running all the tests, but I have the impression that by doing the imports at this level we don't have a circular import problem.

    # If we have not yet found the original type of the tokenizer we are loading we see if we can infer it from the
    # type of the configuration file
    if config_dict is not None and config_tokenizer_class is None:
        from .models.auto.configuration_auto import CONFIG_MAPPING
        from .models.auto.tokenization_auto import TOKENIZER_MAPPING

        if "model_type" in config_dict:
            config_class = CONFIG_MAPPING[config_dict["model_type"]]
        else:
            # Fallback: use pattern matching on the string.
            for pattern, config_class_tmp in CONFIG_MAPPING.items():
                if pattern in str(pretrained_model_name_or_path):
                    config_class = config_class_tmp
                    break

        if config_class in TOKENIZER_MAPPING.keys():
            config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING[config_class]
            if config_tokenizer_class is not None:
                config_tokenizer_class = config_tokenizer_class.__name__
            else:
                config_tokenizer_class = config_tokenizer_class_fast.__name__

Reply: Thank you for an excellent suggestion!
        if config_tokenizer_class is not None:
            if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""):
                raise ValueError(
                    "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. "
                    "It may result in unexpected tokenization. \n"
                    f"The tokenizer class you load from this checkpoint is '{config_tokenizer_class}'. \n"
                    f"The class this function is called from is '{cls.__name__}'."
                )

        # Update with newly provided kwargs
        init_kwargs.update(kwargs)

Review comment (on lines +1794 to +1797): Great 👍! This will really help future users.
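To illustrate what this hunk changes for users, here is a minimal sketch (not part of the PR) of the behavior at this revision, assuming the cl-tohoku/bert-base-japanese checkpoint exercised by the new test, whose config records a different tokenizer class. Note that the review discussion below later moves toward logging a warning instead of raising.

from transformers import BertTokenizer

# The checkpoint's config records BertJapaneseTokenizer, so loading it through
# plain BertTokenizer now fails fast with the message built above.
try:
    BertTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
except ValueError as err:
    print(err)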
@@ -0,0 +1,17 @@ (new test file)
import unittest

from transformers.models.bert.tokenization_bert import BertTokenizer
from transformers.models.bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer


class ClassMismatchTest(unittest.TestCase):
    def test_mismatch_error(self):
        PRETRAINED_MODEL = "cl-tohoku/bert-base-japanese"
        with self.assertRaises(ValueError):
            BertTokenizer.from_pretrained(PRETRAINED_MODEL)

    def test_limit_of_match_validation(self):
        # Can't detect mismatch because this model's config
        # doesn't have information about the tokenizer model.
        PRETRAINED_MODEL = "bert-base-uncased"
        BertJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL)

Review comment: This test should go in an existing test file, for instance the one already testing …
Reply: I'll remove …
Review comment: Thank you very much for writing this test: we immediately understand the new feature! As the added changes concern all tokenizers, not only …

    def test_tokenizer_mismatch_warning(self):
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                if self.tokenizer_class != BertTokenizer:
                    with self.assertLogs("transformers", level="WARNING") as cm:
                        try:
                            BertTokenizer.from_pretrained(pretrained_name)
                        except (TypeError, AttributeError):
                            # Some tokenizers cannot be loaded into `BertTokenizer` at all and errors are returned,
                            # here we just check that the warning has been logged before the error is raised
                            pass
                        finally:
                            self.assertTrue(
                                cm.records[0].message.startswith(
                                    "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
                                )
                            )
                if self.rust_tokenizer_class != BertTokenizerFast:
                    with self.assertLogs("transformers", level="WARNING") as cm:
                        try:
                            BertTokenizerFast.from_pretrained(pretrained_name)
                        except (TypeError, AttributeError):
                            # Some tokenizers cannot be loaded into `BertTokenizerFast` at all and errors are returned,
                            # here we just check that the warning has been logged before the error is raised
                            pass
                        finally:
                            self.assertTrue(
                                cm.records[0].message.startswith(
                                    "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
                                )
                            )

What do you think? Ps: I can of course help make this change if needed, especially as an adaptation will have to be made for …

Reply: It's an excellent idea, and I'd like to check all tokenizers that include …

    def test_tokenizer_mismatch_warning(self):
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                with self.assertLogs("transformers", level="WARNING") as cm:
                    try:
                        if self.tokenizer_class == BertTokenizer:
                            AlbertTokenizer.from_pretrained(pretrained_name)
                        else:
                            BertTokenizer.from_pretrained(pretrained_name)
                    except (TypeError, AttributeError):
                        # Some tokenizers cannot be loaded into the target tokenizer at all and errors are returned,
                        # here we just check that the warning has been logged before the error is raised
                        pass
                    finally:
                        self.assertTrue(
                            cm.records[0].message.startswith(
                                "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
                            )
                        )
                    try:
                        if self.rust_tokenizer_class == BertTokenizerFast:
                            AlbertTokenizerFast.from_pretrained(pretrained_name)
                        else:
                            BertTokenizerFast.from_pretrained(pretrained_name)
                    except (TypeError, AttributeError):
                        # Some tokenizers cannot be loaded into the target tokenizer at all and errors are returned,
                        # here we just check that the warning has been logged before the error is raised
                        pass
                    finally:
                        self.assertTrue(
                            cm.records[0].message.startswith(
                                "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
                            )
                        )

Reply: Looks great to me! 🙂
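As a side note on the proposed common test (not part of the PR), assertLogs("transformers", level="WARNING") catches the message because records emitted on module-level child loggers propagate to the parent "transformers" logger. A minimal, transformers-free sketch of that mechanism; the child logger name is illustrative:

import logging
import unittest


class AssertLogsSketch(unittest.TestCase):
    def test_child_logger_propagates_to_parent(self):
        # A warning emitted on a child logger such as "transformers.tokenization_utils_base"
        # bubbles up to the parent "transformers" logger that assertLogs is watching.
        with self.assertLogs("transformers", level="WARNING") as cm:
            logging.getLogger("transformers.tokenization_utils_base").warning(
                "The tokenizer class you load from this checkpoint is not the same type "
                "as the class this function is called from."
            )
        self.assertTrue(cm.records[0].message.startswith("The tokenizer class you load"))


if __name__ == "__main__":
    unittest.main()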
Review comment: It seems to me that the other tokenizers in the transformers library specify the specific class of tokenizer here instead of the generic AutoTokenizer. Was there any particular reason to prefer AutoTokenizer to BertJapaneseTokenizer? 🙂

Reply: I don't have a strong opinion about it either. I chose AutoTokenizer because I thought leading a user to AutoTokenizer would avoid a problem like this issue.

Reply: I think it's better to encourage users to use the AutoTokenizer class.

Reply: Thanks a lot for your feedback @sgugger! In that case, @europeanplaice your proposal is great - you can ignore my previous comment. @sgugger, should we take this opportunity to make the same change with other tokenizers that log the same type of message (cf. PR #12745)?

Reply: Yes, that was a great idea!
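Following the conclusion above, a short sketch (not part of the PR) of the AutoTokenizer path that the message encourages; it assumes the Japanese tokenization dependencies such as fugashi and ipadic are installed, since BertJapaneseTokenizer needs them.

from transformers import AutoTokenizer

# AutoTokenizer resolves the tokenizer class recorded with the checkpoint,
# so the class/checkpoint mismatch cannot occur in the first place.
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
print(type(tokenizer).__name__)  # expected: BertJapaneseTokenizer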