Skip to content

Commit c0cfee5

Browse files
committed
Initial async kernel support
1 parent 1fe4ad0 commit c0cfee5

File tree

8 files changed

+250
-10
lines changed

8 files changed

+250
-10
lines changed

numba_dpex/core/runtime/CMakeLists.txt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,18 @@ include_directories(${Python_INCLUDE_DIRS})
9494
include_directories(${NumPy_INCLUDE_DIRS})
9595
include_directories(${Numba_INCLUDE_DIRS})
9696
include_directories(${Dpctl_INCLUDE_DIRS})
97+
# include_directories(${CMAKE_CURRENT_SOURCE_DIR})
98+
# include_directories(${CMAKE_CURRENT_SOURCE_DIR}/kernels/tensor/include)
9799
include_directories(.)
98100

99101
# Source files, *.c
100102
file(GLOB_RECURSE DPEXRT_SOURCES CONFIGURE_DEPENDS "*.c")
101103
file(GLOB_RECURSE KERNEL_SOURCES CONFIGURE_DEPENDS "*.cpp")
102-
set(SOURCES ${DPEXRT_SOURCES} ${KERNEL_SOURCES})
104+
# set(SOURCES ${DPEXRT_SOURCES} ${KERNEL_SOURCES})
105+
set(SOURCES ${KERNEL_SOURCES} ${DPEXRT_SOURCES})
106+
107+
message(KERNEL_SOURCES="${KERNEL_SOURCES}")
108+
message(SOURCES="${SOURCES}")
103109

104110
# Link dpctl library path with -L
105111
link_directories(${DPCTL_LIBRARY_PATH})
@@ -109,7 +115,7 @@ python_add_library(${PROJECT_NAME} MODULE ${SOURCES})
109115

110116
# Add SYCL to target, this must come after python_add_library()
111117
# FIXME: sources incompatible with sycl include?
112-
# add_sycl_to_target(TARGET ${PROJECT_NAME} SOURCES ${KERNEL_SOURCES})
118+
add_sycl_to_target(TARGET ${PROJECT_NAME} SOURCES ${KERNEL_SOURCES})
113119

114120
# Link the DPCTLSyclInterface library to target
115121
target_link_libraries(${PROJECT_NAME} PRIVATE DPCTLSyclInterface)

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "_queuestruct.h"
2525
#include "_usmarraystruct.h"
2626

27+
#include "experimental/experimental.h"
2728
#include "numba/core/runtime/nrt_external.h"
2829

2930
// forward declarations
@@ -1490,6 +1491,8 @@ static PyObject *build_c_helpers_dict(void)
14901491
&DPEXRT_sycl_event_from_python);
14911492
_declpointer("DPEXRT_sycl_event_to_python", &DPEXRT_sycl_event_to_python);
14921493
_declpointer("DPEXRT_sycl_event_init", &DPEXRT_sycl_event_init);
1494+
_declpointer("DPEXRT_schedule_nrt_meminfo_release",
1495+
&DPEXRT_schedule_nrt_meminfo_release);
14931496

14941497
#undef _declpointer
14951498
return dct;
@@ -1557,6 +1560,9 @@ MOD_INIT(_dpexrt_python)
15571560
PyLong_FromVoidPtr(&DPEXRT_MemInfo_alloc));
15581561
PyModule_AddObject(m, "DPEXRT_MemInfo_fill",
15591562
PyLong_FromVoidPtr(&DPEXRT_MemInfo_fill));
1563+
PyModule_AddObject(
1564+
m, "DPEXRT_schedule_nrt_meminfo_release",
1565+
PyLong_FromVoidPtr(&DPEXRT_schedule_nrt_meminfo_release));
15601566
PyModule_AddObject(m, "c_helpers", build_c_helpers_dict());
15611567
return MOD_SUCCESS_VAL(m);
15621568
}

numba_dpex/core/runtime/context.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,3 +433,37 @@ def submit_ndrange(
433433
)
434434

