Skip to content

[SYCL] Implement sub-group mask extension #4481

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Sep 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
= SYCL_EXT_ONEAPI_GROUP_MASK
= SYCL_EXT_ONEAPI_SUB_GROUP_MASK
:source-highlighter: coderay
:coderay-linenums-mode: table

Expand All @@ -21,7 +21,7 @@ IMPORTANT: This specification is a draft.

NOTE: Khronos(R) is a registered trademark and SYCL(TM) and SPIR(TM) are trademarks of The Khronos Group Inc. OpenCL(TM) is a trademark of Apple Inc. used by permission by Khronos.

This document describes an extension which adds a `group_mask` type. Such a mask can be used to efficiently represent subsets of work-items in a group for which a given Boolean condition holds. Group mask functionality is currently limited to groups that are instances of the `sub_group` class.
This document describes an extension which adds a `sub_group_mask` type. Such a mask can be used to efficiently represent subsets of work-items in a sub-group for which a given Boolean condition holds.

== Notice

Expand Down Expand Up @@ -51,9 +51,9 @@ This extension is written against the SYCL 2020 specification, Revision 3.
This extension provides a feature-test macro as described in the core SYCL
specification section 6.3.3 "Feature test macros". Therefore, an
implementation supporting this extension must predefine the macro
`SYCL_EXT_ONEAPI_GROUP_MASK` to one of the values defined in the table below.
Applications can test for the existence of this macro to determine if the
implementation supports this feature, or applications can test the macro's
`SYCL_EXT_ONEAPI_SUB_GROUP_MASK` to one of the values defined in the table
below. Applications can test for the existence of this macro to determine if
the implementation supports this feature, or applications can test the macro's
value to determine which of the extension's APIs the implementation supports.

[%header,cols="1,5"]
Expand Down Expand Up @@ -81,18 +81,18 @@ must be encountered by all work-items in the group in converged control flow.
|===
|Function|Description

|`template <typename Group> Group::mask_type group_ballot(Group g, bool predicate = true) const`
|Return a `group_mask` representing the set of work-items in group _g_ for which _predicate_ is `true`.
|`template <typename Group> Group::mask_type group_ballot(Group g, bool predicate = true)`
|Return a `sub_group_mask` with one bit for each work-item in group _g_. A bit is set in this mask if and only if the corresponding work-item's _predicate_ is `true`.
|===

=== Group Masks

The group mask type is an opaque type, permitting implementations to use any
mask representation that has the same size and alignment across host and
device. The maximum number of bits that can be stored in a `group_mask` is
exposed as a static member variable, `group_mask::max_bits`.
device. The maximum number of bits that can be stored in a `sub_group_mask` is
exposed as a static member variable, `sub_group_mask::max_bits`.

Functions declared in the `group_mask` class can be called independently by
Functions declared in the `sub_group_mask` class can be called independently by
different work-items in the same group. An instance of a group class (e.g.
`group` or `sub_group`) is not required to manipulate a group mask.

Expand All @@ -107,7 +107,7 @@ work-item with the id `max_local_range()-1`.
|Return `true` if the bit corresponding to the specified _id_ is set in the
mask.

|`group_mask::reference operator[](id<1> id)`
|`sub_group_mask::reference operator[](id<1> id)`
|Return a reference to the bit corresponding to the specified _id_ in the mask.

|`bool test(id<1> id) const`
Expand Down Expand Up @@ -137,17 +137,15 @@ work-item with the id `max_local_range()-1`.
|Return the highest `id` with a corresponding bit set in the mask. If no bits
are set, the return value is equal to `size()`.

|`template <typename T = marray<uint32_t, max_bits/sizeof(uint32_t)>> void insert_bits(T bits, id<1> pos = 0)`
|`template <typename T> void insert_bits(const T &bits, id<1> pos = 0)`
|Insert `CHAR_BIT * sizeof(T)` bits into the mask, starting from _pos_. `T`
must be an integral type or a SYCL `marray` of integral types. _pos_ must be a
multiple of `CHAR_BIT * sizeof(T)` in the range [0, `size()`). If _pos_ pass:[+]
must be an integral type or a SYCL `marray` of integral types. If _pos_ pass:[+]
`CHAR_BIT * sizeof(T)` is greater than `size()`, the final `size()` - (_pos_ pass:[+]
`CHAR_BIT * sizeof(T)`) bits are ignored.

|`template <typename T = marray<uint32_t, max_bits/sizeof(uint32_t)>> T extract_bits(id<1> pos = 0) const`
|`template <typename T> void extract_bits(T &out, id<1> pos = 0) const`
|Return `CHAR_BIT * sizeof(T)` bits from the mask, starting from _pos_. `T`
must be an integral type or a SYCL `marray` of integral types. _pos_ must be a
multiple of `CHAR_BIT * sizeof(T)` in the range [0, `size()`). If _pos_ pass:[+]
must be an integral type or a SYCL `marray` of integral types. If _pos_ pass:[+]
`CHAR_BIT * sizeof(T)` is greater than `size()`, the final `size()` - (_pos_ pass:[+]
`CHAR_BIT * sizeof(T)`) bits of the return value are zero.

