Skip to content

Commit f785652

Browse files
khaledDiptorup Deb
authored andcommitted
Update unit tests for dpnp array constructors.
1 parent 75e6c12 commit f785652

File tree

10 files changed

+1013
-211
lines changed

10 files changed

+1013
-211
lines changed

numba_dpex/core/types/usm_ndarray_type.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
ndim,
2424
layout="C",
2525
dtype=None,
26+
is_fill_value_float=False,
2627
usm_type="device",
2728
device=None,
2829
queue=None,
@@ -66,9 +67,22 @@ def __init__(
6667
self.device = self.queue.sycl_device.filter_string
6768

6869
if not dtype:
69-
dummy_tensor = dpctl.tensor.empty(
70-
1, order=layout, usm_type=usm_type, sycl_queue=self.queue
71-
)
70+
if is_fill_value_float:
71+
dummy_tensor = dpctl.tensor.empty(
72+
1,
73+
dtype=dpctl.tensor.float64,
74+
order=layout,
75+
usm_type=usm_type,
76+
sycl_queue=self.queue,
77+
)
78+
else:
79+
dummy_tensor = dpctl.tensor.empty(
80+
1,
81+
dtype=dpctl.tensor.int64,
82+
order=layout,
83+
usm_type=usm_type,
84+
sycl_queue=self.queue,
85+
)
7286
# convert dpnp type to numba/numpy type
7387
_dtype = dummy_tensor.dtype
7488
self.dtype = from_dtype(_dtype)

numba_dpex/dpnp_iface/arrayobj.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import dpnp
66
from numba import errors, types
7+
from numba.core.types import scalars
8+
from numba.core.types.containers import UniTuple
79
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
810
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
911
from numba.extending import overload
@@ -20,15 +22,27 @@
2022
impl_dpnp_ones_like,
2123
impl_dpnp_zeros,
2224
impl_dpnp_zeros_like,
23-
intrin_usm_alloc,
2425
)
2526

2627
# =========================================================================
2728
# Helps to parse dpnp constructor arguments
2829
# =========================================================================
2930

3031

