Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@
import com.nvidia.cuvs.internal.panama.cudaDeviceProp;
import com.nvidia.cuvs.internal.panama.headers_h;
import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemoryLayout.PathElement;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SymbolLookup;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.VarHandle;
import java.util.ArrayList;
Expand All @@ -50,6 +52,11 @@ private Util() {}

private static final Linker LINKER = Linker.nativeLinker();

static final SymbolLookup SYMBOL_LOOKUP =
SymbolLookup.libraryLookup(System.mapLibraryName("cuvs_c"), Arena.ofAuto())
.or(SymbolLookup.loaderLookup())
.or(Linker.nativeLinker().defaultLookup());

/**
* Bindings for {@code cudaMemcpyAsync}; differently from the {@code headers_h} bindings (which are
* automatically generated by {@code jextract}), these bindings specify the {@code critical} linker option,
Expand All @@ -62,6 +69,26 @@ private Util() {}
LINKER.downcallHandle(
cudaMemcpyAsync$address(), cudaMemcpyAsync$descriptor(), Linker.Option.critical(true));

private static final String cudaGetDevicePropertiesSymbolName =
"12".equals(System.getenv("RAPIDS_CUDA_MAJOR"))
? "cudaGetDeviceProperties_v2"
: "cudaGetDeviceProperties";

private static final MethodHandle cudaGetDeviceProperties$mh =
LINKER.downcallHandle(
SYMBOL_LOOKUP
.find(cudaGetDevicePropertiesSymbolName)
.orElseThrow(UnsatisfiedLinkError::new),
FunctionDescriptor.of(headers_h.C_INT, headers_h.C_POINTER, headers_h.C_INT));

public static int cudaGetDeviceProperties(MemorySegment prop, int device) {
try {
return (int) cudaGetDeviceProperties$mh.invokeExact(prop, device);
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
}

/**
* Checks the result value of a (CuVS) native method handle call.
*
Expand Down Expand Up @@ -242,8 +269,8 @@ public static List<GPUInfo> availableGPUs() throws Throwable {
returnValue = cudaSetDevice(i);
checkCudaError(returnValue, "cudaSetDevice");

returnValue = cudaGetDeviceProperties_v2(deviceProp, i);
checkCudaError(returnValue, "cudaGetDeviceProperties_v2");
returnValue = cudaGetDeviceProperties(deviceProp, i);
checkCudaError(returnValue, "cudaGetDeviceProperties");

returnValue = cudaMemGetInfo(free, total);
checkCudaError(returnValue, "cudaMemGetInfo");
Expand Down