Skip to content

Commit c30e1bc

Browse files
mingjie-intelDiptorup Deb
authored andcommitted
Changes to usm_ndarray_type.
- Made all args to the constructor except ndim as optional. - If no queue or device is provided then select a default queue using dpctl. - Select a default dtype using the same logic as dpctl.tensor. - Bugfix: make sure the derived dtype is passed to the parent Array type's constuctor. - Fix tests and examples impacted by UsmNdArray type changes. - Skip all dpnp.empty tests for now. These tests will be changed once the new implementation for dpnp.empty is merged.
1 parent 8d0c8ea commit c30e1bc

File tree

8 files changed

+65
-21
lines changed

8 files changed

+65
-21
lines changed

numba_dpex/core/types/usm_ndarray_type.py

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
import dpctl
99
import dpctl.tensor
1010
from numba.core.typeconv import Conversion
11+
from numba.core.typeinfer import CallConstraint
1112
from numba.core.types.npytypes import Array
13+
from numba.np.numpy_support import from_dtype
1214

1315
from numba_dpex.utils import address_space
1416

@@ -18,10 +20,10 @@ class USMNdArray(Array):
1820

1921
def __init__(
2022
self,
21-
dtype,
2223
ndim,
23-
layout,
24-
usm_type="unknown",
24+
layout="C",
25+
dtype=None,
26+
usm_type="device",
2527
device="unknown",
2628
queue=None,
2729
readonly=False,
@@ -32,15 +34,53 @@ def __init__(
3234
self.usm_type = usm_type
3335
self.addrspace = addrspace
3436

35-
# Normalize the device filter string and get the fully qualified three
36-
# tuple (backend:device_type:device_num) filter string from dpctl.
37-
if device != "unknown":
38-
_d = dpctl.SyclDevice(device)
39-
self.device = _d.filter_string
37+
if queue is not None and device != "unknown":
38+
if not isinstance(device, str):
39+
raise TypeError(
40+
"The device keyword arg should be a str object specifying "
41+
"a SYCL filter selector"
42+
)
43+
if not isinstance(queue, dpctl.SyclQueue):
44+
raise TypeError(
45+
"The queue keyword arg should be a dpctl.SyclQueue object"
46+
)
47+
d1 = queue.sycl_device
48+
d2 = dpctl.SyclDevice(device)
49+
if d1 != d2:
50+
raise TypeError(
51+
"The queue keyword arg and the device keyword arg specify "
52+
"different SYCL devices"
53+
)
54+
self.queue = queue
55+
self.device = device
56+
elif queue is None and device != "unknown":
57+
if not isinstance(device, str):
58+
raise TypeError(
59+
"The device keyword arg should be a str object specifying "
60+
"a SYCL filter selector"
61+
)
62+
self.queue = dpctl.SyclQueue(device)
63+
self.device = device
64+
elif queue is not None and device == "unknown":
65+
if not isinstance(queue, dpctl.SyclQueue):
66+
raise TypeError(
67+
"The queue keyword arg should be a dpctl.SyclQueue object"
68+
)
69+
self.device = self.queue.sycl_device.filter_string
70+
self.queue = queue
4071
else:
41-
self.device = "unknown"
72+
self.queue = dpctl.SyclQueue()
73+
self.device = self.queue.sycl_device.filter_string
4274

43-
self.queue = queue
75+
if not dtype:
76+
dummy_tensor = dpctl.tensor.empty(
77+
sh=1, order=layout, usm_type=usm_type, sycl_queue=self.queue
78+
)
79+
# convert dpnp type to numba/numpy type
80+
_dtype = dummy_tensor.dtype
81+
self.dtype = from_dtype(_dtype)
82+
else:
83+
self.dtype = dtype
4484

4585
if name is None:
4686
type_name = "usm_ndarray"
@@ -50,20 +90,21 @@ def __init__(
5090
type_name = "unaligned " + type_name
5191
name_parts = (
5292
type_name,
53-
dtype,
93+
self.dtype,
5494
ndim,
5595
layout,
5696
self.addrspace,
5797
usm_type,
5898
self.device,
99+
self.queue,
59100
)
60101
name = (
61102
"%s(dtype=%s, ndim=%s, layout=%s, address_space=%s, "
62-
"usm_type=%s, sycl_device=%s)" % name_parts
103+
"usm_type=%s, device=%s, sycl_device=%s)" % name_parts
63104
)
64105

65106
super().__init__(
66-
dtype,
107+
self.dtype,
67108
ndim,
68109
layout,
69110
readonly=readonly,

numba_dpex/examples/kernel/kernel_specialization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
# ------------ Example 1. ------------ #
2020

2121
# Define type specializations using the numba_ndpx usm_ndarray data type.
22-
i64arrty = usm_ndarray(int64, 1, "C")
23-
f32arrty = usm_ndarray(float32, 1, "C")
22+
i64arrty = usm_ndarray(1, "C", int64)
23+
f32arrty = usm_ndarray(1, "C", float32)
2424

2525

2626
# specialize a kernel for the i64arrty

numba_dpex/tests/kernel_tests/test_barrier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from numba_dpex import float32, usm_ndarray, void
1212
from numba_dpex.tests._helper import filter_strings
1313

14-
f32arrty = usm_ndarray(float32, 1, "C")
14+
f32arrty = usm_ndarray(ndim=1, dtype=float32, layout="C")
1515

1616

1717
@pytest.mark.parametrize("filter_str", filter_strings)

numba_dpex/tests/kernel_tests/test_kernel_has_return_value_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numba_dpex as dpex
1010
from numba_dpex import int32, usm_ndarray
1111

12-
i32arrty = usm_ndarray(int32, 1, "C")
12+
i32arrty = usm_ndarray(ndim=1, dtype=int32, layout="C")
1313

1414

1515
def f(a):

numba_dpex/tests/kernel_tests/test_kernel_specialization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
)
1414
from numba_dpex.core.kernel_interface.utils import Range
1515

16-
i64arrty = usm_ndarray(int64, 1, "C")
17-
f32arrty = usm_ndarray(float32, 1, "C")
16+
i64arrty = usm_ndarray(ndim=1, dtype=int64, layout="C")
17+
f32arrty = usm_ndarray(ndim=1, dtype=float32, layout="C")
1818

1919
specialized_kernel1 = dpex.kernel((i64arrty, i64arrty, i64arrty))
2020
specialized_kernel2 = dpex.kernel(

numba_dpex/tests/njit_tests/dpnp_ndarray/test_dpnp_empty.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
usm_types = ["device", "shared", "host"]
1414

1515

16+
@pytest.mark.skip(reason="Disabling old dpnp.empty tests")
1617
@pytest.mark.parametrize("shape", shapes)
1718
@pytest.mark.parametrize("dtype", dtypes)
1819
@pytest.mark.parametrize("usm_type", usm_types)

numba_dpex/tests/njit_tests/dpnp_ndarray/test_models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ def test_model_for_DpnpNdArray():
1616
1717
"""
1818

19-
model = default_manager.lookup(DpnpNdArray(types.float64, 1, "C"))
19+
model = default_manager.lookup(
20+
DpnpNdArray(ndim=1, dtype=types.float64, layout="C")
21+
)
2022
assert isinstance(model, ArrayModel)
2123

2224

numba_dpex/tests/test_debuginfo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
debug_options = [True, False]
1717

18-
f32arrty = usm_ndarray(float32, 1, "C")
18+
f32arrty = usm_ndarray(ndim=1, dtype=float32, layout="C")
1919

2020

2121
@pytest.fixture(params=debug_options)

0 commit comments

Comments
 (0)