Skip to content
230 changes: 193 additions & 37 deletions sycl/include/syclcompat/launch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,52 +41,87 @@ constexpr size_t getArgumentCount(R (*f)(Types...)) {
return sizeof...(Types);
}

template <int Dim>
sycl::nd_range<3> transform_nd_range(const sycl::nd_range<Dim> &range) {
sycl::range<Dim> global_range = range.get_global_range();
sycl::range<Dim> local_range = range.get_local_range();
if constexpr (Dim == 3) {
return range;
} else if constexpr (Dim == 2) {
return sycl::nd_range<3>{{1, global_range[0], global_range[1]},
{1, local_range[0], local_range[1]}};
}
return sycl::nd_range<3>{{1, 1, global_range[0]}, {1, 1, local_range[0]}};
}
struct KernelParams {
sycl::range<3> global_range;
sycl::range<3> local_range;

inline KernelParams(const sycl::range<1> &global_range,
const sycl::range<1> &local_range)
: global_range({1, 1, global_range[0]}),
local_range({1, 1, local_range[0]}) {}
inline KernelParams(const sycl::range<2> &global_range,
const sycl::range<2> &local_range)
: global_range({1, global_range[0], global_range[1]}),
local_range({1, local_range[0], local_range[1]}) {}
inline KernelParams(const sycl::range<3> &global_range,
const sycl::range<3> &local_range)
: global_range(global_range), local_range(local_range) {}
inline KernelParams(const dim3 &grid_dim, const dim3 &block_dim)
: global_range(grid_dim * block_dim), local_range(block_dim) {}
};

