| 
 | 1 | +# Copyright 2024-2025 NVIDIA Corporation.  All rights reserved.  | 
 | 2 | +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE  | 
 | 3 | + | 
 | 4 | +import functools  | 
 | 5 | +import glob  | 
 | 6 | +import os  | 
 | 7 | + | 
 | 8 | +from cuda.bindings._path_finder.find_sub_dirs import find_sub_dirs_all_sitepackages  | 
 | 9 | +from cuda.bindings._path_finder.supported_libs import IS_WINDOWS, is_suppressed_dll_file  | 
 | 10 | + | 
 | 11 | + | 
 | 12 | +def _no_such_file_in_sub_dirs(sub_dirs, file_wild, error_messages, attachments):  | 
 | 13 | +    error_messages.append(f"No such file: {file_wild}")  | 
 | 14 | +    for sub_dir in find_sub_dirs_all_sitepackages(sub_dirs):  | 
 | 15 | +        attachments.append(f'  listdir("{sub_dir}"):')  | 
 | 16 | +        for node in sorted(os.listdir(sub_dir)):  | 
 | 17 | +            attachments.append(f"    {node}")  | 
 | 18 | + | 
 | 19 | + | 
 | 20 | +def _find_so_using_nvidia_lib_dirs(libname, so_basename, error_messages, attachments):  | 
 | 21 | +    nvidia_sub_dirs = ("nvidia", "*", "nvvm", "lib64") if libname == "nvvm" else ("nvidia", "*", "lib")  | 
 | 22 | +    file_wild = so_basename + "*"  | 
 | 23 | +    for lib_dir in find_sub_dirs_all_sitepackages(nvidia_sub_dirs):  | 
 | 24 | +        # First look for an exact match  | 
 | 25 | +        so_name = os.path.join(lib_dir, so_basename)  | 
 | 26 | +        if os.path.isfile(so_name):  | 
 | 27 | +            return so_name  | 
 | 28 | +        # Look for a versioned library  | 
 | 29 | +        # Using sort here mainly to make the result deterministic.  | 
 | 30 | +        for so_name in sorted(glob.glob(os.path.join(lib_dir, file_wild))):  | 
 | 31 | +            if os.path.isfile(so_name):  | 
 | 32 | +                return so_name  | 
 | 33 | +    _no_such_file_in_sub_dirs(nvidia_sub_dirs, file_wild, error_messages, attachments)  | 
 | 34 | +    return None  | 
 | 35 | + | 
 | 36 | + | 
 | 37 | +def _find_dll_under_dir(dirpath, file_wild):  | 
 | 38 | +    for path in sorted(glob.glob(os.path.join(dirpath, file_wild))):  | 
 | 39 | +        if not os.path.isfile(path):  | 
 | 40 | +            continue  | 
 | 41 | +        if not is_suppressed_dll_file(os.path.basename(path)):  | 
 | 42 | +            return path  | 
 | 43 | +    return None  | 
 | 44 | + | 
 | 45 | + | 
 | 46 | +def _find_dll_using_nvidia_bin_dirs(libname, lib_searched_for, error_messages, attachments):  | 
 | 47 | +    nvidia_sub_dirs = ("nvidia", "*", "nvvm", "bin") if libname == "nvvm" else ("nvidia", "*", "bin")  | 
 | 48 | +    for bin_dir in find_sub_dirs_all_sitepackages(nvidia_sub_dirs):  | 
 | 49 | +        dll_name = _find_dll_under_dir(bin_dir, lib_searched_for)  | 
 | 50 | +        if dll_name is not None:  | 
 | 51 | +            return dll_name  | 
 | 52 | +    _no_such_file_in_sub_dirs(nvidia_sub_dirs, lib_searched_for, error_messages, attachments)  | 
 | 53 | +    return None  | 
 | 54 | + | 
 | 55 | + | 
 | 56 | +def _get_cuda_home():  | 
 | 57 | +    cuda_home = os.environ.get("CUDA_HOME")  | 
 | 58 | +    if cuda_home is None:  | 
 | 59 | +        cuda_home = os.environ.get("CUDA_PATH")  | 
 | 60 | +    return cuda_home  | 
 | 61 | + | 
 | 62 | + | 
 | 63 | +def _find_lib_dir_using_cuda_home(libname):  | 
 | 64 | +    cuda_home = _get_cuda_home()  | 
 | 65 | +    if cuda_home is None:  | 
 | 66 | +        return None  | 
 | 67 | +    if IS_WINDOWS:  | 
 | 68 | +        subdirs = (os.path.join("nvvm", "bin"),) if libname == "nvvm" else ("bin",)  | 
 | 69 | +    else:  | 
 | 70 | +        subdirs = (  | 
 | 71 | +            (os.path.join("nvvm", "lib64"),)  | 
 | 72 | +            if libname == "nvvm"  | 
 | 73 | +            else (  | 
 | 74 | +                "lib64",  # CTK  | 
 | 75 | +                "lib",  # Conda  | 
 | 76 | +            )  | 
 | 77 | +        )  | 
 | 78 | +    for subdir in subdirs:  | 
 | 79 | +        dirname = os.path.join(cuda_home, subdir)  | 
 | 80 | +        if os.path.isdir(dirname):  | 
 | 81 | +            return dirname  | 
 | 82 | +    return None  | 
 | 83 | + | 
 | 84 | + | 
 | 85 | +def _find_so_using_lib_dir(lib_dir, so_basename, error_messages, attachments):  | 
 | 86 | +    so_name = os.path.join(lib_dir, so_basename)  | 
 | 87 | +    if os.path.isfile(so_name):  | 
 | 88 | +        return so_name  | 
 | 89 | +    error_messages.append(f"No such file: {so_name}")  | 
 | 90 | +    attachments.append(f'  listdir("{lib_dir}"):')  | 
 | 91 | +    if not os.path.isdir(lib_dir):  | 
 | 92 | +        attachments.append("    DIRECTORY DOES NOT EXIST")  | 
 | 93 | +    else:  | 
 | 94 | +        for node in sorted(os.listdir(lib_dir)):  | 
 | 95 | +            attachments.append(f"    {node}")  | 
 | 96 | +    return None  | 
 | 97 | + | 
 | 98 | + | 
 | 99 | +def _find_dll_using_lib_dir(lib_dir, libname, error_messages, attachments):  | 
 | 100 | +    file_wild = libname + "*.dll"  | 
 | 101 | +    dll_name = _find_dll_under_dir(lib_dir, file_wild)  | 
 | 102 | +    if dll_name is not None:  | 
 | 103 | +        return dll_name  | 
 | 104 | +    error_messages.append(f"No such file: {file_wild}")  | 
 | 105 | +    attachments.append(f'  listdir("{lib_dir}"):')  | 
 | 106 | +    for node in sorted(os.listdir(lib_dir)):  | 
 | 107 | +        attachments.append(f"    {node}")  | 
 | 108 | +    return None  | 
 | 109 | + | 
 | 110 | + | 
 | 111 | +class _find_nvidia_dynamic_library:  | 
 | 112 | +    def __init__(self, libname: str):  | 
 | 113 | +        self.libname = libname  | 
 | 114 | +        self.error_messages = []  | 
 | 115 | +        self.attachments = []  | 
 | 116 | +        self.abs_path = None  | 
 | 117 | + | 
 | 118 | +        if IS_WINDOWS:  | 
 | 119 | +            self.lib_searched_for = f"{libname}*.dll"  | 
 | 120 | +            if self.abs_path is None:  | 
 | 121 | +                self.abs_path = _find_dll_using_nvidia_bin_dirs(  | 
 | 122 | +                    libname, self.lib_searched_for, self.error_messages, self.attachments  | 
 | 123 | +                )  | 
 | 124 | +        else:  | 
 | 125 | +            self.lib_searched_for = f"lib{libname}.so"  | 
 | 126 | +            if self.abs_path is None:  | 
 | 127 | +                self.abs_path = _find_so_using_nvidia_lib_dirs(  | 
 | 128 | +                    libname, self.lib_searched_for, self.error_messages, self.attachments  | 
 | 129 | +                )  | 
 | 130 | + | 
 | 131 | +    def retry_with_cuda_home_priority_last(self):  | 
 | 132 | +        cuda_home_lib_dir = _find_lib_dir_using_cuda_home(self.libname)  | 
 | 133 | +        if cuda_home_lib_dir is not None:  | 
 | 134 | +            if IS_WINDOWS:  | 
 | 135 | +                self.abs_path = _find_dll_using_lib_dir(  | 
 | 136 | +                    cuda_home_lib_dir, self.libname, self.error_messages, self.attachments  | 
 | 137 | +                )  | 
 | 138 | +            else:  | 
 | 139 | +                self.abs_path = _find_so_using_lib_dir(  | 
 | 140 | +                    cuda_home_lib_dir, self.lib_searched_for, self.error_messages, self.attachments  | 
 | 141 | +                )  | 
 | 142 | + | 
 | 143 | +    def raise_if_abs_path_is_None(self):  | 
 | 144 | +        if self.abs_path:  | 
 | 145 | +            return self.abs_path  | 
 | 146 | +        err = ", ".join(self.error_messages)  | 
 | 147 | +        att = "\n".join(self.attachments)  | 
 | 148 | +        raise RuntimeError(f'Failure finding "{self.lib_searched_for}": {err}\n{att}')  | 
 | 149 | + | 
 | 150 | + | 
 | 151 | +@functools.cache  | 
 | 152 | +def find_nvidia_dynamic_library(libname: str) -> str:  | 
 | 153 | +    return _find_nvidia_dynamic_library(libname).raise_if_abs_path_is_None()  | 
0 commit comments