Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Add ballot_group support to algorithms #8784

Merged
merged 20 commits into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,68 @@ __SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT int
__spirv_GroupNonUniformBallotFindLSB(__spv::Scope::Flag,
__ocl_vec_t<uint32_t, 4>) noexcept;

template <typename ValueT, typename IdT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformBroadcast(__spv::Scope::Flag, ValueT, IdT);

__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT bool
__spirv_GroupNonUniformAll(__spv::Scope::Flag, bool);

__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT bool
__spirv_GroupNonUniformAny(__spv::Scope::Flag, bool);

template <typename ValueT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformSMin(__spv::Scope::Flag, unsigned int, ValueT);

template <typename ValueT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformUMin(__spv::Scope::Flag, unsigned int, ValueT);

template <typename ValueT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformFMin(__spv::Scope::Flag, unsigned int, ValueT);

template <typename ValueT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformSMax(__spv::Scope::Flag, unsigned int, ValueT);

template <typename ValueT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformUMax(__spv::Scope::Flag, unsigned int, ValueT);

template <typename ValueT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformFMax(__spv::Scope::Flag, unsigned int, ValueT);

template <typename ValueT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformIAdd(__spv::Scope::Flag, unsigned int, ValueT);

template <typename ValueT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformFAdd(__spv::Scope::Flag, unsigned int, ValueT);

template <typename ValueT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformIMul(__spv::Scope::Flag, unsigned int, ValueT);

template <typename ValueT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformFMul(__spv::Scope::Flag, unsigned int, ValueT);

template <typename ValueT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformBitwiseOr(__spv::Scope::Flag, unsigned int, ValueT);

template <typename ValueT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformBitwiseXor(__spv::Scope::Flag, unsigned int, ValueT);

template <typename ValueT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformBitwiseAnd(__spv::Scope::Flag, unsigned int, ValueT);

extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT void
__clc_BarrierInitialize(int64_t *state, int32_t expected_count) noexcept;

Expand Down
190 changes: 176 additions & 14 deletions sycl/include/sycl/detail/spirv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <sycl/detail/generic_type_traits.hpp>
#include <sycl/detail/helpers.hpp>
#include <sycl/detail/type_traits.hpp>
#include <sycl/ext/oneapi/experimental/non_uniform_groups.hpp>
#include <sycl/id.hpp>
#include <sycl/memory_enums.hpp>

Expand All @@ -23,6 +24,9 @@ __SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace ext {
namespace oneapi {
struct sub_group;
namespace experimental {
template <typename ParentGroup> class ballot_group;
} // namespace experimental
} // namespace oneapi
} // namespace ext

Expand Down Expand Up @@ -56,6 +60,11 @@ template <> struct group_scope<::sycl::ext::oneapi::sub_group> {
static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Subgroup;
};

template <typename ParentGroup>
struct group_scope<sycl::ext::oneapi::experimental::ballot_group<ParentGroup>> {
static constexpr __spv::Scope::Flag value = group_scope<ParentGroup>::value;
};

// Generic shuffles and broadcasts may require multiple calls to
// intrinsics, and should use the fewest broadcasts possible
// - Loop over chunks until remaining bytes < chunk size
Expand Down Expand Up @@ -94,13 +103,35 @@ void GenericCall(const Functor &ApplyToBytes) {
}
}

