99from collections import namedtuple
1010from typing import Union
1111
12+ import dpctl
1213from llvmlite import ir as llvmir
1314from numba .core import cgutils , cpu , types
1415from numba .extending import intrinsic , overload
1516
1617from numba_dpex import config , dpjit
1718from numba_dpex .core .exceptions import UnreachableError
19+ from numba_dpex .core .runtime import context as dpexrt
1820from numba_dpex .core .targets .kernel_target import DpexKernelTargetContext
19- from numba_dpex .core .types import DpnpNdArray , NdRangeType , RangeType
21+ from numba_dpex .core .types import (
22+ DpctlSyclEvent ,
23+ DpnpNdArray ,
24+ NdRangeType ,
25+ RangeType ,
26+ )
2027from numba_dpex .core .utils import kernel_launcher as kl
2128from numba_dpex .dpctl_iface import libsyclinterface_bindings as sycl
2229from numba_dpex .experimental .kernel_dispatcher import _KernelModule
@@ -256,6 +263,16 @@ def submit_and_wait(self, submit_call_args: _KernelSubmissionArgs) -> None:
256263 """Generates LLVM IR CallInst to submit a kernel to specified SYCL queue
257264 and then call DPCTLEvent_Wait on the returned event.
258265 """
266+ eref = self .submit (submit_call_args )
267+ sycl .dpctl_event_wait (self ._builder , eref )
268+ sycl .dpctl_event_delete (self ._builder , eref )
269+
270+ def submit (
271+ self , submit_call_args : _KernelSubmissionArgs
272+ ) -> llvmir .PointerType (llvmir .IntType (8 )):
273+ """Generates LLVM IR CallInst to submit a kernel to specified SYCL
274+ queue.
275+ """
259276 if config .DEBUG_KERNEL_LAUNCHER :
260277 cgutils .printf (
261278 self ._builder , "DPEX-DEBUG: Submit sync range kernel.\n "
@@ -274,8 +291,7 @@ def submit_and_wait(self, submit_call_args: _KernelSubmissionArgs) -> None:
274291 if config .DEBUG_KERNEL_LAUNCHER :
275292 cgutils .printf (self ._builder , "DPEX-DEBUG: Wait on event.\n " )
276293
277- sycl .dpctl_event_wait (self ._builder , eref )
278- sycl .dpctl_event_delete (self ._builder , eref )
294+ return eref
279295
280296 def cleanup (
281297 self ,
@@ -305,7 +321,8 @@ def intrin_launch_trampoline(
305321 """
306322 kernel_args_list = list (kernel_args )
307323 # signature of this intrinsic
308- sig = types .void (kernel_fn , index_space , kernel_args )
324+ ty_event = DpctlSyclEvent ()
325+ sig = ty_event (kernel_fn , index_space , kernel_args )
309326 # signature of the kernel_fn
310327 kernel_sig = types .void (* kernel_args_list )
311328 kmodule : _KernelModule = kernel_fn .dispatcher .compile (kernel_sig )
@@ -359,10 +376,27 @@ def codegen(cgctx, builder, sig, llargs):
359376 local_range_extents = index_space_values .local_range_extents ,
360377 )
361378
362- fn_body_gen .submit_and_wait (submit_call_args )
379+ eref = fn_body_gen .submit (submit_call_args )
363380
364381 fn_body_gen .cleanup (kernel_bundle_ref = kbref , kernel_ref = kref )
365382
383+ pyapi = cgctx .get_python_api (builder )
384+
385+ event_struct_proxy = cgutils .create_struct_proxy (ty_event )(
386+ cgctx , builder
387+ )
388+
389+ dpexrtCtx = dpexrt .DpexRTContext (cgctx )
390+
391+ # Ref count after the call is equal to 1.
392+ dpexrtCtx .eventstruct_init (
393+ pyapi , eref , event_struct_proxy ._getpointer ()
394+ )
395+
396+ event_value = event_struct_proxy ._getvalue ()
397+
398+ return event_value
399+
366400 return sig , codegen
367401
368402
@@ -374,15 +408,15 @@ def _launch_trampoline(kernel_fn, index_space, *kernel_args):
374408@overload (_launch_trampoline , target = "cpu" )
375409def _ol_launch_trampoline (kernel_fn , index_space , * kernel_args ):
376410 def impl (kernel_fn , index_space , * kernel_args ):
377- intrin_launch_trampoline ( # pylint: disable=E1120
411+ return intrin_launch_trampoline ( # pylint: disable=E1120
378412 kernel_fn , index_space , kernel_args
379413 )
380414
381415 return impl
382416
383417
384418@dpjit
385- def call_kernel (kernel_fn , index_space , * kernel_args ):
419+ def call_kernel (kernel_fn , index_space , * kernel_args ) -> dpctl . SyclEvent :
386420 """Calls a numba_dpex.kernel decorated function from CPython or from another
387421 dpjit function.
388422
@@ -395,4 +429,4 @@ def call_kernel(kernel_fn, index_space, *kernel_args):
395429 kernel_args : List of objects that are passed to the numba_dpex.kernel
396430 decorated function.
397431 """
398- _launch_trampoline (kernel_fn , index_space , * kernel_args )
432+ return _launch_trampoline (kernel_fn , index_space , * kernel_args )
0 commit comments