template <auto F, typename... Args>
std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event>
launch(const sycl::nd_range<3> &range, sycl::queue q, Args... args) {
static_assert(detail::getArgumentCount(F) == sizeof...(args),
"Wrong number of arguments to SYCL kernel");
static_assert(
std::is_same<std::invoke_result_t<decltype(F), Args...>, void>::value,
"SYCL kernels should return void");
sycl::event launch(const KernelParams &&kernel_params, size_t local_memory_size,
sycl::queue q, Args... args) {
using f_t = decltype(F);
if constexpr (getArgumentCount(F) == sizeof...(args)) {
using f_return_t = typename std::invoke_result_t<f_t, Args...>;
static_assert(std::is_same_v<f_return_t, void>, "SYCL kernels should return void");
static_assert(std::is_invocable_v<decltype(F), Args...>, "Kernel Functor needs to invocable");
} else if constexpr(getArgumentCount(F) == sizeof...(args) + 1){
using f_return_t = typename std::invoke_result_t<f_t, Args..., char*>;
static_assert(std::is_same_v<f_return_t, void>, "SYCL kernels should return void");
static_assert(std::is_invocable_v<decltype(F), Args..., char*>, "Kernel Functor needs to invocable");
}

return q.parallel_for(
range, [=](sycl::nd_item<3>) { [[clang::always_inline]] F(args...); });
return q.submit([&](sycl::handler &cgh) {
auto local_acc = sycl::local_accessor<char, 1>(local_memory_size, cgh);
cgh.parallel_for(sycl::nd_range<3>(kernel_params.global_range,
kernel_params.local_range),
[=](sycl::nd_item<3>) {
if constexpr (detail::getArgumentCount(F) ==
sizeof...(args)) {
[[clang::always_inline]] F(args...);
} else if constexpr (detail::getArgumentCount(F) ==
sizeof...(args) + 1) {
auto local_mem = local_acc.get_pointer();
[[clang::always_inline]] F(args..., local_mem);
}
});
});
}

template <auto F, typename... Args>
sycl::event launch(const sycl::nd_range<3> &range, size_t mem_size,
template <auto F, int SubgroupSize, typename... Args>
sycl::event launch(const KernelParams &&kernel_params, size_t local_memory_size,
sycl::queue q, Args... args) {
static_assert(detail::getArgumentCount(F) == sizeof...(args) + 1,
"Wrong number of arguments to SYCL kernel");

using F_t = decltype(F);
using f_return_t = typename std::invoke_result_t<F_t, Args..., char *>;
static_assert(std::is_same<f_return_t, void>::value,
"SYCL kernels should return void");
using f_t = decltype(F);
if constexpr (getArgumentCount(F) == sizeof...(args)) {
using f_return_t = typename std::invoke_result_t<f_t, Args...>;
static_assert(std::is_same_v<f_return_t, void>, "SYCL kernels should return void");
static_assert(std::is_invocable_v<decltype(F), Args...>, "Kernel Functor needs to invocable");
} else if constexpr(getArgumentCount(F) == sizeof...(args) + 1){
using f_return_t = typename std::invoke_result_t<f_t, Args..., char*>;
static_assert(std::is_same_v<f_return_t, void>, "SYCL kernels should return void");
static_assert(std::is_invocable_v<decltype(F), Args..., char*>, "Kernel Functor needs to invocable");
}

return q.submit([&](sycl::handler &cgh) {
auto local_acc = sycl::local_accessor<char, 1>(mem_size, cgh);
cgh.parallel_for(range, [=](sycl::nd_item<3>) {
auto local_mem = local_acc.get_pointer();
[[clang::always_inline]] F(args..., local_mem);
});
auto local_acc = sycl::local_accessor<char, 1>(local_memory_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(
{kernel_params.global_range, kernel_params.local_range}),
[=](sycl::nd_item<3>) [[sycl::reqd_sub_group_size(SubgroupSize)]] {
if constexpr (detail::getArgumentCount(F) == sizeof...(args)) {
[[clang::always_inline]] F(args...);
} else if constexpr (detail::getArgumentCount(F) ==
sizeof...(args) + 1) {
auto local_mem = local_acc.get_pointer();
[[clang::always_inline]] F(args..., local_mem);
}
});
});
}

} // namespace detail

template <int Dim>
Expand All @@ -112,10 +147,11 @@ inline sycl::nd_range<1> compute_nd_range(int global_size_in,
return compute_nd_range<1>(global_size_in, work_group_size);
}


template <auto F, int Dim, typename... Args>
std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event>
launch(const sycl::nd_range<Dim> &range, sycl::queue q, Args... args) {
return detail::launch<F>(detail::transform_nd_range<Dim>(range), q, args...);
return detail::launch<F>({range.get_global_range(), range.get_local_range()}, 0, q, args...);
}

template <auto F, int Dim, typename... Args>
Expand All @@ -137,6 +173,32 @@ launch(const dim3 &grid, const dim3 &threads, Args... args) {
return launch<F>(grid, threads, get_default_queue(), args...);
}

template <int SubgroupSize, auto F, int Dim, typename... Args>
std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event>
launch(const sycl::nd_range<Dim> &range, sycl::queue q, Args... args) {
return detail::launch<F, SubgroupSize>({range.get_global_range(), range.get_local_range()}, 0, q, args...);
}

template <int SubgroupSize, auto F, int Dim, typename... Args>
std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event>
launch(const sycl::nd_range<Dim> &range, Args... args) {
return launch<SubgroupSize, F>(range, get_default_queue(), args...);
}

// Alternative launch through dim3 objects
template <int SubgroupSize, auto F, typename... Args>
std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event>
launch(const dim3 &grid, const dim3 &threads, sycl::queue q, Args... args) {
return launch<SubgroupSize, F>(sycl::nd_range<3>{grid * threads, threads}, q,
args...);
}

template <int SubgroupSize, auto F, typename... Args>
std::enable_if_t<std::is_invocable_v<decltype(F), Args...>, sycl::event>
launch(const dim3 &grid, const dim3 &threads, Args... args) {
return launch<SubgroupSize, F>(grid, threads, get_default_queue(), args...);
}

/// Launches a kernel with the templated F param and arguments on a
/// device specified by the given nd_range and SYCL queue.
/// @tparam F SYCL kernel to be executed, expects signature F(T* local_mem,
Expand All @@ -154,7 +216,7 @@ launch(const dim3 &grid, const dim3 &threads, Args... args) {
template <auto F, int Dim, typename... Args>
sycl::event launch(const sycl::nd_range<Dim> &range, size_t mem_size,
sycl::queue q, Args... args) {
return detail::launch<F>(detail::transform_nd_range<Dim>(range), mem_size, q,
return detail::launch<F>({range.get_global_range(), range.get_local_range()}, mem_size, q,
args...);
}

Expand Down Expand Up @@ -220,4 +282,98 @@ sycl::event launch(const dim3 &grid, const dim3 &threads, size_t mem_size,
return launch<F>(grid, threads, mem_size, get_default_queue(), args...);
}

/// Launches a kernel with the requested sub group size SubgroupSize, templated
/// F param and arguments on a device specified by the given nd_range and SYCL
/// queue.

/// @tparam SubgroupSize The subgroup size to be used by the kernel.
/// @tparam F SYCL kernel to be executed, expects signature F(T* local_mem,
/// Args... args).
/// @tparam Dim nd_range dimension number.
/// @tparam Args Types of the arguments to be passed to the kernel.
/// @param range Nd_range specifying the work group and global sizes for the
/// kernel.
/// @param q The SYCL queue on which to execute the kernel.
/// @param mem_size The size, in number of bytes, of the local
/// memory to be allocated for kernel.
/// @param args The arguments to be passed to the kernel.
/// @return A SYCL event object that can be used to synchronize with the
/// kernel's execution.
template <int SubgroupSize, auto F, int Dim, typename... Args>
sycl::event launch(const sycl::nd_range<Dim> &range, size_t mem_size,
sycl::queue q, Args... args) {
return detail::launch<F, SubgroupSize>(
{range.get_global_range(), range.get_local_range()}, mem_size, q,
args...);
}

/// Launches a kernel with the requested sub group size SubgroupSize, templated
/// F param and arguments on a device specified by the given nd_range using
/// theSYCL default queue.
/// @tparam SubgroupSize The subgroup size to be used by the kernel.
/// @tparam F SYCL kernel to be executed, expects signature F(T* local_mem,
/// Args... args).
/// @tparam Dim nd_range dimension number.
/// @tparam Args Types of the arguments to be passed to the kernel.
/// @param range Nd_range specifying the work group and global sizes for the
/// kernel.
/// @param mem_size The size, in number of bytes, of the local
/// memory to be allocated for kernel.
/// @param args The arguments to be passed to the kernel.
/// @return A SYCL event object that can be used to synchronize with the
/// kernel's execution.
template <int SubgroupSize, auto F, int Dim, typename... Args>
sycl::event launch(const sycl::nd_range<Dim> &range, size_t mem_size,
Args... args) {
return launch<SubgroupSize, F>(range, mem_size, get_default_queue(), args...);
}

/// Launches a kernel with the requested sub group size SubgroupSize, templated
/// F param and arguments on a device with a user-specified grid and block
/// dimensions following the standard of other programming models using a
/// user-defined SYCL queue
/// @tparam SubgroupSize The subgroup size to be used by the kernel.
/// @tparam F SYCL kernel to be executed, expects signature F(T* local_mem,
/// Args... args).
/// @tparam Dim nd_range dimension number.
/// @tparam Args Types of the arguments to be passed to the kernel.
/// @param grid Grid dimensions represented with an (x, y, z) iteration space.
/// @param threads Block dimensions represented with an (x, y, z) iteration
/// space.
/// @param mem_size The size, in number of bytes, of the local
/// memory to be allocated for kernel.
/// @param args The arguments to be passed to the kernel.
/// @return A SYCL event object that can be used to synchronize with the
/// kernel's execution.
template <int SubgroupSize, auto F, typename... Args>
sycl::event launch(const dim3 &grid, const dim3 &threads, size_t mem_size,
sycl::queue q, Args... args) {
return launch<SubgroupSize, F>(sycl::nd_range<3>{grid * threads, threads},
mem_size, q, args...);
}

/// Launches a kernel with the requested sub group size SubgroupSize, templated
/// F param and arguments on a device with a user-specified grid and block
/// dimensions following the standard of other programming models using the
/// default SYCL queue.
/// @tparam SubgroupSize The subgroup size to be used by the kernel
/// @tparam F SYCL kernel to be executed, expects signature F(T* local_mem,
/// Args... args).
/// @tparam Dim nd_range dimension number.
/// @tparam Args Types of the arguments to be passed to the kernel.
/// @param grid Grid dimensions represented with an (x, y, z) iteration space.
/// @param threads Block dimensions represented with an (x, y, z) iteration
/// space.
/// @param mem_size The size, in number of bytes, of the
/// local memory to be allocated.
/// @param args The arguments to be passed to the kernel.
/// @return A SYCL event object that can be used to synchronize with the
/// kernel's execution.
template <int SubgroupSize, auto F, typename... Args>
sycl::event launch(const dim3 &grid, const dim3 &threads, size_t mem_size,
Args... args) {
return launch<SubgroupSize, F>(grid, threads, mem_size, get_default_queue(),
args...);
}

} // namespace syclcompat
90 changes: 90 additions & 0 deletions sycl/test-e2e/syclcompat/launch/launch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,24 @@ void dynamic_local_mem_typed_kernel(T *data, char *local_mem) {
}
};

template <typename T>
void reqd_sg_size_kernel(int modifier_val, int num_elements, T *data) {

const int id =
sycl::ext::oneapi::this_work_item::get_nd_item<1>().get_global_id(0);
const int sg_size = sycl::ext::oneapi::this_work_item::get_nd_item<1>()
.get_sub_group()
.get_local_linear_range();
if (id < num_elements) {
if (id < num_elements - modifier_val) {
data[id] = static_cast<T>(
(id + modifier_val - sg_size) < 0 ? 0 : id + modifier_val - sg_size);
} else {
data[id] = static_cast<T>(id + modifier_val + sg_size);
}
}
};

template <int Dim>
void compute_nd_range_3d(RangeParams<Dim> range_param, std::string test_name) {
std::cout << __PRETTY_FUNCTION__ << " " << test_name << std::endl;
Expand Down Expand Up @@ -326,7 +344,76 @@ template <typename T> void test_memsize_no_arg_launch_q() {
memsize, lt.q_);
}

template <typename T> void test_reqd_sg_size() {
std::cout << __PRETTY_FUNCTION__ << std::endl;

LaunchTestWithArgs<T> ltt;
if (ltt.skip_) // Unsupported aspect
return;
constexpr int SubgroupSize = 32;
const int modifier_val = 9;

T *h_a = (T *)syclcompat::malloc_host(ltt.memsize_ * sizeof(T));
T *d_a = (T *)syclcompat::malloc(ltt.memsize_ * sizeof(T));

syclcompat::launch<SubgroupSize, reqd_sg_size_kernel<T>>(
ltt.grid_, ltt.thread_, modifier_val, ltt.memsize_, d_a);

syclcompat::wait_and_throw();
syclcompat::memcpy<T>(h_a, d_a, ltt.memsize_);
syclcompat::free(d_a);

for (int i = 0; i < static_cast<int>(ltt.memsize_); i++) {
T result;
if (i < (static_cast<int>(ltt.memsize_) - modifier_val)) {
result = static_cast<T>((i + modifier_val - SubgroupSize) < 0
? 0
: (i + modifier_val - SubgroupSize));
} else {
result = static_cast<T>(i + modifier_val + SubgroupSize);
}
assert(h_a[i] == result);
}

syclcompat::free(h_a);
}

template <typename T> void test_reqd_sg_size_q() {
std::cout << __PRETTY_FUNCTION__ << std::endl;

LaunchTestWithArgs<T> ltt;
if (ltt.skip_) // Unsupported aspect
return;
constexpr int SubgroupSize = 32;
const int modifier_val = 9;
auto &q = ltt.in_order_q_;

T *h_a = (T *)syclcompat::malloc_host(ltt.memsize_ * sizeof(T), q);
T *d_a = (T *)syclcompat::malloc(ltt.memsize_ * sizeof(T), q);

syclcompat::launch<SubgroupSize, reqd_sg_size_kernel<T>>(
ltt.grid_, ltt.thread_, q, modifier_val, ltt.memsize_, d_a);

syclcompat::wait_and_throw();
syclcompat::memcpy<T>(h_a, d_a, ltt.memsize_, q);
syclcompat::free(d_a, q);

for (int i = 0; i < static_cast<int>(ltt.memsize_); i++) {
T result;
if (i < (static_cast<int>(ltt.memsize_) - modifier_val)) {
result = static_cast<T>((i + modifier_val - SubgroupSize) < 0
? 0
: (i + modifier_val - SubgroupSize));
} else {
result = static_cast<T>(i + modifier_val + SubgroupSize);
}
assert(h_a[i] == result);
}
syclcompat::free(h_a, q);
}

int main() {

test_launch_compute_nd_range_3d();
test_no_arg_launch();
test_one_arg_launch();
Expand All @@ -345,5 +432,8 @@ int main() {
INSTANTIATE_ALL_TYPES(memsize_type_list, test_memsize_no_arg_launch);
INSTANTIATE_ALL_TYPES(memsize_type_list, test_memsize_no_arg_launch_q);

INSTANTIATE_ALL_TYPES(memsize_type_list, test_reqd_sg_size);
INSTANTIATE_ALL_TYPES(memsize_type_list, test_reqd_sg_size_q);

return 0;
}