Skip to content

[SYCL] Add non-uniform group classes #8202

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 17, 2023
10 changes: 10 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,16 @@ __SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
__SYCL_EXPORT __ocl_vec_t<uint32_t, 4>
__spirv_GroupNonUniformBallot(uint32_t Execution, bool Predicate) noexcept;

// TODO: It is not certain that these NonUniform instructions need to be
// convergent. They are marked convergent here to follow the precedent set by
// __spirv_GroupNonUniformBallot above.

// Declaration of the SPIR-V GroupNonUniformBallotBitCount built-in.
// Unnamed parameters, in order: a scope flag, an int (callers in this change
// pass a __spv::GroupOperation value cast to int), and a 4 x uint32_t vector.
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT uint32_t
__spirv_GroupNonUniformBallotBitCount(__spv::Scope::Flag, int,
__ocl_vec_t<uint32_t, 4>) noexcept;

// Declaration of the SPIR-V GroupNonUniformBallotFindLSB built-in.
// Unnamed parameters, in order: a scope flag and a 4 x uint32_t vector.
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT int
__spirv_GroupNonUniformBallotFindLSB(__spv::Scope::Flag,
__ocl_vec_t<uint32_t, 4>) noexcept;

extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT void
__clc_BarrierInitialize(int64_t *state, int32_t expected_count) noexcept;

Expand Down
6 changes: 6 additions & 0 deletions sycl/include/CL/__spirv/spirv_vars.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ __SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInNumSubgroups;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupId;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupLocalInvocationId;

// Sub-group mask built-in variables (Eq, Ge, Gt, Le, Lt), each declared as a
// 4 x uint32_t vector.
__SPIRV_VAR_QUALIFIERS __ocl_vec_t<uint32_t, 4> __spirv_BuiltInSubgroupEqMask;
__SPIRV_VAR_QUALIFIERS __ocl_vec_t<uint32_t, 4> __spirv_BuiltInSubgroupGeMask;
__SPIRV_VAR_QUALIFIERS __ocl_vec_t<uint32_t, 4> __spirv_BuiltInSubgroupGtMask;
__SPIRV_VAR_QUALIFIERS __ocl_vec_t<uint32_t, 4> __spirv_BuiltInSubgroupLeMask;
__SPIRV_VAR_QUALIFIERS __ocl_vec_t<uint32_t, 4> __spirv_BuiltInSubgroupLtMask;

// Returns the x component of the __spirv_BuiltInGlobalInvocationId built-in
// variable as a size_t.
__DPCPP_SYCL_EXTERNAL inline size_t __spirv_GlobalInvocationId_x() {
return __spirv_BuiltInGlobalInvocationId.x;
}
Expand Down
153 changes: 153 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/ballot_group.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
//==------ ballot_group.hpp --- SYCL extension for non-uniform groups ------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#pragma once
#include <sycl/ext/oneapi/experimental/non_uniform_groups.hpp>
#include <sycl/ext/oneapi/sub_group_mask.hpp>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace ext::oneapi::experimental {

// Forward declaration so that get_ballot_group can name its return type.
template <typename ParentGroup> class ballot_group;

// Forward declaration of the factory function; ballot_group befriends it.
// Participates in overload resolution only when Group satisfies
// sycl::is_group_v and is exactly sycl::sub_group.
template <typename Group>
inline std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
std::is_same_v<Group, sycl::sub_group>,
ballot_group<Group>>
get_ballot_group(Group group, bool predicate);