Expand Down Expand Up @@ -178,62 +176,63 @@ work-item with the id `max_local_range()-1`.
|`void flip(id<1> id)`
|Toggle the value of the bit corresponding to the specified _id_.

|`bool operator==(group_mask rhs) const`
|`bool operator==(const sub_group_mask &rhs) const`
|Return true if each bit in this mask is equal to the corresponding bit in
`rhs`.

|`bool operator!=(group_mask rhs) const`
|`bool operator!=(const sub_group_mask &rhs) const`
|Return true if any bit in this mask is not equal to the corresponding bit in
`rhs`.

|`group_mask operator &=(group_mask rhs)`
|`sub_group_mask &operator &=(const sub_group_mask &rhs)`
|Set the bits of this mask to the result of performing a bitwise AND with this
mask and `rhs`.

|`group_mask operator \|=(group_mask rhs)`
|`sub_group_mask &operator \|=(const sub_group_mask &rhs)`
|Set the bits of this mask to the result of performing a bitwise OR with this
mask and `rhs`.

|`group_mask operator ^=(group_mask rhs)`
|`sub_group_mask &operator ^=(const sub_group_mask &rhs)`
|Set the bits of this mask to the result of performing a bitwise XOR with this
mask and `rhs`.

|`group_mask operator pass:[<<=](size_t shift)`
|`sub_group_mask &operator pass:[<<=](size_t shift)`
|Set the bits of this mask to the result of shifting its bits _shift_ positions
to the left using a logical shift. Bits that are shifted out to the left are
discarded, and zeroes are shifted in from the right.

|`group_mask operator >>=(size_t shift)`
|`sub_group_mask &operator >>=(size_t shift)`
|Set the bits of this mask to the result of shifting its bits _shift_ positions
to the right using a logical shift. Bits that are shifted out to the right are
discarded, and zeroes are shifted in from the left.

|`group_mask operator ~() const`
|`sub_group_mask operator ~() const`
|Return a mask representing the result of flipping all the bits in this mask.

|`group_mask operator <<(size_t shift)`
|`sub_group_mask operator <<(size_t shift) const`
|Return a mask representing the result of shifting its bits _shift_ positions
to the left using a logical shift. Bits that are shifted out to the left are
discarded, and zeroes are shifted in from the right.

|`group_mask operator >>(size_t shift)`
|`sub_group_mask operator >>(size_t shift) const`
|Return a mask representing the result of shifting its bits _shift_ positions
to the right using a logical shift. Bits that are shifted out to the right are
discarded, and zeroes are shifted in from the left.

|===

|===
|Function|Description

|`group_mask operator &(const group_mask& lhs, const group_mask& rhs)`
|`sub_group_mask operator &(const sub_group_mask& lhs, const sub_group_mask& rhs)`
|Return a mask representing the result of performing a bitwise AND of `lhs` and
`rhs`.

|`group_mask operator \|(const group_mask& lhs, const group_mask& rhs)`
|`sub_group_mask operator \|(const sub_group_mask& lhs, const sub_group_mask& rhs)`
|Return a mask representing the result of performing a bitwise OR of `lhs` and
`rhs`.

|`group_mask operator ^(const group_mask& lhs, const group_mask& rhs)`
|`sub_group_mask operator ^(const sub_group_mask& lhs, const sub_group_mask& rhs)`
|Return a mask representing the result of performing a bitwise XOR of `lhs` and
`rhs`.

