Skip to content

Commit c3e6952

Browse files
authored
Merge pull request #1991 from IntelPython/add-md-local-accessor
Expose LocalAccessor as kernel argument type
2 parents dffd393 + 1d3453e commit c3e6952

6 files changed

+141
-0
lines changed

dpctl/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from ._sycl_event import SyclEvent
4949
from ._sycl_platform import SyclPlatform, get_platforms, lsplatform
5050
from ._sycl_queue import (
51+
LocalAccessor,
5152
SyclKernelInvalidRangeError,
5253
SyclKernelSubmitError,
5354
SyclQueue,
@@ -102,6 +103,7 @@
102103
"SyclKernelSubmitError",
103104
"SyclQueueCreationError",
104105
"WorkGroupMemory",
106+
"LocalAccessor",
105107
]
106108
__all__ += [
107109
"get_device_cached_queue",

dpctl/_backend.pxd

+6
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,12 @@ cdef extern from "syclinterface/dpctl_sycl_kernel_bundle_interface.h":
362362

363363

364364
cdef extern from "syclinterface/dpctl_sycl_queue_interface.h":
365+
ctypedef struct _md_local_accessor 'MDLocalAccessor':
366+
size_t ndim
367+
_arg_data_type dpctl_type_id
368+
size_t dim0
369+
size_t dim1
370+
size_t dim2
365371
cdef bool DPCTLQueue_AreEq(const DPCTLSyclQueueRef QRef1,
366372
const DPCTLSyclQueueRef QRef2)
367373
cdef DPCTLSyclQueueRef DPCTLQueue_Create(

dpctl/_sycl_queue.pyx

+93
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ from ._backend cimport ( # noqa: E211
5959
DPCTLWorkGroupMemory_Delete,
6060
_arg_data_type,
6161
_backend_type,
62+
_md_local_accessor,
6263
_queue_property_type,
6364
)
6465
from .memory._memory cimport _Memory
@@ -125,6 +126,95 @@ cdef class kernel_arg_type_attribute:
125126
return self.attr_value
126127

127128

129+
cdef class LocalAccessor:
    """
    LocalAccessor(dtype, shape)

    Describes the element type and extents of a ``sycl::local_accessor``
    so that an instance can be passed as a kernel argument.

    Args:
        dtype (str):
            Element type of the local memory. The accepted strings are:

            ``'i1'``, ``'i2'``, ``'i4'``, ``'i8'``:
                signed integers int8_t, int16_t, int32_t, int64_t
            ``'u1'``, ``'u2'``, ``'u4'``, ``'u8'``:
                unsigned integers uint8_t, uint16_t, uint32_t, uint64_t
            ``'f4'``, ``'f8'``:
                single- and double-precision floating-point types
                float and double
        shape (tuple, list):
            Extent of every dimension. The number of entries (one, two
            or three) fixes the dimensionality of the accessor; each
            entry must be a non-negative integer.

    Raises:
        TypeError:
            If ``shape`` is neither a tuple nor a list, or if one of its
            entries is not an integer.
        ValueError:
            If ``shape`` does not have between one and three entries, if
            one of its entries is negative, or if ``dtype`` is not one of
            the recognized strings.
    """
    cdef _md_local_accessor lacc

    def __cinit__(self, str dtype, shape):
        if not isinstance(shape, (list, tuple)):
            raise TypeError(
                f"`shape` must be a list or tuple, got {type(shape)}"
            )
        rank = len(shape)
        if not 1 <= rank <= 3:
            raise ValueError(
                "LocalAccessor must have dimension between one and three"
            )
        # Validate entries one by one so that the first offending entry
        # determines which exception is raised.
        for extent in shape:
            if not isinstance(extent, numbers.Integral):
                raise TypeError(
                    "LocalAccessor shape must be a sequence of integers"
                )
            if extent < 0:
                raise ValueError(
                    "LocalAccessor dimensions must be non-negative"
                )

        self.lacc.ndim = rank
        self.lacc.dim0 = <size_t> shape[0]
        # Extents beyond the requested dimensionality are stored as 1.
        if rank > 1:
            self.lacc.dim1 = <size_t> shape[1]
        else:
            self.lacc.dim1 = 1
        if rank > 2:
            self.lacc.dim2 = <size_t> shape[2]
        else:
            self.lacc.dim2 = 1

        # Translate the dtype string into the backend type identifier.
        if dtype == 'i1':
            self.lacc.dpctl_type_id = _arg_data_type._INT8_T
        elif dtype == 'i2':
            self.lacc.dpctl_type_id = _arg_data_type._INT16_T
        elif dtype == 'i4':
            self.lacc.dpctl_type_id = _arg_data_type._INT32_T
        elif dtype == 'i8':
            self.lacc.dpctl_type_id = _arg_data_type._INT64_T
        elif dtype == 'u1':
            self.lacc.dpctl_type_id = _arg_data_type._UINT8_T
        elif dtype == 'u2':
            self.lacc.dpctl_type_id = _arg_data_type._UINT16_T
        elif dtype == 'u4':
            self.lacc.dpctl_type_id = _arg_data_type._UINT32_T
        elif dtype == 'u8':
            self.lacc.dpctl_type_id = _arg_data_type._UINT64_T
        elif dtype == 'f4':
            self.lacc.dpctl_type_id = _arg_data_type._FLOAT
        elif dtype == 'f8':
            self.lacc.dpctl_type_id = _arg_data_type._DOUBLE
        else:
            raise ValueError(f"Unrecognized type value: '{dtype}'")

    def __repr__(self):
        """Return ``LocalAccessor(<ndim>)`` for this instance."""
        return f"LocalAccessor({self.lacc.ndim})"

    cdef size_t addressof(self):
        """
        Returns the address of the ``_md_local_accessor`` struct stored in
        this LocalAccessor, cast to ``size_t``.
        """
        return <size_t>&self.lacc
216+
217+
128218
cdef class _kernel_arg_type:
129219
"""
130220
An enumeration of supported kernel argument types in
@@ -865,6 +955,9 @@ cdef class SyclQueue(_SyclQueue):
865955
elif isinstance(arg, WorkGroupMemory):
866956
kargs[idx] = <void*>(<size_t>arg._ref)
867957
kargty[idx] = _arg_data_type._WORK_GROUP_MEMORY
958+
elif isinstance(arg, LocalAccessor):
959+
kargs[idx] = <void*>((<LocalAccessor>arg).addressof())
960+
kargty[idx] = _arg_data_type._LOCAL_ACCESSOR
868961
else:
869962
ret = -1
870963
return ret
Binary file not shown.
Binary file not shown.

dpctl/tests/test_sycl_kernel_submit.py

+40
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"""
1919

2020
import ctypes
21+
import os
2122

2223
import numpy as np
2324
import pytest
@@ -279,3 +280,42 @@ def test_kernel_arg_type():
279280
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_void_ptr)
280281
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_local_accessor)
281282
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_work_group_memory)
283+
284+
285+
def get_spirv_abspath(fn):
    """
    Return the absolute path of ``fn`` inside the ``input_files``
    directory that sits next to this test module.
    """
    tests_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(tests_dir, "input_files", fn)
