diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java index d62450ece0..7b6b4be7bb 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java @@ -31,10 +31,12 @@ import com.nvidia.cuvs.internal.panama.cudaDeviceProp; import com.nvidia.cuvs.internal.panama.headers_h; import java.lang.foreign.Arena; +import java.lang.foreign.FunctionDescriptor; import java.lang.foreign.Linker; import java.lang.foreign.MemoryLayout; import java.lang.foreign.MemoryLayout.PathElement; import java.lang.foreign.MemorySegment; +import java.lang.foreign.SymbolLookup; import java.lang.invoke.MethodHandle; import java.lang.invoke.VarHandle; import java.util.ArrayList; @@ -50,6 +52,11 @@ private Util() {} private static final Linker LINKER = Linker.nativeLinker(); + static final SymbolLookup SYMBOL_LOOKUP = + SymbolLookup.libraryLookup(System.mapLibraryName("cuvs_c"), Arena.ofAuto()) + .or(SymbolLookup.loaderLookup()) + .or(Linker.nativeLinker().defaultLookup()); + /** * Bindings for {@code cudaMemcpyAsync}; differently from the {@code headers_h} bindings (which are * automatically generated by {@code jextract}), these bindings specify the {@code critical} linker option, @@ -62,6 +69,26 @@ private Util() {} LINKER.downcallHandle( cudaMemcpyAsync$address(), cudaMemcpyAsync$descriptor(), Linker.Option.critical(true)); + private static final String cudaGetDevicePropertiesSymbolName = + "12".equals(System.getenv("RAPIDS_CUDA_MAJOR")) + ? "cudaGetDeviceProperties_v2" + : "cudaGetDeviceProperties"; + + private static final MethodHandle cudaGetDeviceProperties$mh = + LINKER.downcallHandle( + SYMBOL_LOOKUP + .find(cudaGetDevicePropertiesSymbolName) + .orElseThrow(UnsatisfiedLinkError::new), + FunctionDescriptor.of(headers_h.C_INT, headers_h.C_POINTER, headers_h.C_INT)); + + public static int cudaGetDeviceProperties(MemorySegment prop, int device) { + try { + return (int) cudaGetDeviceProperties$mh.invokeExact(prop, device); + } catch (Throwable ex$) { + throw new AssertionError("should not reach here", ex$); + } + } + /** * Checks the result value of a (CuVS) native method handle call. * @@ -242,8 +269,8 @@ public static List availableGPUs() throws Throwable { returnValue = cudaSetDevice(i); checkCudaError(returnValue, "cudaSetDevice"); - returnValue = cudaGetDeviceProperties_v2(deviceProp, i); - checkCudaError(returnValue, "cudaGetDeviceProperties_v2"); + returnValue = cudaGetDeviceProperties(deviceProp, i); + checkCudaError(returnValue, "cudaGetDeviceProperties"); returnValue = cudaMemGetInfo(free, total); checkCudaError(returnValue, "cudaMemGetInfo");