template <typename Group> bool GroupAll(bool pred) {
template <typename Group> bool GroupAll(Group g, bool pred) {
steffenlarsen marked this conversation as resolved.
Show resolved Hide resolved
return __spirv_GroupAll(group_scope<Group>::value, pred);
}
template <typename ParentGroup>
bool GroupAll(ext::oneapi::experimental::ballot_group<ParentGroup> g,
bool pred) {
// Each ballot_group implicitly represents two groups
// We have to force each half down different control flow
if (g.get_group_id() == 1) {
Pennycook marked this conversation as resolved.
Show resolved Hide resolved
return __spirv_GroupNonUniformAll(group_scope<ParentGroup>::value, pred);
} else {
return __spirv_GroupNonUniformAll(group_scope<ParentGroup>::value, pred);
}
}

template <typename Group> bool GroupAny(bool pred) {
template <typename Group> bool GroupAny(Group g, bool pred) {
return __spirv_GroupAny(group_scope<Group>::value, pred);
}
template <typename ParentGroup>
bool GroupAny(ext::oneapi::experimental::ballot_group<ParentGroup> g,
bool pred) {
// Each ballot_group implicitly represents two groups
// We have to force each half down different control flow
if (g.get_group_id() == 1) {
return __spirv_GroupNonUniformAny(group_scope<ParentGroup>::value, pred);
} else {
return __spirv_GroupNonUniformAny(group_scope<ParentGroup>::value, pred);
}
}

// Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
// FIXME: Do not special-case for half once all backends support all data types.
Expand Down Expand Up @@ -157,7 +188,7 @@ template <> struct GroupId<::sycl::ext::oneapi::sub_group> {
using type = uint32_t;
};
template <typename Group, typename T, typename IdT>
EnableIfNativeBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
EnableIfNativeBroadcast<T, IdT> GroupBroadcast(Group, T x, IdT local_id) {
using GroupIdT = typename GroupId<Group>::type;
GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
using OCLT = detail::ConvertToOpenCLType_t<T>;
Expand All @@ -167,23 +198,50 @@ EnableIfNativeBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
}
template <typename ParentGroup, typename T, typename IdT>
EnableIfNativeBroadcast<T, IdT>
GroupBroadcast(sycl::ext::oneapi::experimental::ballot_group<ParentGroup> g,
T x, IdT local_id) {
// Remap local_id to its original numbering in ParentGroup
Pennycook marked this conversation as resolved.
Show resolved Hide resolved
auto LocalId = detail::IdToMaskPosition(g, local_id);

// TODO: Refactor to avoid duplication after design settles
Pennycook marked this conversation as resolved.
Show resolved Hide resolved
using GroupIdT = typename GroupId<ParentGroup>::type;
GroupIdT GroupLocalId = static_cast<GroupIdT>(LocalId);
using OCLT = detail::ConvertToOpenCLType_t<T>;
using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);

// Each ballot_group implicitly represents two groups
// We have to force each half down different control flow
steffenlarsen marked this conversation as resolved.
Show resolved Hide resolved
if (g.get_group_id() == 1) {
return __spirv_GroupNonUniformBroadcast(group_scope<ParentGroup>::value,
OCLX, OCLId);
} else {
return __spirv_GroupNonUniformBroadcast(group_scope<ParentGroup>::value,
OCLX, OCLId);
}
}

template <typename Group, typename T, typename IdT>
EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(Group g, T x, IdT local_id) {
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
auto BroadcastX = bit_cast<BroadcastT>(x);
BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
BroadcastT Result = GroupBroadcast(g, BroadcastX, local_id);
return bit_cast<T>(Result);
}
template <typename Group, typename T, typename IdT>
EnableIfGenericBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
EnableIfGenericBroadcast<T, IdT> GroupBroadcast(Group g, T x, IdT local_id) {
// Initialize with x to support type T without default constructor
T Result = x;
char *XBytes = reinterpret_cast<char *>(&x);
char *ResultBytes = reinterpret_cast<char *>(&Result);
auto BroadcastBytes = [=](size_t Offset, size_t Size) {
uint64_t BroadcastX, BroadcastResult;
std::memcpy(&BroadcastX, XBytes + Offset, Size);
BroadcastResult = GroupBroadcast<Group>(BroadcastX, local_id);
BroadcastResult = GroupBroadcast(g, BroadcastX, local_id);
std::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
};
GenericCall<T>(BroadcastBytes);
Expand All @@ -192,9 +250,10 @@ EnableIfGenericBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {

// Broadcast with vector local index
template <typename Group, typename T, int Dimensions>
EnableIfNativeBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
EnableIfNativeBroadcast<T> GroupBroadcast(Group g, T x,
id<Dimensions> local_id) {
if (Dimensions == 1) {
return GroupBroadcast<Group>(x, local_id[0]);
return GroupBroadcast(g, x, local_id[0]);
}
using IdT = vec<size_t, Dimensions>;
using OCLT = detail::ConvertToOpenCLType_t<T>;
Expand All @@ -208,17 +267,26 @@ EnableIfNativeBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
}
template <typename ParentGroup, typename T>
EnableIfNativeBroadcast<T>
GroupBroadcast(sycl::ext::oneapi::experimental::ballot_group<ParentGroup> g,
T x, id<1> local_id) {
// Limited to 1D indices for now because ParentGroup must be sub-group
Pennycook marked this conversation as resolved.
Show resolved Hide resolved
return GroupBroadcast(g, x, local_id[0]);
}
template <typename Group, typename T, int Dimensions>
EnableIfBitcastBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
EnableIfBitcastBroadcast<T> GroupBroadcast(Group g, T x,
id<Dimensions> local_id) {
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
auto BroadcastX = bit_cast<BroadcastT>(x);
BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
BroadcastT Result = GroupBroadcast(g, BroadcastX, local_id);
return bit_cast<T>(Result);
}
template <typename Group, typename T, int Dimensions>
EnableIfGenericBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
EnableIfGenericBroadcast<T> GroupBroadcast(Group g, T x,
id<Dimensions> local_id) {
if (Dimensions == 1) {
return GroupBroadcast<Group>(x, local_id[0]);
return GroupBroadcast(g, x, local_id[0]);
}
// Initialize with x to support type T without default constructor
T Result = x;
Expand All @@ -227,7 +295,7 @@ EnableIfGenericBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
auto BroadcastBytes = [=](size_t Offset, size_t Size) {
uint64_t BroadcastX, BroadcastResult;
std::memcpy(&BroadcastX, XBytes + Offset, Size);
BroadcastResult = GroupBroadcast<Group>(BroadcastX, local_id);
BroadcastResult = GroupBroadcast(g, BroadcastX, local_id);
std::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
};
GenericCall<T>(BroadcastBytes);
Expand Down Expand Up @@ -801,6 +869,100 @@ EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
return Result;
}

template <typename Group>
typename std::enable_if_t<
ext::oneapi::experimental::is_fixed_topology_group_v<Group>>
ControlBarrier(Group, memory_scope FenceScope, memory_order Order) {
__spirv_ControlBarrier(group_scope<Group>::value, getScope(FenceScope),
getMemorySemanticsMask(Order) |
__spv::MemorySemanticsMask::SubgroupMemory |
__spv::MemorySemanticsMask::WorkgroupMemory |
__spv::MemorySemanticsMask::CrossWorkgroupMemory);
}

template <typename Group>
typename std::enable_if_t<
ext::oneapi::experimental::is_user_constructed_group_v<Group>>
ControlBarrier(Group, memory_scope FenceScope, memory_order Order) {
#if defined(__SPIR__)
// SPIR-V does not define an instruction to synchronize partial groups.
// However, most (possibly all?) of the current SPIR-V targets execute
// work-items in lockstep, so we can probably get away with a MemoryBarrier.
// TODO: Replace this if SPIR-V defines a NonUniformControlBarrier
__spirv_MemoryBarrier(getScope(FenceScope),
getMemorySemanticsMask(Order) |
__spv::MemorySemanticsMask::SubgroupMemory |
__spv::MemorySemanticsMask::WorkgroupMemory |
__spv::MemorySemanticsMask::CrossWorkgroupMemory);
#elif defined(__NVPTX__)
// TODO: Call syncwarp with appropriate mask extracted from the group
#endif
}

// TODO: Refactor to avoid duplication after design settles
#define __SYCL_GROUP_COLLECTIVE_OVERLOAD(Instruction) \
template <__spv::GroupOperation Op, typename Group, typename T> \
inline typename std::enable_if_t< \
ext::oneapi::experimental::is_fixed_topology_group_v<Group>, T> \
Group##Instruction(Group G, T x) { \
using ConvertedT = detail::ConvertToOpenCLType_t<T>; \
\
using OCLT = \
conditional_t<std::is_same<ConvertedT, cl_char>() || \
std::is_same<ConvertedT, cl_short>(), \
cl_int, \
conditional_t<std::is_same<ConvertedT, cl_uchar>() || \
std::is_same<ConvertedT, cl_ushort>(), \
cl_uint, ConvertedT>>; \
OCLT Arg = x; \
OCLT Ret = __spirv_Group##Instruction(group_scope<Group>::value, \
static_cast<unsigned int>(Op), Arg); \
return Ret; \
} \
\
template <__spv::GroupOperation Op, typename ParentGroup, typename T> \
inline T Group##Instruction( \
ext::oneapi::experimental::ballot_group<ParentGroup> g, T x) { \
using ConvertedT = detail::ConvertToOpenCLType_t<T>; \
\
using OCLT = \
conditional_t<std::is_same<ConvertedT, cl_char>() || \
std::is_same<ConvertedT, cl_short>(), \
cl_int, \
conditional_t<std::is_same<ConvertedT, cl_uchar>() || \
std::is_same<ConvertedT, cl_ushort>(), \
cl_uint, ConvertedT>>; \
OCLT Arg = x; \
/* Each ballot_group implicitly represents two groups */ \
/* We have to force each half down different control flow */ \
constexpr auto Scope = group_scope<ParentGroup>::value; \
constexpr auto OpInt = static_cast<unsigned int>(Op); \
if (g.get_group_id() == 1) { \
return __spirv_GroupNonUniform##Instruction(Scope, OpInt, Arg); \
} else { \
return __spirv_GroupNonUniform##Instruction(Scope, OpInt, Arg); \
} \
}

