From 7f1cdf18958efef6339040ba91edb32ae7377720 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 17 Feb 2023 21:22:39 +0100 Subject: [PATCH] Fix dynamic module import error (#21646) * fix dynamic module import error --------- Co-authored-by: ydshieh --- src/transformers/dynamic_module_utils.py | 28 +++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py index f3fc14838275be..9cc1f585654f8e 100644 --- a/src/transformers/dynamic_module_utils.py +++ b/src/transformers/dynamic_module_utils.py @@ -18,7 +18,9 @@ import os import re import shutil +import subprocess import sys +import tempfile from pathlib import Path from typing import Dict, Optional, Union @@ -143,9 +145,25 @@ def get_class_in_module(class_name, module_path): """ Import a module on the cache directory for modules and extract a class from it. """ - module_path = module_path.replace(os.path.sep, ".") - module = importlib.import_module(module_path) - return getattr(module, class_name) + with tempfile.TemporaryDirectory() as tmp_dir: + module_dir = Path(HF_MODULES_CACHE) / os.path.dirname(module_path) + module_file_name = module_path.split(os.path.sep)[-1] + ".py" + + # Copy to a temporary directory. We need to do this in another process to avoid strange and flaky error + # `ModuleNotFoundError: No module named 'transformers_modules.[module_dir_name].modeling'` + shutil.copy(f"{module_dir}/{module_file_name}", tmp_dir) + # On Windows, we need this character `r` before the path argument of `os.remove` + cmd = f'import os; os.remove(r"{module_dir}{os.path.sep}{module_file_name}")' + subprocess.run(["python", "-c", cmd]) + + # copy back the file that we want to import + shutil.copyfile(f"{tmp_dir}/{module_file_name}", f"{module_dir}/{module_file_name}") + + # import the module + module_path = module_path.replace(os.path.sep, ".") + module = importlib.import_module(module_path) + + return getattr(module, class_name) def get_cached_module_file( @@ -212,7 +230,7 @@ def get_cached_module_file( # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. pretrained_model_name_or_path = str(pretrained_model_name_or_path) if os.path.isdir(pretrained_model_name_or_path): - submodule = "local" + submodule = pretrained_model_name_or_path.split(os.path.sep)[-1] else: submodule = pretrained_model_name_or_path.replace("/", os.path.sep) @@ -240,7 +258,7 @@ def get_cached_module_file( full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule create_dynamic_module(full_submodule) submodule_path = Path(HF_MODULES_CACHE) / full_submodule - if submodule == "local": + if submodule == pretrained_model_name_or_path.split(os.path.sep)[-1]: # We always copy local files (we could hash the file to see if there was a change, and give them the name of # that hash, to only copy when there is a modification but it seems overkill for now). # The only reason we do the copy is to avoid putting too many folders in sys.path.