|
9 | 9 | # This code was automatically generated with version 12.8.0. Do not modify it directly. |
10 | 10 | {{if 'Windows' == platform.system()}} |
11 | 11 | import os |
12 | | -import site |
13 | | -import struct |
14 | 12 | import win32api |
15 | | -from pywintypes import error |
16 | 13 | {{else}} |
17 | 14 | cimport cuda.bindings._lib.dlfcn as dlfcn |
| 15 | +from libc.stdint cimport uintptr_t |
18 | 16 | {{endif}} |
| 17 | +from cuda.bindings import path_finder |
19 | 18 |
|
20 | 19 | cdef bint __cuPythonInit = False |
21 | 20 | {{if 'nvrtcGetErrorString' in found_functions}}cdef void *__nvrtcGetErrorString = NULL{{endif}} |
@@ -47,74 +46,17 @@ cdef bint __cuPythonInit = False |
47 | 46 |
|
48 | 47 | cdef int cuPythonInit() except -1 nogil: |
49 | 48 | {{if 'Windows' != platform.system()}} |
50 | | - cdef char* err_msg |
| 49 | + cdef void* handle = NULL |
51 | 50 | {{endif}} |
52 | 51 |
|
53 | 52 | global __cuPythonInit |
54 | 53 | if __cuPythonInit: |
55 | 54 | return 0 |
56 | 55 | __cuPythonInit = True |
57 | 56 |
|
58 | | - # Load library |
59 | | - {{if 'Windows' == platform.system()}} |
60 | | - with gil: |
61 | | - # First check if the DLL has been loaded by 3rd parties |
62 | | - try: |
63 | | - handle = win32api.GetModuleHandle("nvrtc64_120_0.dll") |
64 | | - except: |
65 | | - handle = None |
66 | | - |
67 | | - # Else try default search |
68 | | - if not handle: |
69 | | - LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000 |
70 | | - try: |
71 | | - handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS) |
72 | | - except: |
73 | | - pass |
74 | | - |
75 | | - # Final check if DLLs can be found within pip installations |
76 | | - if not handle: |
77 | | - site_packages = [site.getusersitepackages()] + site.getsitepackages() |
78 | | - for sp in site_packages: |
79 | | - mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin") |
80 | | - if not os.path.isdir(mod_path): |
81 | | - continue |
82 | | - os.add_dll_directory(mod_path) |
83 | | - LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000 |
84 | | - LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100 |
85 | | - try: |
86 | | - handle = win32api.LoadLibraryEx( |
87 | | - # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path... |
88 | | - os.path.join(mod_path, "nvrtc64_120_0.dll"), |
89 | | - 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) |
90 | | - |
91 | | - # Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is |
92 | | - # located in the same mod_path. |
93 | | - # Update PATH environ so that the two dlls can find each other |
94 | | - os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path)) |
95 | | - except: |
96 | | - pass |
97 | | - |
98 | | - if not handle: |
99 | | - raise RuntimeError('Failed to LoadLibraryEx nvrtc64_120_0.dll') |
100 | | - {{else}} |
101 | | - with gil: |
102 | | - print("\nLOOOK dlfcn.dlopen('libnvrtc.so.12', dlfcn.RTLD_NOW)", flush=True) |
103 | | - handle = dlfcn.dlopen('libnvrtc.so.12', dlfcn.RTLD_NOW) |
104 | | - if handle == NULL: |
105 | | - with gil: |
106 | | - err_msg = dlfcn.dlerror() |
107 | | - if err_msg == NULL: |
108 | | - err_msg_str = 'Unknown error' |
109 | | - else: |
110 | | - err_msg_str = err_msg.decode('utf-8', errors='backslashreplace') |
111 | | - raise RuntimeError(f'Failed to dlopen libnvrtc.so.12: {err_msg_str}') |
112 | | - {{endif}} |
113 | | - |
114 | | - |
115 | | - # Load function |
116 | 57 | {{if 'Windows' == platform.system()}} |
117 | 58 | with gil: |
| 59 | + handle = path_finder.load_nvidia_dynamic_library("nvrtc") |
118 | 60 | {{if 'nvrtcGetErrorString' in found_functions}} |
119 | 61 | try: |
120 | 62 | global __nvrtcGetErrorString |
@@ -299,6 +241,8 @@ cdef int cuPythonInit() except -1 nogil: |
299 | 241 | {{endif}} |
300 | 242 |
|
301 | 243 | {{else}} |
| 244 | + with gil: |
| 245 | + handle = <void*><uintptr_t>path_finder.load_nvidia_dynamic_library("nvrtc") |
302 | 246 | {{if 'nvrtcGetErrorString' in found_functions}} |
303 | 247 | global __nvrtcGetErrorString |
304 | 248 | __nvrtcGetErrorString = dlfcn.dlsym(handle, 'nvrtcGetErrorString') |
|
0 commit comments