Skip to content

Add support for work_group_memory extension #1984

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dpctl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
SyclKernelSubmitError,
SyclQueue,
SyclQueueCreationError,
WorkGroupMemory,
)
from ._sycl_queue_manager import get_device_cached_queue
from ._sycl_timer import SyclTimer
Expand Down Expand Up @@ -100,6 +101,7 @@
"SyclKernelInvalidRangeError",
"SyclKernelSubmitError",
"SyclQueueCreationError",
"WorkGroupMemory",
]
__all__ += [
"get_device_cached_queue",
Expand Down
17 changes: 16 additions & 1 deletion dpctl/_backend.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ cdef extern from "syclinterface/dpctl_sycl_enum_types.h":
_FLOAT 'DPCTL_FLOAT32_T',
_DOUBLE 'DPCTL_FLOAT64_T',
_VOID_PTR 'DPCTL_VOID_PTR',
_LOCAL_ACCESSOR 'DPCTL_LOCAL_ACCESSOR'
_LOCAL_ACCESSOR 'DPCTL_LOCAL_ACCESSOR',
_WORK_GROUP_MEMORY 'DPCTL_WORK_GROUP_MEMORY'

ctypedef enum _queue_property_type 'DPCTLQueuePropertyType':
_DEFAULT_PROPERTY 'DPCTL_DEFAULT_PROPERTY'
Expand Down Expand Up @@ -468,3 +469,17 @@ cdef extern from "syclinterface/dpctl_sycl_usm_interface.h":
cdef DPCTLSyclDeviceRef DPCTLUSM_GetPointerDevice(
DPCTLSyclUSMRef MRef,
DPCTLSyclContextRef CRef)

cdef extern from "syclinterface/dpctl_sycl_extension_interface.h":
cdef struct RawWorkGroupMemoryTy
ctypedef RawWorkGroupMemoryTy RawWorkGroupMemory

cdef struct DPCTLOpaqueWorkGroupMemory
ctypedef DPCTLOpaqueWorkGroupMemory *DPCTLSyclWorkGroupMemoryRef;

cdef DPCTLSyclWorkGroupMemoryRef DPCTLWorkGroupMemory_Create(size_t nbytes);

cdef void DPCTLWorkGroupMemory_Delete(
DPCTLSyclWorkGroupMemoryRef Ref);

cdef bint DPCTLWorkGroupMemory_Available();
17 changes: 16 additions & 1 deletion dpctl/_sycl_queue.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@

from libcpp cimport bool as cpp_bool

from ._backend cimport DPCTLSyclDeviceRef, DPCTLSyclQueueRef, _arg_data_type
from ._backend cimport (
DPCTLSyclDeviceRef,
DPCTLSyclQueueRef,
DPCTLSyclWorkGroupMemoryRef,
_arg_data_type,
)
from ._sycl_context cimport SyclContext
from ._sycl_device cimport SyclDevice
from ._sycl_event cimport SyclEvent
Expand Down Expand Up @@ -98,3 +103,13 @@ cdef public api class SyclQueue (_SyclQueue) [
cpdef prefetch(self, ptr, size_t count=*)
cpdef mem_advise(self, ptr, size_t count, int mem)
cpdef SyclEvent submit_barrier(self, dependent_events=*)

cdef public api class _WorkGroupMemory [
object Py_WorkGroupMemoryObject, type Py_WorkGroupMemoryType
]:
cdef DPCTLSyclWorkGroupMemoryRef _mem_ref

cdef public api class WorkGroupMemory(_WorkGroupMemory) [
object PyWorkGroupMemoryObject, type PyWorkGroupMemoryType
]:
pass
102 changes: 102 additions & 0 deletions dpctl/_sycl_queue.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,17 @@ from ._backend cimport ( # noqa: E211
DPCTLSyclContextRef,
DPCTLSyclDeviceSelectorRef,
DPCTLSyclEventRef,
DPCTLWorkGroupMemory_Available,
DPCTLWorkGroupMemory_Create,
DPCTLWorkGroupMemory_Delete,
_arg_data_type,
_backend_type,
_queue_property_type,
)
from .memory._memory cimport _Memory

import ctypes
import numbers

from .enum_types import backend_type

Expand Down Expand Up @@ -250,6 +254,15 @@ cdef class _kernel_arg_type:
_arg_data_type._LOCAL_ACCESSOR
)

@property
def dpctl_work_group_memory(self):
cdef str p_name = "dpctl_work_group_memory"
return kernel_arg_type_attribute(
self._name,
p_name,
_arg_data_type._WORK_GROUP_MEMORY
)


kernel_arg_type = _kernel_arg_type()

Expand Down Expand Up @@ -849,6 +862,9 @@ cdef class SyclQueue(_SyclQueue):
elif isinstance(arg, _Memory):
kargs[idx]= <void*>(<size_t>arg._pointer)
kargty[idx] = _arg_data_type._VOID_PTR
elif isinstance(arg, WorkGroupMemory):
kargs[idx] = <void*>(<size_t>arg._ref)
kargty[idx] = _arg_data_type._WORK_GROUP_MEMORY
else:
ret = -1
return ret
Expand Down Expand Up @@ -1524,3 +1540,89 @@ cdef api SyclQueue SyclQueue_Make(DPCTLSyclQueueRef QRef):
"""
cdef DPCTLSyclQueueRef copied_QRef = DPCTLQueue_Copy(QRef)
return SyclQueue._create(copied_QRef)

cdef class _WorkGroupMemory:
def __dealloc__(self):
if(self._mem_ref):
DPCTLWorkGroupMemory_Delete(self._mem_ref)

cdef class WorkGroupMemory:
"""
WorkGroupMemory(nbytes)
Python class representing the ``work_group_memory`` class from the
Workgroup Memory oneAPI SYCL extension for low-overhead allocation of local
memory shared by the workitems in a workgroup.

This class is intended be used as kernel argument when launching kernels.

This is based on a DPC++ SYCL extension and only available in newer
versions. Use ``is_available()`` to check availability in your build.

There are multiple ways to create a `WorkGroupMemory`.

- If the constructor is invoked with just a single argument, this argument
is interpreted as the number of bytes to allocated in the shared local
memory.

- If the constructor is invoked with two arguments, the first argument is
interpreted as the datatype of the local memory, using the numpy type
naming scheme.
The second argument is interpreted as the number of elements to allocate.
The number of bytes to allocate is then computed from the byte size of
the data type and the element count.

Args:
args:
Variadic argument, see class documentation.

Raises:
TypeError: In case of incorrect arguments given to constructors,
unexpected types of input arguments.
"""
def __cinit__(self, *args):
cdef size_t nbytes
if not DPCTLWorkGroupMemory_Available():
raise RuntimeError("Workgroup memory extension not available")

if not (0 < len(args) < 3):
raise TypeError("WorkGroupMemory constructor takes 1 or 2 "
f"arguments, but {len(args)} were given")

if len(args) == 1:
if not isinstance(args[0], numbers.Integral):
raise TypeError("WorkGroupMemory single argument constructor"
"expects first argument to be `int`",
f"but got {type(args[0])}")
nbytes = <size_t>(args[0])
else:
if not isinstance(args[0], str):
raise TypeError("WorkGroupMemory constructor expects first"
f"argument to be `str`, but got {type(args[0])}")
if not isinstance(args[1], numbers.Integral):
raise TypeError("WorkGroupMemory constructor expects second"
f"argument to be `int`, but got {type(args[1])}")
dtype = <str>(args[0])
count = <size_t>(args[1])
if not dtype[0] in ["i", "u", "f"]:
raise TypeError(f"Unrecognized type value: '{dtype}'")
try:
bit_width = int(dtype[1:])
except ValueError:
raise TypeError(f"Unrecognized type value: '{dtype}'")

byte_size = <size_t>bit_width
nbytes = count * byte_size

self._mem_ref = DPCTLWorkGroupMemory_Create(nbytes)

"""Check whether the work_group_memory extension is available"""
@staticmethod
def is_available():
return DPCTLWorkGroupMemory_Available()

property _ref:
"""Returns the address of the C API ``DPCTLWorkGroupMemoryRef``
pointer as a ``size_t``.
"""
def __get__(self):
return <size_t>self._mem_ref
6 changes: 4 additions & 2 deletions dpctl/apis/include/dpctl_capi.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@
#pragma once

// clang-format off
// Ordering of includes is important here. dpctl_sycl_types defines types
// used by dpctl's Python C-API headers.
// Ordering of includes is important here. dpctl_sycl_types and
// dpctl_sycl_extension_interface define types used by dpctl's Python
// C-API headers.
#include "syclinterface/dpctl_sycl_types.h"
#include "syclinterface/dpctl_sycl_extension_interface.h"
#ifdef __cplusplus
#define CYTHON_EXTERN_C extern "C"
#else
Expand Down
13 changes: 13 additions & 0 deletions dpctl/sycl.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ cdef extern from "sycl/sycl.hpp" namespace "sycl":
"sycl::kernel_bundle<sycl::bundle_state::executable>":
pass

cdef extern from "syclinterface/dpctl_sycl_extension_interface.h":
cdef struct RawWorkGroupMemoryTy
ctypedef RawWorkGroupMemoryTy RawWorkGroupMemory

cdef extern from "syclinterface/dpctl_sycl_type_casters.hpp" \
namespace "dpctl::syclinterface":
# queue
Expand All @@ -67,3 +71,12 @@ cdef extern from "syclinterface/dpctl_sycl_type_casters.hpp" \
"dpctl::syclinterface::wrap<sycl::event>" (const event *)
cdef event * unwrap_event "dpctl::syclinterface::unwrap<sycl::event>" (
dpctl_backend.DPCTLSyclEventRef)

# work group memory extension
cdef dpctl_backend.DPCTLSyclWorkGroupMemoryRef wrap_work_group_memory \
"dpctl::syclinterface::wrap<RawWorkGroupMemory>" \
(const RawWorkGroupMemory *)

cdef RawWorkGroupMemory * unwrap_work_group_memory \
"dpctl::syclinterface::unwrap<RawWorkGroupMemory>" (
dpctl_backend.DPCTLSyclWorkGroupMemoryRef)
Binary file not shown.
1 change: 1 addition & 0 deletions dpctl/tests/test_sycl_kernel_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,4 @@ def test_kernel_arg_type():
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_float64)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_void_ptr)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_local_accessor)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_work_group_memory)
90 changes: 90 additions & 0 deletions dpctl/tests/test_work_group_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Data Parallel Control (dpctl)
#
# Copyright 2020-2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines unit test cases for the work_group_memory in a SYCL kernel"""

import os

import pytest

import dpctl
import dpctl.tensor


def get_spirv_abspath(fn):
curr_dir = os.path.dirname(os.path.abspath(__file__))
spirv_file = os.path.join(curr_dir, "input_files", fn)
return spirv_file


# The kernel in the SPIR-V file used in this test was generated from the
# following SYCL source code:
# #include <sycl/sycl.hpp>
# using namespace sycl;
# namespace syclexp = sycl::ext::oneapi::experimental;
# namespace syclext = sycl::ext::oneapi;
# using data_t = int32_t;
#
# extern "C" SYCL_EXTERNAL
# SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
# void local_mem_kernel(data_t* in, data_t* out,
# syclexp::work_group_memory<data_t> mem){
# auto* local_mem = &mem;
# auto item = syclext::this_work_item::get_nd_item<1>();
# size_t global_id = item.get_global_linear_id();
# size_t local_id = item.get_local_linear_id();
# local_mem[local_id] = in[global_id];
# out[global_id] = local_mem[local_id];
# }


def test_submit_work_group_memory():
if not dpctl.WorkGroupMemory.is_available():
pytest.skip("Work group memory extension not supported")

try:
q = dpctl.SyclQueue("level_zero")
except dpctl.SyclQueueCreationError:
pytest.skip("LevelZero queue could not be created")
spirv_file = get_spirv_abspath("work-group-memory-kernel.spv")
with open(spirv_file, "br") as spv:
spv_bytes = spv.read()
prog = dpctl.program.create_program_from_spirv(q, spv_bytes)
kernel = prog.get_sycl_kernel("__sycl_kernel_local_mem_kernel")
local_size = 16
global_size = local_size * 8

x = dpctl.tensor.ones(global_size, dtype="int32")
y = dpctl.tensor.zeros(global_size, dtype="int32")
x.sycl_queue.wait()
y.sycl_queue.wait()

try:
q.submit(
kernel,
[
x.usm_data,
y.usm_data,
dpctl.WorkGroupMemory("i4", local_size),
],
[global_size],
[local_size],
)
q.wait()
except dpctl._sycl_queue.SyclKernelSubmitError:
pytest.skip(f"Kernel submission to {q.sycl_device} failed")

assert dpctl.tensor.all(x == y)
Loading
Loading