diff --git a/libclc/ptx-nvidiacl/libspirv/SOURCES b/libclc/ptx-nvidiacl/libspirv/SOURCES
index bec378d428511..4177aae12b416 100644
--- a/libclc/ptx-nvidiacl/libspirv/SOURCES
+++ b/libclc/ptx-nvidiacl/libspirv/SOURCES
@@ -93,7 +93,7 @@ images/image_helpers.ll
 images/image.cl
 group/collectives_helpers.ll
 group/collectives.cl
-group/group_ballot.cl
+group/group_non_uniform.cl
 atomic/atomic_add.cl
 atomic/atomic_and.cl
 atomic/atomic_cmpxchg.cl
diff --git a/libclc/ptx-nvidiacl/libspirv/group/group_ballot.cl b/libclc/ptx-nvidiacl/libspirv/group/group_non_uniform.cl
similarity index 80%
rename from libclc/ptx-nvidiacl/libspirv/group/group_ballot.cl
rename to libclc/ptx-nvidiacl/libspirv/group/group_non_uniform.cl
index 33285028b7b39..50826d9bf53e2 100644
--- a/libclc/ptx-nvidiacl/libspirv/group/group_ballot.cl
+++ b/libclc/ptx-nvidiacl/libspirv/group/group_non_uniform.cl
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "membermask.h"
+#include
 #include
 #include
 
@@ -34,3 +35,9 @@ _Z29__spirv_GroupNonUniformBallotjb(unsigned flag, bool predicate) {
 
   return res;
 }
+
+_CLC_DEF _CLC_CONVERGENT uint
+_Z37__spirv_GroupNonUniformBallotBitCountN5__spv5Scope4FlagEiDv4_j(
+    uint scope, uint flag, __clc_vec4_uint32_t mask) {
+  return __clc_native_popcount(__nvvm_read_ptx_sreg_lanemask_lt() & mask[0]);
+}
diff --git a/sycl/include/sycl/detail/cuda/masked_redux.hpp b/sycl/include/sycl/detail/cuda/masked_redux.hpp
new file mode 100644
index 0000000000000..3e6fd92a1003a
--- /dev/null
+++ b/sycl/include/sycl/detail/cuda/masked_redux.hpp
@@ -0,0 +1,396 @@
+//==----- masked_redux.hpp - cuda masked reduction builtins and impls -----==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include + +namespace sycl { +__SYCL_INLINE_VER_NAMESPACE(_V1) { + namespace detail { + +template +using IsRedux = + std::bool_constant::value && + sycl::detail::IsBitAND::value || + sycl::detail::IsBitOR::value || + sycl::detail::IsBitXOR::value || + sycl::detail::IsPlus::value || + sycl::detail::IsMinimum::value || + sycl::detail::IsMaximum::value>; + +#ifdef __SYCL_DEVICE_ONLY__ +#if defined(__NVPTX__) + +//// Masked reductions using redux.sync, requires integer types + +template +std::enable_if_t::value && + sycl::detail::IsMinimum::value, + T> +masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op, + const uint32_t MemberMask) { + return __nvvm_redux_sync_umin(x, MemberMask); +} + +template +std::enable_if_t::value && + sycl::detail::IsMinimum::value, + T> +masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op, + const uint32_t MemberMask) { + return __nvvm_redux_sync_min(x, MemberMask); +} + +template +std::enable_if_t::value && + sycl::detail::IsMaximum::value, + T> +masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op, + const uint32_t MemberMask) { + return __nvvm_redux_sync_umax(x, MemberMask); +} + +template +std::enable_if_t::value && + sycl::detail::IsMaximum::value, + T> +masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op, + const uint32_t MemberMask) { + return __nvvm_redux_sync_max(x, MemberMask); +} + +template +std::enable_if_t<(sycl::detail::is_sugeninteger::value || + sycl::detail::is_sigeninteger::value) && + sycl::detail::IsPlus::value, + T> +masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op, + const uint32_t MemberMask) { + return __nvvm_redux_sync_add(x, MemberMask); +} + +template +std::enable_if_t<(sycl::detail::is_sugeninteger::value || + sycl::detail::is_sigeninteger::value) && + sycl::detail::IsBitAND::value, + T> +masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op, + const uint32_t MemberMask) { + return __nvvm_redux_sync_and(x, MemberMask); +} + +template +std::enable_if_t<(sycl::detail::is_sugeninteger::value || + sycl::detail::is_sigeninteger::value) && + sycl::detail::IsBitOR::value, + T> +masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op, + const uint32_t MemberMask) { + return __nvvm_redux_sync_or(x, MemberMask); +} + +template +std::enable_if_t<(sycl::detail::is_sugeninteger::value || + sycl::detail::is_sigeninteger::value) && + sycl::detail::IsBitXOR::value, + T> +masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op, + const uint32_t MemberMask) { + return __nvvm_redux_sync_xor(x, MemberMask); +} +//// + +//// Shuffle based masked reduction impls + +// Cluster group reduction using shfls, T = double +template +inline __SYCL_ALWAYS_INLINE std::enable_if_t< + ext::oneapi::experimental::is_fixed_size_group::value && + std::is_same_v, + T> +masked_reduction_cuda_shfls(Group g, T x, BinaryOperation binary_op, + const uint32_t MemberMask) { + for (int i = g.get_local_range()[0] / 2; i > 0; i /= 2) { + int x_a, x_b; + asm volatile("mov.b64 {%0,%1},%2; \n\t" : "=r"(x_a), "=r"(x_b) : "l"(x)); + + auto tmp_a = __nvvm_shfl_sync_bfly_i32(MemberMask, x_a, -1, i); + auto tmp_b = __nvvm_shfl_sync_bfly_i32(MemberMask, x_b, -1, i); + double tmp; + asm volatile("mov.b64 %0,{%1,%2}; \n\t" + : "=l"(tmp) + : "r"(tmp_a), "r"(tmp_b)); + x = binary_op(x, tmp); + } + + return 
x; +} + +// Cluster group reduction using shfls, T = float +template +inline __SYCL_ALWAYS_INLINE std::enable_if_t< + ext::oneapi::experimental::is_fixed_size_group::value && + std::is_same_v, + T> +masked_reduction_cuda_shfls(Group g, T x, BinaryOperation binary_op, + const uint32_t MemberMask) { + + for (int i = g.get_local_range()[0] / 2; i > 0; i /= 2) { + auto tmp = + __nvvm_shfl_sync_bfly_i32(MemberMask, __nvvm_bitcast_f2i(x), -1, i); + x = binary_op(x, __nvvm_bitcast_i2f(tmp)); + } + return x; +} + +// Cluster group reduction using shfls, std::is_integral_v +template +inline __SYCL_ALWAYS_INLINE std::enable_if_t< + ext::oneapi::experimental::is_fixed_size_group::value && + std::is_integral_v, + T> +masked_reduction_cuda_shfls( + Group g, T x, BinaryOperation binary_op, + const uint32_t MemberMask) { // todo membermask naming? + + for (int i = g.get_local_range()[0] / 2; i > 0; i /= 2) { + auto tmp = __nvvm_shfl_sync_bfly_i32(MemberMask, x, -1, i); + x = binary_op(x, tmp); + } + return x; +} + +//TODO naming conventions are not everywhere consistent, finalize this + +template +inline __SYCL_ALWAYS_INLINE std::enable_if_t< + ext::oneapi::experimental::is_user_constructed_group_v, T> +non_uniform_shfl_T(const uint32_t MemberMask, T x, int delta) { + if constexpr (ext::oneapi::experimental::is_fixed_size_group::value) { + return __nvvm_shfl_sync_up_i32(MemberMask, x, delta, 0); + } else { + return __nvvm_shfl_sync_idx_i32(MemberMask, x, delta, 31); + } +} + +template +inline __SYCL_ALWAYS_INLINE std::enable_if_t< + ext::oneapi::experimental::is_user_constructed_group_v, T> +non_uniform_shfl(Group g, const uint32_t MemberMask, T x, int delta) { + T res; + if constexpr (std::is_same_v) { + int x_a, x_b; + asm volatile("mov.b64 {%0,%1},%2; \n\t" : "=r"(x_a), "=r"(x_b) : "l"(x)); + + auto tmp_a = non_uniform_shfl_T(MemberMask, x_a, delta); + auto tmp_b = non_uniform_shfl_T(MemberMask, x_b, delta); + asm volatile("mov.b64 %0,{%1,%2}; \n\t" + : "=l"(res) + : "r"(tmp_a), "r"(tmp_b)); + } else { + auto input = std::is_same_v ? __nvvm_bitcast_f2i(x) : x; + auto tmp_b32 = non_uniform_shfl_T(MemberMask, input, delta); + res = std::is_same_v ? __nvvm_bitcast_i2f(tmp_b32) : tmp_b32; + } + return res; +} + +// Opportunistic/Ballot group reduction using shfls +template +inline __SYCL_ALWAYS_INLINE std::enable_if_t< + ext::oneapi::experimental::is_user_constructed_group_v && + !ext::oneapi::experimental::is_fixed_size_group::value, + T> +masked_reduction_cuda_shfls(Group g, T x, BinaryOperation binary_op, + const uint32_t MemberMask) { + + if (MemberMask == 0xffffffff) { + for (int i = 16; i > 0; i /= 2) { + + auto tmp = __nvvm_shfl_sync_bfly_i32(MemberMask, x, -1, i); + x = binary_op(x, tmp); + } + return x; + } + + unsigned local_set_bit = g.get_local_id()[0] + 1; + + // number of elements remaining requiring binary operations + auto op_range = g.get_local_range()[0]; + + // remainder that won't have a binary partner each pass of while loop + int remainder; + + while (op_range / 2 >= 1) { + remainder = op_range % 2; + + // stride between local_ids forming a binary op + int stride = op_range / 2; + + // unfolded position of set bit in mask from shfl src lane. 
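+    // local_set_bit is this lane's 1-based position among the set bits of
+    // MemberMask, so the shuffle partner for this round is the lane holding
+    // the (local_set_bit + stride)-th set bit; __nvvm_fns maps that ordinal
+    // back to an absolute lane id.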
+ int unfolded_src_set_bit = local_set_bit + stride; + + // __nvvm_fns automatically wraps around to correct bit position + // There is no performance impact on src_set_bit position wrt local_set_bit + auto tmp = non_uniform_shfl( + g, MemberMask, x, __nvvm_fns(MemberMask, 0, unfolded_src_set_bit)); + + if (!(local_set_bit == 1 && remainder != 0)) { + x = binary_op(x, tmp); + } + + op_range = std::ceil((float)op_range / 2.0f); + } + + int broadID; + int maskRev; + asm("brev.b32 %0, %1;" : "=r"(maskRev) : "r"(MemberMask)); + asm("clz.b32 %0, %1;" : "=r"(broadID) : "r"(maskRev)); + + return non_uniform_shfl(g, MemberMask, x, broadID); +} + +// Non Redux types must fall back to shfl based implementations. +template +std::enable_if_t< + std::is_same, std::false_type>::value && + ext::oneapi::experimental::is_user_constructed_group_v, + T> +masked_reduction_cuda_sm80(Group g, T x, BinaryOperation binary_op, + const uint32_t MemberMask) { + return masked_reduction_cuda_shfls(g, x, binary_op, MemberMask); +} +//// + +// todo these functions not cuda specific.. +template +inline __SYCL_ALWAYS_INLINE + std::enable_if_t::value || + sycl::detail::IsBitOR::value || + sycl::detail::IsBitXOR::value, + T> + get_identity() { + return 0; +} + +template +inline __SYCL_ALWAYS_INLINE + std::enable_if_t::value, T> + get_identity() { + return 1; +} + +template +inline __SYCL_ALWAYS_INLINE + std::enable_if_t::value, T> + get_identity() { + return ~0; +} + +#define GET_ID(OP_CHECK, OP) \ + template \ + inline __SYCL_ALWAYS_INLINE \ + std::enable_if_t::value, T> \ + get_identity() { \ + if constexpr (std::is_same_v) { \ + return std::numeric_limits::OP(); \ + } else if constexpr (std::is_same_v) { \ + return std::numeric_limits::OP(); \ + } else if constexpr (std::is_same_v) { \ + return std::numeric_limits::OP(); \ + } else if constexpr (std::is_same_v) { \ + return std::numeric_limits::OP(); \ + } else if constexpr (std::is_same_v) { \ + return std::numeric_limits::OP(); \ + } else if constexpr (std::is_same_v) { \ + return std::numeric_limits::OP(); \ + } else if constexpr (std::is_same_v) { \ + return std::numeric_limits::OP(); \ + } else if constexpr (std::is_same_v) { \ + return std::numeric_limits::OP(); \ + } else if constexpr (std::is_same_v) { \ + return std::numeric_limits::OP(); \ + } else if constexpr (std::is_same_v) { \ + return std::numeric_limits::OP(); \ + } \ + return 0; \ + } + +GET_ID(IsMinimum, max) +GET_ID(IsMaximum, min) + +#undef GET_ID + +// Cluster group scan using shfls +template <__spv::GroupOperation Op, typename Group, typename T, + class BinaryOperation> +inline __SYCL_ALWAYS_INLINE std::enable_if_t< + ext::oneapi::experimental::is_fixed_size_group::value, T> +masked_scan_cuda_shfls(Group g, T x, BinaryOperation binary_op, + const uint32_t MemberMask) { // todo membermask naming? 
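+  // Hillis-Steele style inclusive scan over the fixed-size partition: each
+  // round, lanes at or above the current offset accumulate the value shuffled
+  // up from `i` lanes below. For exclusive scans the result is then shifted up
+  // by one lane and lane 0 takes the operation's identity.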
+ + for (int i = 1; i < g.get_local_range()[0]; i *= 2) { + auto tmp = + non_uniform_shfl(g, MemberMask, x, i); + if (g.get_local_id()[0] >= i) + x = binary_op(x, tmp); + } + + if constexpr (Op == __spv::GroupOperation::ExclusiveScan) { + + x = non_uniform_shfl(g, MemberMask, x, 1); + if (g.get_local_id()[0] == 0) { + return get_identity(); + } + } + return x; +} + +template <__spv::GroupOperation Op, typename Group, typename T, + class BinaryOperation> +inline __SYCL_ALWAYS_INLINE std::enable_if_t< + ext::oneapi::experimental::is_user_constructed_group_v && + !ext::oneapi::experimental::is_fixed_size_group::value, + T> +masked_scan_cuda_shfls(Group g, T x, BinaryOperation binary_op, + const uint32_t MemberMask) { + + // position of this lanes set bit with respect to all set bits in mask + // local_set_bit = 1 for first set bit in mask. + // todo finalize naming convention + int local_id_val = g.get_local_id()[0]; + + int local_set_bit = local_id_val + 1; + + for (int i = 1; i < g.get_local_range()[0]; i *= 2) { + int unfolded_src_set_bit = local_set_bit - i; + + auto tmp = non_uniform_shfl( + g, MemberMask, x, __nvvm_fns(MemberMask, 0, unfolded_src_set_bit)); + if (local_id_val >= i) + x = binary_op(x, tmp); + } + + if constexpr (Op == __spv::GroupOperation::ExclusiveScan) { + x = non_uniform_shfl(g, MemberMask, x, + __nvvm_fns(MemberMask, 0, local_set_bit - 1)); + if (g.get_local_id()[0] == 0) { + return get_identity(); + } + } + return x; +} + +#endif +#endif +} // namespace detail +} // __SYCL_INLINE_VER_NAMESPACE(_V1) +} // namespace sycl diff --git a/sycl/include/sycl/detail/spirv.hpp b/sycl/include/sycl/detail/spirv.hpp index 18ff6e4ef6c45..40e0606dd8a30 100644 --- a/sycl/include/sycl/detail/spirv.hpp +++ b/sycl/include/sycl/detail/spirv.hpp @@ -128,6 +128,7 @@ template bool GroupAll(Group, bool pred) { template bool GroupAll(ext::oneapi::experimental::ballot_group g, bool pred) { +#if defined (__SPIR__) // ballot_group partitions its parent into two groups (0 and 1) // We have to force each group down different control flow // Work-items in the "false" group (0) may still be active @@ -136,16 +137,25 @@ bool GroupAll(ext::oneapi::experimental::ballot_group g, } else { return __spirv_GroupNonUniformAll(group_scope::value, pred); } +#elif defined (__NVPTX__) + sycl::vec MemberMask = detail::ExtractMask(detail::GetMask(g)); + return __nvvm_vote_all_sync(MemberMask[0], pred); +#endif } template bool GroupAll( - ext::oneapi::experimental::fixed_size_group, + ext::oneapi::experimental::fixed_size_group g, bool pred) { + #if defined (__SPIR__) // GroupNonUniformAll doesn't support cluster size, so use a reduction return __spirv_GroupNonUniformBitwiseAnd( group_scope::value, static_cast(__spv::GroupOperation::ClusteredReduce), static_cast(pred), PartitionSize); + #elif defined (__NVPTX__) + sycl::vec MemberMask = detail::ExtractMask(detail::GetMask(g)); + return __nvvm_vote_all_sync(MemberMask[0], pred); +#endif } template bool GroupAll(ext::oneapi::experimental::tangle_group, bool pred) { @@ -164,6 +174,7 @@ template bool GroupAny(Group, bool pred) { template bool GroupAny(ext::oneapi::experimental::ballot_group g, bool pred) { +#if defined (__SPIR__) // ballot_group partitions its parent into two groups (0 and 1) // We have to force each group down different control flow // Work-items in the "false" group (0) may still be active @@ -172,6 +183,10 @@ bool GroupAny(ext::oneapi::experimental::ballot_group g, } else { return __spirv_GroupNonUniformAny(group_scope::value, pred); } +#elif 
defined (__NVPTX__) + sycl::vec MemberMask = detail::ExtractMask(detail::GetMask(g)); + return __nvvm_vote_any_sync(MemberMask[0], pred); +#endif } template bool GroupAny( @@ -281,6 +296,7 @@ GroupBroadcast(sycl::ext::oneapi::experimental::ballot_group g, // ballot_group partitions its parent into two groups (0 and 1) // We have to force each group down different control flow // Work-items in the "false" group (0) may still be active +#if defined(__SPIR__) if (g.get_group_id() == 1) { return __spirv_GroupNonUniformBroadcast(group_scope::value, OCLX, OCLId); @@ -288,6 +304,10 @@ GroupBroadcast(sycl::ext::oneapi::experimental::ballot_group g, return __spirv_GroupNonUniformBroadcast(group_scope::value, OCLX, OCLId); } +#elif defined(__NVPTX__) + sycl::vec MemberMask = detail::ExtractMask(detail::GetMask(g)); + return __nvvm_shfl_sync_idx_i32(MemberMask[0], x, LocalId, 31); +#endif } template EnableIfNativeBroadcast GroupBroadcast( @@ -1010,7 +1030,7 @@ ControlBarrier(Group, memory_scope FenceScope, memory_order Order) { template typename std::enable_if_t< ext::oneapi::experimental::is_user_constructed_group_v> -ControlBarrier(Group, memory_scope FenceScope, memory_order Order) { +ControlBarrier(Group g, memory_scope FenceScope, memory_order Order) { #if defined(__SPIR__) // SPIR-V does not define an instruction to synchronize partial groups. // However, most (possibly all?) of the current SPIR-V targets execute @@ -1022,7 +1042,7 @@ ControlBarrier(Group, memory_scope FenceScope, memory_order Order) { __spv::MemorySemanticsMask::WorkgroupMemory | __spv::MemorySemanticsMask::CrossWorkgroupMemory); #elif defined(__NVPTX__) - // TODO: Call syncwarp with appropriate mask extracted from the group + __nvvm_bar_warp_sync(detail::ExtractMask(detail::GetMask(g))[0]); #endif } diff --git a/sycl/include/sycl/detail/type_traits.hpp b/sycl/include/sycl/detail/type_traits.hpp index b6613ea080c03..a84c3fc40baf4 100644 --- a/sycl/include/sycl/detail/type_traits.hpp +++ b/sycl/include/sycl/detail/type_traits.hpp @@ -46,6 +46,8 @@ struct is_fixed_topology_group : std::true_type { template struct is_user_constructed_group : std::false_type {}; +template struct is_fixed_size_group : std::false_type {}; + template inline constexpr bool is_user_constructed_group_v = is_user_constructed_group::value; diff --git a/sycl/include/sycl/ext/oneapi/experimental/ballot_group.hpp b/sycl/include/sycl/ext/oneapi/experimental/ballot_group.hpp index fcdce42652075..b900c335153d2 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/ballot_group.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/ballot_group.hpp @@ -121,8 +121,8 @@ template class ballot_group { friend ballot_group get_ballot_group(ParentGroup g, bool predicate); - friend uint32_t sycl::detail::IdToMaskPosition>( - ballot_group Group, uint32_t Id); + friend sub_group_mask sycl::detail::GetMask>( + ballot_group Group); }; template diff --git a/sycl/include/sycl/ext/oneapi/experimental/fixed_size_group.hpp b/sycl/include/sycl/ext/oneapi/experimental/fixed_size_group.hpp index 3c2a1b07b74d7..a0226283fb259 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/fixed_size_group.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/fixed_size_group.hpp @@ -9,6 +9,7 @@ #pragma once #include +#include namespace sycl { __SYCL_INLINE_VER_NAMESPACE(_V1) { @@ -111,11 +112,24 @@ template class fixed_size_group { #endif } +#if defined (__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +private: + sub_group_mask Mask; +#endif + protected: + +#if defined(__SYCL_DEVICE_ONLY__) && 
defined(__NVPTX__) + fixed_size_group(ext::oneapi::sub_group_mask mask):Mask(mask) {} +#else fixed_size_group() {} +#endif friend fixed_size_group get_fixed_size_group(ParentGroup g); + + friend sub_group_mask sycl::detail::GetMask>( + fixed_size_group Group); }; template @@ -125,7 +139,20 @@ inline std::enable_if_t> && get_fixed_size_group(Group group) { (void)group; #ifdef __SYCL_DEVICE_ONLY__ +#if defined(__NVPTX__) + uint32_t loc_id = group.get_local_linear_id(); + uint32_t loc_size = group.get_local_linear_range(); + uint32_t bits = PartitionSize == 32 + ? 0xffffffff + : ((1 << PartitionSize) - 1) + << ((loc_id / PartitionSize) * PartitionSize); + + return fixed_size_group( + sycl::detail::Builder::createSubGroupMask( + bits, loc_size)); +#else return fixed_size_group(); +#endif #else throw runtime_error("Non-uniform groups are not supported on host device.", PI_ERROR_INVALID_DEVICE); @@ -136,6 +163,10 @@ template struct is_user_constructed_group> : std::true_type {}; +template +struct is_fixed_size_group> + : std::true_type {}; + } // namespace ext::oneapi::experimental template diff --git a/sycl/include/sycl/ext/oneapi/experimental/non_uniform_groups.hpp b/sycl/include/sycl/ext/oneapi/experimental/non_uniform_groups.hpp index 0c31f9ad2290f..cd5a551788aaf 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/non_uniform_groups.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/non_uniform_groups.hpp @@ -9,11 +9,15 @@ #pragma once #include #include +#include #include #include namespace sycl { __SYCL_INLINE_VER_NAMESPACE(_V1) { +namespace ext::oneapi::experimental { +template class cluster_group; +} namespace detail { @@ -39,10 +43,22 @@ inline uint32_t CallerPositionInMask(ext::oneapi::sub_group_mask Mask) { } #endif +// todo "inline" works correctly in nvptx backend too? 
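+// GetMask exposes a non-uniform group's stored sub_group_mask to detail code
+// (e.g. the NVPTX shfl/vote/redux paths, which take an explicit 32-bit member
+// mask) without widening the public group interface.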
+template +inline ext::oneapi::sub_group_mask GetMask(NonUniformGroup Group) { + return Group.Mask; +} + +template +inline ext::oneapi::sub_group_mask GetMask( + ext::oneapi::experimental::cluster_group Group) { + return Group.Mask; +} + template inline uint32_t IdToMaskPosition(NonUniformGroup Group, uint32_t Id) { // TODO: This will need to be optimized - sycl::vec MemberMask = ExtractMask(Group.Mask); + sycl::vec MemberMask = ExtractMask(GetMask(Group)); uint32_t Count = 0; for (int i = 0; i < 4; ++i) { for (int b = 0; b < 32; ++b) { diff --git a/sycl/include/sycl/ext/oneapi/experimental/opportunistic_group.hpp b/sycl/include/sycl/ext/oneapi/experimental/opportunistic_group.hpp index a1c08a35b399b..268eb819b39e7 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/opportunistic_group.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/opportunistic_group.hpp @@ -118,9 +118,8 @@ class opportunistic_group { friend opportunistic_group this_kernel::get_opportunistic_group(); - friend uint32_t - sycl::detail::IdToMaskPosition(opportunistic_group Group, - uint32_t Id); + friend sub_group_mask + sycl::detail::GetMask(opportunistic_group Group); }; namespace this_kernel { @@ -133,7 +132,10 @@ inline opportunistic_group get_opportunistic_group() { sub_group_mask mask = sycl::ext::oneapi::group_ballot(sg, true); return opportunistic_group(mask); #elif defined(__NVPTX__) - // TODO: Construct from __activemask +uint32_t active_mask; +asm volatile("activemask.b32 %0;" : "=r"(active_mask)); + sub_group_mask mask = sycl::detail::Builder::createSubGroupMask(active_mask, 32); + return opportunistic_group(mask); #endif #else throw runtime_error("Non-uniform groups are not supported on host device.", diff --git a/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp b/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp index af86bdafae9d6..fbea7f442a146 100644 --- a/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp +++ b/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp @@ -94,11 +94,15 @@ struct sub_group_mask { bool none() const { return count() == 0; } uint32_t count() const { unsigned int count = 0; + #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + asm("popc.b32 %0, %1;" : "=r"(count) : "r"(Bits)); + #else auto word = (Bits & valuable_bits(bits_num)); while (word) { word &= (word - 1); count++; } + #endif return count; } uint32_t size() const { return bits_num; } @@ -168,6 +172,7 @@ struct sub_group_mask { template >> void extract_bits(marray &bits, id<1> pos = 0) const { + //todo cuda just take first elem size_t cur_pos = pos.get(0); for (auto &elem : bits) { if (cur_pos < size()) { diff --git a/sycl/include/sycl/group_algorithm.hpp b/sycl/include/sycl/group_algorithm.hpp index e4e42fc75091b..b095f646c45ad 100644 --- a/sycl/include/sycl/group_algorithm.hpp +++ b/sycl/include/sycl/group_algorithm.hpp @@ -22,6 +22,7 @@ #include #include #include +#include namespace sycl { __SYCL_INLINE_VER_NAMESPACE(_V1) { @@ -190,12 +191,12 @@ Function for_each(Group g, Ptr first, Ptr last, Function f) { // scalar arithmetic, complex (plus only), and vector arithmetic template -std::enable_if_t<(is_group_v> && - (detail::is_scalar_arithmetic::value || - (detail::is_complex::value && - detail::is_multiplies::value)) && - detail::is_native_op::value), - T> +detail::enable_if_t<(is_group_v> || ext::oneapi::experimental::is_user_constructed_group_v && + (detail::is_scalar_arithmetic::value || + (detail::is_complex::value && + detail::is_multiplies::value)) && + detail::is_native_op::value), + T> reduce_over_group(Group g, T 
x, BinaryOperation binary_op) { // FIXME: Do not special-case for half precision static_assert( @@ -204,8 +205,23 @@ reduce_over_group(Group g, T x, BinaryOperation binary_op) { std::is_same_v), "Result type of binary_op must match reduction accumulation type."); #ifdef __SYCL_DEVICE_ONLY__ +#if defined(__NVPTX__) + sycl::vec MemberMask = + sycl::detail::ExtractMask(sycl::detail::GetMask(g)); + if constexpr (ext::oneapi::experimental::is_user_constructed_group_v) { +#if (__SYCL_CUDA_ARCH__ >= 800) + return detail::masked_reduction_cuda_sm80(g, x, binary_op, MemberMask[0]); +#else + return detail::masked_reduction_cuda_shfls(g, x, binary_op, MemberMask[0]); +#endif + } else { + return sycl::detail::calc<__spv::GroupOperation::Reduce>( + g, typename sycl::detail::GroupOpTag::type(), x, binary_op); + } +#else return sycl::detail::calc<__spv::GroupOperation::Reduce>( g, typename sycl::detail::GroupOpTag::type(), x, binary_op); +#endif #else (void)g; throw runtime_error("Group algorithms are not supported on host.", @@ -625,21 +641,33 @@ group_broadcast(Group g, T x) { // the three argument version is specialized thrice: scalar, complex, and // vector template -std::enable_if_t<(is_group_v> && - (detail::is_scalar_arithmetic::value || - (detail::is_complex::value && - detail::is_multiplies::value)) && - detail::is_native_op::value), - T> +detail::enable_if_t<(is_group_v> || ext::oneapi::experimental::is_user_constructed_group_v && + (detail::is_scalar_arithmetic::value || + (detail::is_complex::value && + detail::is_multiplies::value)) && + detail::is_native_op::value), + T> exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) { // FIXME: Do not special-case for half precision static_assert(std::is_same_v || (std::is_same_v && std::is_same_v), "Result type of binary_op must match scan accumulation type."); + #ifdef __SYCL_DEVICE_ONLY__ +#if defined(__NVPTX__) + sycl::vec MemberMask = + sycl::detail::ExtractMask(sycl::detail::GetMask(g)); + if constexpr (ext::oneapi::experimental::is_user_constructed_group_v) { + return detail::masked_scan_cuda_shfls<__spv::GroupOperation::ExclusiveScan>(g, x, binary_op, MemberMask[0]); + } else { + return sycl::detail::calc<__spv::GroupOperation::ExclusiveScan>( + g, typename sycl::detail::GroupOpTag::type(), x, binary_op); + } +#else return sycl::detail::calc<__spv::GroupOperation::ExclusiveScan>( g, typename sycl::detail::GroupOpTag::type(), x, binary_op); +#endif #else (void)g; throw runtime_error("Group algorithms are not supported on host.", @@ -869,8 +897,19 @@ inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) { std::is_same_v), "Result type of binary_op must match scan accumulation type."); #ifdef __SYCL_DEVICE_ONLY__ +#if defined(__NVPTX__) + sycl::vec MemberMask = + sycl::detail::ExtractMask(sycl::detail::GetMask(g)); + if constexpr (ext::oneapi::experimental::is_user_constructed_group_v) { + return detail::masked_scan_cuda_shfls<__spv::GroupOperation::InclusiveScan>(g, x, binary_op, MemberMask[0]); + } else { + return sycl::detail::calc<__spv::GroupOperation::InclusiveScan>( + g, typename sycl::detail::GroupOpTag::type(), x, binary_op); + } +#else return sycl::detail::calc<__spv::GroupOperation::InclusiveScan>( g, typename sycl::detail::GroupOpTag::type(), x, binary_op); +#endif #else (void)g; throw runtime_error("Group algorithms are not supported on host.",
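For reviewers, a minimal usage sketch of the paths this patch adds. Illustrative only: the queue setup, data, and the partition size of 8 are arbitrary, and the include path for the fixed_size_group header is assumed from this patch's layout.

#include <sycl/sycl.hpp>
#include <sycl/ext/oneapi/experimental/fixed_size_group.hpp>

namespace syclex = sycl::ext::oneapi::experimental;

int main() {
  sycl::queue q;
  constexpr size_t N = 64;
  int *data = sycl::malloc_shared<int>(N, q);
  for (size_t i = 0; i < N; ++i)
    data[i] = 1;

  q.parallel_for(sycl::nd_range<1>{N, 32}, [=](sycl::nd_item<1> it) {
     auto sg = it.get_sub_group();
     // Partition the 32-wide CUDA sub-group into four 8-lane groups; on NVPTX
     // this constructs the member mask consumed by the masked reduction paths.
     auto part = syclex::get_fixed_size_group<8>(sg);
     // Dispatches to masked_reduction_cuda_sm80 on sm_80+ devices, or to the
     // shfl-based masked_reduction_cuda_shfls fallback otherwise.
     data[it.get_global_id(0)] = sycl::reduce_over_group(
         part, data[it.get_global_id(0)], sycl::plus<int>());
   }).wait();

  // Each element now holds its partition's sum (8 for this input).
  sycl::free(data, q);
  return 0;
}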