Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dynamic module import error #21646

Merged
merged 14 commits into from
Feb 17, 2023
36 changes: 34 additions & 2 deletions src/transformers/dynamic_module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,40 @@ def get_class_in_module(class_name, module_path):
"""
Import a module on the cache directory for modules and extract a class from it.
"""
module_dir = Path(HF_MODULES_CACHE) / os.path.dirname(module_path)
module_dir_backup_temp = str(module_dir) + "_backup_temp"
# make sure it doesn't exist yet
if os.path.isdir(module_dir_backup_temp):
shutil.rmtree(module_dir_backup_temp)
# copy to a temporary directory
shutil.copytree(module_dir, module_dir_backup_temp)

# remove `configuration.py`: this is necessary when we try to import modeling module, or other tokenizer/processor
# modules, while configuration module has been imported previously.
# TODO: This is only a simple heuristic. In general, we might need to consider any dynamic module that has been
# imported. However, we don't have this information so far.
if os.path.isfile(f"{module_dir}/configuration.py"):
os.remove(f"{module_dir}/configuration.py")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very weird and way to specific. Just because the tests call the file configuration doesn't mean it will always be called this way.

Copy link
Collaborator Author

@ydshieh ydshieh Feb 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no longer need to deal with this specific file, but the same trick is required for the module file (that we want to import)

# This has to be deleted too!
if os.path.isdir(f"{module_dir}/__pycache__"):
shutil.rmtree(f"{module_dir}/__pycache__")

# copy back the target module file - and ONLY this single file
# Without this hack, we may get error: `ModuleNotFoundError: No module named 'transformers_modules.local.modeling'`
module_file_name = module_path.split(os.path.sep)[-1] + ".py"
shutil.copy(os.path.join(module_dir_backup_temp, module_file_name), module_dir)

# import the module
module_path = module_path.replace(os.path.sep, ".")
module = importlib.import_module(module_path)

# copy the deleted file back
if os.path.isfile(f"{module_dir_backup_temp}/configuration.py"):
shutil.copy(f"{module_dir_backup_temp}/configuration.py", module_dir)

# remove the backup directory
shutil.rmtree(module_dir_backup_temp)

return getattr(module, class_name)


Expand Down Expand Up @@ -212,7 +244,7 @@ def get_cached_module_file(
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
submodule = "local"
submodule = f"local_{pretrained_model_name_or_path.replace(os.path.sep, '_')}"
Copy link
Collaborator Author

@ydshieh ydshieh Feb 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sgugger You already mentioned this in your comment. As I said, the issue doesn't seem come from the concurrent file operations. However, the fix I implemented in this PR add more operations to the module directory, and at some point it looks getting some race condition (not 100% confident).

Therefore, I move forward to make the module directory depending on pretrained_model_name_or_path, but I need to add replace(os.path.sep, '_') to avoid the case where pretrained_model_name_or_path being like /tmp/xxxyyy.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just taje the xxxyyy which should solve the issue for the tests (since they are all in tmp dirs that have unique names).

Copy link
Collaborator Author

@ydshieh ydshieh Feb 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sgugger Sorry, but what is taje the xxxyyy?

Regarding they are all in tmp dirs that have unique names -> should solve the issue for the tests:
I guess what I did here also gives the unique names (during testing), but without the (latest) changes in get_class_in_module, we still get the same issue, as I already run it several times.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you ever want to double check: run this code snippet

This test issue is really tricky to reproduce

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're the one who called your folder /tmp/xxxyyy in your first comment. I'm just saying you should take the last part, so pretrained_model_name_or_path.split(os.path.sep)[-1]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

(My brain also has tmp memory regarding xxxyyy)

else:
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)

Expand Down Expand Up @@ -240,7 +272,7 @@ def get_cached_module_file(
full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
create_dynamic_module(full_submodule)
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
if submodule == "local":
if submodule == f"local_{pretrained_model_name_or_path.replace(os.path.sep, '_')}":
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
# that hash, to only copy when there is a modification but it seems overkill for now).
# The only reason we do the copy is to avoid putting too many folders in sys.path.
Expand Down