From cf81a59a1f777e86d95a8403b772696dda86a7d2 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Fri, 11 Feb 2022 13:46:08 -0500 Subject: [PATCH] Fix _configuration_file argument getting passed to model (#15629) --- src/transformers/configuration_utils.py | 2 +- tests/test_configuration_common.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 580de6b91ee553..8548281631ffba 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -580,7 +580,7 @@ def _get_config_dict( if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): config_file = pretrained_model_name_or_path else: - configuration_file = kwargs.get("_configuration_file", CONFIG_NAME) + configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if os.path.isdir(pretrained_model_name_or_path): config_file = os.path.join(pretrained_model_name_or_path, configuration_file) diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index 2b4a023d91c05c..a073c5250746fa 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -334,8 +334,12 @@ def test_repo_versioning_before(self): import transformers as new_transformers new_transformers.configuration_utils.__version__ = "v4.0.0" - new_configuration = new_transformers.models.auto.AutoConfig.from_pretrained(repo) + new_configuration, kwargs = new_transformers.models.auto.AutoConfig.from_pretrained( + repo, return_unused_kwargs=True + ) self.assertEqual(new_configuration.hidden_size, 2) + # This checks `_configuration_file` ia not kept in the kwargs by mistake. + self.assertDictEqual(kwargs, {"_from_auto": True}) # Testing an older version by monkey-patching the version in the module it's used. import transformers as old_transformers