From af20bbb3188a6ffeaa126fa5118c9cabb529c26a Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Fri, 16 Sep 2022 16:11:47 -0400 Subject: [PATCH] Fix tokenizer load from one file (#19073) * Fix tokenizer load from one file * Add a test * Style Co-authored-by: Lysandre --- src/transformers/tokenization_utils_base.py | 2 ++ tests/test_tokenization_common.py | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 5062a7bfb99991..2e7ac0be0fb29a 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1726,6 +1726,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], for file_id, file_path in vocab_files.items(): if file_path is None: resolved_vocab_files[file_id] = None + elif os.path.isfile(file_path): + resolved_vocab_files[file_id] = file_path elif is_remote_url(file_path): resolved_vocab_files[file_id] = download_url(file_path, proxies=proxies) else: diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index bdb7b6ce673896..ce6908038628a6 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -31,6 +31,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union from huggingface_hub import HfFolder, delete_repo, set_access_token +from huggingface_hub.file_download import http_get from parameterized import parameterized from requests.exceptions import HTTPError from transformers import ( @@ -3886,6 +3887,16 @@ def test_cached_files_are_used_when_internet_is_down(self): # This check we did call the fake head request mock_head.assert_called() + def test_legacy_load_from_one_file(self): + try: + tmp_file = tempfile.mktemp() + with open(tmp_file, "wb") as f: + http_get("https://huggingface.co/albert-base-v1/resolve/main/spiece.model", f) + + AlbertTokenizer.from_pretrained(tmp_file) + finally: + os.remove(tmp_file) + @is_staging_test class TokenizerPushToHubTester(unittest.TestCase):