289+
290+
291+
# the process for generating the .spv files in this test is documented in
292+
# libsyclinterface/tests/test_sycl_queue_submit_local_accessor_arg.cpp
293+
# in a comment starting on line 123
294+
def test_submit_local_accessor_arg():
    """
    Submit a pre-compiled SPIR-V kernel that takes a ``LocalAccessor``
    argument and check the values it writes into a USM-backed array.

    The test is skipped when a Level Zero queue cannot be created or when
    the kernel submission itself fails on the selected device.
    """
    try:
        q = dpctl.SyclQueue("level_zero")
    except dpctl.SyclQueueCreationError:
        # The queue requested above targets Level Zero, so say so.
        pytest.skip("Level Zero queue could not be created")
    fn = get_spirv_abspath("local_accessor_kernel_inttys_fp32.spv")
    with open(fn, "br") as f:
        spirv_bytes = f.read()
    prog = dpctl_prog.create_program_from_spirv(q, spirv_bytes)
    krn = prog.get_sycl_kernel("_ZTS14SyclKernel_SLMIlE")
    lws = 32
    gws = lws * 10
    # Allocate on the queue the kernel is submitted to, so the USM
    # allocation and the kernel are guaranteed to share a context.
    x = dpt.ones(gws, dtype="i8", sycl_queue=q)
    x.sycl_queue.wait()
    try:
        e = q.submit(
            krn,
            [x.usm_data, dpctl.LocalAccessor("i8", (lws,))],
            [gws],
            [lws],
        )
        e.wait()
    except dpctl.SyclKernelSubmitError:
        pytest.skip(f"Kernel submission failed for device {q.sycl_device}")
    expected = dpt.arange(1, x.size + 1, dtype=x.dtype, device=x.device) * (
        2 * lws
    )
    assert dpt.all(x == expected)

0 commit comments

Comments
 (0)