44
55import dpnp
66from numba import errors , types
7+ from numba .core .types import scalars
8+ from numba .core .types .containers import UniTuple
79from numba .core .typing .npydecl import parse_dtype as _ty_parse_dtype
810from numba .core .typing .npydecl import parse_shape as _ty_parse_shape
911from numba .extending import overload
2022 impl_dpnp_ones_like ,
2123 impl_dpnp_zeros ,
2224 impl_dpnp_zeros_like ,
23- intrin_usm_alloc ,
2425)
2526
2627# =========================================================================
2728# Helps to parse dpnp constructor arguments
2829# =========================================================================
2930
3031
31- def _parse_dtype (dtype , data = None ):
32+ def _parse_dim (x1 ):
33+ if hasattr (x1 , "ndim" ) and x1 .ndim :
34+ return x1 .ndim
35+ elif isinstance (x1 , scalars .Integer ):
36+ r = 1
37+ return r
38+ elif isinstance (x1 , UniTuple ):
39+ r = len (x1 )
40+ return r
41+ else :
42+ return 0
43+
44+
45+ def _parse_dtype (dtype ):
3246 """Resolve dtype parameter.
3347
3448 Resolves the dtype parameter based on the given value
@@ -44,9 +58,8 @@ class for nd-arrays. Defaults to None.
4458 numba.core.types.functions.NumberClass: Resolved numba type
4559 class for number classes.
4660 """
61+
4762 _dtype = None
48- if data and isinstance (data , types .Array ):
49- _dtype = data .dtype
5063 if not is_nonelike (dtype ):
5164 _dtype = _ty_parse_dtype (dtype )
5265 return _dtype
@@ -60,6 +73,9 @@ def _parse_layout(layout):
6073 raise errors .NumbaValueError (msg )
6174 return layout_type_str
6275 elif isinstance (layout , str ):
76+ if layout not in ["C" , "F" , "A" ]:
77+ msg = f"Invalid layout specified: '{ layout } '"
78+ raise errors .NumbaValueError (msg )
6379 return layout
6480 else :
6581 raise TypeError (
@@ -94,6 +110,9 @@ def _parse_usm_type(usm_type):
94110 raise errors .NumbaValueError (msg )
95111 return usm_type_str
96112 elif isinstance (usm_type , str ):
113+ if usm_type not in ["shared" , "device" , "host" ]:
114+ msg = f"Invalid usm_type specified: '{ usm_type } '"
115+ raise errors .NumbaValueError (msg )
97116 return usm_type
98117 else :
99118 raise TypeError (
@@ -150,6 +169,7 @@ def build_dpnp_ndarray(
150169 ndim ,
151170 layout = "C" ,
152171 dtype = None ,
172+ is_fill_value_float = False ,
153173 usm_type = "device" ,
154174 device = None ,
155175 sycl_queue = None ,
@@ -163,6 +183,8 @@ def build_dpnp_ndarray(
163183 Data type of the array. Can be typestring, a `numpy.dtype`
164184 object, `numpy` char string, or a numpy scalar type.
165185 Default: None.
186+ is_fill_value_float (bool): Specify if the fill value is floating
187+ point.
166188 usm_type (numba.core.types.misc.StringLiteral, optional):
167189 The type of SYCL USM allocation for the output array.
168190 Allowed values are "device"|"shared"|"host".
@@ -198,6 +220,7 @@ def build_dpnp_ndarray(
198220 ndim = ndim ,
199221 layout = layout ,
200222 dtype = dtype ,
223+ is_fill_value_float = is_fill_value_float ,
201224 usm_type = usm_type ,
202225 device = device ,
203226 queue = sycl_queue ,
@@ -280,6 +303,7 @@ def ol_dpnp_empty(
280303 _ndim ,
281304 layout = _layout ,
282305 dtype = _dtype ,
306+ is_fill_value_float = True ,
283307 usm_type = _usm_type ,
284308 device = _device ,
285309 sycl_queue = _sycl_queue ,
@@ -384,6 +408,7 @@ def ol_dpnp_zeros(
384408 _ndim ,
385409 layout = _layout ,
386410 dtype = _dtype ,
411+ is_fill_value_float = True ,
387412 usm_type = _usm_type ,
388413 device = _device ,
389414 sycl_queue = _sycl_queue ,
@@ -488,6 +513,7 @@ def ol_dpnp_ones(
488513 _ndim ,
489514 layout = _layout ,
490515 dtype = _dtype ,
516+ is_fill_value_float = True ,
491517 usm_type = _usm_type ,
492518 device = _device ,
493519 sycl_queue = _sycl_queue ,
@@ -586,6 +612,7 @@ def ol_dpnp_full(
586612
587613 _ndim = _ty_parse_shape (shape )
588614 _dtype = _parse_dtype (dtype )
615+ _is_fill_value_float = isinstance (fill_value , scalars .Float )
589616 _layout = _parse_layout (order )
590617 _usm_type = _parse_usm_type (usm_type ) if usm_type else "device"
591618 _device = _parse_device_filter_string (device ) if device else None
@@ -596,6 +623,7 @@ def ol_dpnp_full(
596623 _ndim ,
597624 layout = _layout ,
598625 dtype = _dtype ,
626+ is_fill_value_float = _is_fill_value_float ,
599627 usm_type = _usm_type ,
600628 device = _device ,
601629 sycl_queue = _sycl_queue ,
@@ -699,8 +727,8 @@ def ol_dpnp_empty_like(
699727 + "inside overloaded dpnp.empty_like() function."
700728 )
701729
702- _ndim = x1 . ndim if hasattr (x1 , "ndim" ) and x1 . ndim else 0
703- _dtype = _parse_dtype ( dtype , data = x1 )
730+ _ndim = _parse_dim (x1 )
731+ _dtype = x1 . dtype if isinstance ( x1 , types . Array ) else _parse_dtype ( dtype )
704732 _order = x1 .layout if order is None else order
705733 _usm_type = _parse_usm_type (usm_type ) if usm_type else "device"
706734 _device = _parse_device_filter_string (device ) if device else None
@@ -812,8 +840,8 @@ def ol_dpnp_zeros_like(
812840 + "inside overloaded dpnp.zeros_like() function."
813841 )
814842
815- _ndim = x1 . ndim if hasattr (x1 , "ndim" ) and x1 . ndim else 0
816- _dtype = _parse_dtype ( dtype , data = x1 )
843+ _ndim = _parse_dim (x1 )
844+ _dtype = x1 . dtype if isinstance ( x1 , types . Array ) else _parse_dtype ( dtype )
817845 _order = x1 .layout if order is None else order
818846 _usm_type = _parse_usm_type (usm_type ) if usm_type else "device"
819847 _device = _parse_device_filter_string (device ) if device else None
@@ -924,8 +952,8 @@ def ol_dpnp_ones_like(
924952 + "inside overloaded dpnp.ones_like() function."
925953 )
926954
927- _ndim = x1 . ndim if hasattr (x1 , "ndim" ) and x1 . ndim else 0
928- _dtype = _parse_dtype ( dtype , data = x1 )
955+ _ndim = _parse_dim (x1 )
956+ _dtype = x1 . dtype if isinstance ( x1 , types . Array ) else _parse_dtype ( dtype )
929957 _order = x1 .layout if order is None else order
930958 _usm_type = _parse_usm_type (usm_type ) if usm_type else "device"
931959 _device = _parse_device_filter_string (device ) if device else None
@@ -1041,8 +1069,9 @@ def ol_dpnp_full_like(
10411069 + "inside overloaded dpnp.full_like() function."
10421070 )
10431071
1044- _ndim = x1 .ndim if hasattr (x1 , "ndim" ) and x1 .ndim else 0
1045- _dtype = _parse_dtype (dtype , data = x1 )
1072+ _ndim = _parse_dim (x1 )
1073+ _dtype = x1 .dtype if isinstance (x1 , types .Array ) else _parse_dtype (dtype )
1074+ _is_fill_value_float = isinstance (fill_value , scalars .Float )
10461075 _order = x1 .layout if order is None else order
10471076 _usm_type = _parse_usm_type (usm_type ) if usm_type else "device"
10481077 _device = _parse_device_filter_string (device ) if device else None
@@ -1052,6 +1081,7 @@ def ol_dpnp_full_like(
10521081 _ndim ,
10531082 layout = _order ,
10541083 dtype = _dtype ,
1084+ is_fill_value_float = _is_fill_value_float ,
10551085 usm_type = _usm_type ,
10561086 device = _device ,
10571087 sycl_queue = _sycl_queue ,
0 commit comments