From de3725f2e811071893300c28278a9fffa4ab61a8 Mon Sep 17 00:00:00 2001
From: europeanplaice
Date: Fri, 9 Jul 2021 06:13:01 +0900
Subject: [PATCH 1/9] Detect mismatch by analyzing config

---
 src/transformers/tokenization_utils_base.py | 21 ++++++++++++++++++++-
 tests/test_tokenization_base.py             | 17 +++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_tokenization_base.py

diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index bd5642bb352f4d..5741cec64f7359 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -28,7 +28,6 @@
 from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
 
 import numpy as np
-
 import requests
 
 from .file_utils import (
@@ -111,6 +110,7 @@ class EncodingFast:
 SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
 ADDED_TOKENS_FILE = "added_tokens.json"
 TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
+CONFIG_FILE = "config.json"
 
 # Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
 FULL_TOKENIZER_FILE = "tokenizer.json"
@@ -1639,6 +1639,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
             "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
             "tokenizer_config_file": TOKENIZER_CONFIG_FILE,
             "tokenizer_file": FULL_TOKENIZER_FILE,
+            "config_file": CONFIG_FILE,
         }
         # Look for the tokenizer files
         for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
@@ -1742,9 +1743,11 @@ def _from_pretrained(
         # Prepare tokenizer initialization kwargs
         # Did we saved some inputs and kwargs to reload ?
         tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None)
+        config_tokenizer_class = None
         if tokenizer_config_file is not None:
             with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
                 init_kwargs = json.load(tokenizer_config_handle)
+                config_tokenizer_class = init_kwargs.get("tokenizer_class")
                 init_kwargs.pop("tokenizer_class", None)
                 saved_init_inputs = init_kwargs.pop("init_inputs", ())
                 if not init_inputs:
@@ -1752,6 +1755,22 @@ def _from_pretrained(
         else:
             init_kwargs = init_configuration
 
+        if tokenizer_config_file is None or config_tokenizer_class is None:
+            config_file = resolved_vocab_files.pop("config_file", None)
+            if config_file is not None:
+                with open(config_file, encoding="utf-8") as config_handle:
+                    config_dict = json.load(config_handle)
+                    config_tokenizer_class = config_dict.get("tokenizer_class")
+
+        if config_tokenizer_class is not None:
+            if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""):
+                raise ValueError(
+                    "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. "
+                    "It may result in unexpected tokenization. \n"
+                    f"The tokenizer class you load from this checkpoint is '{config_tokenizer_class}'. \n"
+                    f"The class this function is called from is '{cls.__name__}'."
+                )
+
         # Update with newly provided kwargs
         init_kwargs.update(kwargs)
 
diff --git a/tests/test_tokenization_base.py b/tests/test_tokenization_base.py
new file mode 100644
index 00000000000000..f5be88fe911c0c
--- /dev/null
+++ b/tests/test_tokenization_base.py
@@ -0,0 +1,17 @@
+import unittest
+
+from transformers.models.bert.tokenization_bert import BertTokenizer
+from transformers.models.bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer
+
+
+class ClassMismatchTest(unittest.TestCase):
+    def test_mismatch_error(self):
+        PRETRAINED_MODEL = "cl-tohoku/bert-base-japanese"
+        with self.assertRaises(ValueError):
+            BertTokenizer.from_pretrained(PRETRAINED_MODEL)
+
+    def test_limit_of_match_validation(self):
+        # Can't detect mismatch because this model's config
+        # doesn't have information about the tokenizer model.
+        PRETRAINED_MODEL = "bert-base-uncased"
+        BertJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL)

From 4fc28ed5a799358757a37059bbb32b4b167d5991 Mon Sep 17 00:00:00 2001
From: europeanplaice
Date: Sat, 10 Jul 2021 06:13:22 +0900
Subject: [PATCH 2/9] Fix comment

---
 .../models/bert_japanese/tokenization_bert_japanese.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/bert_japanese/tokenization_bert_japanese.py b/src/transformers/models/bert_japanese/tokenization_bert_japanese.py
index be62e92e059e8a..ecd7df9b03227b 100644
--- a/src/transformers/models/bert_japanese/tokenization_bert_japanese.py
+++ b/src/transformers/models/bert_japanese/tokenization_bert_japanese.py
@@ -132,7 +132,7 @@ def __init__(
         if not os.path.isfile(vocab_file):
             raise ValueError(
                 f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
-                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+                "model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
             )
         self.vocab = load_vocab(vocab_file)
         self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])

From faa4fa34c737cc218ba3a127f1a41dc060572fe6 Mon Sep 17 00:00:00 2001
From: europeanplaice
Date: Sat, 10 Jul 2021 06:48:21 +0900
Subject: [PATCH 3/9] Fix import

---
 src/transformers/tokenization_utils_base.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 5741cec64f7359..1488c7e322341d 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -28,6 +28,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
 
 import numpy as np
+
 import requests
 
 from .file_utils import (

From 05b43e8741141c9618604e5dc43f6b8f5f420c9b Mon Sep 17 00:00:00 2001
From: Tomohiro Endo
Date: Tue, 13 Jul 2021 22:10:59 +0900
Subject: [PATCH 4/9] Update src/transformers/tokenization_utils_base.py

Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
---
 src/transformers/tokenization_utils_base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 1488c7e322341d..9f6f8946d53e32 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -1765,7 +1765,7 @@ def _from_pretrained(
 
         if config_tokenizer_class is not None:
             if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""):
-                raise ValueError(
+                logger.warning(
                     "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. "
                     "It may result in unexpected tokenization. \n"
                     f"The tokenizer class you load from this checkpoint is '{config_tokenizer_class}'. \n"

From b998ac58f9229dd0b9f8829ca6ad8ad554b0cc71 Mon Sep 17 00:00:00 2001
From: europeanplaice
Date: Wed, 14 Jul 2021 00:02:59 +0900
Subject: [PATCH 5/9] Revise based on reviews

---
 src/transformers/tokenization_utils_base.py | 45 ++++++++++++++++-----
 tests/test_tokenization_base.py             | 17 --------
 tests/test_tokenization_bert_japanese.py    | 21 ++++++++++
 tests/test_tokenization_common.py           | 38 +++++++++++++++++
 4 files changed, 94 insertions(+), 27 deletions(-)
 delete mode 100644 tests/test_tokenization_base.py

diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 1488c7e322341d..44f7eff93272e6 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -111,7 +111,6 @@ class EncodingFast:
 SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
 ADDED_TOKENS_FILE = "added_tokens.json"
 TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
-CONFIG_FILE = "config.json"
 
 # Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
 FULL_TOKENIZER_FILE = "tokenizer.json"
@@ -1640,7 +1639,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
             "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
             "tokenizer_config_file": TOKENIZER_CONFIG_FILE,
             "tokenizer_file": FULL_TOKENIZER_FILE,
-            "config_file": CONFIG_FILE,
         }
         # Look for the tokenizer files
         for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
@@ -1744,28 +1742,55 @@ def _from_pretrained(
         # Prepare tokenizer initialization kwargs
         # Did we saved some inputs and kwargs to reload ?
         tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None)
-        config_tokenizer_class = None
         if tokenizer_config_file is not None:
             with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
                 init_kwargs = json.load(tokenizer_config_handle)
+                # First attempt. We get tokenizer_class from tokenizer_config to check for a mismatch between tokenizers.
                 config_tokenizer_class = init_kwargs.get("tokenizer_class")
                 init_kwargs.pop("tokenizer_class", None)
                 saved_init_inputs = init_kwargs.pop("init_inputs", ())
                 if not init_inputs:
                     init_inputs = saved_init_inputs
         else:
+            config_tokenizer_class = None
             init_kwargs = init_configuration
 
-        if tokenizer_config_file is None or config_tokenizer_class is None:
-            config_file = resolved_vocab_files.pop("config_file", None)
-            if config_file is not None:
-                with open(config_file, encoding="utf-8") as config_handle:
-                    config_dict = json.load(config_handle)
-                    config_tokenizer_class = config_dict.get("tokenizer_class")
+        if config_tokenizer_class is None:
+            from .models.auto.configuration_auto import AutoConfig
+
+            # Second attempt. If we have not yet found tokenizer_class, let's try to use the config.
+            try:
+                config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+                config_tokenizer_class = config.tokenizer_class
+            except ValueError:
+                # skip if config.json doesn't exist
+                config = None
+            if config_tokenizer_class is None:
+                # Third attempt. If we have not yet found the original type of the tokenizer we are loading,
+                # we see if we can infer it from the type of the configuration file
+                from .models.auto.configuration_auto import CONFIG_MAPPING
+                from .models.auto.tokenization_auto import TOKENIZER_MAPPING
+
+                if hasattr(config, "model_type"):
+                    config_class = CONFIG_MAPPING[config.model_type]
+                else:
+                    # Fallback: use pattern matching on the string.
+                    config_class = None
+                    for pattern, config_class_tmp in CONFIG_MAPPING.items():
+                        if pattern in str(pretrained_model_name_or_path):
+                            config_class = config_class_tmp
+                            break
+
+                if config_class in TOKENIZER_MAPPING.keys():
+                    config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING[config_class]
+                    if config_tokenizer_class is not None:
+                        config_tokenizer_class = config_tokenizer_class.__name__
+                    else:
+                        config_tokenizer_class = config_tokenizer_class_fast.__name__
 
         if config_tokenizer_class is not None:
             if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""):
-                raise ValueError(
+                logger.warning(
                     "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. "
                     "It may result in unexpected tokenization. \n"
                     f"The tokenizer class you load from this checkpoint is '{config_tokenizer_class}'. \n"
diff --git a/tests/test_tokenization_base.py b/tests/test_tokenization_base.py
deleted file mode 100644
index f5be88fe911c0c..00000000000000
--- a/tests/test_tokenization_base.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import unittest
-
-from transformers.models.bert.tokenization_bert import BertTokenizer
-from transformers.models.bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer
-
-
-class ClassMismatchTest(unittest.TestCase):
-    def test_mismatch_error(self):
-        PRETRAINED_MODEL = "cl-tohoku/bert-base-japanese"
-        with self.assertRaises(ValueError):
-            BertTokenizer.from_pretrained(PRETRAINED_MODEL)
-
-    def test_limit_of_match_validation(self):
-        # Can't detect mismatch because this model's config
-        # doesn't have information about the tokenizer model.
-        PRETRAINED_MODEL = "bert-base-uncased"
-        BertJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL)
diff --git a/tests/test_tokenization_bert_japanese.py b/tests/test_tokenization_bert_japanese.py
index b42a14314a4ea2..59942258584f62 100644
--- a/tests/test_tokenization_bert_japanese.py
+++ b/tests/test_tokenization_bert_japanese.py
@@ -22,6 +22,7 @@
 from transformers.models.bert_japanese.tokenization_bert_japanese import (
     VOCAB_FILES_NAMES,
     BertJapaneseTokenizer,
+    BertTokenizer,
     CharacterTokenizer,
     MecabTokenizer,
     WordpieceTokenizer,
@@ -278,3 +279,23 @@ def test_tokenizer_bert_japanese(self):
         EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
         tokenizer = AutoTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
         self.assertIsInstance(tokenizer, BertJapaneseTokenizer)
+
+
+class BertTokenizerMismatchTest(unittest.TestCase):
+    def test_tokenizer_mismatch_warning(self):
+        EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
+        with self.assertLogs("transformers", level="WARNING") as cm:
+            BertTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
+        self.assertTrue(
+            cm.records[0].message.startswith(
+                "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
+            )
+        )
+        EXAMPLE_BERT_ID = "bert-base-cased"
+        with self.assertLogs("transformers", level="WARNING") as cm:
+            BertJapaneseTokenizer.from_pretrained(EXAMPLE_BERT_ID)
+        self.assertTrue(
+            cm.records[0].message.startswith(
+                "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
+            )
+        )
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index 0a662cc62c44ac..a4520b15a69070 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -29,7 +29,10 @@
 from huggingface_hub import HfApi
 from requests.exceptions import HTTPError
 from transformers import (
+    AlbertTokenizer,
+    AlbertTokenizerFast,
     BertTokenizer,
+    BertTokenizerFast,
     PreTrainedTokenizer,
     PreTrainedTokenizerBase,
     PreTrainedTokenizerFast,
@@ -3288,6 +3291,41 @@ def test_training_new_tokenizer_with_special_tokens_change(self):
         expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result)
         self.assertEqual(expected_result, decoded_input)
 
+    def test_tokenizer_mismatch_warning(self):
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                with self.assertLogs("transformers", level="WARNING") as cm:
+                    try:
+                        if self.tokenizer_class == BertTokenizer:
+                            AlbertTokenizer.from_pretrained(pretrained_name)
+                        else:
+                            BertTokenizer.from_pretrained(pretrained_name)
+                    except (TypeError, AttributeError):
+                        # Some tokenizers cannot be loaded into the target tokenizer at all and errors are returned,
+                        # here we just check that the warning has been logged before the error is raised
+                        pass
+                    finally:
+                        self.assertTrue(
+                            cm.records[0].message.startswith(
+                                "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
+                            )
+                        )
+                    try:
+                        if self.rust_tokenizer_class == BertTokenizerFast:
+                            AlbertTokenizerFast.from_pretrained(pretrained_name)
+                        else:
+                            BertTokenizerFast.from_pretrained(pretrained_name)
+                    except (TypeError, AttributeError):
+                        # Some tokenizers cannot be loaded into the target tokenizer at all and errors are returned,
+                        # here we just check that the warning has been logged before the error is raised
+                        pass
+                    finally:
+                        self.assertTrue(
+                            cm.records[0].message.startswith(
+                                "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
+                            )
+                        )
+
 
 @is_staging_test
 class TokenizerPushToHubTester(unittest.TestCase):

From c440af2c7e0f39a6a45362dddf002065cdcd35ea Mon Sep 17 00:00:00 2001
From: europeanplaice
Date: Sat, 17 Jul 2021 13:57:20 +0900
Subject: [PATCH 6/9] remove kwargs

---
 src/transformers/tokenization_utils_base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 10b01a0706b68b..8d161456528cd4 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -1760,7 +1760,7 @@ def _from_pretrained(
             # Second attempt. If we have not yet found tokenizer_class, let's try to use the config.
             try:
-                config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+                config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
                 config_tokenizer_class = config.tokenizer_class
             except ValueError:
                 # skip if config.json doesn't exist
                 config = None

From 634e31d2a4d0f0916f0ac22051db0ccf50372d29 Mon Sep 17 00:00:00 2001
From: europeanplaice
Date: Sat, 17 Jul 2021 14:14:14 +0900
Subject: [PATCH 7/9] Fix exception

---
 src/transformers/tokenization_utils_base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 8d161456528cd4..b593af4cb813f4 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -1762,7 +1762,7 @@ def _from_pretrained(
             try:
                 config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
                 config_tokenizer_class = config.tokenizer_class
-            except ValueError:
+            except (OSError, ValueError):
                 # skip if config.json doesn't exist
                 config = None

From f36c44f30fa3ea3d86c9da7f09d243dfb4bde898 Mon Sep 17 00:00:00 2001
From: europeanplaice
Date: Sat, 17 Jul 2021 14:37:06 +0900
Subject: [PATCH 8/9] Fix exception handling again

---
 src/transformers/tokenization_utils_base.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index b593af4cb813f4..42871ae5587d51 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -1762,17 +1762,17 @@ def _from_pretrained(
             try:
                 config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
                 config_tokenizer_class = config.tokenizer_class
-            except (OSError, ValueError):
-                # skip if config.json doesn't exist
+            except (OSError, ValueError, KeyError):
+                # skip if an error occurred.
                 config = None
             if config_tokenizer_class is None:
                 # Third attempt. If we have not yet found the original type of the tokenizer we are loading,
                 # we see if we can infer it from the type of the configuration file
                 from .models.auto.configuration_auto import CONFIG_MAPPING
                 from .models.auto.tokenization_auto import TOKENIZER_MAPPING
 
                 if hasattr(config, "model_type"):
-                    config_class = CONFIG_MAPPING[config.model_type]
+                    config_class = CONFIG_MAPPING.get(config.model_type)
                 else:
                     # Fallback: use pattern matching on the string.
                     config_class = None

From 2ac9387d51ca51f1d61acf0a7f1c8582bc271333 Mon Sep 17 00:00:00 2001
From: europeanplaice
Date: Sat, 17 Jul 2021 15:03:27 +0900
Subject: [PATCH 9/9] Disable mismatch test in PreTrainedTokenizerFast

---
 tests/test_tokenization_fast.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/tests/test_tokenization_fast.py b/tests/test_tokenization_fast.py
index de237aac18d76c..c6472b0d8db104 100644
--- a/tests/test_tokenization_fast.py
+++ b/tests/test_tokenization_fast.py
@@ -44,6 +44,11 @@ def setUp(self):
         tokenizer = PreTrainedTokenizerFast.from_pretrained(model_paths[0])
         tokenizer.save_pretrained(self.tmpdirname)
 
+    def test_tokenizer_mismatch_warning(self):
+        # We disable this test for PreTrainedTokenizerFast because it is the only tokenizer that is not linked to any
+        # model
+        pass
+
     def test_pretrained_model_lists(self):
         # We disable this test for PreTrainedTokenizerFast because it is the only tokenizer that is not linked to any
         # model
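
Net effect of the series: when a checkpoint records a tokenizer class different from the one `from_pretrained` is called on, loading now logs a warning instead of raising. Below is a minimal sketch of the resulting behavior, reusing the checkpoint IDs exercised in the tests above; the warning text is quoted from the message added in these patches, while the logger prefix shown is an assumption rather than captured output.

    from transformers import BertTokenizer

    # "cl-tohoku/bert-base-japanese" was saved with BertJapaneseTokenizer, so the
    # mismatch check trips when it is loaded through plain BertTokenizer. Since
    # patch 4 replaced `raise ValueError(...)` with `logger.warning(...)`, the
    # call still returns a tokenizer but logs:
    #   WARNING:transformers:The tokenizer class you load from this checkpoint is
    #   not the same type as the class this function is called from. It may
    #   result in unexpected tokenization. ...
    tokenizer = BertTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")

By contrast, a bare `PreTrainedTokenizerFast` is not linked to any model class, so no mismatch can be detected for it, which is why patch 9 disables the mismatch test in that suite.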