Expand All @@ -247,7 +246,7 @@ namespace sycl {
namespace ext {
namespace oneapi {

struct group_mask {
struct sub_group_mask {

// enable reference to individual bit
struct reference {
Expand All @@ -271,11 +270,11 @@ struct group_mask {
id<1> find_low() const;
id<1> find_high() const;

template <typename T = marray<uint32_t, max_bits/sizeof(uint32_t)>>
void insert_bits(T bits, id<1> pos = 0);
template <typename T>
void insert_bits(const T &bits, id<1> pos = 0);

template <typename T = marray<uint32_t, max_bits/sizeof(uint32_t)>>
T extract_bits(id<1> pos = 0);
template <typename T>
void extract_bits(T &out, id<1> pos = 0);

void set();
void set(id<1> id, bool value = true);
Expand All @@ -286,24 +285,24 @@ struct group_mask {
void flip();
void flip(id<1> id);

bool operator==(group_mask rhs) const;
bool operator!=(group_mask rhs) const;
bool operator==(const sub_group_mask &rhs) const;
bool operator!=(const sub_group_mask &rhs) const;

group_mask operator &=(group_mask rhs);
group_mask operator |=(group_mask rhs);
group_mask operator ^=(group_mask rhs);
group_mask operator <<=(size_t);
group_mask operator >>=(size_t rhs);
sub_group_mask &operator &=(const sub_group_mask &rhs);
sub_group_mask &operator |=(const sub_group_mask &rhs);
sub_group_mask &operator ^=(const sub_group_mask &rhs);
sub_group_mask &operator <<=(size_t n);
sub_group_mask &operator >>=(size_t n);

group_mask operator ~() const;
group_mask operator <<(size_t) const;
group_mask operator >>(size_t) const;
sub_group_mask operator ~() const;
sub_group_mask operator <<(size_t n) const;
sub_group_mask operator >>(size_t n) const;

};

group_mask operator &(const group_mask& lhs, const group_mask& rhs);
group_mask operator |(const group_mask& lhs, const group_mask& rhs);
group_mask operator ^(const group_mask& lhs, const group_mask& rhs);
sub_group_mask operator &(const sub_group_mask& lhs, const sub_group_mask& rhs);
sub_group_mask operator |(const sub_group_mask& lhs, const sub_group_mask& rhs);
sub_group_mask operator ^(const sub_group_mask& lhs, const sub_group_mask& rhs);

} // namespace oneapi
} // namespace ext
Expand All @@ -328,6 +327,7 @@ None.
|========================================
|Rev|Date|Author|Changes
|1|2021-08-11|John Pennycook|*Initial public working draft*
|2|2021-09-13|Vladimir Lazarev|*Update during implementation*
|========================================

//************************************************************************
Expand Down
3 changes: 3 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,9 @@ __spirv_ocl_prefetch(const __attribute__((opencl_global)) char *Ptr,
extern SYCL_EXTERNAL uint16_t __spirv_ConvertFToBF16INTEL(float) noexcept;
extern SYCL_EXTERNAL float __spirv_ConvertBF16ToFINTEL(uint16_t) noexcept;

__SYCL_CONVERGENT__ extern SYCL_EXTERNAL __SYCL_EXPORT __ocl_vec_t<uint32_t, 4>
__spirv_GroupNonUniformBallot(uint32_t Execution, bool Predicate) noexcept;

#else // if !__SYCL_DEVICE_ONLY__

template <typename dataT>
Expand Down
1 change: 1 addition & 0 deletions sycl/include/CL/sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,4 @@
#include <sycl/ext/oneapi/matrix/matrix.hpp>
#include <sycl/ext/oneapi/reduction.hpp>
#include <sycl/ext/oneapi/sub_group.hpp>
#include <sycl/ext/oneapi/sub_group_mask.hpp>
6 changes: 6 additions & 0 deletions sycl/include/CL/sycl/detail/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ template <int Dims> class range;
template <int Dims> class id;
template <int Dims> class nd_item;
template <int Dims> class h_item;
template <typename Type, std::size_t NumElements> class marray;
enum class memory_order;

namespace detail {
Expand Down Expand Up @@ -82,6 +83,11 @@ class Builder {
return group<Dims>(Global, Local, Global / Local, Index);
}

template <class ResType>
static ResType createSubGroupMask(uint32_t Bits, size_t BitsNum) {
return ResType(Bits, BitsNum);
}

template <int Dims, bool WithOffset>
static detail::enable_if_t<WithOffset, item<Dims, WithOffset>>
createItem(const range<Dims> &Extent, const id<Dims> &Index,
Expand Down
1 change: 1 addition & 0 deletions sycl/include/CL/sycl/feature_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace sycl {

// TODO: Move these feature-test macros to compiler driver.
#define SYCL_EXT_INTEL_DEVICE_INFO 2
#define SYCL_EXT_ONEAPI_SUB_GROUP_MASK 1
#define SYCL_EXT_ONEAPI_LOCAL_MEMORY 1
// As for SYCL_EXT_ONEAPI_MATRIX:
// 1- provides AOT initial implementation for AMX for the experimental matrix
Expand Down
12 changes: 6 additions & 6 deletions sycl/include/CL/sycl/marray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ template <typename Type, std::size_t NumElements> class marray {
}

#define __SYCL_BINOP_INTEGRAL(BINOP, OPASSIGN) \
template <typename T = DataT> \
friend typename std::enable_if<std::is_integral<T>::value, marray> \
operator BINOP(const marray &Lhs, const marray &Rhs) { \
template <typename T = DataT, \
typename = std::enable_if<std::is_integral<T>::value, marray>> \
friend marray operator BINOP(const marray &Lhs, const marray &Rhs) { \
marray Ret; \
for (size_t I = 0; I < NumElements; ++I) { \
Ret[I] = Lhs[I] BINOP Rhs[I]; \
Expand All @@ -166,9 +166,9 @@ template <typename Type, std::size_t NumElements> class marray {
operator BINOP(const marray &Lhs, const T &Rhs) { \
return Lhs BINOP marray(static_cast<DataT>(Rhs)); \
} \
template <typename T = DataT> \
friend typename std::enable_if<std::is_integral<T>::value, marray> \
&operator OPASSIGN(marray &Lhs, const marray &Rhs) { \
template <typename T = DataT, \
typename = std::enable_if<std::is_integral<T>::value, marray>> \
friend marray &operator OPASSIGN(marray &Lhs, const marray &Rhs) { \
Lhs = Lhs BINOP Rhs; \
return Lhs; \
} \
Expand Down
Loading