Skip to content

Commit

Permalink
based on Windows support PR triton-lang#2456
Browse files Browse the repository at this point in the history
 * WIN32 fix using LoadLibrary

Signed-off-by: Won-Kyu Park <wkpark@gmail.com>
  • Loading branch information
andreigh authored and wkpark committed Dec 1, 2023
1 parent 17fd518 commit c589172
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions python/triton/runtime/backends/cuda.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
#include "cuda.h"
#ifndef _WIN32
#include <dlfcn.h>
#else
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#endif
#include <stdbool.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
Expand Down Expand Up @@ -94,10 +99,17 @@ static bool gpuAssert(CUresult code, const char *file, int line) {
// Argument-count dispatch helpers.
//
// DISPATCH_ARGS_N discards its first 14 arguments and token-pastes the 15th
// (N) onto ADD_ENUM_ITEM_, yielding the name ADD_ENUM_ITEM_<N>.
// NOTE(review): the ADD_ENUM_ITEM_<N> macros are not visible in this part of
// the file; presumably one exists for each supported count -- confirm.
#define DISPATCH_ARGS_N(_14, _13, _12, _11, _10, _9, _8, _7, _6, _5, _4, _3, \
_2, _1, N, ...) \
ADD_ENUM_ITEM_##N
// DISPATCH_ARGS(...) appends the descending list 14..0 after the caller's
// arguments, so that with k arguments (1 <= k <= 14) the 15th slot (N) holds
// k. The selected ADD_ENUM_ITEM_<k> is then invoked with the original
// arguments.
#if !defined(_MSC_VER) || defined(__clang__)
// Non-MSVC compilers, and clang even when _MSC_VER is defined.
#define DISPATCH_ARGS(...) \
DISPATCH_ARGS_N(__VA_ARGS__, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, \
0) \
(__VA_ARGS__)
#else
// MSVC (non-clang) variant: the argument list is routed through EXPAND_ARGS
// before DISPATCH_ARGS_N is applied to it.
// NOTE(review): presumably this works around MSVC's traditional preprocessor
// forwarding __VA_ARGS__ to a nested macro as a single argument -- confirm.
#define EXPAND_ARGS(args) args
#define DISPATCH_ARGS(...) \
DISPATCH_ARGS_N EXPAND_ARGS((__VA_ARGS__, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, \
4, 3, 2, 1, 0))(__VA_ARGS__)
#endif

#define ADD_ENUM_TO_MODULE(module, enum_name, ...) \
do { \
Expand Down Expand Up @@ -377,6 +389,7 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)(
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
CUtensorMapFloatOOBfill oobFill);

#ifndef _WIN32
static cuTensorMapEncodeTiled_t getCuTensorMapEncodeTiledHandle() {
// Open the shared library
void *handle = dlopen("libcuda.so", RTLD_LAZY);
Expand All @@ -398,6 +411,28 @@ static cuTensorMapEncodeTiled_t getCuTensorMapEncodeTiledHandle() {
}
return cuTensorMapEncodeTiledHandle;
}
#else
// Windows implementation: resolves the driver entry point
// cuTensorMapEncodeTiled from nvcuda.dll at runtime.
//
// Returns the function pointer on success. On failure, sets a Python
// RuntimeError and returns NULL.
static cuTensorMapEncodeTiled_t getCuTensorMapEncodeTiledHandle() {
  // Open the shared library
  HMODULE handle = LoadLibraryA("nvcuda.dll");
  if (!handle) {
    PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll");
    return NULL;
  }
  // The module is intentionally not released on success: the returned
  // function pointer is only valid while nvcuda.dll stays loaded.
  cuTensorMapEncodeTiled_t cuTensorMapEncodeTiledHandle =
      (cuTensorMapEncodeTiled_t)GetProcAddress(handle,
                                               "cuTensorMapEncodeTiled");
  // GetProcAddress signals failure by returning NULL. GetLastError() must not
  // be used as the success test: it is not reset on success and may still
  // hold a nonzero code left over from an earlier, unrelated API call.
  if (!cuTensorMapEncodeTiledHandle) {
    PyErr_SetString(
        PyExc_RuntimeError,
        "Failed to retrieve cuTensorMapEncodeTiled from nvcuda.dll");
    return NULL;
  }
  return cuTensorMapEncodeTiledHandle;
}
#endif

static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) {
CUtensorMap *tensorMap = (CUtensorMap *)malloc(sizeof(CUtensorMap));
Expand Down

0 comments on commit c589172

Please sign in to comment.