-
-
Notifications
You must be signed in to change notification settings - Fork 224
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
hooks: fix for pytorch/nvidia/torchmetrics
- Loading branch information
1 parent
3035d11
commit 334a9b1
Showing
2 changed files
with
73 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
"""A collection of functions which are triggered automatically by finder when | ||
nvidia package is included. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
from textwrap import dedent | ||
from typing import TYPE_CHECKING | ||
|
||
if TYPE_CHECKING: | ||
from cx_Freeze.finder import ModuleFinder | ||
from cx_Freeze.module import Module | ||
|
||
|
||
def load_nvidia(finder: ModuleFinder, module: Module) -> None: | ||
"""Hook for nvidia.""" | ||
# include the cuda libraries as fixed libraries | ||
source_lib = module.file.parent | ||
if source_lib.exists(): | ||
target_lib = f"lib/{source_lib.name}" | ||
for source in source_lib.glob("*/lib/*"): | ||
library = source.relative_to(source_lib).as_posix() | ||
finder.lib_files[source] = f"{target_lib}/{library}" | ||
|
||
code_string = module.file.read_text(encoding="utf_8") | ||
# fix for issue #2682 | ||
patch = dedent( | ||
""" | ||
def _cxfreeze_patch(): | ||
import ctypes | ||
import sys | ||
from pathlib import Path | ||
source_lib = Path(sys.frozen_dir, "lib", "nvidia") | ||
for source in source_lib.glob("*/lib/*"): | ||
ctypes.CDLL(source, mode=ctypes.RTLD_GLOBAL) | ||
_cxfreeze_patch() | ||
""" | ||
) | ||
module.code = compile( | ||
code_string + patch, | ||
module.file.as_posix(), | ||
"exec", | ||
dont_inherit=True, | ||
optimize=finder.optimize, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters