diff --git a/sycl/include/CL/sycl/ONEAPI/sub_group.hpp b/sycl/include/CL/sycl/ONEAPI/sub_group.hpp index 11d09c114bf81..8989c4cf027a2 100644 --- a/sycl/include/CL/sycl/ONEAPI/sub_group.hpp +++ b/sycl/include/CL/sycl/ONEAPI/sub_group.hpp @@ -295,29 +295,71 @@ struct sub_group { PI_INVALID_DEVICE); #endif } - +#ifdef __SYCL_DEVICE_ONLY__ +#ifdef __NVPTX__ template sycl::detail::enable_if_t< - sycl::detail::sub_group::AcceptableForGlobalLoadStore::value && - N != 1, + sycl::detail::sub_group::AcceptableForGlobalLoadStore::value, vec> load(const multi_ptr src) const { -#ifdef __SYCL_DEVICE_ONLY__ -#ifdef __NVPTX__ vec res; for (int i = 0; i < N; ++i) { res[i] = *(src.get() + i * get_max_local_range()[0] + get_local_id()[0]); } return res; -#else + } +#else // __NVPTX__ + template + sycl::detail::enable_if_t< + sycl::detail::sub_group::AcceptableForGlobalLoadStore::value && + N != 1 && N != 3 && N != 16, + vec> + load(const multi_ptr src) const { return sycl::detail::sub_group::load(src); -#endif // __NVPTX__ -#else + } + + template + sycl::detail::enable_if_t< + sycl::detail::sub_group::AcceptableForGlobalLoadStore::value && + N == 16, + vec> + load(const multi_ptr src) const { + return {sycl::detail::sub_group::load<8, T>(src), + sycl::detail::sub_group::load<8, T>(src + + 8 * get_max_local_range()[0])}; + } + + template + sycl::detail::enable_if_t< + sycl::detail::sub_group::AcceptableForGlobalLoadStore::value && + N == 3, + vec> + load(const multi_ptr src) const { + return { + sycl::detail::sub_group::load<1, T>(src), + sycl::detail::sub_group::load<2, T>(src + get_max_local_range()[0])}; + } + + template + sycl::detail::enable_if_t< + sycl::detail::sub_group::AcceptableForGlobalLoadStore::value && + N == 1, + vec> + load(const multi_ptr src) const { + return sycl::detail::sub_group::load(src); + } +#endif // ___NVPTX___ +#else // __SYCL_DEVICE_ONLY__ + template + sycl::detail::enable_if_t< + sycl::detail::sub_group::AcceptableForGlobalLoadStore::value, + vec> + load(const multi_ptr src) const { (void)src; throw runtime_error("Sub-groups are not supported on host device.", PI_INVALID_DEVICE); -#endif } +#endif // __SYCL_DEVICE_ONLY__ template sycl::detail::enable_if_t< @@ -337,25 +379,6 @@ struct sub_group { #endif } - template - sycl::detail::enable_if_t< - sycl::detail::sub_group::AcceptableForGlobalLoadStore::value && - N == 1, - vec> - load(const multi_ptr src) const { -#ifdef __SYCL_DEVICE_ONLY__ -#ifdef __NVPTX__ - return src.get()[get_local_id()[0]]; -#else - return sycl::detail::sub_group::load(src); -#endif // __NVPTX__ -#else - (void)src; - throw runtime_error("Sub-groups are not supported on host device.", - PI_INVALID_DEVICE); -#endif - } - #ifdef __SYCL_DEVICE_ONLY__ // Method for decorated pointer template @@ -437,45 +460,63 @@ struct sub_group { #endif } +#ifdef __SYCL_DEVICE_ONLY__ +#ifdef __NVPTX__ + template + sycl::detail::enable_if_t< + sycl::detail::sub_group::AcceptableForGlobalLoadStore::value> + store(multi_ptr dst, const vec &x) const { + for (int i = 0; i < N; ++i) { + *(dst.get() + i * get_max_local_range()[0] + get_local_id()[0]) = x[i]; + } + } +#else // __NVPTX__ + template + sycl::detail::enable_if_t< + sycl::detail::sub_group::AcceptableForGlobalLoadStore::value && + N != 1 && N != 3 && N != 16> + store(multi_ptr dst, const vec &x) const { + sycl::detail::sub_group::store(dst, x); + } + template sycl::detail::enable_if_t< sycl::detail::sub_group::AcceptableForGlobalLoadStore::value && N == 1> store(multi_ptr dst, const vec &x) const { -#ifdef __SYCL_DEVICE_ONLY__ -#ifdef __NVPTX__ - dst.get()[get_local_id()[0]] = x[0]; -#else - store(dst, x); -#endif // __NVPTX__ -#else - (void)dst; - (void)x; - throw runtime_error("Sub-groups are not supported on host device.", - PI_INVALID_DEVICE); -#endif + sycl::detail::sub_group::store(dst, x); } template sycl::detail::enable_if_t< sycl::detail::sub_group::AcceptableForGlobalLoadStore::value && - N != 1> - store(multi_ptr dst, const vec &x) const { -#ifdef __SYCL_DEVICE_ONLY__ -#ifdef __NVPTX__ - for (int i = 0; i < N; ++i) { - *(dst.get() + i * get_max_local_range()[0] + get_local_id()[0]) = x[i]; - } -#else - sycl::detail::sub_group::store(dst, x); + N == 3> + store(multi_ptr dst, const vec &x) const { + store<1, T, Space>(dst, x.s0()); + store<2, T, Space>(dst + get_max_local_range()[0], {x.s1(), x.s2()}); + } + + template + sycl::detail::enable_if_t< + sycl::detail::sub_group::AcceptableForGlobalLoadStore::value && + N == 16> + store(multi_ptr dst, const vec &x) const { + store<8, T, Space>(dst, x.lo()); + store<8, T, Space>(dst + 8 * get_max_local_range()[0], x.hi()); + } + #endif // __NVPTX__ -#else +#else // __SYCL_DEVICE_ONLY__ + template + sycl::detail::enable_if_t< + sycl::detail::sub_group::AcceptableForGlobalLoadStore::value> + store(multi_ptr dst, const vec &x) const { (void)dst; (void)x; throw runtime_error("Sub-groups are not supported on host device.", PI_INVALID_DEVICE); -#endif } +#endif // __SYCL_DEVICE_ONLY__ template sycl::detail::enable_if_t<