@@ -59,6 +59,7 @@ from ._backend cimport ( # noqa: E211
59
59
DPCTLWorkGroupMemory_Delete,
60
60
_arg_data_type,
61
61
_backend_type,
62
+ _md_local_accessor,
62
63
_queue_property_type,
63
64
)
64
65
from .memory._memory cimport _Memory
@@ -125,6 +126,95 @@ cdef class kernel_arg_type_attribute:
125
126
return self .attr_value
126
127
127
128
129
+ cdef class LocalAccessor:
130
+ """
131
+ LocalAccessor(dtype, shape)
132
+
133
+ Python class for specifying the dimensionality and type of a
134
+ ``sycl::local_accessor``, to be used as a kernel argument type.
135
+
136
+ Args:
137
+ dtype (str):
138
+ the data type of the local memory.
139
+ The permitted values are
140
+
141
+ `'i1'`, `'i2'`, `'i4'`, `'i8'`:
142
+ signed integral types int8_t, int16_t, int32_t, int64_t
143
+ `'u1'`, `'u2'`, `'u4'`, `'u8'`
144
+ unsigned integral types uint8_t, uint16_t, uint32_t,
145
+ uint64_t
146
+ `'f4'`, `'f8'`,
147
+ single- and double-precision floating-point types float and
148
+ double
149
+ shape (tuple, list):
150
+ Size of LocalAccessor dimensions. Dimension of the LocalAccessor is
151
+ determined by the length of the tuple. Must be of length 1, 2, or 3,
152
+ and contain only non-negative integers.
153
+
154
+ Raises:
155
+ TypeError:
156
+ If the given shape is not a tuple or list.
157
+ ValueError:
158
+ If the given shape sequence is not between one and three elements long.
159
+ TypeError:
160
+ If the shape is not a sequence of integers.
161
+ ValueError:
162
+ If the shape contains a negative integer.
163
+ ValueError:
164
+ If the dtype string is unrecognized.
165
+ """
166
+ cdef _md_local_accessor lacc
167
+
168
+ def __cinit__ (self , str dtype , shape ):
169
+ if not isinstance (shape, (list , tuple )):
170
+ raise TypeError (f" `shape` must be a list or tuple, got {type(shape)}" )
171
+ ndim = len (shape)
172
+ if ndim < 1 or ndim > 3 :
173
+ raise ValueError (" LocalAccessor must have dimension between one and three" )
174
+ for s in shape:
175
+ if not isinstance (s, numbers.Integral):
176
+ raise TypeError (" LocalAccessor shape must be a sequence of integers" )
177
+ if s < 0 :
178
+ raise ValueError (" LocalAccessor dimensions must be non-negative" )
179
+ self .lacc.ndim = ndim
180
+ self .lacc.dim0 = < size_t> shape[0 ]
181
+ self .lacc.dim1 = < size_t> shape[1 ] if ndim > 1 else 1
182
+ self .lacc.dim2 = < size_t> shape[2 ] if ndim > 2 else 1
183
+
184
+ if dtype == ' i1' :
185
+ self .lacc.dpctl_type_id = _arg_data_type._INT8_T
186
+ elif dtype == ' u1' :
187
+ self .lacc.dpctl_type_id = _arg_data_type._UINT8_T
188
+ elif dtype == ' i2' :
189
+ self .lacc.dpctl_type_id = _arg_data_type._INT16_T
190
+ elif dtype == ' u2' :
191
+ self .lacc.dpctl_type_id = _arg_data_type._UINT16_T
192
+ elif dtype == ' i4' :
193
+ self .lacc.dpctl_type_id = _arg_data_type._INT32_T
194
+ elif dtype == ' u4' :
195
+ self .lacc.dpctl_type_id = _arg_data_type._UINT32_T
196
+ elif dtype == ' i8' :
197
+ self .lacc.dpctl_type_id = _arg_data_type._INT64_T
198
+ elif dtype == ' u8' :
199
+ self .lacc.dpctl_type_id = _arg_data_type._UINT64_T
200
+ elif dtype == ' f4' :
201
+ self .lacc.dpctl_type_id = _arg_data_type._FLOAT
202
+ elif dtype == ' f8' :
203
+ self .lacc.dpctl_type_id = _arg_data_type._DOUBLE
204
+ else :
205
+ raise ValueError (f" Unrecognized type value: '{dtype}'" )
206
+
207
+ def __repr__ (self ):
208
+ return f" LocalAccessor({self.lacc.ndim})"
209
+
210
+ cdef size_t addressof(self ):
211
+ """
212
+ Returns the address of the _md_local_accessor for this LocalAccessor
213
+ cast to ``size_t``.
214
+ """
215
+ return < size_t> & self .lacc
216
+
217
+
128
218
cdef class _kernel_arg_type:
129
219
"""
130
220
An enumeration of supported kernel argument types in
@@ -865,6 +955,9 @@ cdef class SyclQueue(_SyclQueue):
865
955
elif isinstance (arg, WorkGroupMemory):
866
956
kargs[idx] = < void * > (< size_t> arg._ref)
867
957
kargty[idx] = _arg_data_type._WORK_GROUP_MEMORY
958
+ elif isinstance (arg, LocalAccessor):
959
+ kargs[idx] = < void * > ((< LocalAccessor> arg).addressof())
960
+ kargty[idx] = _arg_data_type._LOCAL_ACCESSOR
868
961
else :
869
962
ret = - 1
870
963
return ret
0 commit comments