66
77from libc.stdint cimport intptr_t
88
9- from .utils cimport get_nvvm_dso_version_suffix
10-
119from .utils import FunctionNotFoundError, NotSupportedError
1210
13- import os
14- import site
11+ from cuda.bindings import path_finder
1512
1613import win32api
1714
@@ -40,52 +37,9 @@ cdef void* __nvvmGetProgramLogSize = NULL
4037cdef void * __nvvmGetProgramLog = NULL
4138
4239
43- cdef inline list get_site_packages():
44- return [site.getusersitepackages()] + site.getsitepackages() + [" conda" ]
45-
46-
47- cdef load_library(const int driver_ver):
48- handle = 0
49-
50- for suffix in get_nvvm_dso_version_suffix(driver_ver):
51- if len (suffix) == 0 :
52- continue
53- dll_name = " nvvm64_40_0.dll"
54-
55- # First check if the DLL has been loaded by 3rd parties
56- try :
57- return win32api.GetModuleHandle(dll_name)
58- except :
59- pass
60-
61- # Next, check if DLLs are installed via pip or conda
62- for sp in get_site_packages():
63- if sp == " conda" :
64- # nvvm is not under $CONDA_PREFIX/lib, so it's not in the default search path
65- conda_prefix = os.environ.get(" CONDA_PREFIX" )
66- if conda_prefix is None :
67- continue
68- mod_path = os.path.join(conda_prefix, " Library" , " nvvm" , " bin" )
69- else :
70- mod_path = os.path.join(sp, " nvidia" , " cuda_nvcc" , " nvvm" , " bin" )
71- if os.path.isdir(mod_path):
72- os.add_dll_directory(mod_path)
73- try :
74- return win32api.LoadLibraryEx(
75- # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
76- os.path.join(mod_path, dll_name),
77- 0 , LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
78- except :
79- pass
80-
81- # Finally, try default search
82- # Only reached if DLL wasn't found in any site-package path
83- try :
84- return win32api.LoadLibrary(dll_name)
85- except :
86- pass
87-
88- raise RuntimeError (' Failed to load nvvm' )
40+ cdef void * load_library(int driver_ver) except * with gil:
41+ cdef intptr_t handle = path_finder._load_nvidia_dynamic_library(" nvvm" ).handle
42+ return < void * > handle
8943
9044
9145cdef int _check_or_init_nvvm() except - 1 nogil:
@@ -94,23 +48,24 @@ cdef int _check_or_init_nvvm() except -1 nogil:
9448 return 0
9549
9650 cdef int err, driver_ver
51+ cdef intptr_t handle
9752 with gil:
9853 # Load driver to check version
9954 try :
100- handle = win32api.LoadLibraryEx(" nvcuda.dll" , 0 , LOAD_LIBRARY_SEARCH_SYSTEM32)
55+ nvcuda_handle = win32api.LoadLibraryEx(" nvcuda.dll" , 0 , LOAD_LIBRARY_SEARCH_SYSTEM32)
10156 except Exception as e:
10257 raise NotSupportedError(f' CUDA driver is not found ({e})' )
10358 global __cuDriverGetVersion
10459 if __cuDriverGetVersion == NULL :
105- __cuDriverGetVersion = < void * >< intptr_t> win32api.GetProcAddress(handle , ' cuDriverGetVersion' )
60+ __cuDriverGetVersion = < void * >< intptr_t> win32api.GetProcAddress(nvcuda_handle , ' cuDriverGetVersion' )
10661 if __cuDriverGetVersion == NULL :
10762 raise RuntimeError (' something went wrong' )
10863 err = (< int (* )(int * ) noexcept nogil> __cuDriverGetVersion)(& driver_ver)
10964 if err != 0 :
11065 raise RuntimeError (' something went wrong' )
11166
11267 # Load library
113- handle = load_library(driver_ver)
68+ handle = < intptr_t > load_library(driver_ver)
11469
11570 # Load function
11671 global __nvvmVersion
0 commit comments