__SYCL_GROUP_COLLECTIVE_OVERLOAD(SMin)
__SYCL_GROUP_COLLECTIVE_OVERLOAD(UMin)
__SYCL_GROUP_COLLECTIVE_OVERLOAD(FMin)

__SYCL_GROUP_COLLECTIVE_OVERLOAD(SMax)
__SYCL_GROUP_COLLECTIVE_OVERLOAD(UMax)
__SYCL_GROUP_COLLECTIVE_OVERLOAD(FMax)

__SYCL_GROUP_COLLECTIVE_OVERLOAD(IAdd)
__SYCL_GROUP_COLLECTIVE_OVERLOAD(FAdd)

__SYCL_GROUP_COLLECTIVE_OVERLOAD(IMulKHR)
__SYCL_GROUP_COLLECTIVE_OVERLOAD(FMulKHR)
__SYCL_GROUP_COLLECTIVE_OVERLOAD(CMulINTEL)

__SYCL_GROUP_COLLECTIVE_OVERLOAD(BitwiseOrKHR)
__SYCL_GROUP_COLLECTIVE_OVERLOAD(BitwiseXorKHR)
__SYCL_GROUP_COLLECTIVE_OVERLOAD(BitwiseAndKHR)

} // namespace spirv
} // namespace detail
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
Expand Down
23 changes: 23 additions & 0 deletions sycl/include/sycl/detail/type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,29 @@ struct sub_group;
namespace experimental {
template <typename Group, std::size_t Extent> class group_with_scratchpad;

template <class T> struct is_fixed_topology_group : std::false_type {};

template <class T>
inline constexpr bool is_fixed_topology_group_v =
is_fixed_topology_group<T>::value;

#ifdef SYCL_EXT_ONEAPI_ROOT_GROUP
template <> struct is_fixed_topology_group<root_group> : std::true_type {};
#endif

template <int Dimensions>
struct is_fixed_topology_group<sycl::group<Dimensions>> : std::true_type {};

template <>
struct is_fixed_topology_group<sycl::ext::oneapi::sub_group> : std::true_type {
};

template <class T> struct is_user_constructed_group : std::false_type {};

template <class T>
inline constexpr bool is_user_constructed_group_v =
is_user_constructed_group<T>::value;

namespace detail {
template <typename T> struct is_group_helper : std::false_type {};

Expand Down
Loading