Skip to content

Commit

Permalink
hooks: fix for pytorch/nvidia/torchmetrics
Browse files Browse the repository at this point in the history
  • Loading branch information
marcelotduarte committed Dec 24, 2024
1 parent 3035d11 commit 334a9b1
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 19 deletions.
46 changes: 46 additions & 0 deletions cx_Freeze/hooks/nvidia.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""A collection of functions which are triggered automatically by finder when
nvidia package is included.
"""

from __future__ import annotations

from textwrap import dedent
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from cx_Freeze.finder import ModuleFinder
from cx_Freeze.module import Module


def load_nvidia(finder: ModuleFinder, module: Module) -> None:
"""Hook for nvidia."""
# include the cuda libraries as fixed libraries
source_lib = module.file.parent
if source_lib.exists():
target_lib = f"lib/{source_lib.name}"
for source in source_lib.glob("*/lib/*"):
library = source.relative_to(source_lib).as_posix()
finder.lib_files[source] = f"{target_lib}/{library}"

code_string = module.file.read_text(encoding="utf_8")
# fix for issue #2682
patch = dedent(
"""
def _cxfreeze_patch():
import ctypes
import sys
from pathlib import Path
source_lib = Path(sys.frozen_dir, "lib", "nvidia")
for source in source_lib.glob("*/lib/*"):
ctypes.CDLL(source, mode=ctypes.RTLD_GLOBAL)
_cxfreeze_patch()
"""
)
module.code = compile(
code_string + patch,
module.file.as_posix(),
"exec",
dont_inherit=True,
optimize=finder.optimize,
)
46 changes: 27 additions & 19 deletions cx_Freeze/hooks/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from typing import TYPE_CHECKING

from cx_Freeze._compat import IS_MINGW, IS_WINDOWS
from cx_Freeze._compat import IS_LINUX, IS_MINGW, IS_WINDOWS

if TYPE_CHECKING:
from cx_Freeze.finder import ModuleFinder
Expand Down Expand Up @@ -37,24 +37,32 @@ def load_torch(finder: ModuleFinder, module: Module) -> None:
else:
module.in_file_system = 2

# patch the code to ignore CUDA_PATH_Vxx_x installation directory
code_string = module.file.read_text(encoding="utf_8")
code_string = code_string.replace("CUDA_PATH", "NO_CUDA_PATH")
module.code = compile(
code_string,
module.file.as_posix(),
"exec",
dont_inherit=True,
optimize=finder.optimize,
)

# include the cuda libraries as fixed libraries
source_lib = module.file.parent.parent / "nvidia"
if source_lib.exists():
target_lib = f"lib/{source_lib.name}"
for source in source_lib.glob("*/lib/*"):
target = target_lib / source.relative_to(source_lib)
finder.lib_files[source] = target.as_posix()
# has cuda libraries?
try:
finder.include_module("nvidia")
except ImportError:
pass
else:
code_string = module.file.read_text(encoding="utf_8")
# patch the code to ignore CUDA_PATH_Vxx_x installation directory
code_string = code_string.replace("CUDA_PATH", "NO_CUDA_PATH")
if IS_LINUX:
# fix for issue #2682
lines = code_string.splitlines()
for i, line in enumerate(lines[:]):
if line.strip() == "_load_global_deps()":
lines[i] = line.replace(
"_load_global_deps()",
"import nvidia; _load_global_deps()",
)
code_string = "\n".join(lines)
module.code = compile(
code_string,
module.file.as_posix(),
"exec",
dont_inherit=True,
optimize=finder.optimize,
)

# include the shared libraries in 'lib' as fixed libraries
source_lib = module.file.parent / "lib"
Expand Down

0 comments on commit 334a9b1

Please sign in to comment.