Skip to content

Commit 914dcd5

Browse files
committed
Initial async kernel support
1 parent e4b4d3e commit 914dcd5

File tree

8 files changed

+328
-11
lines changed

8 files changed

+328
-11
lines changed

numba_dpex/core/runtime/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ python_add_library(${PROJECT_NAME} MODULE ${SOURCES})
109109

110110
# Add SYCL to target, this must come after python_add_library()
111111
# FIXME: sources incompatible with sycl include?
112-
# add_sycl_to_target(TARGET ${PROJECT_NAME} SOURCES ${KERNEL_SOURCES})
112+
add_sycl_to_target(TARGET ${PROJECT_NAME} SOURCES ${KERNEL_SOURCES})
113113

114114
# Link the DPCTLSyclInterface library to target
115115
target_link_libraries(${PROJECT_NAME} PRIVATE DPCTLSyclInterface)

numba_dpex/core/runtime/_dbg_printer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
/* Debugging facilities - enabled at compile-time */
1515
/* #undef NDEBUG */
16-
#if 0
16+
#if 1
1717
#include <stdio.h>
1818
#define DPEXRT_DEBUG(X) \
1919
{ \

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "_queuestruct.h"
2525
#include "_usmarraystruct.h"
2626

27+
#include "experimental/experimental.h"
2728
#include "numba/core/runtime/nrt_external.h"
2829

2930
// forward declarations
@@ -1490,6 +1491,8 @@ static PyObject *build_c_helpers_dict(void)
14901491
&DPEXRT_sycl_event_from_python);
14911492
_declpointer("DPEXRT_sycl_event_to_python", &DPEXRT_sycl_event_to_python);
14921493
_declpointer("DPEXRT_sycl_event_init", &DPEXRT_sycl_event_init);
1494+
_declpointer("DPEXRT_nrt_acuire_meminfo_and_schedule_release",
1495+
&DPEXRT_nrt_acuire_meminfo_and_schedule_release);
14931496

14941497
#undef _declpointer
14951498
return dct;
@@ -1557,6 +1560,9 @@ MOD_INIT(_dpexrt_python)
15571560
PyLong_FromVoidPtr(&DPEXRT_MemInfo_alloc));
15581561
PyModule_AddObject(m, "DPEXRT_MemInfo_fill",
15591562
PyLong_FromVoidPtr(&DPEXRT_MemInfo_fill));
1563+
PyModule_AddObject(
1564+
m, "DPEXRT_nrt_acuire_meminfo_and_schedule_release",
1565+
PyLong_FromVoidPtr(&DPEXRT_nrt_acuire_meminfo_and_schedule_release));
15601566
PyModule_AddObject(m, "c_helpers", build_c_helpers_dict());
15611567
return MOD_SUCCESS_VAL(m);
15621568
}

numba_dpex/core/runtime/context.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,3 +433,41 @@ def submit_ndrange(
433433
)
434434

