Skip to content

Commit cc6e4ae

Browse files
committed
[SYCL] Update sub-group reduce/scan syntax
Aligns the sub-group implementation with the extension doc:
- Enables reduce and scan to take functors as arguments
- Adds OpenCL-like reduce(x, op) overload to extension doc

Also clarifies interpretation of init values in documentation.

Signed-off-by: John Pennycook <john.pennycook@intel.com>
1 parent 5be0314 commit cc6e4ae

File tree

5 files changed

+231
-103
lines changed

5 files changed

+231
-103
lines changed

sycl/doc/extensions/SubGroupNDRange/SubGroupNDRange.md

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,12 @@ The `plus`, `minimum` and `maximum` functors in the `cl::sycl` namespace corresp
143143
|Member functions|Description|
144144
|----------------|-----------|
145145
| `template <typename T>T broadcast(T x, id<1> local_id) const` | Broadcast the value of `x` from the work-item with the specified id to all work-items within the sub-group. The value of `local_id` must be the same for all work-items in the sub-group. |
146-
| `template <typename T, class BinaryOp>T reduce(T x, T init, BinaryOp binary_op) const` | Combine the values of `x` from all work-items in the sub-group using the specified operator, which must be one of: `plus`, `minimum` or `maximum`. |
147-
| `template <typename T, class BinaryOp>T exclusive_scan(T x, T init, BinaryOp binary_op) const` | Perform an exclusive scan over the values of `x` from all work-items in the sub-group using the specified operator, which must be one of: `plus`, `minimum` or `maximum`. The value returned on work-item `i` is the exclusive scan of the first `i` work-items in the sub-group. |
148-
| `template <typename T, class BinaryOp>T inclusive_scan(T x, BinaryOp binary_op, T init) const` | Perform an inclusive scan over the values of `x` from all work-items in the sub-group using the specified operator, which must be one of: `plus`, `minimum` or `maximum`. The value returned on work-item `i` is the inclusive scan of the first `i` work-items in the sub-group. |
146+
| `template <typename T, class BinaryOp>T reduce(T x, BinaryOp binary_op) const` | Combine the values of `x` from all work-items in the sub-group using the specified operator, which must be one of: `plus`, `minimum` or `maximum`. |
147+
| `template <typename T, class BinaryOp>T reduce(T x, T init, BinaryOp binary_op) const` | Combine the values of `x` from all work-items in the sub-group using an initial value of `init` and the specified operator, which must be one of: `plus`, `minimum` or `maximum`. |
148+
| `template <typename T, class BinaryOp>T exclusive_scan(T x, BinaryOp binary_op) const` | Perform an exclusive scan over the values of `x` from all work-items in the sub-group using the specified operator, which must be one of: `plus`, `minimum` or `maximum`. The value returned on work-item `i` is the exclusive scan of the first `i` work-items in the sub-group. The initial value is the identity value of the operator. |
149+
| `template <typename T, class BinaryOp>T exclusive_scan(T x, T init, BinaryOp binary_op) const` | Perform an exclusive scan over the values of `x` from all work-items in the sub-group using the specified operator, which must be one of: `plus`, `minimum` or `maximum`. The value returned on work-item `i` is the exclusive scan of the first `i` work-items in the sub-group. The initial value is specified by `init`. |
150+
| `template <typename T, class BinaryOp>T inclusive_scan(T x, BinaryOp binary_op) const` | Perform an inclusive scan over the values of `x` from all work-items in the sub-group using the specified operator, which must be one of: `plus`, `minimum` or `maximum`. The value returned on work-item `i` is the inclusive scan of the values from work-items `0` through `i`, i.e. it includes work-item `i`'s own value of `x`. |
151+
| `template <typename T, class BinaryOp>T inclusive_scan(T x, BinaryOp binary_op, T init) const` | Perform an inclusive scan over the values of `x` from all work-items in the sub-group using the specified operator, which must be one of: `plus`, `minimum` or `maximum`. The value returned on work-item `i` is the inclusive scan of the initial value `init` and the values from work-items `0` through `i`, i.e. it includes work-item `i`'s own value of `x`. |
149152

