Skip to content

Commit 425dabb

Browse files
committed
Add async kernel submission support
1 parent 05aa34d commit 425dabb

File tree

9 files changed

+412
-35
lines changed

9 files changed

+412
-35
lines changed

numba_dpex/core/runtime/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ python_add_library(${PROJECT_NAME} MODULE ${SOURCES})
109109

110110
# Add SYCL to target, this must come after python_add_library()
111111
# FIXME: sources incompatible with sycl include?
112-
# add_sycl_to_target(TARGET ${PROJECT_NAME} SOURCES ${KERNEL_SOURCES})
112+
add_sycl_to_target(TARGET ${PROJECT_NAME} SOURCES ${KERNEL_SOURCES})
113113

114114
# Link the DPCTLSyclInterface library to target
115115
target_link_libraries(${PROJECT_NAME} PRIVATE DPCTLSyclInterface)

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "_queuestruct.h"
2525
#include "_usmarraystruct.h"
2626

27+
#include "experimental/nrt_reserve_meminfo.h"
2728
#include "numba/core/runtime/nrt_external.h"
2829

2930
// forward declarations
@@ -1490,6 +1491,8 @@ static PyObject *build_c_helpers_dict(void)
14901491
&DPEXRT_sycl_event_from_python);
14911492
_declpointer("DPEXRT_sycl_event_to_python", &DPEXRT_sycl_event_to_python);
14921493
_declpointer("DPEXRT_sycl_event_init", &DPEXRT_sycl_event_init);
1494+
_declpointer("DPEXRT_nrt_acquire_meminfo_and_schedule_release",
1495+
&DPEXRT_nrt_acquire_meminfo_and_schedule_release);
14931496

14941497
#undef _declpointer
14951498
return dct;
@@ -1557,6 +1560,9 @@ MOD_INIT(_dpexrt_python)
15571560
PyLong_FromVoidPtr(&DPEXRT_MemInfo_alloc));
15581561
PyModule_AddObject(m, "DPEXRT_MemInfo_fill",
15591562
PyLong_FromVoidPtr(&DPEXRT_MemInfo_fill));
1563+
PyModule_AddObject(
1564+
m, "DPEXRT_nrt_acquire_meminfo_and_schedule_release",
1565+
PyLong_FromVoidPtr(&DPEXRT_nrt_acquire_meminfo_and_schedule_release));
15601566
PyModule_AddObject(m, "c_helpers", build_c_helpers_dict());
15611567
return MOD_SUCCESS_VAL(m);
15621568
}

numba_dpex/core/runtime/context.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,3 +433,41 @@ def submit_ndrange(
433433
)
434434