31-
def _parse_dtype(dtype, data=None):
32+
def _parse_dim(x1):
33+
if hasattr(x1, "ndim") and x1.ndim:
34+
return x1.ndim
35+
elif isinstance(x1, scalars.Integer):
36+
r = 1
37+
return r
38+
elif isinstance(x1, UniTuple):
39+
r = len(x1)
40+
return r
41+
else:
42+
return 0
43+
44+
45+
def _parse_dtype(dtype):
3246
"""Resolve dtype parameter.
3347
3448
Resolves the dtype parameter based on the given value
@@ -44,9 +58,8 @@ class for nd-arrays. Defaults to None.
4458
numba.core.types.functions.NumberClass: Resolved numba type
4559
class for number classes.
4660
"""
61+
4762
_dtype = None
48-
if data and isinstance(data, types.Array):
49-
_dtype = data.dtype
5063
if not is_nonelike(dtype):
5164
_dtype = _ty_parse_dtype(dtype)
5265
return _dtype
@@ -60,6 +73,9 @@ def _parse_layout(layout):
6073
raise errors.NumbaValueError(msg)
6174
return layout_type_str
6275
elif isinstance(layout, str):
76+
if layout not in ["C", "F", "A"]:
77+
msg = f"Invalid layout specified: '{layout}'"
78+
raise errors.NumbaValueError(msg)
6379
return layout
6480
else:
6581
raise TypeError(
@@ -94,6 +110,9 @@ def _parse_usm_type(usm_type):
94110
raise errors.NumbaValueError(msg)
95111
return usm_type_str
96112
elif isinstance(usm_type, str):
113+
if usm_type not in ["shared", "device", "host"]:
114+
msg = f"Invalid usm_type specified: '{usm_type}'"
115+
raise errors.NumbaValueError(msg)
97116
return usm_type
98117
else:
99118
raise TypeError(
@@ -150,6 +169,7 @@ def build_dpnp_ndarray(
150169
ndim,
151170
layout="C",
152171
dtype=None,
172+
is_fill_value_float=False,
153173
usm_type="device",
154174
device=None,
155175
sycl_queue=None,
@@ -163,6 +183,8 @@ def build_dpnp_ndarray(
163183
Data type of the array. Can be typestring, a `numpy.dtype`
164184
object, `numpy` char string, or a numpy scalar type.
165185
Default: None.
186+
is_fill_value_float (bool): Specify if the fill value is floating
187+
point.
166188
usm_type (numba.core.types.misc.StringLiteral, optional):
167189
The type of SYCL USM allocation for the output array.
168190
Allowed values are "device"|"shared"|"host".
@@ -198,6 +220,7 @@ def build_dpnp_ndarray(
198220
ndim=ndim,
199221
layout=layout,
200222
dtype=dtype,
223+
is_fill_value_float=is_fill_value_float,
201224
usm_type=usm_type,
202225
device=device,
203226
queue=sycl_queue,
@@ -280,6 +303,7 @@ def ol_dpnp_empty(
280303
_ndim,
281304
layout=_layout,
282305
dtype=_dtype,
306+
is_fill_value_float=True,
283307
usm_type=_usm_type,
284308
device=_device,
285309
sycl_queue=_sycl_queue,
@@ -384,6 +408,7 @@ def ol_dpnp_zeros(
384408
_ndim,
385409
layout=_layout,
386410
dtype=_dtype,
411+
is_fill_value_float=True,
387412
usm_type=_usm_type,
388413
device=_device,
389414
sycl_queue=_sycl_queue,
@@ -488,6 +513,7 @@ def ol_dpnp_ones(
488513
_ndim,
489514
layout=_layout,
490515
dtype=_dtype,
516+
is_fill_value_float=True,
491517
usm_type=_usm_type,
492518
device=_device,
493519
sycl_queue=_sycl_queue,
@@ -586,6 +612,7 @@ def ol_dpnp_full(
586612

587613
_ndim = _ty_parse_shape(shape)
588614
_dtype = _parse_dtype(dtype)
615+
_is_fill_value_float = isinstance(fill_value, scalars.Float)
589616
_layout = _parse_layout(order)
590617
_usm_type = _parse_usm_type(usm_type) if usm_type else "device"
591618
_device = _parse_device_filter_string(device) if device else None
@@ -596,6 +623,7 @@ def ol_dpnp_full(
596623
_ndim,
597624
layout=_layout,
598625
dtype=_dtype,
626+
is_fill_value_float=_is_fill_value_float,
599627
usm_type=_usm_type,
600628
device=_device,
601629
sycl_queue=_sycl_queue,
@@ -699,8 +727,8 @@ def ol_dpnp_empty_like(
699727
+ "inside overloaded dpnp.empty_like() function."
700728
)
701729

702-
_ndim = x1.ndim if hasattr(x1, "ndim") and x1.ndim else 0
703-
_dtype = _parse_dtype(dtype, data=x1)
730+
_ndim = _parse_dim(x1)
731+
_dtype = x1.dtype if isinstance(x1, types.Array) else _parse_dtype(dtype)
704732
_order = x1.layout if order is None else order
705733
_usm_type = _parse_usm_type(usm_type) if usm_type else "device"
706734
_device = _parse_device_filter_string(device) if device else None
@@ -812,8 +840,8 @@ def ol_dpnp_zeros_like(
812840
+ "inside overloaded dpnp.zeros_like() function."
813841
)
814842

815-
_ndim = x1.ndim if hasattr(x1, "ndim") and x1.ndim else 0
816-
_dtype = _parse_dtype(dtype, data=x1)
843+
_ndim = _parse_dim(x1)
844+
_dtype = x1.dtype if isinstance(x1, types.Array) else _parse_dtype(dtype)
817845
_order = x1.layout if order is None else order
818846
_usm_type = _parse_usm_type(usm_type) if usm_type else "device"
819847
_device = _parse_device_filter_string(device) if device else None
@@ -924,8 +952,8 @@ def ol_dpnp_ones_like(
924952
+ "inside overloaded dpnp.ones_like() function."
925953
)
926954

927-
_ndim = x1.ndim if hasattr(x1, "ndim") and x1.ndim else 0
928-
_dtype = _parse_dtype(dtype, data=x1)
955+
_ndim = _parse_dim(x1)
956+
_dtype = x1.dtype if isinstance(x1, types.Array) else _parse_dtype(dtype)
929957
_order = x1.layout if order is None else order
930958
_usm_type = _parse_usm_type(usm_type) if usm_type else "device"
931959
_device = _parse_device_filter_string(device) if device else None
@@ -1041,8 +1069,9 @@ def ol_dpnp_full_like(
10411069
+ "inside overloaded dpnp.full_like() function."
10421070
)
10431071

1044-
_ndim = x1.ndim if hasattr(x1, "ndim") and x1.ndim else 0
1045-
_dtype = _parse_dtype(dtype, data=x1)
1072+
_ndim = _parse_dim(x1)
1073+
_dtype = x1.dtype if isinstance(x1, types.Array) else _parse_dtype(dtype)
1074+
_is_fill_value_float = isinstance(fill_value, scalars.Float)
10461075
_order = x1.layout if order is None else order
10471076
_usm_type = _parse_usm_type(usm_type) if usm_type else "device"
10481077
_device = _parse_device_filter_string(device) if device else None
@@ -1052,6 +1081,7 @@ def ol_dpnp_full_like(
10521081
_ndim,
10531082
layout=_order,
10541083
dtype=_dtype,
1084+
is_fill_value_float=_is_fill_value_float,
10551085
usm_type=_usm_type,
10561086
device=_device,
10571087
sycl_queue=_sycl_queue,

numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty.py

Lines changed: 83 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,61 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
"""Tests for dpnp ndarray constructors."""
5+
"""Tests for the dpnp.empty overload."""
66

77
import dpctl
88
import dpnp
99
import pytest
10+
from numba import errors
1011

1112
from numba_dpex import dpjit
1213

1314
shapes = [11, (2, 5)]
1415
dtypes = [dpnp.int32, dpnp.int64, dpnp.float32, dpnp.float64]
1516
usm_types = ["device", "shared", "host"]
16-
devices = ["cpu", None]
17+
18+
19+
@pytest.mark.parametrize("shape", shapes)
20+
def test_dpnp_empty_default(shape):
21+
"""Test dpnp.empty() with default parameters inside dpjit."""
22+
23+
@dpjit
24+
def func(shape):
25+
c = dpnp.empty(shape)
26+
return c
27+
28+
try:
29+
c = func(shape)
30+
except Exception:
31+
pytest.fail("Calling dpnp.empty() inside dpjit failed.")
32+
33+
if len(c.shape) == 1:
34+
assert c.shape[0] == shape
35+
else:
36+
assert c.shape == shape
37+
38+
dummy = dpnp.empty(shape)
39+
40+
assert c.dtype == dummy.dtype
41+
assert c.usm_type == dummy.usm_type
42+
assert c.sycl_device == dummy.sycl_device
43+
assert c.sycl_queue == dummy.sycl_queue
44+
if c.sycl_queue != dummy.sycl_queue:
45+
pytest.xfail(
46+
"Returned queue does not have the queue in the dummy array."
47+
)
48+
assert c.sycl_queue == dpctl._sycl_queue_manager.get_device_cached_queue(
49+
dummy.sycl_device
50+
)
1751

1852

1953
@pytest.mark.parametrize("shape", shapes)
2054
@pytest.mark.parametrize("dtype", dtypes)
2155
@pytest.mark.parametrize("usm_type", usm_types)
22-
@pytest.mark.parametrize("device", devices)
23-
def test_dpnp_empty(shape, dtype, usm_type, device):
56+
def test_dpnp_empty_from_device(shape, dtype, usm_type):
57+
""" "Use device only in dpnp.emtpy() inside dpjit."""
58+
device = dpctl.SyclDevice().filter_string
59+
2460
@dpjit
2561
def func(shape):
2662
c = dpnp.empty(shape, dtype=dtype, usm_type=usm_type, device=device)
@@ -29,7 +65,7 @@ def func(shape):
2965
try:
3066
c = func(shape)
3167
except Exception:
32-
pytest.fail("Calling dpnp.empty inside dpjit failed")
68+
pytest.fail("Calling dpnp.empty() inside dpjit failed.")
3369

3470
if len(c.shape) == 1:
3571
assert c.shape[0] == shape
@@ -38,32 +74,61 @@ def func(shape):
3874

3975
assert c.dtype == dtype
4076
assert c.usm_type == usm_type
41-
if device is not None:
42-
assert (
43-
c.sycl_device.filter_string
44-
== dpctl.SyclDevice(device).filter_string
77+
assert c.sycl_device.filter_string == device
78+
if c.sycl_queue != dpctl._sycl_queue_manager.get_device_cached_queue(
79+
device
80+
):
81+
pytest.xfail(
82+
"Returned queue does not have the queue cached against the device."
4583
)
46-
else:
47-
c.sycl_device.filter_string == dpctl.SyclDevice().filter_string
4884

4985

5086
@pytest.mark.parametrize("shape", shapes)
51-
def test_dpnp_empty_default_dtype(shape):
87+
@pytest.mark.parametrize("dtype", dtypes)
88+
@pytest.mark.parametrize("usm_type", usm_types)
89+
def test_dpnp_empty_from_queue(shape, dtype, usm_type):
90+
""" "Use queue only in dpnp.emtpy() inside dpjit."""
91+
5292
@dpjit
53-
def func(shape):
54-
c = dpnp.empty(shape)
93+
def func(shape, queue):
94+
c = dpnp.empty(shape, dtype=dtype, usm_type=usm_type, sycl_queue=queue)
5595
return c
5696

97+
queue = dpctl.SyclQueue()
98+
5799
try:
58-
c = func(shape)
100+
c = func(shape, queue)
59101
except Exception:
60-
pytest.fail("Calling dpnp.empty inside dpjit failed")
102+
pytest.fail("Calling dpnp.empty() inside dpjit failed.")
61103

62104
if len(c.shape) == 1:
63105
assert c.shape[0] == shape
64106
else:
65107
assert c.shape == shape
66108

67-
dummy_tensor = dpctl.tensor.empty(shape)
109+
assert c.dtype == dtype
110+
assert c.usm_type == usm_type
111+
assert c.sycl_device == queue.sycl_device
112+
113+
if c.sycl_queue != queue:
114+
pytest.xfail(
115+
"Returned queue does not have the queue passed to the dpnp function."
116+
)
117+
118+
119+
def test_dpnp_empty_exceptions():
120+
"""Test if exception is raised when both queue and device are specified."""
121+
device = dpctl.SyclDevice().filter_string
68122

69-
assert c.dtype == dummy_tensor.dtype
123+
@dpjit
124+
def func(shape, queue):
125+
c = dpnp.empty(shape, sycl_queue=queue, device=device)
126+
return c
127+
128+
queue = dpctl.SyclQueue()
129+
130+
try:
131+
func(10, queue)
132+
except Exception as e:
133+
assert isinstance(e, errors.TypingError)
134+
assert "`device` and `sycl_queue` are exclusive keywords" in str(e)

0 commit comments

Comments
 (0)