435435
return ret
436+
437+
def schedule_nrt_meminfo_release(self, builder: llvmir.IRBuilder, *args):
    """Inserts LLVM IR to call DPEXRT_schedule_nrt_meminfo_release.

    The C prototype of the runtime function is:

        DPCTLSyclEventRef
        DPEXRT_schedule_nrt_meminfo_release(
            NRT_api_functions *nrt,
            DPCTLSyclQueueRef QRef,
            NRT_MemInfo **meminfo_array,
            size_t meminfo_array_size,
            DPCTLSyclEventRef *depERefs,
            size_t nDepERefs,
            int *status
        );

    Args:
        builder: The IR builder used to emit the call instruction.
        *args: The seven LLVM values matching the prototype above. They may
            be passed either as separate positional arguments or as a single
            list/tuple holding all of them.

    Returns:
        The LLVM call instruction; its value is the pointer returned by the
        runtime function.

    NOTE(review): the prototype declares ``depERefs`` as
    ``DPCTLSyclEventRef *`` and ``status`` as ``int *``, while the LLVM
    signature below uses a plain void pointer and a pointer to a 64-bit
    integer respectively. Confirm that this is intended.
    """
    # Accept the arguments packed into one sequence as well: without this a
    # call such as ``schedule_nrt_meminfo_release(builder, [a, b, ...])``
    # would forward a one-element tuple containing a list to ``builder.call``.
    if len(args) == 1 and isinstance(args[0], (list, tuple)):
        args = tuple(args[0])

    mod = builder.module
    fn = _build_dpctl_function(
        llvm_module=mod,
        return_ty=cgutils.voidptr_t,
        arg_list=[
            cgutils.voidptr_t,  # nrt
            cgutils.voidptr_t,  # QRef
            cgutils.voidptr_t.as_pointer(),  # meminfo_array
            llvmir.IntType(64),  # meminfo_array_size
            cgutils.voidptr_t,  # depERefs
            llvmir.IntType(64),  # nDepERefs
            llvmir.IntType(64).as_pointer(),  # status
        ],
        func_name="DPEXRT_schedule_nrt_meminfo_release",
    )
    ret = builder.call(fn, args)

    return ret
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#ifndef _EXPERIMENTAL_H_
2+
#define _EXPERIMENTAL_H_
3+
4+
#include "dpctl_capi.h"
5+
#include "numba/core/runtime/nrt_external.h"
6+
7+
#ifdef __cplusplus
8+
extern "C"
9+
{
10+
#endif
11+
/*!
 * @brief Schedules the release of a set of NRT meminfo objects on a queue.
 *
 * A host task is submitted to the queue referenced by QRef. The task depends
 * on every event in depERefs and calls nrt->release() on each entry of
 * meminfo_array.
 *
 * @param nrt                NRT API function table used to release meminfos.
 * @param QRef               Queue the host task is submitted to.
 * @param meminfo_array      Array of meminfo pointers to release.
 * @param meminfo_array_size Number of entries in meminfo_array.
 * @param depERefs           Events the host task has to wait for.
 * @param nDepERefs          Number of entries in depERefs.
 * @param status             Output: set to 0 on success and to a non-zero
 *                           code on failure.
 * @return A newly allocated event reference for the host task, or a null
 *         reference on failure. NOTE(review): the caller presumably owns the
 *         returned event and has to delete it -- confirm.
 */
DPCTLSyclEventRef
12+
DPEXRT_schedule_nrt_meminfo_release(NRT_api_functions *nrt,
13+
DPCTLSyclQueueRef QRef,
14+
NRT_MemInfo **meminfo_array,
15+
size_t meminfo_array_size,
16+
DPCTLSyclEventRef *depERefs,
17+
size_t nDepERefs,
18+
int *status);
19+
#ifdef __cplusplus
20+
}
21+
#endif
22+
23+
#endif /* _EXPERIMENTAL_H_ */
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#include "experimental.h"
2+
3+
#include "_dbg_printer.h"
4+
#include "syclinterface/dpctl_sycl_type_casters.hpp"
5+
#include <CL/sycl.hpp>
6+
7+
extern "C"
8+
{
9+
DPCTLSyclEventRef
10+
DPEXRT_schedule_nrt_meminfo_release(NRT_api_functions *nrt,
11+
DPCTLSyclQueueRef QRef,
12+
NRT_MemInfo **meminfo_array,
13+
size_t meminfo_array_size,
14+
DPCTLSyclEventRef *depERefs,
15+
size_t nDepERefs,
16+
int *status)
17+
{
18+
using dpctl::syclinterface::unwrap;
19+
using dpctl::syclinterface::wrap;
20+
21+
sycl::queue *q = unwrap<sycl::queue>(QRef);
22+
23+
std::vector<NRT_MemInfo *> meminfo_vec(
24+
meminfo_array, meminfo_array + meminfo_array_size);
25+
26+
try {
27+
sycl::event ht_ev = q->submit([&](sycl::handler &cgh) {
28+
for (size_t ev_id = 0; ev_id < nDepERefs; ++ev_id) {
29+
cgh.depends_on(*(unwrap<sycl::event>(depERefs[ev_id])));
30+
}
31+
cgh.host_task([meminfo_array_size, meminfo_vec, nrt]() {
32+
for (size_t i = 0; i < meminfo_array_size; ++i) {
33+
nrt->release(meminfo_vec[i]);
34+
}
35+
});
36+
});
37+
38+
constexpr int result_ok = 0;
39+
40+
*status = result_ok;
41+
auto e_ptr = new sycl::event(ht_ev);
42+
return wrap<sycl::event>(e_ptr);
43+
} catch (const std::exception &e) {
44+
constexpr int result_std_exception = 1;
45+
46+
*status = result_std_exception;
47+
return nullptr;
48+
}
49+
50+
constexpr int result_other_abnormal = 2;
51+
52+
*status = result_other_abnormal;
53+
return nullptr;
54+
}
55+
}

