Skip to content

Commit 74d7230

Browse files
committed
Remove os.add_dll_directory() and os.environ["PATH"] manipulations from find_nvidia_dynamic_library.py. Add supported_libs.LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY and use from load_nvidia_dynamic_library().
1 parent 38a1d6c commit 74d7230

File tree

4 files changed

+30
-16
lines changed

4 files changed

+30
-16
lines changed

cuda_bindings/cuda/bindings/_path_finder/find_nvidia_dynamic_library.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,8 @@ def _find_so_using_nvidia_lib_dirs(libname, so_basename, error_messages, attachm
3838
return None
3939

4040

41-
def _append_to_os_environ_path(dirpath):
42-
curr_path = os.environ.get("PATH")
43-
os.environ["PATH"] = dirpath if curr_path is None else os.pathsep.join((curr_path, dirpath))
44-
45-
4641
def _find_dll_under_dir(dirpath, file_wild):
4742
dll_name = None
48-
have_builtins = False
4943
for path in sorted(glob.glob(os.path.join(dirpath, file_wild))):
5044
# nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-win_amd64.whl:
5145
# nvidia\cuda_nvrtc\bin\
@@ -56,19 +50,12 @@ def _find_dll_under_dir(dirpath, file_wild):
5650
if node.endswith(".alt.dll"):
5751
continue
5852
if "-builtins" in node:
59-
have_builtins = True
6053
continue
6154
if dll_name is not None:
6255
continue
6356
if os.path.isfile(path):
6457
dll_name = path
65-
if dll_name is not None:
66-
if have_builtins:
67-
# Add the DLL directory to the search path
68-
os.add_dll_directory(dirpath)
69-
# Update PATH as a fallback for dependent DLL resolution
70-
_append_to_os_environ_path(dirpath)
71-
return dll_name
58+
return dll_name
7259

7360

7461
def _find_dll_using_nvidia_bin_dirs(libname, error_messages, attachments):

cuda_bindings/cuda/bindings/_path_finder/load_nvidia_dynamic_library.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import ctypes
66
import functools
7+
import os
78
import sys
89
from typing import Optional, Tuple
910

@@ -19,7 +20,6 @@
1920

2021
else:
2122
import ctypes.util
22-
import os
2323

2424
_LINUX_CDLL_MODE = os.RTLD_NOW | os.RTLD_GLOBAL
2525

@@ -38,7 +38,23 @@ class Dl_info(ctypes.Structure):
3838

3939

4040
from .find_nvidia_dynamic_library import _find_nvidia_dynamic_library
41-
from .supported_libs import DIRECT_DEPENDENCIES, EXPECTED_LIB_SYMBOLS, SUPPORTED_LINUX_SONAMES, SUPPORTED_WINDOWS_DLLS
41+
from .supported_libs import (
42+
DIRECT_DEPENDENCIES,
43+
EXPECTED_LIB_SYMBOLS,
44+
LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY,
45+
SUPPORTED_LINUX_SONAMES,
46+
SUPPORTED_WINDOWS_DLLS,
47+
)
48+
49+
50+
def _add_dll_directory(dll_abs_path):
51+
dirpath = os.path.dirname(dll_abs_path)
52+
assert os.path.isdir(dirpath), dll_abs_path
53+
# Add the DLL directory to the search path
54+
os.add_dll_directory(dirpath)
55+
# Update PATH as a fallback for dependent DLL resolution
56+
curr_path = os.environ.get("PATH")
57+
os.environ["PATH"] = dirpath if curr_path is None else os.pathsep.join((curr_path, dirpath))
4258

4359

4460
@functools.cache
@@ -137,6 +153,8 @@ def load_nvidia_dynamic_library(libname: str) -> int:
137153
found.raise_if_abs_path_is_None()
138154

139155
if sys.platform == "win32":
156+
if libname in LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY:
157+
_add_dll_directory(found.abs_path)
140158
flags = _WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | _WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR
141159
try:
142160
handle = win32api.LoadLibraryEx(found.abs_path, 0, flags)

cuda_bindings/cuda/bindings/_path_finder/supported_libs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,11 @@
314314
),
315315
}
316316

317+
LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY = (
318+
"cufft",
319+
"nvrtc",
320+
)
321+
317322
# Based on nm output for Linux x86_64 /usr/local/cuda (12.8.1)
318323
EXPECTED_LIB_SYMBOLS = {
319324
"nvJitLink": ("nvJitLinkVersion",),

cuda_bindings/tests/test_path_finder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ def test_all_libnames_windows_dlls_consistency():
2222
assert tuple(sorted(ALL_LIBNAMES)) == tuple(sorted(supported_libs.SUPPORTED_WINDOWS_DLLS.keys()))
2323

2424

25+
def test_all_libnames_libnames_requiring_os_add_dll_directory_consistency():
26+
assert not (set(supported_libs.LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY) - set(ALL_LIBNAMES))
27+
28+
2529
def test_all_libnames_expected_lib_symbols_consistency():
2630
assert tuple(sorted(ALL_LIBNAMES)) == tuple(sorted(supported_libs.EXPECTED_LIB_SYMBOLS.keys()))
2731

0 commit comments

Comments
 (0)