Skip to content

Commit 9d8ae9e

Browse files
Avoid creating SyclDevice from filter_string
1 parent 58b7b84 commit 9d8ae9e

File tree

2 files changed

+42
-37
lines changed

2 files changed

+42
-37
lines changed

numba_dpex/core/types/usm_ndarray_type.py

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -31,48 +31,53 @@ def __init__(
3131
aligned=True,
3232
addrspace=address_space.GLOBAL,
3333
):
34+
# Creating SyclDevice from filter_string is expensive. So, USMNdArray should be able to
35+
# accept and SyclDevice from usm_ndarray as device parameter
36+
if not isinstance(device, (str, dpctl.SyclDevice)):
37+
raise TypeError(
38+
"The device keyword arg should be a str object specifying "
39+
"a SYCL filter selector"
40+
)
41+
42+
if not isinstance(queue, dpctl.SyclQueue) and queue is not None:
43+
raise TypeError(
44+
"The queue keyword arg should be a dpctl.SyclQueue object or None"
45+
)
46+
3447
self.usm_type = usm_type
3548
self.addrspace = addrspace
3649

37-
if queue is not None and device != "unknown":
38-
if not isinstance(device, str):
39-
raise TypeError(
40-
"The device keyword arg should be a str object specifying "
41-
"a SYCL filter selector"
42-
)
43-
if not isinstance(queue, dpctl.SyclQueue):
44-
raise TypeError(
45-
"The queue keyword arg should be a dpctl.SyclQueue object"
46-
)
47-
d1 = queue.sycl_device
48-
d2 = dpctl.SyclDevice(device)
49-
if d1 != d2:
50-
raise TypeError(
51-
"The queue keyword arg and the device keyword arg specify "
52-
"different SYCL devices"
53-
)
50+
def to_device(dev):
51+
if isinstance(dev, dpctl.SyclDevice):
52+
return dev
53+
54+
return dpctl.SyclDevice(dev)
55+
56+
def device_as_string(dev):
57+
if isinstance(dev, dpctl.SyclDevice):
58+
return dev.filter_string
59+
60+
return dev
61+
62+
if queue is not None:
63+
if device != "unknown":
64+
if queue.sycl_device != to_device(device):
65+
raise TypeError(
66+
"The queue keyword arg and the device keyword arg specify "
67+
"different SYCL devices"
68+
)
69+
5470
self.queue = queue
55-
self.device = device
56-
elif queue is None and device != "unknown":
57-
if not isinstance(device, str):
58-
raise TypeError(
59-
"The device keyword arg should be a str object specifying "
60-
"a SYCL filter selector"
61-
)
71+
else:
72+
if device == "unknown":
73+
device = None
74+
75+
device_str = device_as_string(device)
6276
self.queue = dpctl.tensor._device.normalize_queue_device(
63-
device=device
77+
device=device_str
6478
)
65-
self.device = device
66-
elif queue is not None and device == "unknown":
67-
if not isinstance(queue, dpctl.SyclQueue):
68-
raise TypeError(
69-
"The queue keyword arg should be a dpctl.SyclQueue object"
70-
)
71-
self.device = self.queue.sycl_device.filter_string
72-
self.queue = queue
73-
else:
74-
self.queue = dpctl.tensor._device.normalize_queue_device()
75-
self.device = self.queue.sycl_device.filter_string
79+
80+
self.device = self.queue.sycl_device.filter_string
7681

7782
if not dtype:
7883
dummy_tensor = dpctl.tensor.empty(

numba_dpex/core/typing/typeof.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _typeof_helper(val, array_class_type):
4343
)
4444

4545
try:
46-
device = val.sycl_device.filter_string
46+
device = val.sycl_device
4747
except AttributeError:
4848
raise ValueError("The device for the usm_ndarray could not be inferred")
4949

0 commit comments

Comments
 (0)