|
| 1 | +#include "numba/_pymodule.h" |
| 2 | +#include "numba/core/runtime/nrt_external.h" |
| 3 | +#include "assert.h" |
| 4 | +#include <stdio.h> |
| 5 | +#if !defined _WIN32 |
| 6 | + #include <dlfcn.h> |
| 7 | +#else |
| 8 | + #include <windows.h> |
| 9 | +#endif |
| 10 | + |
| 11 | +NRT_ExternalAllocator usmarray_allocator; |
| 12 | +NRT_external_malloc_func internal_allocator = NULL; |
| 13 | +NRT_external_free_func internal_free = NULL; |
| 14 | +void *(*get_queue_internal)(void) = NULL; |
| 15 | +void (*free_queue_internal)(void*) = NULL; |
| 16 | + |
/*
 * Allocate `size` usable bytes through internal_allocator, using the queue
 * returned by get_queue_internal().
 *
 * The queue handle is written into a pointer-sized header located directly
 * in front of the address handed back to the caller, so that
 * save_queue_deallocator() can later release the block with the same queue.
 *
 * size:   number of usable bytes requested by the caller.
 * opaque: unused.
 *
 * Returns the address just past the hidden header, or NULL when the padded
 * size would overflow, no queue could be obtained, or the underlying
 * allocation fails.
 */
void * save_queue_allocator(size_t size, void *opaque) {
    const size_t header_size = sizeof(void*);
    /* Keep the arithmetic in size_t: narrowing to int would truncate
     * large requests. */
    size_t padded_size = size + header_size;
    void *cur_queue;
    void *data;

    (void)opaque;

    /* Unsigned wrap-around means the request cannot be represented. */
    if (padded_size < size) {
        return NULL;
    }

    /* Get the current queue; this hands us a copy that we must dispose of. */
    cur_queue = get_queue_internal();
    if (cur_queue == NULL) {
        return NULL;
    }

    /* Use that queue to allocate the padded block. */
    data = internal_allocator(padded_size, cur_queue);
    if (data == NULL) {
        /* Nothing will ever reach the deallocator for this request, so the
         * queue copy has to be released here to avoid leaking it. */
        free_queue_internal(cur_queue);
        return NULL;
    }

    /* Stash the queue in the header, then skip past it. */
    *(void**)data = cur_queue;
    return (char*)data + header_size;
}
| 29 | + |
/*
 * Release a block previously returned by save_queue_allocator().
 *
 * The address the caller holds sits one pointer-size past the start of the
 * real allocation; the pointer-sized header in front of it stores the queue
 * that was used to allocate the block.
 *
 * data:   address previously returned by save_queue_allocator(), or NULL.
 * opaque: unused.
 *
 * A NULL `data` is ignored, mirroring free(NULL).
 */
