22#
33# SPDX-License-Identifier: Apache-2.0
44
5+ from functools import partial
6+
57import dpnp
68from numba .core import ir , types
79from numba .core .ir_utils import get_np_ufunc_typ , mk_unique_var
8- from numba .core .pythonapi import NativeValue , PythonAPI , box , unbox
910
1011from .usm_ndarray_type import USMNdArray
1112
1213
14+ def partialclass (cls , * args , ** kwds ):
15+ """Creates fabric class of the original class with preset initialization
16+ arguments."""
17+ cls0 = partial (cls , * args , ** kwds )
18+ new_cls = type (
19+ cls .__name__ + "Partial" ,
20+ (cls ,),
21+ {"__new__" : lambda cls , * args , ** kwds : cls0 (* args , ** kwds )},
22+ )
23+
24+ return new_cls
25+
26+
1327class DpnpNdArray (USMNdArray ):
1428 """
1529 The Numba type to represent an dpnp.ndarray. The type has the same
@@ -40,15 +54,22 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
4054 Returns: The DpnpNdArray class.
4155 """
4256 if method == "__call__" :
43- if not all (
44- (
45- isinstance (inp , DpnpNdArray )
46- or isinstance (inp , types .abstract .Number )
47- )
48- for inp in inputs
49- ):
57+ dpnp_type = None
58+
59+ for inp in inputs :
60+ if isinstance (inp , DpnpNdArray ):
61+ dpnp_type = inp
62+ continue
63+ if isinstance (inp , types .abstract .Number ):
64+ continue
65+
5066 return NotImplemented
51- return DpnpNdArray
67+
68+ assert dpnp_type is not None
69+
70+ return partialclass (
71+ DpnpNdArray , queue = dpnp_type .queue , usm_type = dpnp_type .usm_type
72+ )
5273 else :
5374 return
5475
@@ -71,6 +92,8 @@ def __allocate__(
7192 lhs_typ ,
7293 size_typ ,
7394 out ,
95+ # dpex specific argument:
96+ queue_ir_val = None ,
7497 ):
7598 """Generates the Numba typed IR representing the allocation of a new
7699 DpnpNdArray using the dpnp.ndarray overload.
@@ -94,6 +117,10 @@ def __allocate__(
94117
95118 Returns: The IR Value for the allocated array
96119 """
120+ # TODO: it looks like it is being called only for parfor allocations,
121+ # so can we rely on it? We can grab information from input arguments
122+ # from rhs, but doc does not set any restriction on parfor use only.
123+ assert queue_ir_val is not None
97124 g_np_var = ir .Var (scope , mk_unique_var ("$np_g_var" ), loc )
98125 if typemap :
99126 typemap [g_np_var .name ] = types .misc .Module (dpnp )
@@ -132,11 +159,13 @@ def __allocate__(
132159 usm_typ_var = ir .Var (scope , mk_unique_var ("$np_usm_type_var" ), loc )
133160 # A default device string arg added as a placeholder
134161 device_typ_var = ir .Var (scope , mk_unique_var ("$np_device_var" ), loc )
162+ queue_typ_var = ir .Var (scope , mk_unique_var ("$np_queue_var" ), loc )
135163
136164 if typemap :
137165 typemap [layout_var .name ] = types .literal (lhs_typ .layout )
138166 typemap [usm_typ_var .name ] = types .literal (lhs_typ .usm_type )
139- typemap [device_typ_var .name ] = types .literal (lhs_typ .device )
167+ typemap [device_typ_var .name ] = types .none
168+ typemap [queue_typ_var .name ] = lhs_typ .queue
140169
141170 layout_var_assign = ir .Assign (
142171 ir .Const (lhs_typ .layout , loc ), layout_var , loc
@@ -145,16 +174,29 @@ def __allocate__(
145174 ir .Const (lhs_typ .usm_type , loc ), usm_typ_var , loc
146175 )
147176 device_typ_var_assign = ir .Assign (
148- ir .Const (lhs_typ . device , loc ), device_typ_var , loc
177+ ir .Const (None , loc ), device_typ_var , loc
149178 )
179+ queue_typ_var_assign = ir .Assign (queue_ir_val , queue_typ_var , loc )
150180
151181 out .extend (
152- [layout_var_assign , usm_typ_var_assign , device_typ_var_assign ]
182+ [
183+ layout_var_assign ,
184+ usm_typ_var_assign ,
185+ device_typ_var_assign ,
186+ queue_typ_var_assign ,
187+ ]
153188 )
154189
155190 alloc_call = ir .Expr .call (
156191 attr_var ,
157- [size_var , typ_var , layout_var , device_typ_var , usm_typ_var ],
192+ [
193+ size_var ,
194+ typ_var ,
195+ layout_var ,
196+ device_typ_var ,
197+ usm_typ_var ,
198+ queue_typ_var ,
199+ ],
158200 (),
159201 loc ,
160202 )
@@ -170,6 +212,7 @@ def __allocate__(
170212 layout_var ,
171213 device_typ_var ,
172214 usm_typ_var ,
215+ queue_typ_var ,
173216 ]
174217 ],
175218 {},
0 commit comments