From 880ae2a269647440d2bb9055561e0cbf639fa035 Mon Sep 17 00:00:00 2001 From: Vladimir Lazarev Date: Tue, 24 Aug 2021 11:40:14 +0300 Subject: [PATCH 01/15] [SYCL] Implement GroupMask extension Specification is available under https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/GroupMask/GroupMask.asciidoc --- .../extensions/GroupMask/GroupMask.asciidoc | 23 ++-- sycl/include/CL/__spirv/spirv_ops.hpp | 3 + sycl/include/CL/sycl.hpp | 1 + sycl/include/CL/sycl/feature_test.hpp | 1 + sycl/include/CL/sycl/marray.hpp | 12 +- sycl/include/sycl/ext/oneapi/group_mask.hpp | 118 ++++++++++++++++++ sycl/source/CMakeLists.txt | 1 + sycl/source/ext/group_mask.cpp | 34 +++++ sycl/test/check_device_code/group_mask.cpp | 8 ++ 9 files changed, 184 insertions(+), 17 deletions(-) create mode 100644 sycl/include/sycl/ext/oneapi/group_mask.hpp create mode 100644 sycl/source/ext/group_mask.cpp create mode 100644 sycl/test/check_device_code/group_mask.cpp diff --git a/sycl/doc/extensions/GroupMask/GroupMask.asciidoc b/sycl/doc/extensions/GroupMask/GroupMask.asciidoc index d95fd6c5b12be..8be3de09bb98b 100755 --- a/sycl/doc/extensions/GroupMask/GroupMask.asciidoc +++ b/sycl/doc/extensions/GroupMask/GroupMask.asciidoc @@ -81,7 +81,7 @@ must be encountered by all work-items in the group in converged control flow. |=== |Function|Description -|`template Group::mask_type group_ballot(Group g, bool predicate = true) const` +|`template Group::mask_type group_ballot(Group g, bool predicate = true)` |Return a `group_mask` representing the set of work-items in group _g_ for which _predicate_ is `true`. |=== @@ -137,14 +137,14 @@ work-item with the id `max_local_range()-1`. |Return the highest `id` with a corresponding bit set in the mask. If no bits are set, the return value is equal to `size()`. -|`template > void insert_bits(T bits, id<1> pos = 0)` +|`template > void insert_bits(const T& bits, id<1> pos = 0)` |Insert `CHAR_BIT * sizeof(T)` bits into the mask, starting from _pos_. `T` must be an integral type or a SYCL `marray` of integral types. _pos_ must be a multiple of `CHAR_BIT * sizeof(T)` in the range [0, `size()`). If _pos_ pass:[+] `CHAR_BIT * sizeof(T)` is greater than `size()`, the final `size()` - (_pos_ pass:[+] `CHAR_BIT * sizeof(T)`) bits are ignored. -|`template > T extract_bits(id<1> pos = 0) const` +|`template > T extract_bits(id<1> pos = 0) const` |Return `CHAR_BIT * sizeof(T)` bits from the mask, starting from _pos_. `T` must be an integral type or a SYCL `marray` of integral types. _pos_ must be a multiple of `CHAR_BIT * sizeof(T)` in the range [0, `size()`). If _pos_ pass:[+] @@ -259,6 +259,7 @@ struct group_mask { }; static constexpr size_t max_bits = /* implementation-defined */; + static constexpr size_t marray_size = max_bits/sizeof(uint32_t)/8; bool operator[](id<1> id) const; reference operator[](id<1> id); @@ -271,10 +272,10 @@ struct group_mask { id<1> find_low() const; id<1> find_high() const; - template > - void insert_bits(T bits, id<1> pos = 0); + template > + void insert_bits(const T& bits, id<1> pos = 0); - template > + template > T extract_bits(id<1> pos = 0); void set(); @@ -286,12 +287,12 @@ struct group_mask { void flip(); void flip(id<1> id); - bool operator==(group_mask rhs) const; - bool operator!=(group_mask rhs) const; + bool operator==(const group_mask& rhs) const; + bool operator!=(const group_mask& rhs) const; - group_mask operator &=(group_mask rhs); - group_mask operator |=(group_mask rhs); - group_mask operator ^=(group_mask rhs); + group_mask operator &=(const group_mask& rhs); + group_mask operator |=(const group_mask& rhs); + group_mask operator ^=(const group_mask& rhs); group_mask operator <<=(size_t); group_mask operator >>=(size_t rhs); diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index 18ef03cc70607..7e6a2aad1f5fb 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -597,6 +597,9 @@ __spirv_ocl_prefetch(const __attribute__((opencl_global)) char *Ptr, extern SYCL_EXTERNAL uint16_t __spirv_ConvertFToBF16INTEL(float) noexcept; extern SYCL_EXTERNAL float __spirv_ConvertBF16ToFINTEL(uint16_t) noexcept; +__SYCL_CONVERGENT__ extern SYCL_EXTERNAL __SYCL_EXPORT __ocl_vec_t +__spirv_GroupNonUniformBallot(uint32_t Execution, bool Predicate) noexcept; + #else // if !__SYCL_DEVICE_ONLY__ template diff --git a/sycl/include/CL/sycl.hpp b/sycl/include/CL/sycl.hpp index 75f0cedd57c44..d307b9053aca2 100644 --- a/sycl/include/CL/sycl.hpp +++ b/sycl/include/CL/sycl.hpp @@ -56,6 +56,7 @@ #include #include #include +#include #include #include #include diff --git a/sycl/include/CL/sycl/feature_test.hpp b/sycl/include/CL/sycl/feature_test.hpp index 4625cfa06fed7..1f6ec558f13cd 100644 --- a/sycl/include/CL/sycl/feature_test.hpp +++ b/sycl/include/CL/sycl/feature_test.hpp @@ -14,6 +14,7 @@ namespace sycl { // TODO: Move these feature-test macros to compiler driver. #define SYCL_EXT_INTEL_DEVICE_INFO 2 +#define SYCL_EXT_ONEAPI_GROUP_MASK 1 #define SYCL_EXT_ONEAPI_LOCAL_MEMORY 1 // As for SYCL_EXT_ONEAPI_MATRIX: // 1- provides AOT initial implementation for AMX for the experimental matrix diff --git a/sycl/include/CL/sycl/marray.hpp b/sycl/include/CL/sycl/marray.hpp index 5b758b80683d0..0267f0a85ff8a 100644 --- a/sycl/include/CL/sycl/marray.hpp +++ b/sycl/include/CL/sycl/marray.hpp @@ -149,9 +149,9 @@ template class marray { } #define __SYCL_BINOP_INTEGRAL(BINOP, OPASSIGN) \ - template \ - friend typename std::enable_if::value, marray> \ - operator BINOP(const marray &Lhs, const marray &Rhs) { \ + template ::value, marray>> \ + friend marray operator BINOP(const marray &Lhs, const marray &Rhs) { \ marray Ret; \ for (size_t I = 0; I < NumElements; ++I) { \ Ret[I] = Lhs[I] BINOP Rhs[I]; \ @@ -166,9 +166,9 @@ template class marray { operator BINOP(const marray &Lhs, const T &Rhs) { \ return Lhs BINOP marray(static_cast(Rhs)); \ } \ - template \ - friend typename std::enable_if::value, marray> \ - &operator OPASSIGN(marray &Lhs, const marray &Rhs) { \ + template ::value, marray>> \ + friend marray &operator OPASSIGN(marray &Lhs, const marray &Rhs) { \ Lhs = Lhs BINOP Rhs; \ return Lhs; \ } \ diff --git a/sycl/include/sycl/ext/oneapi/group_mask.hpp b/sycl/include/sycl/ext/oneapi/group_mask.hpp new file mode 100644 index 0000000000000..a87ecb766fb1b --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/group_mask.hpp @@ -0,0 +1,118 @@ +//==----------------- group_mask.hpp --- SYCL group mask -------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#pragma once + +#include +#include +#include +#include +#include + +__SYCL_INLINE_NAMESPACE(cl) { +namespace sycl { +namespace ext { +namespace oneapi { + +struct group_mask { + + // enable reference to individual bit + struct reference { + reference &operator=(bool x); + reference &operator=(const reference &x); + bool operator~() const; + operator bool() const; + reference &flip(); + }; + + static constexpr size_t max_bits = 128 /* implementation-defined */; + /* Bitmask is packed in marray of uint32_t elements. This value represents + * legth of marray. */ + static constexpr size_t marray_size = max_bits / sizeof(uint32_t) / CHAR_BIT; + + bool operator[](id<1> id) const; + reference operator[](id<1> id); + bool test(id<1> id) const; + bool all() const; + bool any() const; + bool none() const; + uint32_t count() const; + uint32_t size() const; + id<1> find_low() const; + id<1> find_high() const; + + template > + void insert_bits(const T &bits, id<1> pos = 0); + + template > + T extract_bits(id<1> pos = 0); + + void set(); + void set(id<1> id, bool value = true); + void reset(); + void reset(id<1> id); + void reset_low(); + void reset_high(); + void flip(); + void flip(id<1> id); + + bool operator==(const group_mask &rhs) const { return Bits == rhs.Bits; } + bool operator!=(const group_mask &rhs) const { return Bits != rhs.Bits; } + + group_mask &operator&=(const group_mask &rhs) { + Bits &= rhs.Bits; + return *this; + } + group_mask &operator|=(const group_mask &rhs) { + Bits |= rhs.Bits; + return *this; + } + + group_mask &operator^=(const group_mask &rhs) { + Bits ^= rhs.Bits; + return *this; + } + + group_mask &operator<<=(size_t); + group_mask &operator>>=(size_t rhs); + + group_mask operator~() const { + auto Tmp = *this; + Tmp.flip(); + return Tmp; + } + group_mask &operator<<(size_t) const; + group_mask &operator>>(size_t) const; + group_mask(const group_mask &rhs) : Bits(rhs.Bits) {} + template + friend group_mask group_ballot(Group g, bool predicate); +protected: + group_mask(const marray &rhs) : Bits(rhs) {} + marray Bits; +}; + +group_mask operator&(const group_mask &lhs, const group_mask &rhs); +group_mask operator|(const group_mask &lhs, const group_mask &rhs); +group_mask operator^(const group_mask &lhs, const group_mask &rhs); + +template +group_mask group_ballot(Group g, bool predicate) { + (void)g; +#ifdef __SYCL_DEVICE_ONLY__ + auto res = __spirv_GroupNonUniformBallot( + detail::spirv::group_scope::value, predicate); + return marray{res[3], res[2], res[1], res[0]}; +#else + (void)predicate; + throw exception{errc::feature_not_supported, + "Group mask is not supported on host device"}; +#endif +} +} // namespace oneapi +} // namespace ext +} // namespace sycl +} // __SYCL_INLINE_NAMESPACE(cl) diff --git a/sycl/source/CMakeLists.txt b/sycl/source/CMakeLists.txt index fd6e120107c92..7a0a8e0d7b174 100644 --- a/sycl/source/CMakeLists.txt +++ b/sycl/source/CMakeLists.txt @@ -149,6 +149,7 @@ set(SYCL_SOURCES "detail/sycl_mem_obj_t.cpp" "detail/usm/usm_impl.cpp" "detail/util.cpp" + "ext/group_mask.cpp" "accessor.cpp" "context.cpp" "device.cpp" diff --git a/sycl/source/ext/group_mask.cpp b/sycl/source/ext/group_mask.cpp new file mode 100644 index 0000000000000..96d06e8afc7dc --- /dev/null +++ b/sycl/source/ext/group_mask.cpp @@ -0,0 +1,34 @@ +//==------------------- group_mask.cpp -------------------------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include + +__SYCL_INLINE_NAMESPACE(cl) { +namespace sycl { +namespace ext { +namespace oneapi { +group_mask operator&(const group_mask &lhs, const group_mask &rhs) { + auto Res = lhs; + Res &= rhs; + return Res; +} +group_mask operator|(const group_mask &lhs, const group_mask &rhs) { + auto Res = lhs; + Res |= rhs; + return Res; +} + +group_mask operator^(const group_mask &lhs, const group_mask &rhs) { + auto Res = lhs; + Res ^= rhs; + return Res; +} +} // namespace oneapi +} // namespace ext +} // namespace sycl +} // __SYCL_INLINE_NAMESPACE(cl) diff --git a/sycl/test/check_device_code/group_mask.cpp b/sycl/test/check_device_code/group_mask.cpp new file mode 100644 index 0000000000000..06afa8fa55d84 --- /dev/null +++ b/sycl/test/check_device_code/group_mask.cpp @@ -0,0 +1,8 @@ +// RUN: %clangxx -I %sycl_include -S -emit-llvm -fsycl-device-only %s -o - -Xclang -disable-llvm-passes | FileCheck %s + +#include + +using namespace sycl; + +SYCL_EXTERNAL void test_group_mask(group<> g) { ext::oneapi::group_ballot(g, true); } +// CHECK: %{{.*}} = call spir_func <4 x i32> @_Z[[#]]__spirv_GroupNonUniformBallotjb(i32 {{.*}}, i1{{.*}}) From 8b20c3ad181bd97874e111aafdca354a9d72dc17 Mon Sep 17 00:00:00 2001 From: Vladimir Lazarev Date: Thu, 2 Sep 2021 19:45:51 +0300 Subject: [PATCH 02/15] Add referencing individual build --- sycl/include/sycl/ext/oneapi/group_mask.hpp | 70 +++++++++++++++------ 1 file changed, 51 insertions(+), 19 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/group_mask.hpp b/sycl/include/sycl/ext/oneapi/group_mask.hpp index a87ecb766fb1b..b8b635a2bbc9b 100644 --- a/sycl/include/sycl/ext/oneapi/group_mask.hpp +++ b/sycl/include/sycl/ext/oneapi/group_mask.hpp @@ -19,23 +19,50 @@ namespace ext { namespace oneapi { struct group_mask { + using WordType = uint32_t; + static constexpr size_t max_bits = 128 /* implementation-defined */; + static constexpr size_t word_size = sizeof(WordType) * CHAR_BIT; + /* Bitmask is packed in marray of uint32_t elements. This value represents + * legth of marray. */ + static constexpr size_t marray_size = max_bits / word_size; // enable reference to individual bit struct reference { - reference &operator=(bool x); - reference &operator=(const reference &x); - bool operator~() const; - operator bool() const; - reference &flip(); + reference &operator=(bool x) { + Ref &= x; + return *this; + } + reference &operator=(const reference &x) { + Ref &= (bool)x; + return *this; + } + bool operator~() const { return !(Ref & RefBit); } + operator bool() const { return Ref & RefBit; } + reference &flip() { + if ((bool)*this) { + Ref &= ~RefBit; + } else { + Ref |= RefBit; + } + return *this; + } + + reference(group_mask &gmask, size_t pos) + : Ref(gmask.Bits[pos / word_size]) { + size_t WordPos = pos; + while (WordPos -= word_size && WordPos) + WordPos = pos; + RefBit = 1 << WordPos; + } + private: + WordType &Ref; + WordType RefBit; }; - static constexpr size_t max_bits = 128 /* implementation-defined */; - /* Bitmask is packed in marray of uint32_t elements. This value represents - * legth of marray. */ - static constexpr size_t marray_size = max_bits / sizeof(uint32_t) / CHAR_BIT; - bool operator[](id<1> id) const; - reference operator[](id<1> id); + reference operator[](id<1> id) { + return {*this, id.get(0)}; + } bool test(id<1> id) const; bool all() const; bool any() const; @@ -45,11 +72,15 @@ struct group_mask { id<1> find_low() const; id<1> find_high() const; - template > + template > void insert_bits(const T &bits, id<1> pos = 0); - template > - T extract_bits(id<1> pos = 0); + template > + T extract_bits(id<1> pos = 0) { + T Res = Bits; + Res <<= pos; + return Res; + } void set(); void set(id<1> id, bool value = true); @@ -90,22 +121,23 @@ struct group_mask { group_mask(const group_mask &rhs) : Bits(rhs.Bits) {} template friend group_mask group_ballot(Group g, bool predicate); + protected: - group_mask(const marray &rhs) : Bits(rhs) {} - marray Bits; + group_mask(const marray &rhs) : Bits(rhs) {} + marray Bits; }; group_mask operator&(const group_mask &lhs, const group_mask &rhs); group_mask operator|(const group_mask &lhs, const group_mask &rhs); group_mask operator^(const group_mask &lhs, const group_mask &rhs); -template -group_mask group_ballot(Group g, bool predicate) { +template group_mask group_ballot(Group g, bool predicate) { (void)g; #ifdef __SYCL_DEVICE_ONLY__ auto res = __spirv_GroupNonUniformBallot( detail::spirv::group_scope::value, predicate); - return marray{res[3], res[2], res[1], res[0]}; + return marray{res[3], res[2], + res[1], res[0]}; #else (void)predicate; throw exception{errc::feature_not_supported, From 95e56e9ee224dd7e0855bdd70c0e257faf2c7036 Mon Sep 17 00:00:00 2001 From: Vladimir Lazarev Date: Fri, 3 Sep 2021 10:21:33 +0300 Subject: [PATCH 03/15] Initial implementation of all methods --- .../extensions/GroupMask/GroupMask.asciidoc | 16 +-- sycl/include/sycl/ext/oneapi/group_mask.hpp | 120 ++++++++++++++---- 2 files changed, 100 insertions(+), 36 deletions(-) diff --git a/sycl/doc/extensions/GroupMask/GroupMask.asciidoc b/sycl/doc/extensions/GroupMask/GroupMask.asciidoc index 8be3de09bb98b..dd1ee0cdcf256 100755 --- a/sycl/doc/extensions/GroupMask/GroupMask.asciidoc +++ b/sycl/doc/extensions/GroupMask/GroupMask.asciidoc @@ -259,7 +259,7 @@ struct group_mask { }; static constexpr size_t max_bits = /* implementation-defined */; - static constexpr size_t marray_size = max_bits/sizeof(uint32_t)/8; + static constexpr size_t marray_size = max_bits/sizeof(uint32_t)/CHAR_BIT; bool operator[](id<1> id) const; reference operator[](id<1> id); @@ -290,15 +290,15 @@ struct group_mask { bool operator==(const group_mask& rhs) const; bool operator!=(const group_mask& rhs) const; - group_mask operator &=(const group_mask& rhs); - group_mask operator |=(const group_mask& rhs); - group_mask operator ^=(const group_mask& rhs); - group_mask operator <<=(size_t); - group_mask operator >>=(size_t rhs); + group_mask &operator &=(const group_mask& rhs); + group_mask &operator |=(const group_mask& rhs); + group_mask &operator ^=(const group_mask& rhs); + group_mask &operator <<=(size_t n); + group_mask &operator >>=(size_t n); group_mask operator ~() const; - group_mask operator <<(size_t) const; - group_mask operator >>(size_t) const; + group_mask operator <<(size_t n) const; + group_mask operator >>(size_t n) const; }; diff --git a/sycl/include/sycl/ext/oneapi/group_mask.hpp b/sycl/include/sycl/ext/oneapi/group_mask.hpp index b8b635a2bbc9b..58f07b3386f22 100644 --- a/sycl/include/sycl/ext/oneapi/group_mask.hpp +++ b/sycl/include/sycl/ext/oneapi/group_mask.hpp @@ -25,7 +25,10 @@ struct group_mask { /* Bitmask is packed in marray of uint32_t elements. This value represents * legth of marray. */ static constexpr size_t marray_size = max_bits / word_size; - + /* The bits are stored in the memory in the following way: + marray id | 0 | 1 | 2 | 3 | + bit id |31 .. 0|63 .. 32|95 .. 64|127 .. 96| + */ // enable reference to individual bit struct reference { reference &operator=(bool x) { @@ -54,42 +57,68 @@ struct group_mask { WordPos = pos; RefBit = 1 << WordPos; } + private: + // Reference to the word containing the bit WordType &Ref; + // Bit mask where only referenced bit is set WordType RefBit; }; - bool operator[](id<1> id) const; - reference operator[](id<1> id) { - return {*this, id.get(0)}; + bool operator[](id<1> id) const { return operator[](id); } + reference operator[](id<1> id) { return {*this, id.get(0)}; } + bool test(id<1> id) const { return operator[](id); } + bool all() const { return !(~(Bits[0] & Bits[1] & Bits[2] & Bits[3])); } + bool any() const { return Bits[0] | Bits[1] | Bits[2] | Bits[3]; } + bool none() const { return !any(); } + uint32_t count() const { + unsigned int count = 0; + for (auto word : Bits) { + while (word) { + word &= (word - 1); + count++; + } + } + return count; + } + uint32_t size() const { return max_bits; } + id<1> find_low() const { + size_t i = 0; + while (i < size() && !operator[](i)) + i++; + return {i}; + } + id<1> find_high() const { + size_t i = size() - 1; + while (i > 0 && !operator[](i)) + i--; + return {operator[](i) ? i : size()}; } - bool test(id<1> id) const; - bool all() const; - bool any() const; - bool none() const; - uint32_t count() const; - uint32_t size() const; - id<1> find_low() const; - id<1> find_high() const; template > - void insert_bits(const T &bits, id<1> pos = 0); + void insert_bits(const T &bits, id<1> pos = 0) { + operator>>=(pos.get(0)); + operator<<=(pos.get(0)); + group_mask tmp(bits); + tmp>>=pos.get(0); + Bits |= tmp.Bits; + } template > T extract_bits(id<1> pos = 0) { - T Res = Bits; - Res <<= pos; - return Res; + group_mask Tmp = *this; + Tmp <<= pos.get(0); + return Tmp.Bits; } - void set(); - void set(id<1> id, bool value = true); - void reset(); - void reset(id<1> id); - void reset_low(); - void reset_high(); - void flip(); - void flip(id<1> id); + void set() { Bits = !(WordType{0}); } + void set(id<1> id, bool value = true) { operator[](id) = value; } + void reset() { Bits = WordType{0}; } + void reset(id<1> id) { operator[](id) = 0; } + void reset_low() { reset(find_low()); } + void reset_high() { reset(find_high()); } + void flip() { Bits = ~Bits; } + void flip(id<1> id) { operator[](id).flip(); } bool operator==(const group_mask &rhs) const { return Bits == rhs.Bits; } bool operator!=(const group_mask &rhs) const { return Bits != rhs.Bits; } @@ -108,16 +137,51 @@ struct group_mask { return *this; } - group_mask &operator<<=(size_t); - group_mask &operator>>=(size_t rhs); + group_mask &operator<<=(size_t pos) { + marray Res{0}; + size_t word_shift = pos / word_size; + size_t bit_shift = pos % word_size; + WordType extra_bits = 0; + for (size_t i = 0; i < marray_size; i++) { + extra_bits = Bits[i] >> (word_size - bit_shift); + Bits[i] <<= bit_shift; + Res[i + word_shift] = Bits[i] + extra_bits; + } + Bits = Res; + return *this; + } + + group_mask &operator>>=(size_t pos) { + marray Res{0}; + size_t word_shift = pos / word_size; + size_t bit_shift = pos % word_size; + WordType extra_bits = 0; + for (int i = marray_size - 1; i >= 0; i--) { + extra_bits = Bits[i] << (word_size - bit_shift); + Bits[i] >>= bit_shift; + Res[i - word_shift] = Bits[i] + extra_bits; + } + Bits = Res; + + return *this; + } group_mask operator~() const { auto Tmp = *this; Tmp.flip(); return Tmp; } - group_mask &operator<<(size_t) const; - group_mask &operator>>(size_t) const; + group_mask operator<<(size_t pos) const { + auto Tmp = *this; + Tmp <<= pos; + return Tmp; + } + group_mask operator>>(size_t pos) const { + auto Tmp = *this; + Tmp >>= pos; + return Tmp; + } + group_mask(const group_mask &rhs) : Bits(rhs.Bits) {} template friend group_mask group_ballot(Group g, bool predicate); From a26748fe6c2dab31d075bd0d8cb0273ccb5ea0cc Mon Sep 17 00:00:00 2001 From: Vladimir Lazarev Date: Fri, 3 Sep 2021 14:29:44 +0300 Subject: [PATCH 04/15] Remove dependency on lib --- sycl/include/sycl/ext/oneapi/group_mask.hpp | 24 +++++++++++---- sycl/source/CMakeLists.txt | 1 - sycl/source/ext/group_mask.cpp | 34 --------------------- 3 files changed, 18 insertions(+), 41 deletions(-) delete mode 100644 sycl/source/ext/group_mask.cpp diff --git a/sycl/include/sycl/ext/oneapi/group_mask.hpp b/sycl/include/sycl/ext/oneapi/group_mask.hpp index 58f07b3386f22..633c35490582c 100644 --- a/sycl/include/sycl/ext/oneapi/group_mask.hpp +++ b/sycl/include/sycl/ext/oneapi/group_mask.hpp @@ -100,7 +100,7 @@ struct group_mask { operator>>=(pos.get(0)); operator<<=(pos.get(0)); group_mask tmp(bits); - tmp>>=pos.get(0); + tmp >>= pos.get(0); Bits |= tmp.Bits; } @@ -189,12 +189,24 @@ struct group_mask { protected: group_mask(const marray &rhs) : Bits(rhs) {} marray Bits; -}; - -group_mask operator&(const group_mask &lhs, const group_mask &rhs); -group_mask operator|(const group_mask &lhs, const group_mask &rhs); -group_mask operator^(const group_mask &lhs, const group_mask &rhs); +public: + group_mask operator&(const group_mask &rhs) { + auto Res = *this; + Res &= rhs; + return Res; + } + group_mask operator|(const group_mask &rhs) { + auto Res = *this; + Res |= rhs; + return Res; + } + group_mask operator^(const group_mask &rhs) { + auto Res = *this; + Res ^= rhs; + return Res; + } +}; template group_mask group_ballot(Group g, bool predicate) { (void)g; #ifdef __SYCL_DEVICE_ONLY__ diff --git a/sycl/source/CMakeLists.txt b/sycl/source/CMakeLists.txt index 7a0a8e0d7b174..fd6e120107c92 100644 --- a/sycl/source/CMakeLists.txt +++ b/sycl/source/CMakeLists.txt @@ -149,7 +149,6 @@ set(SYCL_SOURCES "detail/sycl_mem_obj_t.cpp" "detail/usm/usm_impl.cpp" "detail/util.cpp" - "ext/group_mask.cpp" "accessor.cpp" "context.cpp" "device.cpp" diff --git a/sycl/source/ext/group_mask.cpp b/sycl/source/ext/group_mask.cpp deleted file mode 100644 index 96d06e8afc7dc..0000000000000 --- a/sycl/source/ext/group_mask.cpp +++ /dev/null @@ -1,34 +0,0 @@ -//==------------------- group_mask.cpp -------------------------------------==// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include - -__SYCL_INLINE_NAMESPACE(cl) { -namespace sycl { -namespace ext { -namespace oneapi { -group_mask operator&(const group_mask &lhs, const group_mask &rhs) { - auto Res = lhs; - Res &= rhs; - return Res; -} -group_mask operator|(const group_mask &lhs, const group_mask &rhs) { - auto Res = lhs; - Res |= rhs; - return Res; -} - -group_mask operator^(const group_mask &lhs, const group_mask &rhs) { - auto Res = lhs; - Res ^= rhs; - return Res; -} -} // namespace oneapi -} // namespace ext -} // namespace sycl -} // __SYCL_INLINE_NAMESPACE(cl) From 3c578c65be22d2be0816424d154ac4a14587f903 Mon Sep 17 00:00:00 2001 From: Vladimir Lazarev Date: Mon, 6 Sep 2021 07:27:03 +0300 Subject: [PATCH 05/15] Bugfix --- .../extensions/GroupMask/GroupMask.asciidoc | 13 +-- sycl/include/sycl/ext/oneapi/group_mask.hpp | 98 ++++++++++--------- sycl/test/basic_tests/group_mask.cpp | 82 ++++++++++++++++ 3 files changed, 141 insertions(+), 52 deletions(-) create mode 100644 sycl/test/basic_tests/group_mask.cpp diff --git a/sycl/doc/extensions/GroupMask/GroupMask.asciidoc b/sycl/doc/extensions/GroupMask/GroupMask.asciidoc index dd1ee0cdcf256..4cf1a5691cdab 100755 --- a/sycl/doc/extensions/GroupMask/GroupMask.asciidoc +++ b/sycl/doc/extensions/GroupMask/GroupMask.asciidoc @@ -139,15 +139,13 @@ work-item with the id `max_local_range()-1`. |`template > void insert_bits(const T& bits, id<1> pos = 0)` |Insert `CHAR_BIT * sizeof(T)` bits into the mask, starting from _pos_. `T` - must be an integral type or a SYCL `marray` of integral types. _pos_ must be a - multiple of `CHAR_BIT * sizeof(T)` in the range [0, `size()`). If _pos_ pass:[+] + must be an integral type or a SYCL `marray` of integral types. If _pos_ pass:[+] `CHAR_BIT * sizeof(T)` is greater than `size()`, the final `size()` - (_pos_ pass:[+] `CHAR_BIT * sizeof(T)`) bits are ignored. |`template > T extract_bits(id<1> pos = 0) const` |Return `CHAR_BIT * sizeof(T)` bits from the mask, starting from _pos_. `T` - must be an integral type or a SYCL `marray` of integral types. _pos_ must be a - multiple of `CHAR_BIT * sizeof(T)` in the range [0, `size()`). If _pos_ pass:[+] + must be an integral type or a SYCL `marray` of integral types. If _pos_ pass:[+] `CHAR_BIT * sizeof(T)` is greater than `size()`, the final `size()` - (_pos_ pass:[+] `CHAR_BIT * sizeof(T)`) bits of the return value are zero. @@ -300,12 +298,11 @@ struct group_mask { group_mask operator <<(size_t n) const; group_mask operator >>(size_t n) const; + group_mask operator &(const group_mask& rhs) const; + group_mask operator |(const group_mask& rhs) const; + group_mask operator ^(const group_mask& rhs) const; }; -group_mask operator &(const group_mask& lhs, const group_mask& rhs); -group_mask operator |(const group_mask& lhs, const group_mask& rhs); -group_mask operator ^(const group_mask& lhs, const group_mask& rhs); - } // namespace oneapi } // namespace ext } // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/group_mask.hpp b/sycl/include/sycl/ext/oneapi/group_mask.hpp index 633c35490582c..2d9b820532cf7 100644 --- a/sycl/include/sycl/ext/oneapi/group_mask.hpp +++ b/sycl/include/sycl/ext/oneapi/group_mask.hpp @@ -27,35 +27,32 @@ struct group_mask { static constexpr size_t marray_size = max_bits / word_size; /* The bits are stored in the memory in the following way: marray id | 0 | 1 | 2 | 3 | - bit id |31 .. 0|63 .. 32|95 .. 64|127 .. 96| + bit id |127 .. 96|95 .. 64|63 .. 32|31 .. 0| */ // enable reference to individual bit struct reference { reference &operator=(bool x) { - Ref &= x; + if (x) { + Ref |= RefBit; + } else { + Ref &= ~RefBit; + } return *this; } reference &operator=(const reference &x) { - Ref &= (bool)x; + operator=((bool)x); return *this; } bool operator~() const { return !(Ref & RefBit); } operator bool() const { return Ref & RefBit; } reference &flip() { - if ((bool)*this) { - Ref &= ~RefBit; - } else { - Ref |= RefBit; - } + operator=(!(bool)*this); return *this; } reference(group_mask &gmask, size_t pos) - : Ref(gmask.Bits[pos / word_size]) { - size_t WordPos = pos; - while (WordPos -= word_size && WordPos) - WordPos = pos; - RefBit = 1 << WordPos; + : Ref(gmask.Bits[marray_size - (pos / word_size) - 1]) { + RefBit = 1 << pos % word_size; } private: @@ -65,7 +62,10 @@ struct group_mask { WordType RefBit; }; - bool operator[](id<1> id) const { return operator[](id); } + bool operator[](id<1> id) const { + return Bits[marray_size - id.get(0) / word_size - 1] & + (1 << (id.get(0) % word_size)); + } reference operator[](id<1> id) { return {*this, id.get(0)}; } bool test(id<1> id) const { return operator[](id); } bool all() const { return !(~(Bits[0] & Bits[1] & Bits[2] & Bits[3])); } @@ -97,10 +97,14 @@ struct group_mask { template > void insert_bits(const T &bits, id<1> pos = 0) { - operator>>=(pos.get(0)); - operator<<=(pos.get(0)); group_mask tmp(bits); - tmp >>= pos.get(0); + if (pos.get(0) > 0) { + operator<<=(max_bits - pos.get(0)); + operator>>=(max_bits - pos.get(0)); + tmp <<= pos.get(0); + } else { + reset(); + } Bits |= tmp.Bits; } @@ -111,7 +115,7 @@ struct group_mask { return Tmp.Bits; } - void set() { Bits = !(WordType{0}); } + void set() { Bits = ~(WordType{0}); } void set(id<1> id, bool value = true) { operator[](id) = value; } void reset() { Bits = WordType{0}; } void reset(id<1> id) { operator[](id) = 0; } @@ -120,8 +124,13 @@ struct group_mask { void flip() { Bits = ~Bits; } void flip(id<1> id) { operator[](id).flip(); } - bool operator==(const group_mask &rhs) const { return Bits == rhs.Bits; } - bool operator!=(const group_mask &rhs) const { return Bits != rhs.Bits; } + bool operator==(const group_mask &rhs) const { + bool Res = true; + for (size_t i = 0; i < marray_size; i++) + Res &= Bits[i] == rhs.Bits[i]; + return Res; + } + bool operator!=(const group_mask &rhs) const { return !(*this == rhs); } group_mask &operator&=(const group_mask &rhs) { Bits &= rhs.Bits; @@ -138,31 +147,32 @@ struct group_mask { } group_mask &operator<<=(size_t pos) { - marray Res{0}; - size_t word_shift = pos / word_size; - size_t bit_shift = pos % word_size; - WordType extra_bits = 0; - for (size_t i = 0; i < marray_size; i++) { - extra_bits = Bits[i] >> (word_size - bit_shift); - Bits[i] <<= bit_shift; - Res[i + word_shift] = Bits[i] + extra_bits; + if (pos > 0) { + marray Res{0}; + size_t word_shift = pos / word_size; + size_t bit_shift = pos % word_size; + WordType extra_bits = 0; + for (int i = marray_size - 1; i >= 0; i--) { + Res[i - word_shift] = (Bits[i] << bit_shift) + extra_bits; + extra_bits = Bits[i] >> (word_size - bit_shift); + } + Bits = Res; } - Bits = Res; return *this; } group_mask &operator>>=(size_t pos) { - marray Res{0}; - size_t word_shift = pos / word_size; - size_t bit_shift = pos % word_size; - WordType extra_bits = 0; - for (int i = marray_size - 1; i >= 0; i--) { - extra_bits = Bits[i] << (word_size - bit_shift); - Bits[i] >>= bit_shift; - Res[i - word_shift] = Bits[i] + extra_bits; + if (pos > 0) { + marray Res{0}; + size_t word_shift = pos / word_size; + size_t bit_shift = pos % word_size; + WordType extra_bits = 0; + for (size_t i = 0; i < marray_size; i++) { + Res[i + word_shift] = (Bits[i] >> bit_shift) + extra_bits; + extra_bits = Bits[i] << (word_size - bit_shift); + } + Bits = Res; } - Bits = Res; - return *this; } @@ -186,11 +196,8 @@ struct group_mask { template friend group_mask group_ballot(Group g, bool predicate); -protected: group_mask(const marray &rhs) : Bits(rhs) {} - marray Bits; -public: group_mask operator&(const group_mask &rhs) { auto Res = *this; Res &= rhs; @@ -206,14 +213,17 @@ struct group_mask { Res ^= rhs; return Res; } + +private: + marray Bits; }; template group_mask group_ballot(Group g, bool predicate) { (void)g; #ifdef __SYCL_DEVICE_ONLY__ auto res = __spirv_GroupNonUniformBallot( detail::spirv::group_scope::value, predicate); - return marray{res[3], res[2], - res[1], res[0]}; + return marray{res[0], res[1], + res[2], res[3]}; #else (void)predicate; throw exception{errc::feature_not_supported, diff --git a/sycl/test/basic_tests/group_mask.cpp b/sycl/test/basic_tests/group_mask.cpp new file mode 100644 index 0000000000000..ca4184301e85d --- /dev/null +++ b/sycl/test/basic_tests/group_mask.cpp @@ -0,0 +1,82 @@ +// RUN: %clangxx -g -O0 -fsycl -fsycl-targets=%sycl_triple %s -o %t.out +// RUN: %t.out + +//==-------- group_mask.cpp - SYCL group_mask test -------------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include +#include +#include + +int main() { + sycl::ext::oneapi::group_mask g{sycl::marray{0}}; + assert(g.none() && !g.any() && !g.all()); + assert(g[10] == false); // reference::operator[](id) const; + g[10] = true; // reference::operator=(bool); + assert(g[10] == true); + g[11] = g[10]; // reference::operator=(reference) reference::operator[](id); + assert(g[10].flip() == false); // reference::flip() + assert(~g[10] == true); // refernce::operator~() + assert(g[10] == false); + assert(g[11] == true); + assert(g.test(10) == false && g.test(11) == true); + g.set(101, 1); + g.set(11, 0); + g.set(53, 1); + assert(!g.none() && g.any() && !g.all()); + + assert(g.count() == 2); + assert(g.find_low() == 53); + assert(g.find_high() == 101); + assert(g.size() == 128); + + g.reset(); + assert(g.none() && !g.any() && !g.all()); + assert(g.find_low() == g.size() && g.find_high() == g.size()); + g.set(); + assert(!g.none() && g.any() && g.all()); + assert(g.find_low() == 0 && g.find_high() == 127); + g.flip(); + assert(g.none() && !g.any() && !g.all()); + + g.flip(13); + g.flip(43); + g.flip(79); + auto b = g; + assert(b == g && !(b != g)); + g.flip(101); + assert(g.find_high() == 101); + assert(b.find_high() == 79); + assert(b != g && !(b == g)); + b.flip(101); + assert(b == g && !(b != g)); + b = g >> 1; + assert(b[12] && b[42] && b[78] && b[100]); + b <<= 1; + assert(b == g); + g ^= ~b; + assert(!g.none() && g.any() && g.all()); + assert((g | ~g).all()); + assert((g & ~g).none()); + assert((g ^ ~g).all()); + b.reset_low(); + b.reset_high(); + assert(!b[13] && b[43] && b[79] && !b[101]); + b.insert_bits({1, 2, 4, 8}); + assert(b[96] && b[65] && b[34] && b[3]); + g = b; + g <<= 33; + assert(!g[96] && !g[65] && !g[34] && !g[3] && g[98] && g[67] && g[36]); + b.insert_bits({1, 1, 1, 1}, 15); + assert(b[111] && !b[96] && b[79] && !b[65] && b[47] && !b[34] && b[15] && + b[3]); + b >>= 79; + assert(b[32] && b[0]); + b.flip(32); + b.flip(0); + assert(b.none()); +} From 24e5a95ab6c9bf756fa7d992babc8d576cb34758 Mon Sep 17 00:00:00 2001 From: Vladimir Lazarev Date: Mon, 6 Sep 2021 17:09:06 +0300 Subject: [PATCH 06/15] Fix internal layout --- sycl/include/sycl/ext/oneapi/group_mask.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/group_mask.hpp b/sycl/include/sycl/ext/oneapi/group_mask.hpp index 2d9b820532cf7..deb3fa61f69eb 100644 --- a/sycl/include/sycl/ext/oneapi/group_mask.hpp +++ b/sycl/include/sycl/ext/oneapi/group_mask.hpp @@ -222,8 +222,8 @@ template group_mask group_ballot(Group g, bool predicate) { #ifdef __SYCL_DEVICE_ONLY__ auto res = __spirv_GroupNonUniformBallot( detail::spirv::group_scope::value, predicate); - return marray{res[0], res[1], - res[2], res[3]}; + return marray{res[3], res[2], + res[1], res[0]}; #else (void)predicate; throw exception{errc::feature_not_supported, From db1f5b822ee05308eb6edbf687ef903eeda13d12 Mon Sep 17 00:00:00 2001 From: Vladimir Lazarev Date: Tue, 7 Sep 2021 13:04:45 +0300 Subject: [PATCH 07/15] Update tests Move Group Mask test to extension folder Add test for extension macro --- .../{basic_tests => extensions}/group_mask.cpp | 0 sycl/test/extensions/macro.cpp | 15 +++++++++++++++ 2 files changed, 15 insertions(+) rename sycl/test/{basic_tests => extensions}/group_mask.cpp (100%) create mode 100644 sycl/test/extensions/macro.cpp diff --git a/sycl/test/basic_tests/group_mask.cpp b/sycl/test/extensions/group_mask.cpp similarity index 100% rename from sycl/test/basic_tests/group_mask.cpp rename to sycl/test/extensions/group_mask.cpp diff --git a/sycl/test/extensions/macro.cpp b/sycl/test/extensions/macro.cpp new file mode 100644 index 0000000000000..aaa70eb21ce36 --- /dev/null +++ b/sycl/test/extensions/macro.cpp @@ -0,0 +1,15 @@ +// This test checks presence of macros for available extensions. +// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %t.out + +#include +#include +int main() { +#if SYCL_EXT_ONEAPI_GROUP_MASK == 1 + std::cout << "SYCL_EXT_ONEAPI_GROUP_MASK=1" << std::endl; +#else + std::cerr << "SYCL_EXT_ONEAPI_GROUP_MASK!=1" << std::endl; + exit(1); +#endif + exit(0); +} From 4442ea443856db99c13e73bafc18e44ce99cb5c2 Mon Sep 17 00:00:00 2001 From: Vladimir Lazarev Date: Wed, 8 Sep 2021 19:04:37 +0300 Subject: [PATCH 08/15] Apply review comments --- .../extensions/GroupMask/GroupMask.asciidoc | 59 ++++++++----------- sycl/include/sycl/ext/oneapi/group_mask.hpp | 40 ++++++------- .../group_mask.cpp | 0 3 files changed, 46 insertions(+), 53 deletions(-) rename sycl/test/{basic_tests => extensions}/group_mask.cpp (100%) diff --git a/sycl/doc/extensions/GroupMask/GroupMask.asciidoc b/sycl/doc/extensions/GroupMask/GroupMask.asciidoc index 4cf1a5691cdab..7cce193a11f17 100755 --- a/sycl/doc/extensions/GroupMask/GroupMask.asciidoc +++ b/sycl/doc/extensions/GroupMask/GroupMask.asciidoc @@ -137,7 +137,7 @@ work-item with the id `max_local_range()-1`. |Return the highest `id` with a corresponding bit set in the mask. If no bits are set, the return value is equal to `size()`. -|`template > void insert_bits(const T& bits, id<1> pos = 0)` +|`template > void insert_bits(const T &bits, id<1> pos = 0)` |Insert `CHAR_BIT * sizeof(T)` bits into the mask, starting from _pos_. `T` must be an integral type or a SYCL `marray` of integral types. If _pos_ pass:[+] `CHAR_BIT * sizeof(T)` is greater than `size()`, the final `size()` - (_pos_ pass:[+] @@ -176,32 +176,32 @@ work-item with the id `max_local_range()-1`. |`void flip(id<1> id)` |Toggle the value of the bit corresponding to the specified _id_. -|`bool operator==(group_mask rhs) const` +|`bool operator==(const group_mask &rhs) const` |Return true if each bit in this mask is equal to the corresponding bit in `rhs`. -|`bool operator!=(group_mask rhs) const` +|`bool operator!=(const group_mask &rhs) const` |Return true if any bit in this mask is not equal to the corresponding bit in `rhs`. -|`group_mask operator &=(group_mask rhs)` +|`group_mask &operator &=(const group_mask &rhs)` |Set the bits of this mask to the result of performing a bitwise AND with this mask and `rhs`. -|`group_mask operator \|=(group_mask rhs)` +|`group_mask &operator \|=(const group_mask &rhs)` |Set the bits of this mask to the result of performing a bitwise OR with this mask and `rhs`. -|`group_mask operator ^=(group_mask rhs)` +|`group_mask &operator ^=(const group_mask &rhs)` |Set the bits of this mask to the result of performing a bitwise XOR with this mask and `rhs`. -|`group_mask operator pass:[<<=](size_t shift)` +|`group_mask &operator pass:[<<=](size_t shift)` |Set the bits of this mask to the result of shifting its bits _shift_ positions to the left using a logical shift. Bits that are shifted out to the left are discarded, and zeroes are shifted in from the right. -|`group_mask operator >>=(size_t shift)` +|`group_mask &operator >>=(size_t shift)` |Set the bits of this mask to the result of shifting its bits _shift_ positions to the right using a logical shift. Bits that are shifted out to the right are discarded, and zeroes are shifted in from the left. @@ -209,31 +209,24 @@ work-item with the id `max_local_range()-1`. |`group_mask operator ~() const` |Return a mask representing the result of flipping all the bits in this mask. -|`group_mask operator <<(size_t shift)` +|`group_mask operator <<(size_t shift) const` |Return a mask representing the result of shifting its bits _shift_ positions to the left using a logical shift. Bits that are shifted out to the left are discarded, and zeroes are shifted in from the right. -|`group_mask operator >>(size_t shift)` +|`group_mask operator >>(size_t shift) const` |Return a mask representing the result of shifting its bits _shift_ positions to the right using a logical shift. Bits that are shifted out to the right are discarded, and zeroes are shifted in from the left. -|=== -|=== -|Function|Description +|`group_mask operator &(const group_mask &rhs) const` +|Return a mask representing the result of performing a bitwise AND of two masks. -|`group_mask operator &(const group_mask& lhs, const group_mask& rhs)` -|Return a mask representing the result of performing a bitwise AND of `lhs` and - `rhs`. +|`group_mask operator \|(const group_mask &rhs) const` +|Return a mask representing the result of performing a bitwise OR of two masks. -|`group_mask operator \|(const group_mask& lhs, const group_mask& rhs)` -|Return a mask representing the result of performing a bitwise OR of `lhs` and - `rhs`. - -|`group_mask operator ^(const group_mask& lhs, const group_mask& rhs)` -|Return a mask representing the result of performing a bitwise XOR of `lhs` and - `rhs`. +|`group_mask operator ^(const group_mask &rhs) const` +|Return a mask representing the result of performing a bitwise XOR of two masks. |=== @@ -257,7 +250,7 @@ struct group_mask { }; static constexpr size_t max_bits = /* implementation-defined */; - static constexpr size_t marray_size = max_bits/sizeof(uint32_t)/CHAR_BIT; + static constexpr size_t marray_size = /* implementation defined */; bool operator[](id<1> id) const; reference operator[](id<1> id); @@ -271,7 +264,7 @@ struct group_mask { id<1> find_high() const; template > - void insert_bits(const T& bits, id<1> pos = 0); + void insert_bits(const T &bits, id<1> pos = 0); template > T extract_bits(id<1> pos = 0); @@ -285,12 +278,12 @@ struct group_mask { void flip(); void flip(id<1> id); - bool operator==(const group_mask& rhs) const; - bool operator!=(const group_mask& rhs) const; + bool operator==(const group_mask &rhs) const; + bool operator!=(const group_mask &rhs) const; - group_mask &operator &=(const group_mask& rhs); - group_mask &operator |=(const group_mask& rhs); - group_mask &operator ^=(const group_mask& rhs); + group_mask &operator &=(const group_mask &rhs); + group_mask &operator |=(const group_mask &rhs); + group_mask &operator ^=(const group_mask &rhs); group_mask &operator <<=(size_t n); group_mask &operator >>=(size_t n); @@ -298,9 +291,9 @@ struct group_mask { group_mask operator <<(size_t n) const; group_mask operator >>(size_t n) const; - group_mask operator &(const group_mask& rhs) const; - group_mask operator |(const group_mask& rhs) const; - group_mask operator ^(const group_mask& rhs) const; + group_mask operator &(const group_mask &rhs) const; + group_mask operator |(const group_mask &rhs) const; + group_mask operator ^(const group_mask &rhs) const; }; } // namespace oneapi diff --git a/sycl/include/sycl/ext/oneapi/group_mask.hpp b/sycl/include/sycl/ext/oneapi/group_mask.hpp index deb3fa61f69eb..dd904bac39b65 100644 --- a/sycl/include/sycl/ext/oneapi/group_mask.hpp +++ b/sycl/include/sycl/ext/oneapi/group_mask.hpp @@ -19,16 +19,16 @@ namespace ext { namespace oneapi { struct group_mask { - using WordType = uint32_t; static constexpr size_t max_bits = 128 /* implementation-defined */; - static constexpr size_t word_size = sizeof(WordType) * CHAR_BIT; + static constexpr size_t word_size = sizeof(uint32_t) * CHAR_BIT; /* Bitmask is packed in marray of uint32_t elements. This value represents - * legth of marray. */ - static constexpr size_t marray_size = max_bits / word_size; + * legth of marray. Round up in case when it is not evenly divisible. */ + static constexpr size_t marray_size = (max_bits + word_size - 1) / word_size; /* The bits are stored in the memory in the following way: marray id | 0 | 1 | 2 | 3 | bit id |127 .. 96|95 .. 64|63 .. 32|31 .. 0| */ + // enable reference to individual bit struct reference { reference &operator=(bool x) { @@ -57,9 +57,9 @@ struct group_mask { private: // Reference to the word containing the bit - WordType &Ref; + uint32_t &Ref; // Bit mask where only referenced bit is set - WordType RefBit; + uint32_t RefBit; }; bool operator[](id<1> id) const { @@ -95,7 +95,7 @@ struct group_mask { return {operator[](i) ? i : size()}; } - template > + template > void insert_bits(const T &bits, id<1> pos = 0) { group_mask tmp(bits); if (pos.get(0) > 0) { @@ -108,16 +108,16 @@ struct group_mask { Bits |= tmp.Bits; } - template > + template > T extract_bits(id<1> pos = 0) { group_mask Tmp = *this; Tmp <<= pos.get(0); return Tmp.Bits; } - void set() { Bits = ~(WordType{0}); } + void set() { Bits = ~(uint32_t{0}); } void set(id<1> id, bool value = true) { operator[](id) = value; } - void reset() { Bits = WordType{0}; } + void reset() { Bits = uint32_t{0}; } void reset(id<1> id) { operator[](id) = 0; } void reset_low() { reset(find_low()); } void reset_high() { reset(find_high()); } @@ -148,10 +148,10 @@ struct group_mask { group_mask &operator<<=(size_t pos) { if (pos > 0) { - marray Res{0}; + marray Res{0}; size_t word_shift = pos / word_size; size_t bit_shift = pos % word_size; - WordType extra_bits = 0; + uint32_t extra_bits = 0; for (int i = marray_size - 1; i >= 0; i--) { Res[i - word_shift] = (Bits[i] << bit_shift) + extra_bits; extra_bits = Bits[i] >> (word_size - bit_shift); @@ -163,10 +163,10 @@ struct group_mask { group_mask &operator>>=(size_t pos) { if (pos > 0) { - marray Res{0}; + marray Res{0}; size_t word_shift = pos / word_size; size_t bit_shift = pos % word_size; - WordType extra_bits = 0; + uint32_t extra_bits = 0; for (size_t i = 0; i < marray_size; i++) { Res[i + word_shift] = (Bits[i] >> bit_shift) + extra_bits; extra_bits = Bits[i] << (word_size - bit_shift); @@ -196,33 +196,33 @@ struct group_mask { template friend group_mask group_ballot(Group g, bool predicate); - group_mask(const marray &rhs) : Bits(rhs) {} + group_mask(const marray &rhs) : Bits(rhs) {} - group_mask operator&(const group_mask &rhs) { + group_mask operator&(const group_mask &rhs) const { auto Res = *this; Res &= rhs; return Res; } - group_mask operator|(const group_mask &rhs) { + group_mask operator|(const group_mask &rhs) const { auto Res = *this; Res |= rhs; return Res; } - group_mask operator^(const group_mask &rhs) { + group_mask operator^(const group_mask &rhs) const { auto Res = *this; Res ^= rhs; return Res; } private: - marray Bits; + marray Bits; }; template group_mask group_ballot(Group g, bool predicate) { (void)g; #ifdef __SYCL_DEVICE_ONLY__ auto res = __spirv_GroupNonUniformBallot( detail::spirv::group_scope::value, predicate); - return marray{res[3], res[2], + return marray{res[3], res[2], res[1], res[0]}; #else (void)predicate; diff --git a/sycl/test/basic_tests/group_mask.cpp b/sycl/test/extensions/group_mask.cpp similarity index 100% rename from sycl/test/basic_tests/group_mask.cpp rename to sycl/test/extensions/group_mask.cpp From baaeb39b0f1ea5121b31014361245970154402b5 Mon Sep 17 00:00:00 2001 From: Vladimir Lazarev Date: Wed, 8 Sep 2021 20:51:36 +0300 Subject: [PATCH 09/15] Fix clang-format --- sycl/include/sycl/ext/oneapi/group_mask.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/group_mask.hpp b/sycl/include/sycl/ext/oneapi/group_mask.hpp index dd904bac39b65..d18286f4240af 100644 --- a/sycl/include/sycl/ext/oneapi/group_mask.hpp +++ b/sycl/include/sycl/ext/oneapi/group_mask.hpp @@ -222,8 +222,8 @@ template group_mask group_ballot(Group g, bool predicate) { #ifdef __SYCL_DEVICE_ONLY__ auto res = __spirv_GroupNonUniformBallot( detail::spirv::group_scope::value, predicate); - return marray{res[3], res[2], - res[1], res[0]}; + return marray{res[3], res[2], res[1], + res[0]}; #else (void)predicate; throw exception{errc::feature_not_supported, From 585fa1b26385cf961008d354a660673e80316154 Mon Sep 17 00:00:00 2001 From: Vladimir Lazarev Date: Mon, 13 Sep 2021 11:18:24 +0300 Subject: [PATCH 10/15] Apply review comments --- .../extensions/GroupMask/GroupMask.asciidoc | 30 +++++++++----- sycl/include/CL/sycl/detail/helpers.hpp | 6 +++ sycl/include/sycl/ext/oneapi/group_mask.hpp | 41 ++++++++++++------- sycl/test/extensions/group_mask.cpp | 14 +++++-- 4 files changed, 64 insertions(+), 27 deletions(-) diff --git a/sycl/doc/extensions/GroupMask/GroupMask.asciidoc b/sycl/doc/extensions/GroupMask/GroupMask.asciidoc index 7cce193a11f17..3cf970bd6afd6 100755 --- a/sycl/doc/extensions/GroupMask/GroupMask.asciidoc +++ b/sycl/doc/extensions/GroupMask/GroupMask.asciidoc @@ -82,7 +82,7 @@ must be encountered by all work-items in the group in converged control flow. |Function|Description |`template Group::mask_type group_ballot(Group g, bool predicate = true)` -|Return a `group_mask` representing the set of work-items in group _g_ for which _predicate_ is `true`. +|Return a `group_mask` with one bit for each work-item in group _g_. A bit is set in this mask if and only if the corresponding work-item's _predicate_ is `true`. |=== === Group Masks @@ -219,14 +219,22 @@ work-item with the id `max_local_range()-1`. to the right using a logical shift. Bits that are shifted out to the right are discarded, and zeroes are shifted in from the left. -|`group_mask operator &(const group_mask &rhs) const` -|Return a mask representing the result of performing a bitwise AND of two masks. +|=== -|`group_mask operator \|(const group_mask &rhs) const` -|Return a mask representing the result of performing a bitwise OR of two masks. +|=== +|Function|Description -|`group_mask operator ^(const group_mask &rhs) const` -|Return a mask representing the result of performing a bitwise XOR of two masks. +|`group_mask operator &(const group_mask& lhs, const group_mask& rhs)` +|Return a mask representing the result of performing a bitwise AND of `lhs` and + `rhs`. + +|`group_mask operator \|(const group_mask& lhs, const group_mask& rhs)` +|Return a mask representing the result of performing a bitwise OR of `lhs` and + `rhs`. + +|`group_mask operator ^(const group_mask& lhs, const group_mask& rhs)` +|Return a mask representing the result of performing a bitwise XOR of `lhs` and + `rhs`. |=== @@ -291,11 +299,12 @@ struct group_mask { group_mask operator <<(size_t n) const; group_mask operator >>(size_t n) const; - group_mask operator &(const group_mask &rhs) const; - group_mask operator |(const group_mask &rhs) const; - group_mask operator ^(const group_mask &rhs) const; }; +group_mask operator &(const group_mask& lhs, const group_mask& rhs); +group_mask operator |(const group_mask& lhs, const group_mask& rhs); +group_mask operator ^(const group_mask& lhs, const group_mask& rhs); + } // namespace oneapi } // namespace ext } // namespace sycl @@ -319,6 +328,7 @@ None. |======================================== |Rev|Date|Author|Changes |1|2021-08-11|John Pennycook|*Initial public working draft* +|2|2021-09-13|Vladimir Lazarev|*Update during implementation* |======================================== //************************************************************************ diff --git a/sycl/include/CL/sycl/detail/helpers.hpp b/sycl/include/CL/sycl/detail/helpers.hpp index 118271a35bab5..28fa272f55e50 100644 --- a/sycl/include/CL/sycl/detail/helpers.hpp +++ b/sycl/include/CL/sycl/detail/helpers.hpp @@ -31,6 +31,7 @@ template class range; template class id; template class nd_item; template class h_item; +template class marray; enum class memory_order; namespace detail { @@ -82,6 +83,11 @@ class Builder { return group(Global, Local, Global / Local, Index); } + template + static ResType createGroupMask(marray Bits) { + return ResType(Bits); + } + template static detail::enable_if_t> createItem(const range &Extent, const id &Index, diff --git a/sycl/include/sycl/ext/oneapi/group_mask.hpp b/sycl/include/sycl/ext/oneapi/group_mask.hpp index d18286f4240af..8a1d8f4a0addf 100644 --- a/sycl/include/sycl/ext/oneapi/group_mask.hpp +++ b/sycl/include/sycl/ext/oneapi/group_mask.hpp @@ -9,16 +9,22 @@ #include #include +#include #include #include #include __SYCL_INLINE_NAMESPACE(cl) { namespace sycl { +namespace detail { +class Builder; +} // namespace detail + namespace ext { namespace oneapi { struct group_mask { + friend class detail::Builder; static constexpr size_t max_bits = 128 /* implementation-defined */; static constexpr size_t word_size = sizeof(uint32_t) * CHAR_BIT; /* Bitmask is packed in marray of uint32_t elements. This value represents @@ -95,12 +101,16 @@ struct group_mask { return {operator[](i) ? i : size()}; } - template > - void insert_bits(const T &bits, id<1> pos = 0) { + template ::value>> + void insert_bits(T bits, id<1> pos = 0) {} + + template + void insert_bits(const marray &bits, id<1> pos = 0) { group_mask tmp(bits); if (pos.get(0) > 0) { - operator<<=(max_bits - pos.get(0)); - operator>>=(max_bits - pos.get(0)); + operator<<=(size() - pos.get(0)); + operator>>=(size() - pos.get(0)); tmp <<= pos.get(0); } else { reset(); @@ -193,28 +203,30 @@ struct group_mask { } group_mask(const group_mask &rhs) : Bits(rhs.Bits) {} + template friend group_mask group_ballot(Group g, bool predicate); - group_mask(const marray &rhs) : Bits(rhs) {} - - group_mask operator&(const group_mask &rhs) const { - auto Res = *this; + friend group_mask operator&(const group_mask &lhs, const group_mask &rhs) { + auto Res = lhs; Res &= rhs; return Res; } - group_mask operator|(const group_mask &rhs) const { - auto Res = *this; + + friend group_mask operator|(const group_mask &lhs, const group_mask &rhs) { + auto Res = lhs; Res |= rhs; return Res; } - group_mask operator^(const group_mask &rhs) const { - auto Res = *this; + + friend group_mask operator^(const group_mask &lhs, const group_mask &rhs) { + auto Res = lhs; Res ^= rhs; return Res; } private: + group_mask(const marray &rhs) : Bits(rhs) {} marray Bits; }; template group_mask group_ballot(Group g, bool predicate) { @@ -222,8 +234,9 @@ template group_mask group_ballot(Group g, bool predicate) { #ifdef __SYCL_DEVICE_ONLY__ auto res = __spirv_GroupNonUniformBallot( detail::spirv::group_scope::value, predicate); - return marray{res[3], res[2], res[1], - res[0]}; + return detail::Builder::createGroupMask( + marray{res[3], res[2], res[1], + res[0]}); #else (void)predicate; throw exception{errc::feature_not_supported, diff --git a/sycl/test/extensions/group_mask.cpp b/sycl/test/extensions/group_mask.cpp index ca4184301e85d..67c54c5b209fd 100644 --- a/sycl/test/extensions/group_mask.cpp +++ b/sycl/test/extensions/group_mask.cpp @@ -13,7 +13,9 @@ #include int main() { - sycl::ext::oneapi::group_mask g{sycl::marray{0}}; + auto g = + sycl::detail::Builder::createGroupMask( + sycl::marray{0}); assert(g.none() && !g.any() && !g.all()); assert(g[10] == false); // reference::operator[](id) const; g[10] = true; // reference::operator=(bool); @@ -66,17 +68,23 @@ int main() { b.reset_low(); b.reset_high(); assert(!b[13] && b[43] && b[79] && !b[101]); - b.insert_bits({1, 2, 4, 8}); + b.insert_bits(sycl::marray{1, 2, 4, 8}); assert(b[96] && b[65] && b[34] && b[3]); g = b; g <<= 33; assert(!g[96] && !g[65] && !g[34] && !g[3] && g[98] && g[67] && g[36]); - b.insert_bits({1, 1, 1, 1}, 15); + b.insert_bits(sycl::marray{1, 1, 1, 1}, 15); assert(b[111] && !b[96] && b[79] && !b[65] && b[47] && !b[34] && b[15] && b[3]); + + auto r = b.extract_bits>(); + for(size_t i=0; i>= 79; assert(b[32] && b[0]); b.flip(32); b.flip(0); assert(b.none()); + b.insert_bits((int)1); } From d855e4843cb4ee61ae7f9c6dfbecf83481aa6738 Mon Sep 17 00:00:00 2001 From: Vladimir Lazarev Date: Tue, 14 Sep 2021 09:50:55 +0300 Subject: [PATCH 11/15] Rename feature to sub-group mask --- .../SubGroupMask.asciidoc} | 78 +++--- sycl/include/CL/sycl.hpp | 2 +- sycl/include/CL/sycl/detail/helpers.hpp | 3 +- sycl/include/CL/sycl/feature_test.hpp | 2 +- sycl/include/sycl/ext/oneapi/group_mask.hpp | 249 ----------------- .../sycl/ext/oneapi/sub_group_mask.hpp | 255 ++++++++++++++++++ .../{group_mask.cpp => sub_group_mask.cpp} | 4 +- sycl/test/extensions/macro.cpp | 6 +- .../{group_mask.cpp => sub_group_mask.cpp} | 74 ++--- 9 files changed, 340 insertions(+), 333 deletions(-) rename sycl/doc/extensions/{GroupMask/GroupMask.asciidoc => SubGroupMask/SubGroupMask.asciidoc} (77%) delete mode 100644 sycl/include/sycl/ext/oneapi/group_mask.hpp create mode 100644 sycl/include/sycl/ext/oneapi/sub_group_mask.hpp rename sycl/test/check_device_code/{group_mask.cpp => sub_group_mask.cpp} (74%) rename sycl/test/extensions/{group_mask.cpp => sub_group_mask.cpp} (55%) diff --git a/sycl/doc/extensions/GroupMask/GroupMask.asciidoc b/sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc similarity index 77% rename from sycl/doc/extensions/GroupMask/GroupMask.asciidoc rename to sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc index 3cf970bd6afd6..f28857c2e8d53 100755 --- a/sycl/doc/extensions/GroupMask/GroupMask.asciidoc +++ b/sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc @@ -1,4 +1,4 @@ -= SYCL_EXT_ONEAPI_GROUP_MASK += SYCL_EXT_ONEAPI_SUB_GROUP_MASK :source-highlighter: coderay :coderay-linenums-mode: table @@ -21,7 +21,7 @@ IMPORTANT: This specification is a draft. NOTE: Khronos(R) is a registered trademark and SYCL(TM) and SPIR(TM) are trademarks of The Khronos Group Inc. OpenCL(TM) is a trademark of Apple Inc. used by permission by Khronos. -This document describes an extension which adds a `group_mask` type. Such a mask can be used to efficiently represent subsets of work-items in a group for which a given Boolean condition holds. Group mask functionality is currently limited to groups that are instances of the `sub_group` class. +This document describes an extension which adds a `sub_group_mask` type. Such a mask can be used to efficiently represent subsets of work-items in a sub-group for which a given Boolean condition holds. Sub-group mask functionality is currently limited to groups that are instances of the `sub_group` class. == Notice @@ -74,25 +74,25 @@ the specification. === Ballot -The `group_ballot` algorithm converts a Boolean condition from each work-item -in the group into a group mask. Like other group algorithms, `group_ballot` +The `sub_group_ballot` algorithm converts a Boolean condition from each work-item +in the group into a group mask. Like other group algorithms, `sub_group_ballot` must be encountered by all work-items in the group in converged control flow. |=== |Function|Description -|`template Group::mask_type group_ballot(Group g, bool predicate = true)` -|Return a `group_mask` with one bit for each work-item in group _g_. A bit is set in this mask if and only if the corresponding work-item's _predicate_ is `true`. +|`template Group::mask_type sub_group_ballot(Group g, bool predicate = true)` +|Return a `sub_group_mask` with one bit for each work-item in group _g_. A bit is set in this mask if and only if the corresponding work-item's _predicate_ is `true`. |=== === Group Masks The group mask type is an opaque type, permitting implementations to use any mask representation that has the same size and alignment across host and -device. The maximum number of bits that can be stored in a `group_mask` is -exposed as a static member variable, `group_mask::max_bits`. +device. The maximum number of bits that can be stored in a `sub_group_mask` is +exposed as a static member variable, `sub_group_mask::max_bits`. -Functions declared in the `group_mask` class can be called independently by +Functions declared in the `sub_group_mask` class can be called independently by different work-items in the same group. An instance of a group class (e.g. `group` or `sub_group`) is not required to manipulate a group mask. @@ -107,7 +107,7 @@ work-item with the id `max_local_range()-1`. |Return `true` if the bit corresponding to the specified _id_ is set in the mask. -|`group_mask::reference operator[](id<1> id)` +|`sub_group_mask::reference operator[](id<1> id)` |Return a reference to the bit corresponding to the specified _id_ in the mask. |`bool test(id<1> id) const` @@ -143,7 +143,7 @@ work-item with the id `max_local_range()-1`. `CHAR_BIT * sizeof(T)` is greater than `size()`, the final `size()` - (_pos_ pass:[+] `CHAR_BIT * sizeof(T)`) bits are ignored. -|`template > T extract_bits(id<1> pos = 0) const` +|`template > void extract_bits(T &out, id<1> pos = 0) const` |Return `CHAR_BIT * sizeof(T)` bits from the mask, starting from _pos_. `T` must be an integral type or a SYCL `marray` of integral types. If _pos_ pass:[+] `CHAR_BIT * sizeof(T)` is greater than `size()`, the final `size()` - (_pos_ pass:[+] @@ -176,45 +176,45 @@ work-item with the id `max_local_range()-1`. |`void flip(id<1> id)` |Toggle the value of the bit corresponding to the specified _id_. -|`bool operator==(const group_mask &rhs) const` +|`bool operator==(const sub_group_mask &rhs) const` |Return true if each bit in this mask is equal to the corresponding bit in `rhs`. -|`bool operator!=(const group_mask &rhs) const` +|`bool operator!=(const sub_group_mask &rhs) const` |Return true if any bit in this mask is not equal to the corresponding bit in `rhs`. -|`group_mask &operator &=(const group_mask &rhs)` +|`sub_group_mask &operator &=(const sub_group_mask &rhs)` |Set the bits of this mask to the result of performing a bitwise AND with this mask and `rhs`. -|`group_mask &operator \|=(const group_mask &rhs)` +|`sub_group_mask &operator \|=(const sub_group_mask &rhs)` |Set the bits of this mask to the result of performing a bitwise OR with this mask and `rhs`. -|`group_mask &operator ^=(const group_mask &rhs)` +|`sub_group_mask &operator ^=(const sub_group_mask &rhs)` |Set the bits of this mask to the result of performing a bitwise XOR with this mask and `rhs`. -|`group_mask &operator pass:[<<=](size_t shift)` +|`sub_group_mask &operator pass:[<<=](size_t shift)` |Set the bits of this mask to the result of shifting its bits _shift_ positions to the left using a logical shift. Bits that are shifted out to the left are discarded, and zeroes are shifted in from the right. -|`group_mask &operator >>=(size_t shift)` +|`sub_group_mask &operator >>=(size_t shift)` |Set the bits of this mask to the result of shifting its bits _shift_ positions to the right using a logical shift. Bits that are shifted out to the right are discarded, and zeroes are shifted in from the left. -|`group_mask operator ~() const` +|`sub_group_mask operator ~() const` |Return a mask representing the result of flipping all the bits in this mask. -|`group_mask operator <<(size_t shift) const` +|`sub_group_mask operator <<(size_t shift) const` |Return a mask representing the result of shifting its bits _shift_ positions to the left using a logical shift. Bits that are shifted out to the left are discarded, and zeroes are shifted in from the right. -|`group_mask operator >>(size_t shift) const` +|`sub_group_mask operator >>(size_t shift) const` |Return a mask representing the result of shifting its bits _shift_ positions to the right using a logical shift. Bits that are shifted out to the right are discarded, and zeroes are shifted in from the left. @@ -224,15 +224,15 @@ work-item with the id `max_local_range()-1`. |=== |Function|Description -|`group_mask operator &(const group_mask& lhs, const group_mask& rhs)` +|`sub_group_mask operator &(const sub_group_mask& lhs, const sub_group_mask& rhs)` |Return a mask representing the result of performing a bitwise AND of `lhs` and `rhs`. -|`group_mask operator \|(const group_mask& lhs, const group_mask& rhs)` +|`sub_group_mask operator \|(const sub_group_mask& lhs, const sub_group_mask& rhs)` |Return a mask representing the result of performing a bitwise OR of `lhs` and `rhs`. -|`group_mask operator ^(const group_mask& lhs, const group_mask& rhs)` +|`sub_group_mask operator ^(const sub_group_mask& lhs, const sub_group_mask& rhs)` |Return a mask representing the result of performing a bitwise XOR of `lhs` and `rhs`. @@ -246,7 +246,7 @@ namespace sycl { namespace ext { namespace oneapi { -struct group_mask { +struct sub_group_mask { // enable reference to individual bit struct reference { @@ -275,7 +275,7 @@ struct group_mask { void insert_bits(const T &bits, id<1> pos = 0); template > - T extract_bits(id<1> pos = 0); + void extract_bits(T &out, id<1> pos = 0); void set(); void set(id<1> id, bool value = true); @@ -286,24 +286,24 @@ struct group_mask { void flip(); void flip(id<1> id); - bool operator==(const group_mask &rhs) const; - bool operator!=(const group_mask &rhs) const; + bool operator==(const sub_group_mask &rhs) const; + bool operator!=(const sub_group_mask &rhs) const; - group_mask &operator &=(const group_mask &rhs); - group_mask &operator |=(const group_mask &rhs); - group_mask &operator ^=(const group_mask &rhs); - group_mask &operator <<=(size_t n); - group_mask &operator >>=(size_t n); + sub_group_mask &operator &=(const sub_group_mask &rhs); + sub_group_mask &operator |=(const sub_group_mask &rhs); + sub_group_mask &operator ^=(const sub_group_mask &rhs); + sub_group_mask &operator <<=(size_t n); + sub_group_mask &operator >>=(size_t n); - group_mask operator ~() const; - group_mask operator <<(size_t n) const; - group_mask operator >>(size_t n) const; + sub_group_mask operator ~() const; + sub_group_mask operator <<(size_t n) const; + sub_group_mask operator >>(size_t n) const; }; -group_mask operator &(const group_mask& lhs, const group_mask& rhs); -group_mask operator |(const group_mask& lhs, const group_mask& rhs); -group_mask operator ^(const group_mask& lhs, const group_mask& rhs); +sub_group_mask operator &(const sub_group_mask& lhs, const sub_group_mask& rhs); +sub_group_mask operator |(const sub_group_mask& lhs, const sub_group_mask& rhs); +sub_group_mask operator ^(const sub_group_mask& lhs, const sub_group_mask& rhs); } // namespace oneapi } // namespace ext diff --git a/sycl/include/CL/sycl.hpp b/sycl/include/CL/sycl.hpp index d307b9053aca2..46c080ee880e4 100644 --- a/sycl/include/CL/sycl.hpp +++ b/sycl/include/CL/sycl.hpp @@ -56,7 +56,7 @@ #include #include #include -#include #include #include #include +#include diff --git a/sycl/include/CL/sycl/detail/helpers.hpp b/sycl/include/CL/sycl/detail/helpers.hpp index 28fa272f55e50..a70fc8410f2d5 100644 --- a/sycl/include/CL/sycl/detail/helpers.hpp +++ b/sycl/include/CL/sycl/detail/helpers.hpp @@ -83,8 +83,7 @@ class Builder { return group(Global, Local, Global / Local, Index); } - template - static ResType createGroupMask(marray Bits) { + template static ResType createSubGroupMask(uint32_t Bits) { return ResType(Bits); } diff --git a/sycl/include/CL/sycl/feature_test.hpp b/sycl/include/CL/sycl/feature_test.hpp index 1f6ec558f13cd..e3cacc48f0982 100644 --- a/sycl/include/CL/sycl/feature_test.hpp +++ b/sycl/include/CL/sycl/feature_test.hpp @@ -14,7 +14,7 @@ namespace sycl { // TODO: Move these feature-test macros to compiler driver. #define SYCL_EXT_INTEL_DEVICE_INFO 2 -#define SYCL_EXT_ONEAPI_GROUP_MASK 1 +#define SYCL_EXT_ONEAPI_SUB_GROUP_MASK 1 #define SYCL_EXT_ONEAPI_LOCAL_MEMORY 1 // As for SYCL_EXT_ONEAPI_MATRIX: // 1- provides AOT initial implementation for AMX for the experimental matrix diff --git a/sycl/include/sycl/ext/oneapi/group_mask.hpp b/sycl/include/sycl/ext/oneapi/group_mask.hpp deleted file mode 100644 index 8a1d8f4a0addf..0000000000000 --- a/sycl/include/sycl/ext/oneapi/group_mask.hpp +++ /dev/null @@ -1,249 +0,0 @@ -//==----------------- group_mask.hpp --- SYCL group mask -------------------==// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -#pragma once - -#include -#include -#include -#include -#include -#include - -__SYCL_INLINE_NAMESPACE(cl) { -namespace sycl { -namespace detail { -class Builder; -} // namespace detail - -namespace ext { -namespace oneapi { - -struct group_mask { - friend class detail::Builder; - static constexpr size_t max_bits = 128 /* implementation-defined */; - static constexpr size_t word_size = sizeof(uint32_t) * CHAR_BIT; - /* Bitmask is packed in marray of uint32_t elements. This value represents - * legth of marray. Round up in case when it is not evenly divisible. */ - static constexpr size_t marray_size = (max_bits + word_size - 1) / word_size; - /* The bits are stored in the memory in the following way: - marray id | 0 | 1 | 2 | 3 | - bit id |127 .. 96|95 .. 64|63 .. 32|31 .. 0| - */ - - // enable reference to individual bit - struct reference { - reference &operator=(bool x) { - if (x) { - Ref |= RefBit; - } else { - Ref &= ~RefBit; - } - return *this; - } - reference &operator=(const reference &x) { - operator=((bool)x); - return *this; - } - bool operator~() const { return !(Ref & RefBit); } - operator bool() const { return Ref & RefBit; } - reference &flip() { - operator=(!(bool)*this); - return *this; - } - - reference(group_mask &gmask, size_t pos) - : Ref(gmask.Bits[marray_size - (pos / word_size) - 1]) { - RefBit = 1 << pos % word_size; - } - - private: - // Reference to the word containing the bit - uint32_t &Ref; - // Bit mask where only referenced bit is set - uint32_t RefBit; - }; - - bool operator[](id<1> id) const { - return Bits[marray_size - id.get(0) / word_size - 1] & - (1 << (id.get(0) % word_size)); - } - reference operator[](id<1> id) { return {*this, id.get(0)}; } - bool test(id<1> id) const { return operator[](id); } - bool all() const { return !(~(Bits[0] & Bits[1] & Bits[2] & Bits[3])); } - bool any() const { return Bits[0] | Bits[1] | Bits[2] | Bits[3]; } - bool none() const { return !any(); } - uint32_t count() const { - unsigned int count = 0; - for (auto word : Bits) { - while (word) { - word &= (word - 1); - count++; - } - } - return count; - } - uint32_t size() const { return max_bits; } - id<1> find_low() const { - size_t i = 0; - while (i < size() && !operator[](i)) - i++; - return {i}; - } - id<1> find_high() const { - size_t i = size() - 1; - while (i > 0 && !operator[](i)) - i--; - return {operator[](i) ? i : size()}; - } - - template ::value>> - void insert_bits(T bits, id<1> pos = 0) {} - - template - void insert_bits(const marray &bits, id<1> pos = 0) { - group_mask tmp(bits); - if (pos.get(0) > 0) { - operator<<=(size() - pos.get(0)); - operator>>=(size() - pos.get(0)); - tmp <<= pos.get(0); - } else { - reset(); - } - Bits |= tmp.Bits; - } - - template > - T extract_bits(id<1> pos = 0) { - group_mask Tmp = *this; - Tmp <<= pos.get(0); - return Tmp.Bits; - } - - void set() { Bits = ~(uint32_t{0}); } - void set(id<1> id, bool value = true) { operator[](id) = value; } - void reset() { Bits = uint32_t{0}; } - void reset(id<1> id) { operator[](id) = 0; } - void reset_low() { reset(find_low()); } - void reset_high() { reset(find_high()); } - void flip() { Bits = ~Bits; } - void flip(id<1> id) { operator[](id).flip(); } - - bool operator==(const group_mask &rhs) const { - bool Res = true; - for (size_t i = 0; i < marray_size; i++) - Res &= Bits[i] == rhs.Bits[i]; - return Res; - } - bool operator!=(const group_mask &rhs) const { return !(*this == rhs); } - - group_mask &operator&=(const group_mask &rhs) { - Bits &= rhs.Bits; - return *this; - } - group_mask &operator|=(const group_mask &rhs) { - Bits |= rhs.Bits; - return *this; - } - - group_mask &operator^=(const group_mask &rhs) { - Bits ^= rhs.Bits; - return *this; - } - - group_mask &operator<<=(size_t pos) { - if (pos > 0) { - marray Res{0}; - size_t word_shift = pos / word_size; - size_t bit_shift = pos % word_size; - uint32_t extra_bits = 0; - for (int i = marray_size - 1; i >= 0; i--) { - Res[i - word_shift] = (Bits[i] << bit_shift) + extra_bits; - extra_bits = Bits[i] >> (word_size - bit_shift); - } - Bits = Res; - } - return *this; - } - - group_mask &operator>>=(size_t pos) { - if (pos > 0) { - marray Res{0}; - size_t word_shift = pos / word_size; - size_t bit_shift = pos % word_size; - uint32_t extra_bits = 0; - for (size_t i = 0; i < marray_size; i++) { - Res[i + word_shift] = (Bits[i] >> bit_shift) + extra_bits; - extra_bits = Bits[i] << (word_size - bit_shift); - } - Bits = Res; - } - return *this; - } - - group_mask operator~() const { - auto Tmp = *this; - Tmp.flip(); - return Tmp; - } - group_mask operator<<(size_t pos) const { - auto Tmp = *this; - Tmp <<= pos; - return Tmp; - } - group_mask operator>>(size_t pos) const { - auto Tmp = *this; - Tmp >>= pos; - return Tmp; - } - - group_mask(const group_mask &rhs) : Bits(rhs.Bits) {} - - template - friend group_mask group_ballot(Group g, bool predicate); - - friend group_mask operator&(const group_mask &lhs, const group_mask &rhs) { - auto Res = lhs; - Res &= rhs; - return Res; - } - - friend group_mask operator|(const group_mask &lhs, const group_mask &rhs) { - auto Res = lhs; - Res |= rhs; - return Res; - } - - friend group_mask operator^(const group_mask &lhs, const group_mask &rhs) { - auto Res = lhs; - Res ^= rhs; - return Res; - } - -private: - group_mask(const marray &rhs) : Bits(rhs) {} - marray Bits; -}; -template group_mask group_ballot(Group g, bool predicate) { - (void)g; -#ifdef __SYCL_DEVICE_ONLY__ - auto res = __spirv_GroupNonUniformBallot( - detail::spirv::group_scope::value, predicate); - return detail::Builder::createGroupMask( - marray{res[3], res[2], res[1], - res[0]}); -#else - (void)predicate; - throw exception{errc::feature_not_supported, - "Group mask is not supported on host device"}; -#endif -} -} // namespace oneapi -} // namespace ext -} // namespace sycl -} // __SYCL_INLINE_NAMESPACE(cl) diff --git a/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp b/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp new file mode 100644 index 0000000000000..8e5c73179f8c3 --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp @@ -0,0 +1,255 @@ +//==------------ sub_group_mask.hpp --- SYCL sub-group mask ----------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#pragma once + +#include +#include +#include +#include +#include +#include + +__SYCL_INLINE_NAMESPACE(cl) { +namespace sycl { +namespace detail { +class Builder; +} // namespace detail + +namespace ext { +namespace oneapi { + +struct sub_group_mask { + friend class detail::Builder; + static constexpr size_t max_bits = 32 /* implementation-defined */; + static constexpr size_t word_size = sizeof(uint32_t) * CHAR_BIT; + + // enable reference to individual bit + struct reference { + reference &operator=(bool x) { + if (x) { + Ref |= RefBit; + } else { + Ref &= ~RefBit; + } + return *this; + } + reference &operator=(const reference &x) { + operator=((bool)x); + return *this; + } + bool operator~() const { return !(Ref & RefBit); } + operator bool() const { return Ref & RefBit; } + reference &flip() { + operator=(!(bool)*this); + return *this; + } + + reference(sub_group_mask &gmask, size_t pos) : Ref(gmask.Bits) { + RefBit = 1 << pos % word_size; + } + + private: + // Reference to the word containing the bit + uint32_t &Ref; + // Bit mask where only referenced bit is set + uint32_t RefBit; + }; + + bool operator[](id<1> id) const { + return Bits & (1 << (id.get(0) % word_size)); + } + reference operator[](id<1> id) { return {*this, id.get(0)}; } + bool test(id<1> id) const { return operator[](id); } + bool all() const { return !~Bits; } + bool any() const { return Bits; } + bool none() const { return !Bits; } + uint32_t count() const { + unsigned int count = 0; + auto word = Bits; + while (word) { + word &= (word - 1); + count++; + } + return count; + } + uint32_t size() const { return max_bits; } + id<1> find_low() const { + size_t i = 0; + while (i < size() && !operator[](i)) + i++; + return {i}; + } + id<1> find_high() const { + size_t i = size() - 1; + while (i > 0 && !operator[](i)) + i--; + return {operator[](i) ? i : size()}; + } + + template ::value>> + void insert_bits(Type bits, id<1> pos = 0) { + size_t insert_size = sizeof(Type) * CHAR_BIT; + uint32_t insert_data = (uint32_t)bits; + insert_data <<= pos.get(0); + uint32_t mask = 0; + if (pos.get(0) + insert_size < size()) + mask |= (0xffffffff << (pos.get(0) + insert_size)); + if (pos.get(0) < size()) + mask |= (0xffffffff >> (size() - pos.get(0))); + Bits &= mask; + Bits += insert_data; + } + + /* The bits are stored in the memory in the following way: + marray id | 0 | 1 | 2 | 3 | + bit id |7 .. 0|15 .. 8|23 .. 16|31 .. 24| + */ + template ::value>> + void insert_bits(const marray &bits, id<1> pos = 0) { + size_t cur_pos = pos.get(0); + for (auto elem : bits) { + if (cur_pos < size()) { + this->insert_bits(elem, cur_pos); + cur_pos += sizeof(Type) * CHAR_BIT; + } + } + } + + template ::value>> + void extract_bits(Type &bits, id<1> pos = 0) { + uint32_t Res = Bits; + if (pos.get(0) < size()) { + if (pos.get(0) > 0) { + Res >>= pos.get(0); + } + + if (sizeof(Type) * CHAR_BIT < size()) { + Res &= (0xffffffff >> (size() - (sizeof(Type) * CHAR_BIT))); + } + bits = (Type)Res; + } else { + bits = 0; + } + } + + template ::value>> + void extract_bits(marray &bits, id<1> pos = 0) { + size_t cur_pos = pos.get(0); + for (auto &elem : bits) { + if (cur_pos < size()) { + this->extract_bits(elem, cur_pos); + cur_pos += sizeof(Type) * CHAR_BIT; + } else { + elem = 0; + } + } + } + + void set() { Bits = uint32_t{0xffffffff}; } + void set(id<1> id, bool value = true) { operator[](id) = value; } + void reset() { Bits = uint32_t{0}; } + void reset(id<1> id) { operator[](id) = 0; } + void reset_low() { reset(find_low()); } + void reset_high() { reset(find_high()); } + void flip() { Bits = ~Bits; } + void flip(id<1> id) { operator[](id).flip(); } + + bool operator==(const sub_group_mask &rhs) const { return Bits == rhs.Bits; } + bool operator!=(const sub_group_mask &rhs) const { return !(*this == rhs); } + + sub_group_mask &operator&=(const sub_group_mask &rhs) { + Bits &= rhs.Bits; + return *this; + } + sub_group_mask &operator|=(const sub_group_mask &rhs) { + Bits |= rhs.Bits; + return *this; + } + + sub_group_mask &operator^=(const sub_group_mask &rhs) { + Bits ^= rhs.Bits; + return *this; + } + + sub_group_mask &operator<<=(size_t pos) { + Bits <<= pos; + return *this; + } + + sub_group_mask &operator>>=(size_t pos) { + Bits >>= pos; + return *this; + } + + sub_group_mask operator~() const { + auto Tmp = *this; + Tmp.flip(); + return Tmp; + } + sub_group_mask operator<<(size_t pos) const { + auto Tmp = *this; + Tmp <<= pos; + return Tmp; + } + sub_group_mask operator>>(size_t pos) const { + auto Tmp = *this; + Tmp >>= pos; + return Tmp; + } + + sub_group_mask(const sub_group_mask &rhs) : Bits(rhs.Bits) {} + + template + friend sub_group_mask sub_group_ballot(Group g, bool predicate); + + friend sub_group_mask operator&(const sub_group_mask &lhs, + const sub_group_mask &rhs) { + auto Res = lhs; + Res &= rhs; + return Res; + } + + friend sub_group_mask operator|(const sub_group_mask &lhs, + const sub_group_mask &rhs) { + auto Res = lhs; + Res |= rhs; + return Res; + } + + friend sub_group_mask operator^(const sub_group_mask &lhs, + const sub_group_mask &rhs) { + auto Res = lhs; + Res ^= rhs; + return Res; + } + +private: + sub_group_mask(uint32_t rhs) : Bits(rhs) {} + uint32_t Bits; +}; +template +sub_group_mask sub_group_ballot(Group g, bool predicate) { + (void)g; +#ifdef __SYCL_DEVICE_ONLY__ + auto res = __spirv_GroupNonUniformBallot( + detail::spirv::group_scope::value, predicate); + return detail::Builder::createSubGroupMask(res[0]); +#else + (void)predicate; + throw exception{errc::feature_not_supported, + "Sub-group mask is not supported on host device"}; +#endif +} +} // namespace oneapi +} // namespace ext +} // namespace sycl +} // __SYCL_INLINE_NAMESPACE(cl) diff --git a/sycl/test/check_device_code/group_mask.cpp b/sycl/test/check_device_code/sub_group_mask.cpp similarity index 74% rename from sycl/test/check_device_code/group_mask.cpp rename to sycl/test/check_device_code/sub_group_mask.cpp index 06afa8fa55d84..5ebca8f738714 100644 --- a/sycl/test/check_device_code/group_mask.cpp +++ b/sycl/test/check_device_code/sub_group_mask.cpp @@ -4,5 +4,7 @@ using namespace sycl; -SYCL_EXTERNAL void test_group_mask(group<> g) { ext::oneapi::group_ballot(g, true); } +SYCL_EXTERNAL void test_group_mask(group<> g) { + ext::oneapi::sub_group_ballot(g, true); +} // CHECK: %{{.*}} = call spir_func <4 x i32> @_Z[[#]]__spirv_GroupNonUniformBallotjb(i32 {{.*}}, i1{{.*}}) diff --git a/sycl/test/extensions/macro.cpp b/sycl/test/extensions/macro.cpp index aaa70eb21ce36..7264ac21e4264 100644 --- a/sycl/test/extensions/macro.cpp +++ b/sycl/test/extensions/macro.cpp @@ -5,10 +5,10 @@ #include #include int main() { -#if SYCL_EXT_ONEAPI_GROUP_MASK == 1 - std::cout << "SYCL_EXT_ONEAPI_GROUP_MASK=1" << std::endl; +#if SYCL_EXT_ONEAPI_SUB_GROUP_MASK == 1 + std::cout << "SYCL_EXT_ONEAPI_SUB_GROUP_MASK=1" << std::endl; #else - std::cerr << "SYCL_EXT_ONEAPI_GROUP_MASK!=1" << std::endl; + std::cerr << "SYCL_EXT_ONEAPI_SUB_GROUP_MASK!=1" << std::endl; exit(1); #endif exit(0); diff --git a/sycl/test/extensions/group_mask.cpp b/sycl/test/extensions/sub_group_mask.cpp similarity index 55% rename from sycl/test/extensions/group_mask.cpp rename to sycl/test/extensions/sub_group_mask.cpp index 67c54c5b209fd..9a76d0a3fe0e4 100644 --- a/sycl/test/extensions/group_mask.cpp +++ b/sycl/test/extensions/sub_group_mask.cpp @@ -1,7 +1,7 @@ // RUN: %clangxx -g -O0 -fsycl -fsycl-targets=%sycl_triple %s -o %t.out // RUN: %t.out -//==-------- group_mask.cpp - SYCL group_mask test -------------------------==// +//==-------- sub_group_mask.cpp - SYCL sub-group mask test -----------------==// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -13,9 +13,8 @@ #include int main() { - auto g = - sycl::detail::Builder::createGroupMask( - sycl::marray{0}); + auto g = sycl::detail::Builder::createSubGroupMask< + sycl::ext::oneapi::sub_group_mask>(0); assert(g.none() && !g.any() && !g.all()); assert(g[10] == false); // reference::operator[](id) const; g[10] = true; // reference::operator=(bool); @@ -26,38 +25,38 @@ int main() { assert(g[10] == false); assert(g[11] == true); assert(g.test(10) == false && g.test(11) == true); - g.set(101, 1); + g.set(30, 1); g.set(11, 0); - g.set(53, 1); + g.set(23, 1); assert(!g.none() && g.any() && !g.all()); assert(g.count() == 2); - assert(g.find_low() == 53); - assert(g.find_high() == 101); - assert(g.size() == 128); + assert(g.find_low() == 23); + assert(g.find_high() == 30); + assert(g.size() == 32); g.reset(); assert(g.none() && !g.any() && !g.all()); assert(g.find_low() == g.size() && g.find_high() == g.size()); g.set(); assert(!g.none() && g.any() && g.all()); - assert(g.find_low() == 0 && g.find_high() == 127); + assert(g.find_low() == 0 && g.find_high() == 31); g.flip(); assert(g.none() && !g.any() && !g.all()); g.flip(13); - g.flip(43); - g.flip(79); + g.flip(23); + g.flip(29); auto b = g; assert(b == g && !(b != g)); - g.flip(101); - assert(g.find_high() == 101); - assert(b.find_high() == 79); + g.flip(31); + assert(g.find_high() == 31); + assert(b.find_high() == 29); assert(b != g && !(b == g)); - b.flip(101); + b.flip(31); assert(b == g && !(b != g)); b = g >> 1; - assert(b[12] && b[42] && b[78] && b[100]); + assert(b[12] && b[22] && b[28] && b[30]); b <<= 1; assert(b == g); g ^= ~b; @@ -67,24 +66,25 @@ int main() { assert((g ^ ~g).all()); b.reset_low(); b.reset_high(); - assert(!b[13] && b[43] && b[79] && !b[101]); - b.insert_bits(sycl::marray{1, 2, 4, 8}); - assert(b[96] && b[65] && b[34] && b[3]); - g = b; - g <<= 33; - assert(!g[96] && !g[65] && !g[34] && !g[3] && g[98] && g[67] && g[36]); - b.insert_bits(sycl::marray{1, 1, 1, 1}, 15); - assert(b[111] && !b[96] && b[79] && !b[65] && b[47] && !b[34] && b[15] && - b[3]); - - auto r = b.extract_bits>(); - for(size_t i=0; i>= 79; - assert(b[32] && b[0]); - b.flip(32); - b.flip(0); - assert(b.none()); - b.insert_bits((int)1); + assert(!b[13] && b[23] && b[29] && !b[31]); + b.insert_bits(0x01020408); + assert(b[24] && b[17] && b[10] && b[3]); + b <<= 13; + assert(!b[24] && !b[17] && !b[10] && !b[3] && b[30] && b[23] && b[16]); + b.insert_bits((char)0b01010101, 18); + assert(b[18] && b[20] && b[22] && b[24] && b[30] && !b[23] && b[16]); + b[3] = true; + b.insert_bits(sycl::marray{1, 2, 4, 8, 16, 32, 64, 128}, 5); + assert(!b[18] && !b[20] && !b[22] && !b[24] && !b[30] && !b[16] && b[3] && + b[5] && b[14] && b[23]); + char r; + b.extract_bits(r); + assert(r == 0b00101000); + long r2 = -1; + b.extract_bits(r2, 16); + assert(r2 == 128); + b[31] = true; + sycl::marray r3{-1}; + b.extract_bits(r3, 14); + assert(r3[0] == 1 && r3[1] == 2 && r3[2] == 2 && !r3[3] && !r3[4] && !r3[5]); } From a184e291ef878f0022614cf24513ed67c72985ec Mon Sep 17 00:00:00 2001 From: Vladimir Lazarev Date: Wed, 15 Sep 2021 19:31:07 +0300 Subject: [PATCH 12/15] Apply review comments --- .../SubGroupMask/SubGroupMask.asciidoc | 8 +++---- sycl/include/CL/sycl/detail/helpers.hpp | 5 ++-- .../sycl/ext/oneapi/sub_group_mask.hpp | 24 ++++++++++++++----- .../test/check_device_code/sub_group_mask.cpp | 4 ++-- sycl/test/extensions/sub_group_mask.cpp | 2 +- 5 files changed, 28 insertions(+), 15 deletions(-) diff --git a/sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc b/sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc index f28857c2e8d53..ac2cd8401690c 100755 --- a/sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc +++ b/sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc @@ -21,7 +21,7 @@ IMPORTANT: This specification is a draft. NOTE: Khronos(R) is a registered trademark and SYCL(TM) and SPIR(TM) are trademarks of The Khronos Group Inc. OpenCL(TM) is a trademark of Apple Inc. used by permission by Khronos. -This document describes an extension which adds a `sub_group_mask` type. Such a mask can be used to efficiently represent subsets of work-items in a sub-group for which a given Boolean condition holds. Sub-group mask functionality is currently limited to groups that are instances of the `sub_group` class. +This document describes an extension which adds a `sub_group_mask` type. Such a mask can be used to efficiently represent subsets of work-items in a sub-group for which a given Boolean condition holds. == Notice @@ -74,14 +74,14 @@ the specification. === Ballot -The `sub_group_ballot` algorithm converts a Boolean condition from each work-item -in the group into a group mask. Like other group algorithms, `sub_group_ballot` +The `group_ballot` algorithm converts a Boolean condition from each work-item +in the group into a group mask. Like other group algorithms, `group_ballot` must be encountered by all work-items in the group in converged control flow. |=== |Function|Description -|`template Group::mask_type sub_group_ballot(Group g, bool predicate = true)` +|`template Group::mask_type group_ballot(Group g, bool predicate = true)` |Return a `sub_group_mask` with one bit for each work-item in group _g_. A bit is set in this mask if and only if the corresponding work-item's _predicate_ is `true`. |=== diff --git a/sycl/include/CL/sycl/detail/helpers.hpp b/sycl/include/CL/sycl/detail/helpers.hpp index a70fc8410f2d5..d09738cc3fca0 100644 --- a/sycl/include/CL/sycl/detail/helpers.hpp +++ b/sycl/include/CL/sycl/detail/helpers.hpp @@ -83,8 +83,9 @@ class Builder { return group(Global, Local, Global / Local, Index); } - template static ResType createSubGroupMask(uint32_t Bits) { - return ResType(Bits); + template static ResType createSubGroupMask(uint32_t Bits, + size_t BitsNum) { + return ResType(Bits, BitsNum); } template diff --git a/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp b/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp index 8e5c73179f8c3..3ff837beabd16 100644 --- a/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp +++ b/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp @@ -77,7 +77,7 @@ struct sub_group_mask { } return count; } - uint32_t size() const { return max_bits; } + uint32_t size() const { return bits_num; } id<1> find_low() const { size_t i = 0; while (i < size() && !operator[](i)) @@ -206,10 +206,13 @@ struct sub_group_mask { return Tmp; } - sub_group_mask(const sub_group_mask &rhs) : Bits(rhs.Bits) {} + sub_group_mask(const sub_group_mask &rhs) + : Bits(rhs.Bits), bits_num(rhs.bits_num) {} template - friend sub_group_mask sub_group_ballot(Group g, bool predicate); + friend detail::enable_if_t< + std::is_same, sub_group>::value, sub_group_mask> + group_ballot(Group g, bool predicate); friend sub_group_mask operator&(const sub_group_mask &lhs, const sub_group_mask &rhs) { @@ -233,22 +236,31 @@ struct sub_group_mask { } private: - sub_group_mask(uint32_t rhs) : Bits(rhs) {} + sub_group_mask(uint32_t rhs, size_t bn) : Bits(rhs), bits_num(bn) { + assert(bits_num <= max_bits); + } uint32_t Bits; + // Number of valuable bits + size_t bits_num; }; + template -sub_group_mask sub_group_ballot(Group g, bool predicate) { +detail::enable_if_t, sub_group>::value, + sub_group_mask> +group_ballot(Group g, bool predicate) { (void)g; #ifdef __SYCL_DEVICE_ONLY__ auto res = __spirv_GroupNonUniformBallot( detail::spirv::group_scope::value, predicate); - return detail::Builder::createSubGroupMask(res[0]); + return detail::Builder::createSubGroupMask( + res[0], g.get_max_local_range()[0]); #else (void)predicate; throw exception{errc::feature_not_supported, "Sub-group mask is not supported on host device"}; #endif } + } // namespace oneapi } // namespace ext } // namespace sycl diff --git a/sycl/test/check_device_code/sub_group_mask.cpp b/sycl/test/check_device_code/sub_group_mask.cpp index 5ebca8f738714..b074567e6d461 100644 --- a/sycl/test/check_device_code/sub_group_mask.cpp +++ b/sycl/test/check_device_code/sub_group_mask.cpp @@ -4,7 +4,7 @@ using namespace sycl; -SYCL_EXTERNAL void test_group_mask(group<> g) { - ext::oneapi::sub_group_ballot(g, true); +SYCL_EXTERNAL void test_group_mask(sub_group g) { + ext::oneapi::group_ballot(g, true); } // CHECK: %{{.*}} = call spir_func <4 x i32> @_Z[[#]]__spirv_GroupNonUniformBallotjb(i32 {{.*}}, i1{{.*}}) diff --git a/sycl/test/extensions/sub_group_mask.cpp b/sycl/test/extensions/sub_group_mask.cpp index 9a76d0a3fe0e4..cc2cdfa17c43c 100644 --- a/sycl/test/extensions/sub_group_mask.cpp +++ b/sycl/test/extensions/sub_group_mask.cpp @@ -14,7 +14,7 @@ int main() { auto g = sycl::detail::Builder::createSubGroupMask< - sycl::ext::oneapi::sub_group_mask>(0); + sycl::ext::oneapi::sub_group_mask>(0, 32); assert(g.none() && !g.any() && !g.all()); assert(g[10] == false); // reference::operator[](id) const; g[10] = true; // reference::operator=(bool); From e39e8e17adfe7eaff03686d0cb5c63957e2774c1 Mon Sep 17 00:00:00 2001 From: Vladimir Lazarev Date: Wed, 15 Sep 2021 19:38:09 +0300 Subject: [PATCH 13/15] Fix clang format --- sycl/include/CL/sycl/detail/helpers.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sycl/include/CL/sycl/detail/helpers.hpp b/sycl/include/CL/sycl/detail/helpers.hpp index d09738cc3fca0..838b68f33a641 100644 --- a/sycl/include/CL/sycl/detail/helpers.hpp +++ b/sycl/include/CL/sycl/detail/helpers.hpp @@ -83,8 +83,8 @@ class Builder { return group(Global, Local, Global / Local, Index); } - template static ResType createSubGroupMask(uint32_t Bits, - size_t BitsNum) { + template + static ResType createSubGroupMask(uint32_t Bits, size_t BitsNum) { return ResType(Bits, BitsNum); } From fdd0024ad244dab1d5a034e78c90d3ef8f056a49 Mon Sep 17 00:00:00 2001 From: Vladimir Lazarev Date: Wed, 15 Sep 2021 20:34:20 +0300 Subject: [PATCH 14/15] Apply review comments --- sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc b/sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc index ac2cd8401690c..aedbd42f92476 100755 --- a/sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc +++ b/sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc @@ -137,13 +137,13 @@ work-item with the id `max_local_range()-1`. |Return the highest `id` with a corresponding bit set in the mask. If no bits are set, the return value is equal to `size()`. -|`template > void insert_bits(const T &bits, id<1> pos = 0)` +|`template void insert_bits(const T &bits, id<1> pos = 0)` |Insert `CHAR_BIT * sizeof(T)` bits into the mask, starting from _pos_. `T` must be an integral type or a SYCL `marray` of integral types. If _pos_ pass:[+] `CHAR_BIT * sizeof(T)` is greater than `size()`, the final `size()` - (_pos_ pass:[+] `CHAR_BIT * sizeof(T)`) bits are ignored. -|`template > void extract_bits(T &out, id<1> pos = 0) const` +|`template void extract_bits(T &out, id<1> pos = 0) const` |Return `CHAR_BIT * sizeof(T)` bits from the mask, starting from _pos_. `T` must be an integral type or a SYCL `marray` of integral types. If _pos_ pass:[+] `CHAR_BIT * sizeof(T)` is greater than `size()`, the final `size()` - (_pos_ pass:[+] @@ -258,7 +258,6 @@ struct sub_group_mask { }; static constexpr size_t max_bits = /* implementation-defined */; - static constexpr size_t marray_size = /* implementation defined */; bool operator[](id<1> id) const; reference operator[](id<1> id); @@ -271,10 +270,10 @@ struct sub_group_mask { id<1> find_low() const; id<1> find_high() const; - template > + template void insert_bits(const T &bits, id<1> pos = 0); - template > + template void extract_bits(T &out, id<1> pos = 0); void set(); From 1bd4fc3c23a9e131be6ab2acff689c6981e6c40b Mon Sep 17 00:00:00 2001 From: Vladimir Lazarev Date: Fri, 17 Sep 2021 21:38:16 +0300 Subject: [PATCH 15/15] Apply review comment --- sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc b/sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc index aedbd42f92476..c3b9a6ca98ca4 100755 --- a/sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc +++ b/sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc @@ -51,9 +51,9 @@ This extension is written against the SYCL 2020 specification, Revision 3. This extension provides a feature-test macro as described in the core SYCL specification section 6.3.3 "Feature test macros". Therefore, an implementation supporting this extension must predefine the macro -`SYCL_EXT_ONEAPI_GROUP_MASK` to one of the values defined in the table below. -Applications can test for the existence of this macro to determine if the -implementation supports this feature, or applications can test the macro's +`SYCL_EXT_ONEAPI_SUB_GROUP_MASK` to one of the values defined in the table +below. Applications can test for the existence of this macro to determine if +the implementation supports this feature, or applications can test the macro's value to determine which of the extension's APIs the implementation supports. [%header,cols="1,5"]