Skip to content

Commit

Permalink
dlopen fix for win32
Browse files Browse the repository at this point in the history
 * based on Windows support PR triton-lang#2456 by @andreigh
 * WIN32 fix using LoadLibrary
  • Loading branch information
wkpark committed Oct 17, 2024
1 parent bef6e24 commit 68febb5
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
27 changes: 27 additions & 0 deletions third_party/nvidia/backend/driver.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
#include "cuda.h"
#ifndef _WIN32
#include <dlfcn.h>
#else
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#endif
#include <stdbool.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
Expand Down Expand Up @@ -162,6 +167,7 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)(
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
CUtensorMapFloatOOBfill oobFill);

#ifndef _WIN32
#define defineGetFunctionHandle(name, symbolName) \
static symbolName##_t name() { \
/* Open the shared library */ \
Expand All @@ -183,6 +189,27 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)(
} \
return funcHandle; \
}
#else
/*
 * Windows variant of defineGetFunctionHandle.
 *
 * Expands to a static function `name` that loads nvcuda.dll and looks up the
 * exported symbol `symbolName`, returning it cast to `symbolName##_t`.
 *
 * On failure the generated function sets a Python RuntimeError and returns
 * NULL. Failure of the lookup is detected from the GetProcAddress return
 * value: GetLastError() is only meaningful after a call has failed and is not
 * reset on success, so a stale error code left by an unrelated earlier call
 * must not be treated as a lookup failure.
 */
#define defineGetFunctionHandle(name, symbolName)                              \
  static symbolName##_t name() {                                               \
    /* Open the CUDA driver library */                                         \
    HMODULE handle = LoadLibraryA("nvcuda.dll");                               \
    if (!handle) {                                                             \
      PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll");        \
      return NULL;                                                             \
    }                                                                          \
    symbolName##_t funcHandle =                                                \
        (symbolName##_t)GetProcAddress(handle, #symbolName);                   \
    /* A NULL result is the only reliable failure indicator here */            \
    if (!funcHandle) {                                                         \
      PyErr_SetString(PyExc_RuntimeError,                                      \
                      "Failed to retrieve " #symbolName " from nvcuda.dll");   \
      return NULL;                                                             \
    }                                                                          \
    return funcHandle;                                                         \
  }
#endif

defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle,
cuOccupancyMaxActiveClusters);
Expand Down
25 changes: 25 additions & 0 deletions third_party/nvidia/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,12 @@ def format_of(ty):
#include \"cuda.h\"
#include <stdbool.h>
#include <Python.h>
#ifndef _WIN32
#include <dlfcn.h>
#else
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#endif
static inline void gpuAssert(CUresult code, const char *file, int line)
{{
Expand All @@ -186,6 +191,7 @@ def format_of(ty):
typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra);
#ifndef _WIN32
static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
// Open the shared library
void* handle = dlopen("libcuda.so.1", RTLD_LAZY);
Expand All @@ -204,6 +210,25 @@ def format_of(ty):
}}
return cuLaunchKernelExHandle;
}}
#else
// Windows variant: load nvcuda.dll and look up cuLaunchKernelEx.
// Returns the function pointer, or NULL with a Python RuntimeError set when
// the library cannot be opened or the symbol cannot be found.
static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
  // Open the shared library
  HMODULE handle = LoadLibraryA("nvcuda.dll");
  if (!handle) {{
    PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll");
    return NULL;
  }}
  cuLaunchKernelEx_t cuLaunchKernelExHandle =
      (cuLaunchKernelEx_t)GetProcAddress(handle, "cuLaunchKernelEx");
  // Check the returned pointer rather than GetLastError(): the last-error
  // code is not cleared on success, so a stale value from an earlier call
  // would otherwise be reported as a lookup failure.
  if (!cuLaunchKernelExHandle) {{
    PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from nvcuda.dll");
    return NULL;
  }}
  return cuLaunchKernelExHandle;
}}
#endif
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }};
Expand Down

0 comments on commit 68febb5

Please sign in to comment.