435435
return ret
436+
437+
def acuire_meminfo_and_schedule_release(
438+
self, builder: llvmir.IRBuilder, args
439+
):
440+
"""Inserts LLVM IR to call nrt_acuire_meminfo_and_schedule_release.
441+
442+
DPCTLSyclEventRef
443+
DPEXRT_nrt_acuire_meminfo_and_schedule_release(
444+
NRT_api_functions *nrt,
445+
DPCTLSyclQueueRef QRef,
446+
NRT_MemInfo **meminfo_array,
447+
size_t meminfo_array_size,
448+
DPCTLSyclEventRef *depERefs,
449+
size_t nDepERefs,
450+
int *status,
451+
);
452+
453+
"""
454+
mod = builder.module
455+
456+
func_ty = llvmir.FunctionType(
457+
cgutils.voidptr_t,
458+
[
459+
cgutils.voidptr_t,
460+
cgutils.voidptr_t,
461+
cgutils.voidptr_t.as_pointer(),
462+
llvmir.IntType(64),
463+
cgutils.voidptr_t,
464+
llvmir.IntType(64),
465+
llvmir.IntType(64).as_pointer(),
466+
],
467+
)
468+
fn = cgutils.get_or_insert_function(
469+
mod, func_ty, "DPEXRT_nrt_acuire_meminfo_and_schedule_release"
470+
)
471+
ret = builder.call(fn, args)
472+
473+
return ret
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#ifndef _EXPERIMENTAL_H_
2+
#define _EXPERIMENTAL_H_
3+
4+
#include "dpctl_capi.h"
5+
#include "numba/core/runtime/nrt_external.h"
6+
7+
#ifdef __cplusplus
8+
extern "C"
9+
{
10+
#endif
11+
DPCTLSyclEventRef DPEXRT_nrt_acuire_meminfo_and_schedule_release(
12+
NRT_api_functions *nrt,
13+
DPCTLSyclQueueRef QRef,
14+
NRT_MemInfo **meminfo_array,
15+
size_t meminfo_array_size,
16+
// DPCTLSyclEventRef *depERefs,
17+
DPCTLSyclEventRef depERef,
18+
size_t nDepERefs,
19+
int *status);
20+
#ifdef __cplusplus
21+
}
22+
#endif
23+
24+
#endif /* _EXPERIMENTAL_H_ */
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#include "experimental.h"
2+
3+
#include "_dbg_printer.h"
4+
#include "syclinterface/dpctl_sycl_type_casters.hpp"
5+
#include <CL/sycl.hpp>
6+
7+
extern "C"
8+
{
9+
DPCTLSyclEventRef DPEXRT_nrt_acuire_meminfo_and_schedule_release(
10+
NRT_api_functions *nrt,
11+
DPCTLSyclQueueRef QRef,
12+
NRT_MemInfo **meminfo_array,
13+
size_t meminfo_array_size,
14+
// DPCTLSyclEventRef *depERefs,
15+
DPCTLSyclEventRef depERef,
16+
size_t nDepERefs,
17+
int *status)
18+
{
19+
DPEXRT_DEBUG(drt_debug_print(
20+
"DPEXRT-DEBUG: scheduling nrt meminfo release.\n"););
21+
22+
using dpctl::syclinterface::unwrap;
23+
using dpctl::syclinterface::wrap;
24+
25+
sycl::queue *q = unwrap<sycl::queue>(QRef);
26+
27+
std::vector<NRT_MemInfo *> meminfo_vec(
28+
meminfo_array, meminfo_array + meminfo_array_size);
29+
30+
for (size_t i = 0; i < meminfo_array_size; ++i) {
31+
nrt->acquire(meminfo_vec[i]);
32+
}
33+
34+
DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: acquired meminfo.\n"););
35+
36+
try {
37+
sycl::event ht_ev = q->submit([&](sycl::handler &cgh) {
38+
for (size_t ev_id = 0; ev_id < nDepERefs; ++ev_id) {
39+
// cgh.depends_on(*(unwrap<sycl::event>(depERefs[ev_id])));
40+
cgh.depends_on(*(unwrap<sycl::event>(depERef)));
41+
}
42+
cgh.host_task([meminfo_array_size, meminfo_vec, nrt]() {
43+
for (size_t i = 0; i < meminfo_array_size; ++i) {
44+
nrt->release(meminfo_vec[i]);
45+
DPEXRT_DEBUG(
46+
drt_debug_print("DPEXRT-DEBUG: released meminfo "
47+
"from host_task.\n"););
48+
}
49+
});
50+
});
51+
52+
constexpr int result_ok = 0;
53+
54+
*status = result_ok;
55+
auto e_ptr = new sycl::event(ht_ev);
56+
return wrap<sycl::event>(e_ptr);
57+
} catch (const std::exception &e) {
58+
constexpr int result_std_exception = 1;
59+
60+
*status = result_std_exception;
61+
return nullptr;
62+
}
63+
64+
constexpr int result_other_abnormal = 2;
65+
66+
*status = result_other_abnormal;
67+
return nullptr;
68+
}
69+
}

0 commit comments

Comments
 (0)