diff --git a/sycl/include/syclcompat/launch.hpp b/sycl/include/syclcompat/launch.hpp index 503f29ff8b91..7340f6e72c04 100644 --- a/sycl/include/syclcompat/launch.hpp +++ b/sycl/include/syclcompat/launch.hpp @@ -41,52 +41,87 @@ constexpr size_t getArgumentCount(R (*f)(Types...)) { return sizeof...(Types); } -template -sycl::nd_range<3> transform_nd_range(const sycl::nd_range &range) { - sycl::range global_range = range.get_global_range(); - sycl::range local_range = range.get_local_range(); - if constexpr (Dim == 3) { - return range; - } else if constexpr (Dim == 2) { - return sycl::nd_range<3>{{1, global_range[0], global_range[1]}, - {1, local_range[0], local_range[1]}}; - } - return sycl::nd_range<3>{{1, 1, global_range[0]}, {1, 1, local_range[0]}}; -} +struct KernelParams { + sycl::range<3> global_range; + sycl::range<3> local_range; + + inline KernelParams(const sycl::range<1> &global_range, + const sycl::range<1> &local_range) + : global_range({1, 1, global_range[0]}), + local_range({1, 1, local_range[0]}) {} + inline KernelParams(const sycl::range<2> &global_range, + const sycl::range<2> &local_range) + : global_range({1, global_range[0], global_range[1]}), + local_range({1, local_range[0], local_range[1]}) {} + inline KernelParams(const sycl::range<3> &global_range, + const sycl::range<3> &local_range) + : global_range(global_range), local_range(local_range) {} + inline KernelParams(const dim3 &grid_dim, const dim3 &block_dim) + : global_range(grid_dim * block_dim), local_range(block_dim) {} +}; template -std::enable_if_t, sycl::event> -launch(const sycl::nd_range<3> &range, sycl::queue q, Args... args) { - static_assert(detail::getArgumentCount(F) == sizeof...(args), - "Wrong number of arguments to SYCL kernel"); - static_assert( - std::is_same, void>::value, - "SYCL kernels should return void"); +sycl::event launch(const KernelParams &&kernel_params, size_t local_memory_size, + sycl::queue q, Args... args) { + using f_t = decltype(F); + if constexpr (getArgumentCount(F) == sizeof...(args)) { + using f_return_t = typename std::invoke_result_t; + static_assert(std::is_same_v, "SYCL kernels should return void"); + static_assert(std::is_invocable_v, "Kernel Functor needs to invocable"); + } else if constexpr(getArgumentCount(F) == sizeof...(args) + 1){ + using f_return_t = typename std::invoke_result_t; + static_assert(std::is_same_v, "SYCL kernels should return void"); + static_assert(std::is_invocable_v, "Kernel Functor needs to invocable"); + } - return q.parallel_for( - range, [=](sycl::nd_item<3>) { [[clang::always_inline]] F(args...); }); + return q.submit([&](sycl::handler &cgh) { + auto local_acc = sycl::local_accessor(local_memory_size, cgh); + cgh.parallel_for(sycl::nd_range<3>(kernel_params.global_range, + kernel_params.local_range), + [=](sycl::nd_item<3>) { + if constexpr (detail::getArgumentCount(F) == + sizeof...(args)) { + [[clang::always_inline]] F(args...); + } else if constexpr (detail::getArgumentCount(F) == + sizeof...(args) + 1) { + auto local_mem = local_acc.get_pointer(); + [[clang::always_inline]] F(args..., local_mem); + } + }); + }); } -template -sycl::event launch(const sycl::nd_range<3> &range, size_t mem_size, +template +sycl::event launch(const KernelParams &&kernel_params, size_t local_memory_size, sycl::queue q, Args... args) { - static_assert(detail::getArgumentCount(F) == sizeof...(args) + 1, - "Wrong number of arguments to SYCL kernel"); - using F_t = decltype(F); - using f_return_t = typename std::invoke_result_t; - static_assert(std::is_same::value, - "SYCL kernels should return void"); + using f_t = decltype(F); + if constexpr (getArgumentCount(F) == sizeof...(args)) { + using f_return_t = typename std::invoke_result_t; + static_assert(std::is_same_v, "SYCL kernels should return void"); + static_assert(std::is_invocable_v, "Kernel Functor needs to invocable"); + } else if constexpr(getArgumentCount(F) == sizeof...(args) + 1){ + using f_return_t = typename std::invoke_result_t; + static_assert(std::is_same_v, "SYCL kernels should return void"); + static_assert(std::is_invocable_v, "Kernel Functor needs to invocable"); + } return q.submit([&](sycl::handler &cgh) { - auto local_acc = sycl::local_accessor(mem_size, cgh); - cgh.parallel_for(range, [=](sycl::nd_item<3>) { - auto local_mem = local_acc.get_pointer(); - [[clang::always_inline]] F(args..., local_mem); - }); + auto local_acc = sycl::local_accessor(local_memory_size, cgh); + cgh.parallel_for( + sycl::nd_range<3>( + {kernel_params.global_range, kernel_params.local_range}), + [=](sycl::nd_item<3>) [[sycl::reqd_sub_group_size(SubgroupSize)]] { + if constexpr (detail::getArgumentCount(F) == sizeof...(args)) { + [[clang::always_inline]] F(args...); + } else if constexpr (detail::getArgumentCount(F) == + sizeof...(args) + 1) { + auto local_mem = local_acc.get_pointer(); + [[clang::always_inline]] F(args..., local_mem); + } + }); }); } - } // namespace detail template @@ -112,10 +147,11 @@ inline sycl::nd_range<1> compute_nd_range(int global_size_in, return compute_nd_range<1>(global_size_in, work_group_size); } + template std::enable_if_t, sycl::event> launch(const sycl::nd_range &range, sycl::queue q, Args... args) { - return detail::launch(detail::transform_nd_range(range), q, args...); + return detail::launch({range.get_global_range(), range.get_local_range()}, 0, q, args...); } template @@ -137,6 +173,32 @@ launch(const dim3 &grid, const dim3 &threads, Args... args) { return launch(grid, threads, get_default_queue(), args...); } +template +std::enable_if_t, sycl::event> +launch(const sycl::nd_range &range, sycl::queue q, Args... args) { + return detail::launch({range.get_global_range(), range.get_local_range()}, 0, q, args...); +} + +template +std::enable_if_t, sycl::event> +launch(const sycl::nd_range &range, Args... args) { + return launch(range, get_default_queue(), args...); +} + +// Alternative launch through dim3 objects +template +std::enable_if_t, sycl::event> +launch(const dim3 &grid, const dim3 &threads, sycl::queue q, Args... args) { + return launch(sycl::nd_range<3>{grid * threads, threads}, q, + args...); +} + +template +std::enable_if_t, sycl::event> +launch(const dim3 &grid, const dim3 &threads, Args... args) { + return launch(grid, threads, get_default_queue(), args...); +} + /// Launches a kernel with the templated F param and arguments on a /// device specified by the given nd_range and SYCL queue. /// @tparam F SYCL kernel to be executed, expects signature F(T* local_mem, @@ -154,7 +216,7 @@ launch(const dim3 &grid, const dim3 &threads, Args... args) { template sycl::event launch(const sycl::nd_range &range, size_t mem_size, sycl::queue q, Args... args) { - return detail::launch(detail::transform_nd_range(range), mem_size, q, + return detail::launch({range.get_global_range(), range.get_local_range()}, mem_size, q, args...); } @@ -220,4 +282,98 @@ sycl::event launch(const dim3 &grid, const dim3 &threads, size_t mem_size, return launch(grid, threads, mem_size, get_default_queue(), args...); } +/// Launches a kernel with the requested sub group size SubgroupSize, templated +/// F param and arguments on a device specified by the given nd_range and SYCL +/// queue. + +/// @tparam SubgroupSize The subgroup size to be used by the kernel. +/// @tparam F SYCL kernel to be executed, expects signature F(T* local_mem, +/// Args... args). +/// @tparam Dim nd_range dimension number. +/// @tparam Args Types of the arguments to be passed to the kernel. +/// @param range Nd_range specifying the work group and global sizes for the +/// kernel. +/// @param q The SYCL queue on which to execute the kernel. +/// @param mem_size The size, in number of bytes, of the local +/// memory to be allocated for kernel. +/// @param args The arguments to be passed to the kernel. +/// @return A SYCL event object that can be used to synchronize with the +/// kernel's execution. +template +sycl::event launch(const sycl::nd_range &range, size_t mem_size, + sycl::queue q, Args... args) { + return detail::launch( + {range.get_global_range(), range.get_local_range()}, mem_size, q, + args...); +} + +/// Launches a kernel with the requested sub group size SubgroupSize, templated +/// F param and arguments on a device specified by the given nd_range using +/// theSYCL default queue. +/// @tparam SubgroupSize The subgroup size to be used by the kernel. +/// @tparam F SYCL kernel to be executed, expects signature F(T* local_mem, +/// Args... args). +/// @tparam Dim nd_range dimension number. +/// @tparam Args Types of the arguments to be passed to the kernel. +/// @param range Nd_range specifying the work group and global sizes for the +/// kernel. +/// @param mem_size The size, in number of bytes, of the local +/// memory to be allocated for kernel. +/// @param args The arguments to be passed to the kernel. +/// @return A SYCL event object that can be used to synchronize with the +/// kernel's execution. +template +sycl::event launch(const sycl::nd_range &range, size_t mem_size, + Args... args) { + return launch(range, mem_size, get_default_queue(), args...); +} + +/// Launches a kernel with the requested sub group size SubgroupSize, templated +/// F param and arguments on a device with a user-specified grid and block +/// dimensions following the standard of other programming models using a +/// user-defined SYCL queue +/// @tparam SubgroupSize The subgroup size to be used by the kernel. +/// @tparam F SYCL kernel to be executed, expects signature F(T* local_mem, +/// Args... args). +/// @tparam Dim nd_range dimension number. +/// @tparam Args Types of the arguments to be passed to the kernel. +/// @param grid Grid dimensions represented with an (x, y, z) iteration space. +/// @param threads Block dimensions represented with an (x, y, z) iteration +/// space. +/// @param mem_size The size, in number of bytes, of the local +/// memory to be allocated for kernel. +/// @param args The arguments to be passed to the kernel. +/// @return A SYCL event object that can be used to synchronize with the +/// kernel's execution. +template +sycl::event launch(const dim3 &grid, const dim3 &threads, size_t mem_size, + sycl::queue q, Args... args) { + return launch(sycl::nd_range<3>{grid * threads, threads}, + mem_size, q, args...); +} + +/// Launches a kernel with the requested sub group size SubgroupSize, templated +/// F param and arguments on a device with a user-specified grid and block +/// dimensions following the standard of other programming models using the +/// default SYCL queue. +/// @tparam SubgroupSize The subgroup size to be used by the kernel +/// @tparam F SYCL kernel to be executed, expects signature F(T* local_mem, +/// Args... args). +/// @tparam Dim nd_range dimension number. +/// @tparam Args Types of the arguments to be passed to the kernel. +/// @param grid Grid dimensions represented with an (x, y, z) iteration space. +/// @param threads Block dimensions represented with an (x, y, z) iteration +/// space. +/// @param mem_size The size, in number of bytes, of the +/// local memory to be allocated. +/// @param args The arguments to be passed to the kernel. +/// @return A SYCL event object that can be used to synchronize with the +/// kernel's execution. +template +sycl::event launch(const dim3 &grid, const dim3 &threads, size_t mem_size, + Args... args) { + return launch(grid, threads, mem_size, get_default_queue(), + args...); +} + } // namespace syclcompat diff --git a/sycl/test-e2e/syclcompat/launch/launch.cpp b/sycl/test-e2e/syclcompat/launch/launch.cpp index a06dabd6b18b..aa8ad4d218ab 100644 --- a/sycl/test-e2e/syclcompat/launch/launch.cpp +++ b/sycl/test-e2e/syclcompat/launch/launch.cpp @@ -60,6 +60,24 @@ void dynamic_local_mem_typed_kernel(T *data, char *local_mem) { } }; +template +void reqd_sg_size_kernel(int modifier_val, int num_elements, T *data) { + + const int id = + sycl::ext::oneapi::this_work_item::get_nd_item<1>().get_global_id(0); + const int sg_size = sycl::ext::oneapi::this_work_item::get_nd_item<1>() + .get_sub_group() + .get_local_linear_range(); + if (id < num_elements) { + if (id < num_elements - modifier_val) { + data[id] = static_cast( + (id + modifier_val - sg_size) < 0 ? 0 : id + modifier_val - sg_size); + } else { + data[id] = static_cast(id + modifier_val + sg_size); + } + } +}; + template void compute_nd_range_3d(RangeParams range_param, std::string test_name) { std::cout << __PRETTY_FUNCTION__ << " " << test_name << std::endl; @@ -326,7 +344,76 @@ template void test_memsize_no_arg_launch_q() { memsize, lt.q_); } +template void test_reqd_sg_size() { + std::cout << __PRETTY_FUNCTION__ << std::endl; + + LaunchTestWithArgs ltt; + if (ltt.skip_) // Unsupported aspect + return; + constexpr int SubgroupSize = 32; + const int modifier_val = 9; + + T *h_a = (T *)syclcompat::malloc_host(ltt.memsize_ * sizeof(T)); + T *d_a = (T *)syclcompat::malloc(ltt.memsize_ * sizeof(T)); + + syclcompat::launch>( + ltt.grid_, ltt.thread_, modifier_val, ltt.memsize_, d_a); + + syclcompat::wait_and_throw(); + syclcompat::memcpy(h_a, d_a, ltt.memsize_); + syclcompat::free(d_a); + + for (int i = 0; i < static_cast(ltt.memsize_); i++) { + T result; + if (i < (static_cast(ltt.memsize_) - modifier_val)) { + result = static_cast((i + modifier_val - SubgroupSize) < 0 + ? 0 + : (i + modifier_val - SubgroupSize)); + } else { + result = static_cast(i + modifier_val + SubgroupSize); + } + assert(h_a[i] == result); + } + + syclcompat::free(h_a); +} + +template void test_reqd_sg_size_q() { + std::cout << __PRETTY_FUNCTION__ << std::endl; + + LaunchTestWithArgs ltt; + if (ltt.skip_) // Unsupported aspect + return; + constexpr int SubgroupSize = 32; + const int modifier_val = 9; + auto &q = ltt.in_order_q_; + + T *h_a = (T *)syclcompat::malloc_host(ltt.memsize_ * sizeof(T), q); + T *d_a = (T *)syclcompat::malloc(ltt.memsize_ * sizeof(T), q); + + syclcompat::launch>( + ltt.grid_, ltt.thread_, q, modifier_val, ltt.memsize_, d_a); + + syclcompat::wait_and_throw(); + syclcompat::memcpy(h_a, d_a, ltt.memsize_, q); + syclcompat::free(d_a, q); + + for (int i = 0; i < static_cast(ltt.memsize_); i++) { + T result; + if (i < (static_cast(ltt.memsize_) - modifier_val)) { + result = static_cast((i + modifier_val - SubgroupSize) < 0 + ? 0 + : (i + modifier_val - SubgroupSize)); + } else { + result = static_cast(i + modifier_val + SubgroupSize); + } + assert(h_a[i] == result); + } + syclcompat::free(h_a, q); +} + int main() { + test_launch_compute_nd_range_3d(); test_no_arg_launch(); test_one_arg_launch(); @@ -345,5 +432,8 @@ int main() { INSTANTIATE_ALL_TYPES(memsize_type_list, test_memsize_no_arg_launch); INSTANTIATE_ALL_TYPES(memsize_type_list, test_memsize_no_arg_launch_q); + INSTANTIATE_ALL_TYPES(memsize_type_list, test_reqd_sg_size); + INSTANTIATE_ALL_TYPES(memsize_type_list, test_reqd_sg_size_q); + return 0; }