diff --git a/cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in b/cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in index bd5e14d03..edd61ecaf 100644 --- a/cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in +++ b/cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in @@ -6,8 +6,10 @@ # this software and related documentation outside the terms of the EULA # is strictly prohibited. {{if 'Windows' == platform.system()}} -import win32api +import os +import site import struct +import win32api from pywintypes import error {{else}} cimport cuda.bindings._lib.dlfcn as dlfcn @@ -40,18 +42,70 @@ cdef int cuPythonInit() except -1 nogil: # Load library {{if 'Windows' == platform.system()}} - LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000 with gil: + # First check if the DLL has been loaded by 3rd parties try: - handle = win32api.LoadLibraryEx("nvrtc64_112_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS) + handle = win32api.GetModuleHandle("nvrtc64_112_0.dll") except: try: - handle = win32api.LoadLibraryEx("nvrtc64_111_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS) + handle = win32api.GetModuleHandle("nvrtc64_111_0.dll") + except: + try: + handle = win32api.GetModuleHandle("nvrtc64_110_0.dll") + except: + handle = None + + # Else try default search + if not handle: + LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000 + try: + handle = win32api.LoadLibraryEx("nvrtc64_112_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS) + except: + try: + handle = win32api.LoadLibraryEx("nvrtc64_111_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS) + except: + try: + handle = win32api.LoadLibraryEx("nvrtc64_110_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS) + except: + pass + + # Final check if DLLs can be found within pip installations + if not handle: + site_packages = [site.getusersitepackages()] + site.getsitepackages() + for sp in site_packages: + mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin") + if not os.path.isdir(mod_path): + continue + os.add_dll_directory(mod_path) + LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000 + LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100 + try: + handle = win32api.LoadLibraryEx( + # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path... + os.path.join(mod_path, "nvrtc64_112_0.dll"), + 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) + + # Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is + # located in the same mod_path. + # Update PATH environ so that the two dlls can find each other + os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path)) except: try: - handle = win32api.LoadLibraryEx("nvrtc64_110_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS) + handle = win32api.LoadLibraryEx( + os.path.join(mod_path, "nvrtc64_111_0.dll"), + 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) + os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path)) except: - raise RuntimeError('Failed to LoadLibraryEx nvrtc64_112_0.dll, or nvrtc64_111_0.dll, or nvrtc64_110_0.dll') + try: + handle = win32api.LoadLibraryEx( + os.path.join(mod_path, "nvrtc64_110_0.dll"), + 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) + os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path)) + except: + pass + + if not handle: + raise RuntimeError('Failed to LoadLibraryEx nvrtc64_112_0.dll, or nvrtc64_111_0.dll, or nvrtc64_110_0.dll') {{else}} handle = NULL if handle == NULL: diff --git a/cuda_bindings/pyproject.toml b/cuda_bindings/pyproject.toml index 63c09db5c..cf22cd1a7 100644 --- a/cuda_bindings/pyproject.toml +++ b/cuda_bindings/pyproject.toml @@ -32,6 +32,11 @@ dependencies = [ "pywin32; sys_platform == 'win32'", ] +[project.optional-dependencies] +all = [ + "nvidia-cuda-nvrtc-cu11" +] + [project.urls] Repository = "https://github.com/NVIDIA/cuda-python" Documentation = "https://nvidia.github.io/cuda-python/" diff --git a/cuda_bindings/setup.py b/cuda_bindings/setup.py index 13c793d28..0db64c0db 100644 --- a/cuda_bindings/setup.py +++ b/cuda_bindings/setup.py @@ -17,6 +17,7 @@ from pyclibrary import CParser from setuptools import find_packages, setup from setuptools.extension import Extension +from setuptools.command.bdist_wheel import bdist_wheel from setuptools.command.build_ext import build_ext import versioneer @@ -24,9 +25,7 @@ # ---------------------------------------------------------------------- # Fetch configuration options -CUDA_HOME = os.environ.get("CUDA_HOME") -if not CUDA_HOME: - CUDA_HOME = os.environ.get("CUDA_PATH") +CUDA_HOME = os.environ.get("CUDA_HOME", os.environ.get("CUDA_PATH", None)) if not CUDA_HOME: raise RuntimeError('Environment variable CUDA_HOME or CUDA_PATH is not set') @@ -236,20 +235,50 @@ def do_cythonize(extensions): extensions += prep_extensions(sources) # --------------------------------------------------------------------- -# Custom build_ext command -# Files are build in two steps: -# 1) Cythonized (in the do_cythonize() command) -# 2) Compiled to .o files as part of build_ext -# This class is solely for passing the value of nthreads to build_ext +# Custom cmdclass extensions + +building_wheel = False + + +class WheelsBuildExtensions(bdist_wheel): + def run(self): + global building_wheel + building_wheel = True + super().run() + class ParallelBuildExtensions(build_ext): def initialize_options(self): - build_ext.initialize_options(self) + super().initialize_options() if nthreads > 0: self.parallel = nthreads - def finalize_options(self): - build_ext.finalize_options(self) + def build_extension(self, ext): + if building_wheel and sys.platform == "linux": + # Strip binaries to remove debug symbols + extra_linker_flags = ["-Wl,--strip-all"] + + # Allow extensions to discover libraries at runtime + # relative their wheels installation. + if ext.name == "cuda.bindings._bindings.cynvrtc": + ldflag = f"-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/cuda_nvrtc/lib" + else: + ldflag = None + + if ldflag: + extra_linker_flags.append(ldflag) + else: + extra_linker_flags = [] + + ext.extra_link_args += extra_linker_flags + super().build_extension(ext) + + +cmdclass = { + "bdist_wheel": WheelsBuildExtensions, + "build_ext": ParallelBuildExtensions, + } + cmdclass = {"build_ext": ParallelBuildExtensions} cmdclass = versioneer.get_cmdclass(cmdclass)