// Non-uniform group described by a sub-group membership mask and the
// predicate value that selected it.
//
// Objects are created through get_ballot_group(); the constructor is not
// public. Every query throws runtime_error when compiled for host, because
// non-uniform groups are only implemented for device code.
template <typename ParentGroup> class ballot_group {
public:
  using id_type = id<1>;
  using range_type = range<1>;
  using linear_id_type = typename ParentGroup::linear_id_type;
  static constexpr int dimensions = 1;
  static constexpr sycl::memory_scope fence_scope = ParentGroup::fence_scope;

  // Index of this group: 1 when the stored predicate is true, otherwise 0.
  id_type get_group_id() const {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("Non-uniform groups are not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#else
    return Predicate ? 1 : 0;
#endif
  }

  // Caller's index inside the group, as computed by
  // detail::CallerPositionInMask over the membership mask.
  id_type get_local_id() const {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("Non-uniform groups are not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#else
    return detail::CallerPositionInMask(Mask);
#endif
  }

  // Number of groups; always 2 (one per predicate value).
  range_type get_group_range() const {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("Non-uniform groups are not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#else
    return 2;
#endif
  }

  // Number of members in this group, i.e. the count reported by the mask.
  range_type get_local_range() const {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("Non-uniform groups are not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#else
    return Mask.count();
#endif
  }

  // get_group_id() flattened to linear_id_type.
  linear_id_type get_group_linear_id() const {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("Non-uniform groups are not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#else
    return static_cast<linear_id_type>(get_group_id()[0]);
#endif
  }

  // get_local_id() flattened to linear_id_type.
  linear_id_type get_local_linear_id() const {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("Non-uniform groups are not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#else
    return static_cast<linear_id_type>(get_local_id()[0]);
#endif
  }

  // get_group_range() flattened to linear_id_type.
  linear_id_type get_group_linear_range() const {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("Non-uniform groups are not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#else
    return static_cast<linear_id_type>(get_group_range()[0]);
#endif
  }

  // get_local_range() flattened to linear_id_type.
  linear_id_type get_local_linear_range() const {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("Non-uniform groups are not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#else
    return static_cast<linear_id_type>(get_local_range()[0]);
#endif
  }

  // True for the work-item whose sub-group local invocation id equals the
  // lowest position reported by the mask (Mask.find_low()).
  bool leader() const {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("Non-uniform groups are not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#else
    return __spirv_SubgroupLocalInvocationId() ==
           static_cast<uint32_t>(Mask.find_low()[0]);
#endif
  }

private:
  // Membership mask of this group.
  sub_group_mask Mask;
  // Predicate value shared by the members of this group.
  bool Predicate;

protected:
  // Only reachable from get_ballot_group (friend below) and derived classes.
  ballot_group(sub_group_mask m, bool p) : Mask(m), Predicate(p) {}

  friend ballot_group<ParentGroup>
  get_ballot_group<ParentGroup>(ParentGroup g, bool predicate);
};

// Builds the ballot_group the calling work-item belongs to.
//
// On device, a ballot over `group` yields the mask of work-items whose
// predicate is true. Callers with a true predicate get that mask; callers
// with a false predicate get its complement.
// On host this always throws runtime_error.
template <typename Group>
inline std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
                            std::is_same_v<Group, sycl::sub_group>,
                        ballot_group<Group>>
get_ballot_group(Group group, bool predicate) {
  (void)group;
#ifndef __SYCL_DEVICE_ONLY__
  (void)predicate;
  throw runtime_error("Non-uniform groups are not supported on host device.",
                      PI_ERROR_INVALID_DEVICE);
#else
  // The two groups partition the sub-group, so the membership mask of one is
  // the negation of the other.
  sub_group_mask TrueVotes = sycl::ext::oneapi::group_ballot(group, predicate);
  return ballot_group<sycl::sub_group>(predicate ? TrueVotes : ~TrueVotes,
                                       predicate);
#endif
}

// Marks every ballot_group instantiation as a user-constructed group.
template <typename Parent>
struct is_user_constructed_group<ballot_group<Parent>> : std::true_type {};

} // namespace ext::oneapi::experimental
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl
141 changes: 141 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/cluster_group.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
//==------ cluster_group.hpp --- SYCL extension for non-uniform groups -----==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#pragma once

#include <sycl/ext/oneapi/experimental/non_uniform_groups.hpp>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace ext::oneapi::experimental {

// Forward declaration so that get_cluster_group can name its return type.
template <size_t ClusterSize, typename ParentGroup> class cluster_group;

// Forward declaration of the factory function; cluster_group befriends it.
// Participates in overload resolution only when Group satisfies
// sycl::is_group_v and is exactly sycl::sub_group.
template <size_t ClusterSize, typename Group>
inline std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
std::is_same_v<Group, sycl::sub_group>,
cluster_group<ClusterSize, Group>>
get_cluster_group(Group group);

// Non-uniform group made of ClusterSize consecutive sub-group local ids.
//
// The class is stateless: all queries are derived from the caller's
// sub-group local invocation id and the compile-time ClusterSize. Objects
// are created through get_cluster_group(); the constructor is not public.
// Every query throws runtime_error when compiled for host.
template <size_t ClusterSize, typename ParentGroup> class cluster_group {
public:
  using id_type = id<1>;
  using range_type = range<1>;
  using linear_id_type = typename ParentGroup::linear_id_type;
  static constexpr int dimensions = 1;
  static constexpr sycl::memory_scope fence_scope = ParentGroup::fence_scope;

  // Index of the caller's cluster: local invocation id / ClusterSize.
  id_type get_group_id() const {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("Non-uniform groups are not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#else
    return __spirv_SubgroupLocalInvocationId() / ClusterSize;
#endif
  }

  // Caller's index inside its cluster: local invocation id % ClusterSize.
  id_type get_local_id() const {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("Non-uniform groups are not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#else
    return __spirv_SubgroupLocalInvocationId() % ClusterSize;
#endif
  }

  // Number of clusters: __spirv_SubgroupMaxSize() / ClusterSize.
  range_type get_group_range() const {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("Non-uniform groups are not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#else
    return __spirv_SubgroupMaxSize() / ClusterSize;
#endif
  }

  // Number of work-items per cluster; always ClusterSize.
  range_type get_local_range() const {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("Non-uniform groups are not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#else
    return ClusterSize;
#endif
  }

  // get_group_id() flattened to linear_id_type.
  linear_id_type get_group_linear_id() const {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("Non-uniform groups are not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#else
    return static_cast<linear_id_type>(get_group_id()[0]);
#endif
  }

  // get_local_id() flattened to linear_id_type.
  linear_id_type get_local_linear_id() const {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("Non-uniform groups are not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#else
    return static_cast<linear_id_type>(get_local_id()[0]);
#endif
  }

  // get_group_range() flattened to linear_id_type.
  linear_id_type get_group_linear_range() const {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("Non-uniform groups are not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#else
    return static_cast<linear_id_type>(get_group_range()[0]);
#endif
  }

  // get_local_range() flattened to linear_id_type.
  linear_id_type get_local_linear_range() const {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("Non-uniform groups are not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#else
    return static_cast<linear_id_type>(get_local_range()[0]);
#endif
  }

  // True for the work-item whose linear local id within the cluster is 0.
  bool leader() const {
#ifndef __SYCL_DEVICE_ONLY__
    throw runtime_error("Non-uniform groups are not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#else
    return get_local_linear_id() == 0;
#endif
  }

protected:
  // Only reachable from get_cluster_group (friend below) and derived classes.
  cluster_group() {}

  friend cluster_group<ClusterSize, ParentGroup>
  get_cluster_group<ClusterSize, ParentGroup>(ParentGroup g);
};

// Creates the cluster_group of width ClusterSize for the calling work-item.
//
// The parent group argument is accepted only for overload selection; the
// returned object carries no state. On host this always throws
// runtime_error.
template <size_t ClusterSize, typename Group>
inline std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
                            std::is_same_v<Group, sycl::sub_group>,
                        cluster_group<ClusterSize, Group>>
get_cluster_group(Group group) {
  (void)group;
#ifndef __SYCL_DEVICE_ONLY__
  throw runtime_error("Non-uniform groups are not supported on host device.",
                      PI_ERROR_INVALID_DEVICE);
#else
  return cluster_group<ClusterSize, sycl::sub_group>();
#endif
}

// Marks every cluster_group instantiation as a user-constructed group.
template <size_t Size, typename Parent>
struct is_user_constructed_group<cluster_group<Size, Parent>>
    : std::true_type {};

} // namespace ext::oneapi::experimental
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//==--- non_uniform_groups.hpp --- SYCL extension for non-uniform groups ---==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#pragma once
#include <CL/__spirv/spirv_ops.hpp>
#include <CL/__spirv/spirv_vars.hpp>
#include <sycl/ext/oneapi/sub_group_mask.hpp>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace ext::oneapi::experimental {

// Type trait: false for every type unless a specialization says otherwise.
template <typename T> struct is_fixed_topology_group : std::false_type {};

// Convenience variable template for is_fixed_topology_group<T>::value.
template <typename T>
inline constexpr bool is_fixed_topology_group_v =
    is_fixed_topology_group<T>::value;

#ifdef SYCL_EXT_ONEAPI_ROOT_GROUP
template <> struct is_fixed_topology_group<root_group> : std::true_type {};
#endif

template <int Dimensions>
struct is_fixed_topology_group<sycl::group<Dimensions>> : std::true_type {};

template <> struct is_fixed_topology_group<sycl::sub_group> : std::true_type {};

// Type trait: false for every type unless a specialization says otherwise.
template <typename T> struct is_user_constructed_group : std::false_type {};

// Convenience variable template for is_user_constructed_group<T>::value.
template <typename T>
inline constexpr bool is_user_constructed_group_v =
    is_user_constructed_group<T>::value;

#ifdef __SYCL_DEVICE_ONLY__
// TODO: This may need to be generalized beyond uint32_t for big masks
namespace detail {
// Returns the result of __spirv_GroupNonUniformBallotBitCount evaluated at
// sub-group scope with the ExclusiveScan group operation over Mask.
// NOTE(review): per the SPIR-V specification this should be the number of
// set mask bits below the caller's own position -- confirm against the spec.
//
// Declared inline because this definition lives in a header: without it,
// every translation unit that includes the header would emit its own
// external definition, violating the one-definition rule.
inline uint32_t CallerPositionInMask(sub_group_mask Mask) {
  // FIXME: It would be nice to be able to jump straight to an __ocl_vec_t
  // Copy the mask bits into a 4 x unsigned marray, then element-wise into a
  // sycl::vec so it can be converted to the OpenCL vector type the built-in
  // expects.
  sycl::marray<unsigned, 4> TmpMArray;
  Mask.extract_bits(TmpMArray);
  sycl::vec<unsigned, 4> MemberMask;
  for (int i = 0; i < 4; ++i) {
    MemberMask[i] = TmpMArray[i];
  }
  auto OCLMask =
      sycl::detail::ConvertToOpenCLType_t<sycl::vec<unsigned, 4>>(MemberMask);
  return __spirv_GroupNonUniformBallotBitCount(
      __spv::Scope::Subgroup,
      static_cast<int>(__spv::GroupOperation::ExclusiveScan), OCLMask);
}
} // namespace detail
#endif

} // namespace ext::oneapi::experimental
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl
Loading