150153
## Extended Functionality
151154

@@ -214,12 +217,21 @@ struct sub_group {
214217
template <typename T>
215218
T broadcast(T x, id<1> local_id) const;
216219

220+
template <typename T, class BinaryOp>
221+
T reduce(T x, BinaryOp binary_op) const;
222+
217223
template <typename T, class BinaryOp>
218224
T reduce(T x, T init, BinaryOp binary_op) const;
219225

226+
template <typename T, class BinaryOp>
227+
T exclusive_scan(T x, BinaryOp binary_op) const;
228+
220229
template <typename T, class BinaryOp>
221230
T exclusive_scan(T x, T init, BinaryOp binary_op) const;
222231

232+
template <typename T, class BinaryOp>
233+
T inclusive_scan(T x, BinaryOp binary_op) const;
234+
223235
template <typename T, class BinaryOp>
224236
T inclusive_scan(T x, BinaryOp binary_op, T init) const;
225237

sycl/include/CL/sycl/intel/sub_group.hpp

Lines changed: 92 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -22,69 +22,80 @@ namespace cl {
2222
namespace sycl {
2323
template <typename T, access::address_space Space> class multi_ptr;
2424
namespace intel {
25-
template <typename>
2625

27-
struct is_vec : std::false_type {};
26+
template <typename> struct is_vec : std::false_type {};
2827
template <typename T, std::size_t N>
2928
struct is_vec<cl::sycl::vec<T, N>> : std::true_type {};
3029

31-
struct minimum {
32-
template <typename T, __spv::GroupOperation O>
33-
static typename std::enable_if<
34-
!detail::is_floating_point<T>::value && std::is_signed<T>::value, T>::type
35-
calc(T x) {
36-
return __spirv_GroupSMin(__spv::Scope::Subgroup, O, x);
30+
template <typename T> struct minimum {
31+
T operator()(const T &lhs, const T &rhs) const {
32+
return (lhs <= rhs) ? lhs : rhs;
3733
}
34+
};
3835

39-
template <typename T, __spv::GroupOperation O>
40-
static typename std::enable_if<
41-
!detail::is_floating_point<T>::value && std::is_unsigned<T>::value, T>::type
42-
calc(T x) {
43-
return __spirv_GroupUMin(__spv::Scope::Subgroup, O, x);
36+
template <typename T> struct maximum {
37+
T operator()(const T &lhs, const T &rhs) const {
38+
return (lhs >= rhs) ? lhs : rhs;
4439
}
40+
};
4541

46-
template <typename T, __spv::GroupOperation O>
47-
static typename std::enable_if<detail::is_floating_point<T>::value, T>::type
48-
calc(T x) {
49-
return __spirv_GroupFMin(__spv::Scope::Subgroup, O, x);
50-
}
42+
template <typename T> struct plus {
43+
T operator()(const T &lhs, const T &rhs) const { return lhs + rhs; }
5144
};
5245

53-
struct maximum {
54-
template <typename T, __spv::GroupOperation O>
55-
static typename std::enable_if<
46+
template <typename T, __spv::GroupOperation O>
47+
static typename std::enable_if<
5648
!detail::is_floating_point<T>::value && std::is_signed<T>::value, T>::type
57-
calc(T x) {
58-
return __spirv_GroupSMax(__spv::Scope::Subgroup, O, x);
59-
}
49+
calc(T x, minimum<T> op) {
50+
return __spirv_GroupSMin(__spv::Scope::Subgroup, O, x);
51+
}
6052

61-
template <typename T, __spv::GroupOperation O>
62-
static typename std::enable_if<
53+
template <typename T, __spv::GroupOperation O>
54+
static typename std::enable_if<
6355
!detail::is_floating_point<T>::value && std::is_unsigned<T>::value, T>::type
64-
calc(T x) {
65-
return __spirv_GroupUMax(__spv::Scope::Subgroup, O, x);
66-
}
67-
68-
template <typename T, __spv::GroupOperation O>
69-
static typename std::enable_if<detail::is_floating_point<T>::value, T>::type
70-
calc(T x) {
71-
return __spirv_GroupFMax(__spv::Scope::Subgroup, O, x);
72-
}
73-
};
56+
calc(T x, minimum<T> op) {
57+
return __spirv_GroupUMin(__spv::Scope::Subgroup, O, x);
58+
}
59+
60+
template <typename T, __spv::GroupOperation O>
61+
static typename std::enable_if<detail::is_floating_point<T>::value, T>::type
62+
calc(T x, minimum<T> op) {
63+
return __spirv_GroupFMin(__spv::Scope::Subgroup, O, x);
64+
}
65+
66+
template <typename T, __spv::GroupOperation O>
67+
static typename std::enable_if<
68+
!detail::is_floating_point<T>::value && std::is_signed<T>::value, T>::type
69+
calc(T x, maximum<T> op) {
70+
return __spirv_GroupSMax(__spv::Scope::Subgroup, O, x);
71+
}
7472

75-
struct plus {
76-
template <typename T, __spv::GroupOperation O>
77-
static typename std::enable_if<
73+
template <typename T, __spv::GroupOperation O>
74+
static typename std::enable_if<
75+
!detail::is_floating_point<T>::value && std::is_unsigned<T>::value, T>::type
76+
calc(T x, maximum<T> op) {
77+
return __spirv_GroupUMax(__spv::Scope::Subgroup, O, x);
78+
}
79+
80+
template <typename T, __spv::GroupOperation O>
81+
static typename std::enable_if<detail::is_floating_point<T>::value, T>::type
82+
calc(T x, maximum<T> op) {
83+
return __spirv_GroupFMax(__spv::Scope::Subgroup, O, x);
84+
}
85+
86+
template <typename T, __spv::GroupOperation O>
87+
static typename std::enable_if<
7888
!detail::is_floating_point<T>::value && std::is_integral<T>::value, T>::type
79-
calc(T x) {
80-
return __spirv_GroupIAdd<T>(__spv::Scope::Subgroup, O, x);
81-
}
82-
template <typename T, __spv::GroupOperation O>
83-
static typename std::enable_if<detail::is_floating_point<T>::value, T>::type
84-
calc(T x) {
85-
return __spirv_GroupFAdd<T>(__spv::Scope::Subgroup, O, x);
86-
}
87-
};
89+
calc(T x, plus<T> op) {
90+
return __spirv_GroupIAdd<T>(__spv::Scope::Subgroup, O, x);
91+
}
92+
93+
template <typename T, __spv::GroupOperation O>
94+
static typename std::enable_if<detail::is_floating_point<T>::value, T>::type
95+
calc(T x, plus<T> op) {
96+
return __spirv_GroupFAdd<T>(__spv::Scope::Subgroup, O, x);
97+
}
98+
8899
struct sub_group {
89100
/* --- common interface members --- */
90101

@@ -131,20 +142,45 @@ struct sub_group {
131142
}
132143

133144
template <typename T, class BinaryOperation>
134-
T reduce(EnableIfIsScalarArithmetic<T> x) const {
135-
return BinaryOperation::template calc<T, __spv::GroupOperation::Reduce>(x);
145+
EnableIfIsScalarArithmetic<T> reduce(T x, BinaryOperation op) const {
146+
return calc<T, __spv::GroupOperation::Reduce>(x, op);
147+
}
148+
149+
template <typename T, class BinaryOperation>
150+
EnableIfIsScalarArithmetic<T> reduce(T x, T init, BinaryOperation op) const {
151+
return op(init, reduce(x, op));
152+
}
153+
154+
template <typename T, class BinaryOperation>
155+
EnableIfIsScalarArithmetic<T> exclusive_scan(T x, BinaryOperation op) const {
156+
return calc<T, __spv::GroupOperation::ExclusiveScan>(x, op);
136157
}
137158

138159
template <typename T, class BinaryOperation>
139-
T exclusive_scan(EnableIfIsScalarArithmetic<T> x) const {
140-
return BinaryOperation::template
141-
calc<T, __spv::GroupOperation::ExclusiveScan>(x);
160+
EnableIfIsScalarArithmetic<T> exclusive_scan(T x, T init,
161+
BinaryOperation op) const {
162+
if (get_local_id().get(0) == 0) {
163+
x = op(init, x);
164+
}
165+
T scan = exclusive_scan(x, op);
166+
if (get_local_id().get(0) == 0) {
167+
scan = init;
168+
}
169+
return scan;
142170
}
143171

144172
template <typename T, class BinaryOperation>
145-
T inclusive_scan(EnableIfIsScalarArithmetic<T> x) const {
146-
return BinaryOperation::template
147-
calc<T, __spv::GroupOperation::InclusiveScan>(x);
173+
EnableIfIsScalarArithmetic<T> inclusive_scan(T x, BinaryOperation op) const {
174+
return calc<T, __spv::GroupOperation::InclusiveScan>(x, op);
175+
}
176+
177+
template <typename T, class BinaryOperation>
178+
EnableIfIsScalarArithmetic<T> inclusive_scan(T x, BinaryOperation op,
179+
T init) const {
180+
if (get_local_id().get(0) == 0) {
181+
x = op(init, x);
182+
}
183+
return inclusive_scan(x, op);
148184
}
149185

150186
/* --- one - input shuffles --- */

sycl/include/CL/sycl/intel/sub_group_host.hpp

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,22 @@ namespace cl {
1818
namespace sycl {
1919
template <typename T, access::address_space Space> class multi_ptr;
2020
namespace intel {
21-
struct minimum {};
22-
struct maximum {};
23-
struct plus {};
21+
22+
template <typename T> struct minimum {
23+
T operator()(const T &lhs, const T &rhs) const {
24+
return (lhs <= rhs) ? lhs : rhs;
25+
}
26+
};
27+
28+
template <typename T> struct maximum {
29+
T operator()(const T &lhs, const T &rhs) const {
30+
return (lhs >= rhs) ? lhs : rhs;
31+
}
32+
};
33+
34+
template <typename T> struct plus {
35+
T operator()(const T &lhs, const T &rhs) const { return lhs + rhs; }
36+
};
2437

2538
struct sub_group {
2639
/* --- common interface members --- */
@@ -64,15 +77,33 @@ struct sub_group {
6477
throw runtime_error("Subgroups are not supported on host device. ");
6578
}
6679

67-
template <typename T, class BinaryOperation> T reduce(T x) const {
80+
template <typename T, class BinaryOperation>
81+
T reduce(T x, BinaryOperation op) const {
82+
throw runtime_error("Subgroups are not supported on host device. ");
83+
}
84+
85+
template <typename T, class BinaryOperation>
86+
T reduce(T x, T init, BinaryOperation op) const {
87+
throw runtime_error("Subgroups are not supported on host device. ");
88+
}
89+
90+
template <typename T, class BinaryOperation>
91+
T exclusive_scan(T x, BinaryOperation op) const {
92+
throw runtime_error("Subgroups are not supported on host device. ");
93+
}
94+
95+
template <typename T, class BinaryOperation>
96+
T exclusive_scan(T x, T init, BinaryOperation op) const {
6897
throw runtime_error("Subgroups are not supported on host device. ");
6998
}
7099

71-
template <typename T, class BinaryOperation> T exclusive_scan(T x) const {
100+
template <typename T, class BinaryOperation>
101+
T inclusive_scan(T x, BinaryOperation op) const {
72102
throw runtime_error("Subgroups are not supported on host device. ");
73103
}
74104

75-
template <typename T, class BinaryOperation> T inclusive_scan(T x) const {
105+
template <typename T, class BinaryOperation>
106+
T inclusive_scan(T x, BinaryOperation op, T init) const {
76107
throw runtime_error("Subgroups are not supported on host device. ");
77108
}
78109

sycl/test/sub_group/reduce.cpp

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414

1515
#include "helper.hpp"
1616
#include <CL/sycl.hpp>
17-
template <typename T> class sycl_subgr;
17+
template <typename T, bool init> class sycl_subgr;
1818
using namespace cl::sycl;
19-
template <typename T> void check(queue &Queue, size_t G = 240, size_t L = 60) {
19+
template <typename T, bool init>
20+
void check(queue &Queue, size_t G = 240, size_t L = 60) {
2021
try {
2122
nd_range<1> NdRange(G, L);
2223
buffer<T> minbuf(G);
@@ -26,14 +27,26 @@ template <typename T> void check(queue &Queue, size_t G = 240, size_t L = 60) {
2627
auto minacc = minbuf.template get_access<access::mode::read_write>(cgh);
2728
auto maxacc = maxbuf.template get_access<access::mode::read_write>(cgh);
2829
auto addacc = addbuf.template get_access<access::mode::read_write>(cgh);
29-
cgh.parallel_for<sycl_subgr<T>>(NdRange, [=](nd_item<1> NdItem) {
30+
cgh.parallel_for<sycl_subgr<T, init>>(NdRange, [=](nd_item<1> NdItem) {
3031
intel::sub_group sg = NdItem.get_sub_group();
31-
minacc[NdItem.get_global_id()] =
32-
sg.reduce<T, intel::minimum>(NdItem.get_global_id(0));
33-
maxacc[NdItem.get_global_id()] =
34-
sg.reduce<T, intel::maximum>(NdItem.get_global_id(0));
35-
addacc[NdItem.get_global_id()] =
36-
sg.reduce<T, intel::plus>(NdItem.get_global_id(0));
32+
if (init) {
33+
minacc[NdItem.get_global_id()] = sg.reduce(
34+
static_cast<T>(NdItem.get_global_id(0)),
35+
static_cast<T>(NdItem.get_global_range(0)), intel::minimum<T>());
36+
maxacc[NdItem.get_global_id()] =
37+
sg.reduce(static_cast<T>(NdItem.get_global_id(0)),
38+
static_cast<T>(0), intel::maximum<T>());
39+
addacc[NdItem.get_global_id()] =
40+
sg.reduce(static_cast<T>(NdItem.get_global_id(0)),
41+
static_cast<T>(0), intel::plus<T>());
42+
} else {
43+
minacc[NdItem.get_global_id()] = sg.reduce(
44+
static_cast<T>(NdItem.get_global_id(0)), intel::minimum<T>());
45+
maxacc[NdItem.get_global_id()] = sg.reduce(
46+
static_cast<T>(NdItem.get_global_id(0)), intel::maximum<T>());
47+
addacc[NdItem.get_global_id()] = sg.reduce(
48+
static_cast<T>(NdItem.get_global_id(0)), intel::plus<T>());
49+
}
3750
});
3851
});
3952
auto minacc = minbuf.template get_access<access::mode::read_write>();
@@ -71,19 +84,26 @@ int main() {
7184
std::cout << "Skipping test\n";
7285
return 0;
7386
}
74-
check<int>(Queue);
75-
check<unsigned int>(Queue);
76-
check<long>(Queue);
77-
check<unsigned long>(Queue);
78-
check<float>(Queue);
87+
check<int, true>(Queue);
88+
check<int, false>(Queue);
89+
check<unsigned int, true>(Queue);
90+
check<unsigned int, false>(Queue);
91+
check<long, true>(Queue);
92+
check<long, false>(Queue);
93+
check<unsigned long, true>(Queue);
94+
check<unsigned long, false>(Queue);
95+
check<float, true>(Queue);
96+
check<float, false>(Queue);
7997
// reduce half type is not supported in OCL CPU RT
8098
#ifdef SG_GPU
8199
if (Queue.get_device().has_extension("cl_khr_fp16")) {
82-
check<cl::sycl::half>(Queue);
100+
check<cl::sycl::half, true>(Queue);
101+
check<cl::sycl::half, false>(Queue);
83102
}
84103
#endif
85104
if (Queue.get_device().has_extension("cl_khr_fp64")) {
86-
check<double>(Queue);
105+
check<double, true>(Queue);
106+
check<double, false>(Queue);
87107
}
88108
std::cout << "Test passed." << std::endl;
89109
return 0;

0 commit comments

Comments
 (0)