numba_dpex/core/utils/kernel_launcher.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ def submit_sycl_kernel(
375375
"""
376376
Submits the kernel to the specified queue, waits.
377377
"""
378+
378379
eref = None
379380
gr = self._create_sycl_range(global_range)
380381
args1 = [
@@ -409,6 +410,24 @@ def submit_sycl_kernel(
409410
sycl.dpctl_event_delete(self.builder, eref)
410411
return None
411412
else:
413+
# Making sure that arguments not released before the end of
414+
# execution.
415+
meminfo_list = arg_list # FIXME: extract meminfos
416+
total_meminfos = total_kernel_args # FIXME: -||-
417+
dpexrtCtx = DpexRTContext(self.context)
418+
host_eref = dpexrtCtx.schedule_nrt_meminfo_release(
419+
self.builder,
420+
[
421+
dpexrtCtx._context.nrt.get_nrt_api(self.builder),
422+
sycl_queue_ref,
423+
meminfo_list,
424+
self.context.get_constant(types.uintp, total_meminfos),
425+
eref, # ???? should I get pointer of it
426+
1, # ???? should I wrap it as a constant?
427+
None, # put pointer for the status
428+
],
429+
)
430+
# FIXME: should we return host event instead
412431
return eref
413432

414433
def populate_kernel_args_and_args_ty_arrays(

numba_dpex/experimental/launcher.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,21 @@
99
from collections import namedtuple
1010
from typing import Union
1111

12+
import dpctl
1213
from llvmlite import ir as llvmir
1314
from numba.core import cgutils, cpu, types
1415
from numba.extending import intrinsic, overload
1516

1617
from numba_dpex import config, dpjit
1718
from numba_dpex.core.exceptions import UnreachableError
19+
from numba_dpex.core.runtime import context as dpexrt
1820
from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext
19-
from numba_dpex.core.types import DpnpNdArray, NdRangeType, RangeType
21+
from numba_dpex.core.types import (
22+
DpctlSyclEvent,
23+
DpnpNdArray,
24+
NdRangeType,
25+
RangeType,
26+
)
2027
from numba_dpex.core.utils import kernel_launcher as kl
2128
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
2229
from numba_dpex.experimental.kernel_dispatcher import _KernelModule
@@ -256,6 +263,16 @@ def submit_and_wait(self, submit_call_args: _KernelSubmissionArgs) -> None:
256263
"""Generates LLVM IR CallInst to submit a kernel to specified SYCL queue
257264
and then call DPCTLEvent_Wait on the returned event.
258265
"""
266+
eref = self.submit(submit_call_args)
267+
sycl.dpctl_event_wait(self._builder, eref)
268+
sycl.dpctl_event_delete(self._builder, eref)
269+
270+
def submit(
271+
self, submit_call_args: _KernelSubmissionArgs
272+
) -> llvmir.PointerType(llvmir.IntType(8)):
273+
"""Generates LLVM IR CallInst to submit a kernel to specified SYCL
274+
queue.
275+
"""
259276
if config.DEBUG_KERNEL_LAUNCHER:
260277
cgutils.printf(
261278
self._builder, "DPEX-DEBUG: Submit sync range kernel.\n"
@@ -274,8 +291,7 @@ def submit_and_wait(self, submit_call_args: _KernelSubmissionArgs) -> None:
274291
if config.DEBUG_KERNEL_LAUNCHER:
275292
cgutils.printf(self._builder, "DPEX-DEBUG: Wait on event.\n")
276293

277-
sycl.dpctl_event_wait(self._builder, eref)
278-
sycl.dpctl_event_delete(self._builder, eref)
294+
return eref
279295

280296
def cleanup(
281297
self,
@@ -305,7 +321,8 @@ def intrin_launch_trampoline(
305321
"""
306322
kernel_args_list = list(kernel_args)
307323
# signature of this intrinsic
308-
sig = types.void(kernel_fn, index_space, kernel_args)
324+
ty_event = DpctlSyclEvent()
325+
sig = ty_event(kernel_fn, index_space, kernel_args)
309326
# signature of the kernel_fn
310327
kernel_sig = types.void(*kernel_args_list)
311328
kmodule: _KernelModule = kernel_fn.dispatcher.compile(kernel_sig)
@@ -359,10 +376,27 @@ def codegen(cgctx, builder, sig, llargs):
359376
local_range_extents=index_space_values.local_range_extents,
360377
)
361378

362-
fn_body_gen.submit_and_wait(submit_call_args)
379+
eref = fn_body_gen.submit(submit_call_args)
363380

364381
fn_body_gen.cleanup(kernel_bundle_ref=kbref, kernel_ref=kref)
365382

383+
pyapi = cgctx.get_python_api(builder)
384+
385+
event_struct_proxy = cgutils.create_struct_proxy(ty_event)(
386+
cgctx, builder
387+
)
388+
389+
dpexrtCtx = dpexrt.DpexRTContext(cgctx)
390+
391+
# Ref count after the call is equal to 1.
392+
dpexrtCtx.eventstruct_init(
393+
pyapi, eref, event_struct_proxy._getpointer()
394+
)
395+
396+
event_value = event_struct_proxy._getvalue()
397+
398+
return event_value
399+
366400
return sig, codegen
367401

368402

@@ -374,15 +408,15 @@ def _launch_trampoline(kernel_fn, index_space, *kernel_args):
374408
@overload(_launch_trampoline, target="cpu")
375409
def _ol_launch_trampoline(kernel_fn, index_space, *kernel_args):
376410
def impl(kernel_fn, index_space, *kernel_args):
377-
intrin_launch_trampoline( # pylint: disable=E1120
411+
return intrin_launch_trampoline( # pylint: disable=E1120
378412
kernel_fn, index_space, kernel_args
379413
)
380414

381415
return impl
382416

383417

384418
@dpjit
385-
def call_kernel(kernel_fn, index_space, *kernel_args):
419+
def call_kernel(kernel_fn, index_space, *kernel_args) -> dpctl.SyclEvent:
386420
"""Calls a numba_dpex.kernel decorated function from CPython or from another
387421
dpjit function.
388422
@@ -395,4 +429,4 @@ def call_kernel(kernel_fn, index_space, *kernel_args):
395429
kernel_args : List of objects that are passed to the numba_dpex.kernel
396430
decorated function.
397431
"""
398-
_launch_trampoline(kernel_fn, index_space, *kernel_args)
432+
return _launch_trampoline(kernel_fn, index_space, *kernel_args)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import time
2+
3+
import dpnp
4+
5+
import numba_dpex as dpex
6+
import numba_dpex.experimental as exp_dpex
7+
from numba_dpex import Range
8+
9+
10+
@exp_dpex.kernel(
    release_gil=False,
    no_compile=True,
    no_cpython_wrapper=True,
    no_cfunc_wrapper=True,
)
def add(a, b, c):
    # Element-wise addition: every work-item handles the element selected by
    # its global id along dimension 0.
    gid = dpex.get_global_id(0)
    total = b[gid] + a[gid]
    c[gid] = total
19+
20+
21+
@exp_dpex.kernel(
    release_gil=False,
    no_compile=True,
    no_cpython_wrapper=True,
    no_cfunc_wrapper=True,
)
def count_of_two(a, count):
    # Every work-item atomically adds 1 to the slot of ``count`` selected by
    # its own element of ``a`` (a filter on the value 2 is currently disabled).
    gid = dpex.get_global_id(0)
    slot = a[gid]
    dpex.atomic.add(count, slot, 1)
31+
32+
33+
def test_async_add():
    """Launches the ``add`` and ``count_of_two`` kernels one after another,
    waiting on the event returned by each ``call_kernel`` invocation and
    printing the arrays involved.
    """
    size = 1000_000_000
    # size = 1000
    a = dpnp.ones(size)
    b = dpnp.ones(size)
    c = dpnp.zeros(size)
    count = dpnp.zeros(10)
    for array in (a, b, c):
        print(array)

    add_range = Range(size)
    count_range = Range(480)

    # First kernel: c = b + a. The returned event is printed, then waited on.
    add_event = exp_dpex.call_kernel(add, add_range, a, b, c)
    print(add_event)
    add_event.wait()

    # time.sleep(3)

    # Second kernel: accumulate into ``count`` using the values of ``b``.
    count_event = exp_dpex.call_kernel(count_of_two, count_range, b, count)
    count_event.wait()
    print(count)

    count[0] = 0

    print(c)


if __name__ == "__main__":
    test_async_add()

0 commit comments

Comments
 (0)