435435
return ret
436+
437+
def acquire_meminfo_and_schedule_release(
    self, builder: llvmir.IRBuilder, args
):
    """Emit a call to ``DPEXRT_nrt_acquire_meminfo_and_schedule_release``.

    The runtime helper is declared in C as::

        DPCTLSyclEventRef
        DPEXRT_nrt_acquire_meminfo_and_schedule_release(
            NRT_api_functions *nrt,
            DPCTLSyclQueueRef QRef,
            NRT_MemInfo **meminfo_array,
            size_t meminfo_array_size,
            DPCTLSyclEventRef *depERefs,
            size_t nDepERefs,
            int *status
        );

    Args:
        builder: IR builder used to insert the call; the function
            declaration is added to ``builder.module`` if it is missing.
        args: LLVM values passed to the helper, in the order of the
            prototype above.

    Returns:
        The value produced by the inserted call instruction (typed as a
        void pointer).
    """
    void_ptr = cgutils.voidptr_t
    i64 = llvmir.IntType(64)

    # Parameter types, listed in the same order as the C prototype.
    # NOTE(review): the prototype above declares ``int *status`` while the
    # last slot here is an i64 pointer -- confirm that callers account for
    # the width difference.
    param_types = [
        void_ptr,  # nrt
        void_ptr,  # QRef
        void_ptr.as_pointer(),  # meminfo_array
        i64,  # meminfo_array_size
        void_ptr.as_pointer(),  # depERefs
        i64,  # nDepERefs
        i64.as_pointer(),  # status
    ]
    signature = llvmir.FunctionType(void_ptr, param_types)

    callee = cgutils.get_or_insert_function(
        builder.module,
        signature,
        "DPEXRT_nrt_acquire_meminfo_and_schedule_release",
    )
    return builder.call(callee, args)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// SPDX-FileCopyrightText: 2023 Intel Corporation
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
#include "nrt_reserve_meminfo.h"
6+
7+
#include "_dbg_printer.h"
8+
#include "syclinterface/dpctl_sycl_type_casters.hpp"
9+
#include <CL/sycl.hpp>
10+
11+
extern "C"
12+
{
13+
DPCTLSyclEventRef
14+
DPEXRT_nrt_acquire_meminfo_and_schedule_release(NRT_api_functions *nrt,
15+
DPCTLSyclQueueRef QRef,
16+
NRT_MemInfo **meminfo_array,
17+
size_t meminfo_array_size,
18+
DPCTLSyclEventRef *depERefs,
19+
size_t nDepERefs,
20+
int *status)
21+
{
22+
DPEXRT_DEBUG(drt_debug_print(
23+
"DPEXRT-DEBUG: scheduling nrt meminfo release.\n"););
24+
25+
using dpctl::syclinterface::unwrap;
26+
using dpctl::syclinterface::wrap;
27+
28+
sycl::queue *q = unwrap<sycl::queue>(QRef);
29+
30+
std::vector<NRT_MemInfo *> meminfo_vec(
31+
meminfo_array, meminfo_array + meminfo_array_size);
32+
33+
for (size_t i = 0; i < meminfo_array_size; ++i) {
34+
nrt->acquire(meminfo_vec[i]);
35+
}
36+
37+
DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: acquired meminfo.\n"););
38+
39+
try {
40+
sycl::event ht_ev = q->submit([&](sycl::handler &cgh) {
41+
for (size_t ev_id = 0; ev_id < nDepERefs; ++ev_id) {
42+
cgh.depends_on(*(unwrap<sycl::event>(depERefs[ev_id])));
43+
}
44+
cgh.host_task([meminfo_array_size, meminfo_vec, nrt]() {
45+
for (size_t i = 0; i < meminfo_array_size; ++i) {
46+
nrt->release(meminfo_vec[i]);
47+
DPEXRT_DEBUG(
48+
drt_debug_print("DPEXRT-DEBUG: released meminfo "
49+
"from host_task.\n"););
50+
}
51+
});
52+
});
53+
54+
constexpr int result_ok = 0;
55+
56+
*status = result_ok;
57+
auto e_ptr = new sycl::event(ht_ev);
58+
return wrap<sycl::event>(e_ptr);
59+
} catch (const std::exception &e) {
60+
constexpr int result_std_exception = 1;
61+
62+
*status = result_std_exception;
63+
return nullptr;
64+
}
65+
66+
constexpr int result_other_abnormal = 2;
67+
68+
*status = result_other_abnormal;
69+
return nullptr;
70+
}
71+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// SPDX-FileCopyrightText: 2023 Intel Corporation
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
/* Guard name avoids the reserved `_Uppercase` identifier form and is specific
 * to this header so it cannot collide with another "experimental" header. */
#ifndef NUMBA_DPEX_EXPERIMENTAL_NRT_RESERVE_MEMINFO_H
#define NUMBA_DPEX_EXPERIMENTAL_NRT_RESERVE_MEMINFO_H

#include "dpctl_capi.h"
#include "numba/core/runtime/nrt_external.h"

#ifdef __cplusplus
extern "C"
{
#endif
    /*!
     * @brief Acquires every meminfo in an array and schedules a host task on
     * a SYCL queue that releases them after the given events complete.
     *
     * @param nrt                NRT API table used to acquire and release.
     * @param QRef               Opaque reference to the queue the host task
     *                           is submitted to.
     * @param meminfo_array      Array of meminfo pointers.
     * @param meminfo_array_size Number of entries in @p meminfo_array.
     * @param depERefs           Array of opaque event references the host
     *                           task depends on.
     * @param nDepERefs          Number of entries in @p depERefs.
     * @param status             Out parameter: 0 on success, 1 when a
     *                           std::exception was caught, 2 for any other
     *                           abnormal termination.
     * @return Opaque reference to the host task event, or NULL on failure.
     */
    DPCTLSyclEventRef
    DPEXRT_nrt_acquire_meminfo_and_schedule_release(NRT_api_functions *nrt,
                                                    DPCTLSyclQueueRef QRef,
                                                    NRT_MemInfo **meminfo_array,
                                                    size_t meminfo_array_size,
                                                    DPCTLSyclEventRef *depERefs,
                                                    size_t nDepERefs,
                                                    int *status);
#ifdef __cplusplus
}
#endif

#endif /* NUMBA_DPEX_EXPERIMENTAL_NRT_RESERVE_MEMINFO_H */

numba_dpex/experimental/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from .decorators import kernel
1212
from .kernel_dispatcher import KernelDispatcher
13-
from .launcher import call_kernel
13+
from .launcher import call_kernel, call_kernel_async
1414
from .models import *
1515
from .types import KernelDispatcherType
1616

@@ -26,4 +26,4 @@ def dpex_dispatcher_const(context):
2626
return context.get_dummy_value()
2727

2828

29-
__all__ = ["kernel", "KernelDispatcher", "call_kernel"]
29+
__all__ = ["kernel", "KernelDispatcher", "call_kernel", "call_kernel_async"]

numba_dpex/experimental/kernel_dispatcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def get_overload_device_ir(self, sig):
254254
args, _ = sigutils.normalize_signature(sig)
255255
return self.overloads[tuple(args)].kernel_device_ir_module
256256

257-
def compile(self, sig) -> _KernelCompileResult:
257+
def compile(self, sig) -> any:
258258
disp = self._get_dispatcher_for_current_target()
259259
if disp is not self:
260260
return disp.compile(sig)

0 commit comments

Comments
 (0)