void save_queue_deallocator(void *data, void *opaque) {
    void *orig_data;
    void *obj_queue;

    (void)opaque;

    /* Stepping back from NULL and reading the header would be undefined. */
    if (data == NULL) {
        return;
    }

    /* Recover the true start of the allocation. */
    orig_data = (char*)data - sizeof(void*);
    /* Read the queue out of the header BEFORE the memory is released. */
    obj_queue = *(void**)orig_data;
    /* Free the space using the queue it was allocated with. */
    internal_free(orig_data, obj_queue);
    /* Finally dispose of the queue copy itself. */
    free_queue_internal(obj_queue);
}
| 42 | + |
| 43 | +void usmarray_memsys_init(void) { |
| 44 | + #if !defined _WIN32 |
| 45 | + char *lib_name = "libDPCTLSyclInterface.so"; |
| 46 | + char *malloc_name = "DPCTLmalloc_shared"; |
| 47 | + char *free_name = "DPCTLfree_with_queue"; |
| 48 | + char *get_queue_name = "DPCTLQueueMgr_GetCurrentQueue"; |
| 49 | + char *free_queue_name = "DPCTLQueue_Delete"; |
| 50 | + |
| 51 | + void *sycldl = dlopen(lib_name, RTLD_NOW); |
| 52 | + assert(sycldl != NULL); |
| 53 | + internal_allocator = (NRT_external_malloc_func)dlsym(sycldl, malloc_name); |
| 54 | + usmarray_allocator.malloc = save_queue_allocator; |
| 55 | + if (internal_allocator == NULL) { |
| 56 | + printf("Did not find %s in %s\n", malloc_name, lib_name); |
| 57 | + exit(-1); |
| 58 | + } |
| 59 | + |
| 60 | + usmarray_allocator.realloc = NULL; |
| 61 | + |
| 62 | + internal_free = (NRT_external_free_func)dlsym(sycldl, free_name); |
| 63 | + usmarray_allocator.free = save_queue_deallocator; |
| 64 | + if (internal_free == NULL) { |
| 65 | + printf("Did not find %s in %s\n", free_name, lib_name); |
| 66 | + exit(-1); |
| 67 | + } |
| 68 | + |
| 69 | + get_queue_internal = (void *(*)(void))dlsym(sycldl, get_queue_name); |
| 70 | + if (get_queue_internal == NULL) { |
| 71 | + printf("Did not find %s in %s\n", get_queue_name, lib_name); |
| 72 | + exit(-1); |
| 73 | + } |
| 74 | + usmarray_allocator.opaque_data = NULL; |
| 75 | + |
| 76 | + free_queue_internal = (void (*)(void*))dlsym(sycldl, free_queue_name); |
| 77 | + if (free_queue_internal == NULL) { |
| 78 | + printf("Did not find %s in %s\n", free_queue_name, lib_name); |
| 79 | + exit(-1); |
| 80 | + } |
| 81 | + #else |
| 82 | + char *lib_name = "DPCTLSyclInterface.dll"; |
| 83 | + char *malloc_name = "DPCTLmalloc_shared"; |
| 84 | + char *free_name = "DPCTLfree_with_queue"; |
| 85 | + char *get_queue_name = "DPCTLQueueMgr_GetCurrentQueue"; |
| 86 | + char *free_queue_name = "DPCTLQueue_Delete"; |
| 87 | + |
| 88 | + HMODULE sycldl = LoadLibrary(lib_name); |
| 89 | + assert(sycldl != NULL); |
| 90 | + internal_allocator = (NRT_external_malloc_func)GetProcAddress(sycldl, malloc_name); |
| 91 | + usmarray_allocator.malloc = save_queue_allocator; |
| 92 | + if (internal_allocator == NULL) { |
| 93 | + printf("Did not find %s in %s\n", malloc_name, lib_name); |
| 94 | + exit(-1); |
| 95 | + } |
| 96 | + |
| 97 | + usmarray_allocator.realloc = NULL; |
| 98 | + |
| 99 | + internal_free = (NRT_external_free_func)GetProcAddress(sycldl, free_name); |
| 100 | + usmarray_allocator.free = save_queue_deallocator; |
| 101 | + if (internal_free == NULL) { |
| 102 | + printf("Did not find %s in %s\n", free_name, lib_name); |
| 103 | + exit(-1); |
| 104 | + } |
| 105 | + |
| 106 | + get_queue_internal = (void *(*)(void))GetProcAddress(sycldl, get_queue_name); |
| 107 | + if (get_queue_internal == NULL) { |
| 108 | + printf("Did not find %s in %s\n", get_queue_name, lib_name); |
| 109 | + exit(-1); |
| 110 | + } |
| 111 | + usmarray_allocator.opaque_data = NULL; |
| 112 | + |
| 113 | + free_queue_internal = (void (*)(void*))GetProcAddress(sycldl, free_queue_name); |
| 114 | + if (free_queue_internal == NULL) { |
| 115 | + printf("Did not find %s in %s\n", free_queue_name, lib_name); |
| 116 | + exit(-1); |
| 117 | + } |
| 118 | + #endif |
| 119 | +} |
| 120 | + |
| 121 | +void * usmarray_get_ext_allocator(void) { |
| 122 | + return (void*)&usmarray_allocator; |
| 123 | +} |
| 124 | + |
| 125 | +static PyObject * |
| 126 | +get_external_allocator(PyObject *self, PyObject *args) { |
| 127 | + return PyLong_FromVoidPtr(usmarray_get_ext_allocator()); |
| 128 | +} |
| 129 | + |
| 130 | +static PyMethodDef ext_methods[] = { |
| 131 | +#define declmethod_noargs(func) { #func , ( PyCFunction )func , METH_NOARGS, NULL } |
| 132 | + declmethod_noargs(get_external_allocator), |
| 133 | + {NULL}, |
| 134 | +#undef declmethod_noargs |
| 135 | +}; |
| 136 | + |
| 137 | +static PyObject * |
| 138 | +build_c_helpers_dict(void) |
| 139 | +{ |
| 140 | + PyObject *dct = PyDict_New(); |
| 141 | + if (dct == NULL) |
| 142 | + goto error; |
| 143 | + |
| 144 | +#define _declpointer(name, value) do { \ |
| 145 | + PyObject *o = PyLong_FromVoidPtr(value); \ |
| 146 | + if (o == NULL) goto error; \ |
| 147 | + if (PyDict_SetItemString(dct, name, o)) { \ |
| 148 | + Py_DECREF(o); \ |
| 149 | + goto error; \ |
| 150 | + } \ |
| 151 | + Py_DECREF(o); \ |
| 152 | +} while (0) |
| 153 | + |
| 154 | + _declpointer("usmarray_get_ext_allocator", &usmarray_get_ext_allocator); |
| 155 | + |
| 156 | +#undef _declpointer |
| 157 | + return dct; |
| 158 | +error: |
| 159 | + Py_XDECREF(dct); |
| 160 | + return NULL; |
| 161 | +} |
| 162 | + |
| 163 | +MOD_INIT(_dppy_rt) { |
| 164 | + PyObject *m; |
| 165 | + MOD_DEF(m, "numba_dppy._dppy_rt", "No docs", ext_methods) |
| 166 | + if (m == NULL) |
| 167 | + return MOD_ERROR_VAL; |
| 168 | + usmarray_memsys_init(); |
| 169 | + PyModule_AddObject(m, "c_helpers", build_c_helpers_dict()); |
| 170 | + return MOD_SUCCESS_VAL(m); |
| 171 | +} |
0 commit comments