88import dpctl
99import dpctl .tensor
1010from numba .core .typeconv import Conversion
11+ from numba .core .typeinfer import CallConstraint
1112from numba .core .types .npytypes import Array
13+ from numba .np .numpy_support import from_dtype
1214
1315from numba_dpex .utils import address_space
1416
@@ -18,10 +20,10 @@ class USMNdArray(Array):
1820
1921 def __init__ (
2022 self ,
21- dtype ,
2223 ndim ,
23- layout ,
24- usm_type = "unknown" ,
24+ layout = "C" ,
25+ dtype = None ,
26+ usm_type = "device" ,
2527 device = "unknown" ,
2628 queue = None ,
2729 readonly = False ,
@@ -32,15 +34,53 @@ def __init__(
3234 self .usm_type = usm_type
3335 self .addrspace = addrspace
3436
35- # Normalize the device filter string and get the fully qualified three
36- # tuple (backend:device_type:device_num) filter string from dpctl.
37- if device != "unknown" :
38- _d = dpctl .SyclDevice (device )
39- self .device = _d .filter_string
37+ if queue is not None and device != "unknown" :
38+ if not isinstance (device , str ):
39+ raise TypeError (
40+ "The device keyword arg should be a str object specifying "
41+ "a SYCL filter selector"
42+ )
43+ if not isinstance (queue , dpctl .SyclQueue ):
44+ raise TypeError (
45+ "The queue keyword arg should be a dpctl.SyclQueue object"
46+ )
47+ d1 = queue .sycl_device
48+ d2 = dpctl .SyclDevice (device )
49+ if d1 != d2 :
50+ raise TypeError (
51+ "The queue keyword arg and the device keyword arg specify "
52+ "different SYCL devices"
53+ )
54+ self .queue = queue
55+ self .device = device
56+ elif queue is None and device != "unknown" :
57+ if not isinstance (device , str ):
58+ raise TypeError (
59+ "The device keyword arg should be a str object specifying "
60+ "a SYCL filter selector"
61+ )
62+ self .queue = dpctl .SyclQueue (device )
63+ self .device = device
64+ elif queue is not None and device == "unknown" :
65+ if not isinstance (queue , dpctl .SyclQueue ):
66+ raise TypeError (
67+ "The queue keyword arg should be a dpctl.SyclQueue object"
68+ )
69+ self .device = self .queue .sycl_device .filter_string
70+ self .queue = queue
4071 else :
41- self .device = "unknown"
72+ self .queue = dpctl .SyclQueue ()
73+ self .device = self .queue .sycl_device .filter_string
4274
43- self .queue = queue
75+ if not dtype :
76+ dummy_tensor = dpctl .tensor .empty (
77+ sh = 1 , order = layout , usm_type = usm_type , sycl_queue = self .queue
78+ )
79+ # convert dpnp type to numba/numpy type
80+ _dtype = dummy_tensor .dtype
81+ self .dtype = from_dtype (_dtype )
82+ else :
83+ self .dtype = dtype
4484
4585 if name is None :
4686 type_name = "usm_ndarray"
@@ -50,20 +90,21 @@ def __init__(
5090 type_name = "unaligned " + type_name
5191 name_parts = (
5292 type_name ,
53- dtype ,
93+ self . dtype ,
5494 ndim ,
5595 layout ,
5696 self .addrspace ,
5797 usm_type ,
5898 self .device ,
99+ self .queue ,
59100 )
60101 name = (
61102 "%s(dtype=%s, ndim=%s, layout=%s, address_space=%s, "
62- "usm_type=%s, sycl_device=%s)" % name_parts
103+ "usm_type=%s, device=%s, sycl_device=%s)" % name_parts
63104 )
64105
65106 super ().__init__ (
66- dtype ,
107+ self . dtype ,
67108 ndim ,
68109 layout ,
69110 readonly = readonly ,
0 commit comments