22#
33# SPDX-License-Identifier: Apache-2.0
44
5+ from collections import namedtuple
6+
57import numpy
68from llvmlite import ir as llvmir
79from llvmlite .ir import Constant
2729from numba_dpex .core .types import DpnpNdArray
2830from numba_dpex .core .types .dpctl_types import DpctlSyclQueue
2931
32+ # from numba_dpex import utils
33+ # from numba_dpex.dpctl_iface import DpctlCAPIFnBuilder
34+
35+
3036# Numpy array constructors
3137
3238
@@ -195,8 +201,9 @@ def make_queue(context, builder, arrtype):
195201 function for details on how to construct this argument.
196202
197203 Returns:
198- ret (tuple): A tuple containing `llvmlite.ir.instructions.ExtractValue`,
199- `llvmlite.ir.instructions.CastInstr` and `numba.core.pythonapi.PythonAPI`.
204+ ret (namedtuple): A namedtuple containing `llvmlite.ir.instructions.ExtractValue`
205+ as `queue_ref`, `llvmlite.ir.instructions.CastInstr` as `queue_address_ptr`
206+ and `numba.core.pythonapi.PythonAPI` as `pyapi`.
200207 """
201208
202209 pyapi = context .get_python_api (builder )
@@ -229,7 +236,10 @@ def make_queue(context, builder, arrtype):
229236 queue_struct = builder .load (queue_struct_ptr )
230237 queue_ref = builder .extract_value (queue_struct , 1 )
231238
232- ret = (queue_ref , queue_address_ptr , pyapi )
239+ return_values = namedtuple (
240+ "return_values" , "queue_ref queue_address_ptr pyapi"
241+ )
242+ ret = return_values (queue_ref , queue_address_ptr , pyapi )
233243
234244 return ret
235245
@@ -294,7 +304,17 @@ def _empty_nd_impl(context, builder, arrtype, shapes):
294304 )
295305
296306 if isinstance (arrtype , DpnpNdArray ):
297- (queue , queue_ptr , pyapi ) = make_queue (context , builder , arrtype )
307+ (queue_ref , queue_ptr , pyapi ) = make_queue (context , builder , arrtype )
308+ # This might fix the segfault
309+ # sycl_queue_val = cgutils.alloca_once(
310+ # builder,
311+ # utils.get_llvm_type(context=context, type=types.voidptr),
312+ # )
313+ # fn = DpctlCAPIFnBuilder.get_dpctl_queue_copy(
314+ # builder=builder, context=context
315+ # )
316+ # builder.store(builder.call(fn, []), sycl_queue_val)
317+
298318 usm_ty = arrtype .usm_type
299319 usm_ty_map = {"device" : 1 , "shared" : 2 , "host" : 3 }
300320 usm_type = context .get_constant (
@@ -305,7 +325,7 @@ def _empty_nd_impl(context, builder, arrtype, shapes):
305325 context .get_dummy_value (),
306326 allocsize ,
307327 usm_type ,
308- queue ,
328+ queue_ref ,
309329 )
310330 mip = types .MemInfoPointer (types .voidptr )
311331 arytypeclass = types .TypeRef (type (arrtype ))
@@ -355,7 +375,11 @@ def _empty_nd_impl(context, builder, arrtype, shapes):
355375 meminfo = meminfo ,
356376 )
357377
358- ret = (ary , queue ) if isinstance (arrtype , DpnpNdArray ) else ary
378+ if isinstance (arrtype , DpnpNdArray ):
379+ return_values = namedtuple ("return_values" , "ary queue_ref" )
380+ ret = return_values (ary , queue_ref )
381+ else :
382+ ret = ary
359383
360384 return ret
361385
0 commit comments