@@ -31,48 +31,53 @@ def __init__(
3131 aligned = True ,
3232 addrspace = address_space .GLOBAL ,
3333 ):
34+ # Creating SyclDevice from filter_string is expensive. So, USMNdArray should be able to
35+ # accept and SyclDevice from usm_ndarray as device parameter
36+ if not isinstance (device , (str , dpctl .SyclDevice )):
37+ raise TypeError (
38+ "The device keyword arg should be a str object specifying "
39+ "a SYCL filter selector"
40+ )
41+
42+ if not isinstance (queue , dpctl .SyclQueue ) and queue is not None :
43+ raise TypeError (
44+ "The queue keyword arg should be a dpctl.SyclQueue object or None"
45+ )
46+
3447 self .usm_type = usm_type
3548 self .addrspace = addrspace
3649
37- if queue is not None and device != "unknown" :
38- if not isinstance (device , str ):
39- raise TypeError (
40- "The device keyword arg should be a str object specifying "
41- "a SYCL filter selector"
42- )
43- if not isinstance (queue , dpctl .SyclQueue ):
44- raise TypeError (
45- "The queue keyword arg should be a dpctl.SyclQueue object"
46- )
47- d1 = queue .sycl_device
48- d2 = dpctl .SyclDevice (device )
49- if d1 != d2 :
50- raise TypeError (
51- "The queue keyword arg and the device keyword arg specify "
52- "different SYCL devices"
53- )
50+ def to_device (dev ):
51+ if isinstance (dev , dpctl .SyclDevice ):
52+ return dev
53+
54+ return dpctl .SyclDevice (dev )
55+
56+ def device_as_string (dev ):
57+ if isinstance (dev , dpctl .SyclDevice ):
58+ return dev .filter_string
59+
60+ return dev
61+
62+ if queue is not None :
63+ if device != "unknown" :
64+ if queue .sycl_device != to_device (device ):
65+ raise TypeError (
66+ "The queue keyword arg and the device keyword arg specify "
67+ "different SYCL devices"
68+ )
69+
5470 self .queue = queue
55- self .device = device
56- elif queue is None and device != "unknown" :
57- if not isinstance (device , str ):
58- raise TypeError (
59- "The device keyword arg should be a str object specifying "
60- "a SYCL filter selector"
61- )
71+ else :
72+ if device == "unknown" :
73+ device = None
74+
75+ device_str = device_as_string (device )
6276 self .queue = dpctl .tensor ._device .normalize_queue_device (
63- device = device
77+ device = device_str
6478 )
65- self .device = device
66- elif queue is not None and device == "unknown" :
67- if not isinstance (queue , dpctl .SyclQueue ):
68- raise TypeError (
69- "The queue keyword arg should be a dpctl.SyclQueue object"
70- )
71- self .device = self .queue .sycl_device .filter_string
72- self .queue = queue
73- else :
74- self .queue = dpctl .tensor ._device .normalize_queue_device ()
75- self .device = self .queue .sycl_device .filter_string
79+
80+ self .device = self .queue .sycl_device .filter_string
7681
7782 if not dtype :
7883 dummy_tensor = dpctl .tensor .empty (
0 commit comments