Skip to content

Commit cd8194d

Browse files
authored
[SYCL] Update sub-group reduce/scan syntax (#688)
Aligns the sub-group implementation with the extension doc: - Enables reduce and scan to take functors as arguments - Adds OpenCL-like reduce(x, op) overload to extension doc Also clarifies interpretation of init values in documentation. Signed-off-by: John Pennycook john.pennycook@intel.com
2 parents 6e76414 + 9b03fec commit cd8194d

File tree

6 files changed

+320
-162
lines changed

6 files changed

+320
-162
lines changed

sycl/doc/extensions/SubGroupNDRange/SubGroupNDRange.md

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,12 @@ The `plus`, `minimum` and `maximum` functors in the `cl::sycl` namespace corresp
143143
|Member functions|Description|
144144
|----------------|-----------|
145145
| `template <typename T>T broadcast(T x, id<1> local_id) const` | Broadcast the value of `x` from the work-item with the specified id to all work-items within the sub-group. The value of `local_id` must be the same for all work-items in the sub-group. |
146-
| `template <typename T, class BinaryOp>T reduce(T x, T init, BinaryOp binary_op) const` | Combine the values of `x` from all work-items in the sub-group using the specified operator, which must be one of: `plus`, `minimum` or `maximum`. |
147-
| `template <typename T, class BinaryOp>T exclusive_scan(T x, T init, BinaryOp binary_op) const` | Perform an exclusive scan over the values of `x` from all work-items in the sub-group using the specified operator, which must be one of: `plus`, `minimum` or `maximum`. The value returned on work-item `i` is the exclusive scan of the first `i` work-items in the sub-group. |
148-
| `template <typename T, class BinaryOp>T inclusive_scan(T x, BinaryOp binary_op, T init) const` | Perform an inclusive scan over the values of `x` from all work-items in the sub-group using the specified operator, which must be one of: `plus`, `minimum` or `maximum`. The value returned on work-item `i` is the inclusive scan of the first `i` work-items in the sub-group. |
146+
| `template <typename T, class BinaryOp>T reduce(T x, BinaryOp binary_op) const` | Combine the values of `x` from all work-items in the sub-group using the specified operator, which must be one of: `plus`, `minimum` or `maximum`. |
147+
| `template <typename T, class BinaryOp>T reduce(T x, T init, BinaryOp binary_op) const` | Combine the values of `x` from all work-items in the sub-group using an initial value of `init` and the specified operator, which must be one of: `plus`, `minimum` or `maximum`. |
148+
| `template <typename T, class BinaryOp>T exclusive_scan(T x, BinaryOp binary_op) const` | Perform an exclusive scan over the values of `x` from all work-items in the sub-group using the specified operator, which must be one of: `plus`, `minimum` or `maximum`. The value returned on work-item `i` is the exclusive scan of the first `i` work-items in the sub-group. The initial value is the identity value of the operator. |
149+
| `template <typename T, class BinaryOp>T exclusive_scan(T x, T init, BinaryOp binary_op) const` | Perform an exclusive scan over the values of `x` from all work-items in the sub-group using the specified operator, which must be one of: `plus`, `minimum` or `maximum`. The value returned on work-item `i` is the exclusive scan of the first `i` work-items in the sub-group. The initial value is specified by `init`. |
150+
| `template <typename T, class BinaryOp>T inclusive_scan(T x, BinaryOp binary_op) const` | Perform an inclusive scan over the values of `x` from all work-items in the sub-group using the specified operator, which must be one of: `plus`, `minimum` or `maximum`. The value returned on work-item `i` is the inclusive scan of the first `i` work-items in the sub-group. |
151+
| `template <typename T, class BinaryOp>T inclusive_scan(T x, BinaryOp binary_op, T init) const` | Perform an inclusive scan over the values of `x` from all work-items in the sub-group using the specified operator, which must be one of: `plus`, `minimum` or `maximum`. The value returned on work-item `i` is the inclusive scan of the initial value `init` and the first `i` work-items in the sub-group. |
149152

150153
## Extended Functionality
151154

@@ -214,12 +217,21 @@ struct sub_group {
214217
template <typename T>
215218
T broadcast(T x, id<1> local_id) const;
216219

220+
template <typename T, class BinaryOp>
221+
T reduce(T x, BinaryOp binary_op) const;
222+
217223
template <typename T, class BinaryOp>
218224
T reduce(T x, T init, BinaryOp binary_op) const;
219225

226+
template <typename T, class BinaryOp>
227+
T exclusive_scan(T x, BinaryOp binary_op) const;
228+
220229
template <typename T, class BinaryOp>
221230
T exclusive_scan(T x, T init, BinaryOp binary_op) const;
222231

232+
template <typename T, class BinaryOp>
233+
T inclusive_scan(T x, BinaryOp binary_op) const;
234+
223235
template <typename T, class BinaryOp>
224236
T inclusive_scan(T x, BinaryOp binary_op, T init) const;
225237

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
//==----------- functional.hpp --- SYCL functional -------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
11+
namespace cl {
12+
namespace sycl {
13+
namespace intel {
14+
15+
template <typename T = void> struct minimum {
16+
T operator()(const T &lhs, const T &rhs) const {
17+
return (lhs <= rhs) ? lhs : rhs;
18+
}
19+
};
20+
21+
template <> struct minimum<void> {
22+
template <typename T> T operator()(const T &lhs, const T &rhs) const {
23+
return (lhs <= rhs) ? lhs : rhs;
24+
}
25+
};
26+
27+
template <typename T = void> struct maximum {
28+
T operator()(const T &lhs, const T &rhs) const {
29+
return (lhs >= rhs) ? lhs : rhs;
30+
}
31+
};
32+
33+
template <> struct maximum<void> {
34+
template <typename T> T operator()(const T &lhs, const T &rhs) const {
35+
return (lhs >= rhs) ? lhs : rhs;
36+
}
37+
};
38+
39+
template <typename T = void> struct plus {
40+
T operator()(const T &lhs, const T &rhs) const { return lhs + rhs; }
41+
};
42+
43+
template <> struct plus<void> {
44+
template <typename T> T operator()(const T &lhs, const T &rhs) const {
45+
return lhs + rhs;
46+
}
47+
};
48+
49+
} // namespace intel
50+
} // namespace sycl
51+
} // namespace cl

sycl/include/CL/sycl/intel/sub_group.hpp

Lines changed: 100 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -15,76 +15,83 @@
1515
#include <CL/sycl/id.hpp>
1616
#include <CL/sycl/range.hpp>
1717
#include <CL/sycl/types.hpp>
18+
#include <CL/sycl/intel/functional.hpp>
1819
#include <type_traits>
1920
#ifdef __SYCL_DEVICE_ONLY__
2021

2122
namespace cl {
2223
namespace sycl {
2324
template <typename T, access::address_space Space> class multi_ptr;
24-
namespace intel {
25-
template <typename>
2625

27-
struct is_vec : std::false_type {};
26+
namespace detail {
27+
28+
template <typename> struct is_vec : std::false_type {};
2829
template <typename T, std::size_t N>
2930
struct is_vec<cl::sycl::vec<T, N>> : std::true_type {};
3031

31-
struct minimum {
32-
template <typename T, __spv::GroupOperation O>
33-
static typename std::enable_if<
32+
template <typename T, __spv::GroupOperation O>
33+
static typename std::enable_if<
3434
!detail::is_floating_point<T>::value && std::is_signed<T>::value, T>::type
35-
calc(T x) {
36-
return __spirv_GroupSMin(__spv::Scope::Subgroup, O, x);
37-
}
35+
calc(T x, intel::minimum<T> op) {
36+
return __spirv_GroupSMin(__spv::Scope::Subgroup, O, x);
37+
}
3838

39-
template <typename T, __spv::GroupOperation O>
40-
static typename std::enable_if<
39+
template <typename T, __spv::GroupOperation O>
40+
static typename std::enable_if<
4141
!detail::is_floating_point<T>::value && std::is_unsigned<T>::value, T>::type
42-
calc(T x) {
43-
return __spirv_GroupUMin(__spv::Scope::Subgroup, O, x);
44-
}
45-
46-
template <typename T, __spv::GroupOperation O>
47-
static typename std::enable_if<detail::is_floating_point<T>::value, T>::type
48-
calc(T x) {
49-
return __spirv_GroupFMin(__spv::Scope::Subgroup, O, x);
50-
}
51-
};
52-
53-
struct maximum {
54-
template <typename T, __spv::GroupOperation O>
55-
static typename std::enable_if<
42+
calc(T x, intel::minimum<T> op) {
43+
return __spirv_GroupUMin(__spv::Scope::Subgroup, O, x);
44+
}
45+
46+
template <typename T, __spv::GroupOperation O>
47+
static typename std::enable_if<detail::is_floating_point<T>::value, T>::type
48+
calc(T x, intel::minimum<T> op) {
49+
return __spirv_GroupFMin(__spv::Scope::Subgroup, O, x);
50+
}
51+
52+
template <typename T, __spv::GroupOperation O>
53+
static typename std::enable_if<
5654
!detail::is_floating_point<T>::value && std::is_signed<T>::value, T>::type
57-
calc(T x) {
58-
return __spirv_GroupSMax(__spv::Scope::Subgroup, O, x);
59-
}
55+
calc(T x, intel::maximum<T> op) {
56+
return __spirv_GroupSMax(__spv::Scope::Subgroup, O, x);
57+
}
6058

61-
template <typename T, __spv::GroupOperation O>
62-
static typename std::enable_if<
59+
template <typename T, __spv::GroupOperation O>
60+
static typename std::enable_if<
6361
!detail::is_floating_point<T>::value && std::is_unsigned<T>::value, T>::type
64-
calc(T x) {
65-
return __spirv_GroupUMax(__spv::Scope::Subgroup, O, x);
66-
}
62+
calc(T x, intel::maximum<T> op) {
63+
return __spirv_GroupUMax(__spv::Scope::Subgroup, O, x);
64+
}
65+
66+
template <typename T, __spv::GroupOperation O>
67+
static typename std::enable_if<detail::is_floating_point<T>::value, T>::type
68+
calc(T x, intel::maximum<T> op) {
69+
return __spirv_GroupFMax(__spv::Scope::Subgroup, O, x);
70+
}
71+
72+
template <typename T, __spv::GroupOperation O>
73+
static typename std::enable_if<
74+
!detail::is_floating_point<T>::value && std::is_integral<T>::value, T>::type
75+
calc(T x, intel::plus<T> op) {
76+
return __spirv_GroupIAdd<T>(__spv::Scope::Subgroup, O, x);
77+
}
6778

68-
template <typename T, __spv::GroupOperation O>
69-
static typename std::enable_if<detail::is_floating_point<T>::value, T>::type
70-
calc(T x) {
71-
return __spirv_GroupFMax(__spv::Scope::Subgroup, O, x);
72-
}
73-
};
79+
template <typename T, __spv::GroupOperation O>
80+
static typename std::enable_if<detail::is_floating_point<T>::value, T>::type
81+
calc(T x, intel::plus<T> op) {
82+
return __spirv_GroupFAdd<T>(__spv::Scope::Subgroup, O, x);
83+
}
84+
85+
template <typename T, __spv::GroupOperation O,
86+
template <typename> class BinaryOperation>
87+
static T calc(T x, BinaryOperation<void>) {
88+
return calc<T, O>(x, BinaryOperation<T>());
89+
}
90+
91+
} // namespace detail
92+
93+
namespace intel {
7494

75-
struct plus {
76-
template <typename T, __spv::GroupOperation O>
77-
static typename std::enable_if<
78-
!detail::is_floating_point<T>::value && std::is_integral<T>::value, T>::type
79-
calc(T x) {
80-
return __spirv_GroupIAdd<T>(__spv::Scope::Subgroup, O, x);
81-
}
82-
template <typename T, __spv::GroupOperation O>
83-
static typename std::enable_if<detail::is_floating_point<T>::value, T>::type
84-
calc(T x) {
85-
return __spirv_GroupFAdd<T>(__spv::Scope::Subgroup, O, x);
86-
}
87-
};
8895
struct sub_group {
8996
/* --- common interface members --- */
9097

@@ -120,7 +127,7 @@ struct sub_group {
120127

121128
template <typename T>
122129
using EnableIfIsScalarArithmetic = detail::enable_if_t<
123-
!is_vec<T>::value && detail::is_arithmetic<T>::value, T>;
130+
!detail::is_vec<T>::value && detail::is_arithmetic<T>::value, T>;
124131

125132
/* --- collectives --- */
126133

@@ -131,20 +138,45 @@ struct sub_group {
131138
}
132139

133140
template <typename T, class BinaryOperation>
134-
T reduce(EnableIfIsScalarArithmetic<T> x) const {
135-
return BinaryOperation::template calc<T, __spv::GroupOperation::Reduce>(x);
141+
EnableIfIsScalarArithmetic<T> reduce(T x, BinaryOperation op) const {
142+
return detail::calc<T, __spv::GroupOperation::Reduce>(x, op);
143+
}
144+
145+
template <typename T, class BinaryOperation>
146+
EnableIfIsScalarArithmetic<T> reduce(T x, T init, BinaryOperation op) const {
147+
return op(init, reduce(x, op));
148+
}
149+
150+
template <typename T, class BinaryOperation>
151+
EnableIfIsScalarArithmetic<T> exclusive_scan(T x, BinaryOperation op) const {
152+
return detail::calc<T, __spv::GroupOperation::ExclusiveScan>(x, op);
153+
}
154+
155+
template <typename T, class BinaryOperation>
156+
EnableIfIsScalarArithmetic<T> exclusive_scan(T x, T init,
157+
BinaryOperation op) const {
158+
if (get_local_id().get(0) == 0) {
159+
x = op(init, x);
160+
}
161+
T scan = exclusive_scan(x, op);
162+
if (get_local_id().get(0) == 0) {
163+
scan = init;
164+
}
165+
return scan;
136166
}
137167

138168
template <typename T, class BinaryOperation>
139-
T exclusive_scan(EnableIfIsScalarArithmetic<T> x) const {
140-
return BinaryOperation::template
141-
calc<T, __spv::GroupOperation::ExclusiveScan>(x);
169+
EnableIfIsScalarArithmetic<T> inclusive_scan(T x, BinaryOperation op) const {
170+
return detail::calc<T, __spv::GroupOperation::InclusiveScan>(x, op);
142171
}
143172

144173
template <typename T, class BinaryOperation>
145-
T inclusive_scan(EnableIfIsScalarArithmetic<T> x) const {
146-
return BinaryOperation::template
147-
calc<T, __spv::GroupOperation::InclusiveScan>(x);
174+
EnableIfIsScalarArithmetic<T> inclusive_scan(T x, BinaryOperation op,
175+
T init) const {
176+
if (get_local_id().get(0) == 0) {
177+
x = op(init, x);
178+
}
179+
return inclusive_scan(x, op);
148180
}
149181

150182
/* --- one - input shuffles --- */
@@ -157,7 +189,7 @@ struct sub_group {
157189
}
158190

159191
template <typename T>
160-
typename std::enable_if<is_vec<T>::value, T>::type
192+
typename std::enable_if<detail::is_vec<T>::value, T>::type
161193
shuffle(T x, id<1> local_id) const {
162194
return __spirv_SubgroupShuffleINTEL((typename T::vector_t)x,
163195
local_id.get(0));
@@ -170,7 +202,7 @@ struct sub_group {
170202
}
171203

172204
template <typename T>
173-
typename std::enable_if<is_vec<T>::value, T>::type
205+
typename std::enable_if<detail::is_vec<T>::value, T>::type
174206
shuffle_down(T x, uint32_t delta) const {
175207
return shuffle_down(x, x, delta);
176208
}
@@ -182,7 +214,7 @@ struct sub_group {
182214
}
183215

184216
template <typename T>
185-
typename std::enable_if<is_vec<T>::value, T>::type
217+
typename std::enable_if<detail::is_vec<T>::value, T>::type
186218
shuffle_up(T x, uint32_t delta) const {
187219
return shuffle_up(x, x, delta);
188220
}
@@ -194,7 +226,7 @@ struct sub_group {
194226
}
195227

196228
template <typename T>
197-
typename std::enable_if<is_vec<T>::value, T>::type
229+
typename std::enable_if<detail::is_vec<T>::value, T>::type
198230
shuffle_xor(T x, id<1> value) const {
199231
return __spirv_SubgroupShuffleXorINTEL((typename T::vector_t)x,
200232
(uint32_t)value.get(0));
@@ -210,7 +242,7 @@ struct sub_group {
210242
}
211243

212244
template <typename T>
213-
typename std::enable_if<is_vec<T>::value, T>::type
245+
typename std::enable_if<detail::is_vec<T>::value, T>::type
214246
shuffle(T x, T y, id<1> local_id) const {
215247
return __spirv_SubgroupShuffleDownINTEL(
216248
(typename T::vector_t)x, (typename T::vector_t)y,
@@ -224,7 +256,7 @@ struct sub_group {
224256
}
225257

226258
template <typename T>
227-
typename std::enable_if<is_vec<T>::value, T>::type
259+
typename std::enable_if<detail::is_vec<T>::value, T>::type
228260
shuffle_down(T current, T next, uint32_t delta) const {
229261
return __spirv_SubgroupShuffleDownINTEL(
230262
(typename T::vector_t)current, (typename T::vector_t)next, delta);
@@ -237,7 +269,7 @@ struct sub_group {
237269
}
238270

239271
template <typename T>
240-
typename std::enable_if<is_vec<T>::value, T>::type
272+
typename std::enable_if<detail::is_vec<T>::value, T>::type
241273
shuffle_up(T previous, T current, uint32_t delta) const {
242274
return __spirv_SubgroupShuffleUpINTEL(
243275
(typename T::vector_t)previous, (typename T::vector_t)current, delta);

0 commit comments

Comments
 (0)