2525
2626from numba_dpex .core .runtime import context as dpexrt
2727from numba_dpex .core .types import DpnpNdArray
28+ from numba_dpex .core .types .dpctl_types import DpctlSyclQueue
2829
2930# Numpy array constructors
3031
@@ -172,13 +173,73 @@ def _mk_alloc(
172173 return out
173174
174175
176+ def make_queue (context , builder , arrtype ):
177+ """Utility function used for allocating a new queue.
178+
179+ This function will allocates a new queue (e.g. SYCL queue)
180+ during LLVM code generation (lowering). Given a target context,
181+ builder, array type, returns a LLVM value pointing at a numba-dpex
182+ runtime allocated queue.
183+
184+ Args:
185+ context (numba.core.base.BaseContext): Any of the context
186+ derived from Numba's BaseContext
187+ (e.g. `numba.core.cpu.CPUContext`).
188+ builder (llvmlite.ir.builder.IRBuilder): The IR builder
189+ from `llvmlite` for code generation.
190+ arrtype (numba_dpex.core.types.dpnp_ndarray_type.DpnpNdArray):
191+ Any of the array types derived from
192+ `numba.core.types.nptypes.Array`,
193+ e.g. `numba_dpex.core.types.dpnp_ndarray_type.DpnpNdArray`.
194+ Please refer to `numba_dpex.dpnp_iface._intrinsic.alloc_empty_arrayobj()`
195+ function for details on how to construct this argument.
196+
197+ Returns:
198+ tuple: A tuple containing `llvmlite.ir.instructions.ExtractValue`,
199+ `llvmlite.ir.instructions.CastInstr` and `numba.core.pythonapi.PythonAPI`.
200+ """
201+
202+ pyapi = context .get_python_api (builder )
203+ queue_struct_proxy = cgutils .create_struct_proxy (
204+ DpctlSyclQueue (arrtype .queue )
205+ )(context , builder )
206+ queue_struct_ptr = queue_struct_proxy ._getpointer ()
207+ queue_struct_voidptr = builder .bitcast (queue_struct_ptr , cgutils .voidptr_t )
208+
209+ address = context .get_constant (types .intp , id (arrtype .queue ))
210+ queue_address_ptr = builder .inttoptr (address , cgutils .voidptr_t )
211+
212+ dpexrtCtx = dpexrt .DpexRTContext (context )
213+ dpexrtCtx .queuestruct_from_python (
214+ pyapi , queue_address_ptr , queue_struct_voidptr
215+ )
216+ # errcode = dpexrtCtx.queuestruct_from_python(
217+ # pyapi, queue_address_ptr, queue_struct_voidptr
218+ # )
219+ # is_error = cgutils.is_not_null(builder, errcode)
220+ # # Handle error
221+ # with builder.if_then(is_error, likely=False):
222+ # pyapi.err_set_string(
223+ # "_patches.make_queue(): PyExc_TypeError",
224+ # "can't unbox dpctl.SyclQueue from PyObject into a Numba "
225+ # "native value. The object maybe of a different type",
226+ # )
227+
228+ queue_struct = builder .load (queue_struct_ptr )
229+ queue_ref = builder .extract_value (queue_struct , 1 )
230+
231+ return (queue_ref , queue_address_ptr , pyapi )
232+
233+
175234def _empty_nd_impl (context , builder , arrtype , shapes ):
176235 """Utility function used for allocating a new array during LLVM code
177236 generation (lowering). Given a target context, builder, array
178237 type, and a tuple or list of lowered dimension sizes, returns a
179238 LLVM value pointing at a Numba runtime allocated array.
180239 """
181240
241+ (queue , queue_ptr , pyapi ) = make_queue (context , builder , arrtype )
242+
182243 arycls = make_array (arrtype )
183244 ary = arycls (context , builder )
184245
@@ -231,21 +292,16 @@ def _empty_nd_impl(context, builder, arrtype, shapes):
231292
232293 if isinstance (arrtype , DpnpNdArray ):
233294 usm_ty = arrtype .usm_type
234- usm_ty_val = 0
235- if usm_ty == "device" :
236- usm_ty_val = 1
237- elif usm_ty == "shared" :
238- usm_ty_val = 2
239- elif usm_ty == "host" :
240- usm_ty_val = 3
241- usm_type = context .get_constant (types .uint64 , usm_ty_val )
242- device = context .insert_const_string (builder .module , arrtype .device )
295+ usm_ty_map = {"device" : 1 , "shared" : 2 , "host" : 3 }
296+ usm_type = context .get_constant (
297+ types .uint64 , usm_ty_map [usm_ty ] if usm_ty in usm_ty_map else 0
298+ )
243299
244300 args = (
245301 context .get_dummy_value (),
246302 allocsize ,
247303 usm_type ,
248- device ,
304+ queue ,
249305 )
250306 mip = types .MemInfoPointer (types .voidptr )
251307 arytypeclass = types .TypeRef (type (arrtype ))
@@ -265,6 +321,7 @@ def _empty_nd_impl(context, builder, arrtype, shapes):
265321 fnop .get_call_type (context .typing_context , sig .args , {})
266322 eqfn = context .get_function (fnop , sig )
267323 meminfo = eqfn (builder , args )
324+ pyapi .decref (queue_ptr )
268325 else :
269326 dtype = arrtype .dtype
270327 align_val = context .get_preferred_array_alignment (dtype )
@@ -298,36 +355,36 @@ def _empty_nd_impl(context, builder, arrtype, shapes):
298355
299356
300357@overload_classmethod (DpnpNdArray , "_usm_allocate" )
301- def _ol_array_allocate (cls , allocsize , usm_type , device ):
358+ def _ol_array_allocate (cls , allocsize , usm_type , queue ):
302359 """Implements an allocator for dpnp.ndarrays."""
303360
304- def impl (cls , allocsize , usm_type , device ):
305- return intrin_usm_alloc (allocsize , usm_type , device )
361+ def impl (cls , allocsize , usm_type , queue ):
362+ return intrin_usm_alloc (allocsize , usm_type , queue )
306363
307364 return impl
308365
309366
310367numba_config .DISABLE_PERFORMANCE_WARNINGS = 0
311368
312369
313- def _call_usm_allocator (arrtype , size , usm_type , device ):
370+ def _call_usm_allocator (arrtype , size , usm_type , queue ):
314371 """Trampoline to call the intrinsic used for allocation"""
315- return arrtype ._usm_allocate (size , usm_type , device )
372+ return arrtype ._usm_allocate (size , usm_type , queue )
316373
317374
318375numba_config .DISABLE_PERFORMANCE_WARNINGS = 1
319376
320377
321378@intrinsic
322- def intrin_usm_alloc (typingctx , allocsize , usm_type , device ):
379+ def intrin_usm_alloc (typingctx , allocsize , usm_type , queue ):
323380 """Intrinsic to call into the allocator for Array"""
324381
325382 def codegen (context , builder , signature , args ):
326- [allocsize , usm_type , device ] = args
383+ [allocsize , usm_type , queue ] = args
327384 dpexrtCtx = dpexrt .DpexRTContext (context )
328- meminfo = dpexrtCtx .meminfo_alloc (builder , allocsize , usm_type , device )
385+ meminfo = dpexrtCtx .meminfo_alloc (builder , allocsize , usm_type , queue )
329386 return meminfo
330387
331388 mip = types .MemInfoPointer (types .voidptr ) # return untyped pointer
332- sig = signature (mip , allocsize , usm_type , device )
389+ sig = signature (mip , allocsize , usm_type , queue )
333390 return sig , codegen
0 commit comments