Fix dynamic module import error #21646

Merged
merged 14 commits into from
Feb 17, 2023
28 changes: 23 additions & 5 deletions src/transformers/dynamic_module_utils.py
@@ -18,7 +18,9 @@
 import os
 import re
 import shutil
+import subprocess
 import sys
+import tempfile
 from pathlib import Path
 from typing import Dict, Optional, Union

@@ -143,9 +145,25 @@ def get_class_in_module(class_name, module_path):
     """
     Import a module on the cache directory for modules and extract a class from it.
     """
-    module_path = module_path.replace(os.path.sep, ".")
-    module = importlib.import_module(module_path)
-    return getattr(module, class_name)
+    with tempfile.TemporaryDirectory() as tmp_dir:
Collaborator Author

This is not the location the module is loaded from. The temporary directory only holds the file for a moment; the file is then copied back to its original place.

+        module_dir = Path(HF_MODULES_CACHE) / os.path.dirname(module_path)
+        module_file_name = module_path.split(os.path.sep)[-1] + ".py"
+
+        # Copy to a temporary directory. We need to do this in another process to avoid strange and flaky error
+        # `ModuleNotFoundError: No module named 'transformers_modules.[module_dir_name].modeling'`
+        shutil.copy(f"{module_dir}/{module_file_name}", tmp_dir)
+        # On Windows, we need this character `r` before the path argument of `os.remove`
+        cmd = f'import os; os.remove(r"{module_dir}{os.path.sep}{module_file_name}")'
+        subprocess.run(["python", "-c", cmd])
Collaborator Author
@ydshieh, Feb 17, 2023

If something goes wrong inside subprocess.run, no error is raised in the process that calls this method.
I think we should capture and check the result of subprocess.run, and then:

  • either skip the shutil.copyfile call below (although this leaves the test flaky in that branch),
  • or raise an error manually with some diagnostic information.

Let me know if you have any suggestions :-)

Collaborator

Why do we need to do something? If there is a problem deleting the file (which we copy just after), at worst we get the flaky failure again (though it should be extremely rare at this stage).

Collaborator Author

yeah, right!
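For reference, the option raised in this thread (checking the outcome of the child process) could look roughly like the sketch below. It is an illustration only, not the merged code: `remove_in_subprocess` is a hypothetical helper, and launching the child with `sys.executable` instead of the literal `"python"` is an assumption made so the sketch also runs where no `python` binary is on the PATH.

```python
import subprocess
import sys
import tempfile
from pathlib import Path


def remove_in_subprocess(path):
    """Delete `path` in a child Python process and report how it went.

    Hypothetical helper: returns (ok, stderr_text) instead of raising, so the
    caller decides whether a failed deletion should be ignored or reported.
    """
    # The raw-string prefix keeps Windows backslashes from being read as escapes.
    cmd = f'import os; os.remove(r"{path}")'
    result = subprocess.run([sys.executable, "-c", cmd], capture_output=True, text=True)
    return result.returncode == 0, result.stderr


with tempfile.TemporaryDirectory() as tmp_dir:
    target = Path(tmp_dir) / "modeling.py"
    target.write_text("x = 1\n")

    # The file exists, so the child process deletes it and exits with code 0.
    ok, _ = remove_in_subprocess(target)
    # The file is already gone: the child fails, but nothing is raised here.
    missing_ok, err = remove_in_subprocess(target)
```

The second call shows the behaviour described above: the failure is only visible through the return code and the captured stderr, never as an exception in the calling process.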


+        # copy back the file that we want to import
+        shutil.copyfile(f"{tmp_dir}/{module_file_name}", f"{module_dir}/{module_file_name}")
+
+        # import the module
+        module_path = module_path.replace(os.path.sep, ".")
+        module = importlib.import_module(module_path)
+
+        return getattr(module, class_name)
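Taken together, the added lines copy the module file aside, delete it from another process, copy it back, and only then import it. A self-contained sketch of that sequence, under stated assumptions, is given here: `load_class_with_refresh` and the `demo_modules` package are hypothetical names, `cache_root` stands in for `HF_MODULES_CACHE`, the child process is started with `sys.executable` rather than `"python"`, and the `importlib.invalidate_caches()` call is an extra precaution of this sketch that does not appear in the diff.

```python
import importlib
import os
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path


def load_class_with_refresh(cache_root, module_path, class_name):
    """Re-create the module file on disk, then import it and return the class.

    `cache_root` must already be on sys.path; `module_path` is relative to it
    and carries no ".py" suffix.
    """
    module_dir = Path(cache_root) / os.path.dirname(module_path)
    module_file = module_dir / (os.path.basename(module_path) + ".py")

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Hold a copy of the file outside the cache directory.
        backup = shutil.copy(module_file, tmp_dir)
        # Delete the original from a separate process.
        cmd = f'import os; os.remove(r"{module_file}")'
        subprocess.run([sys.executable, "-c", cmd])
        # Put the file back where the import system expects it.
        shutil.copyfile(backup, module_file)

    # Assumption of this sketch: drop stale finder caches before importing.
    importlib.invalidate_caches()
    module = importlib.import_module(module_path.replace(os.path.sep, "."))
    return getattr(module, class_name)


with tempfile.TemporaryDirectory() as cache_root:
    # Build a throwaway package tree that plays the role of the modules cache.
    package_dir = Path(cache_root) / "demo_modules" / "my_model"
    package_dir.mkdir(parents=True)
    (Path(cache_root) / "demo_modules" / "__init__.py").touch()
    (package_dir / "__init__.py").touch()
    (package_dir / "modeling.py").write_text("class DemoModel:\n    name = 'demo'\n")

    sys.path.insert(0, cache_root)
    try:
        relative_path = os.path.join("demo_modules", "my_model", "modeling")
        cls = load_class_with_refresh(cache_root, relative_path, "DemoModel")
        loaded_name = cls.name
    finally:
        sys.path.remove(cache_root)
```

As in the diff, a failed deletion is not treated as an error here: the copy-back step simply overwrites the file that is still in place.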


def get_cached_module_file(
@@ -212,7 +230,7 @@ def get_cached_module_file(
     # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
     if os.path.isdir(pretrained_model_name_or_path):
-        submodule = "local"
+        submodule = pretrained_model_name_or_path.split(os.path.sep)[-1]
     else:
         submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
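To make the new naming of local checkpoints concrete, the short sketch below evaluates the same `split(os.path.sep)[-1]` expression on two example paths. `local_submodule_name` and the `checkpoints/my_model` path are hypothetical, and the `os.path.normpath` variant at the end is an assumption of this sketch rather than something the PR does.

```python
import os


def local_submodule_name(pretrained_model_name_or_path):
    """Hypothetical helper that mirrors the expression used in the diff."""
    return str(pretrained_model_name_or_path).split(os.path.sep)[-1]


model_dir = os.path.join("checkpoints", "my_model")

# The last path component becomes the submodule name instead of "local".
plain = local_submodule_name(model_dir)

# With a trailing separator, the last element of the split is an empty string.
trailing = local_submodule_name(model_dir + os.path.sep)

# Assumption of this sketch: normalizing the path first gives the same name
# with or without a trailing separator.
normalized = os.path.basename(os.path.normpath(model_dir + os.path.sep))
```

The trailing-separator case follows directly from how `str.split` works, so callers passing a directory path that ends in a separator would get an empty submodule name from this expression.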

@@ -240,7 +258,7 @@ def get_cached_module_file(
     full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
     create_dynamic_module(full_submodule)
     submodule_path = Path(HF_MODULES_CACHE) / full_submodule
-    if submodule == "local":
+    if submodule == pretrained_model_name_or_path.split(os.path.sep)[-1]:
         # We always copy local files (we could hash the file to see if there was a change, and give them the name of
         # that hash, to only copy when there is a modification but it seems overkill for now).
         # The only reason we do the copy is to avoid putting too many folders in sys.path.