@@ -392,18 +392,17 @@ class reducer<T, BinaryOperation,
392392// / implementation classes. It is needed to detect the reduction classes.
393393class reduction_impl_base {};
394394
395- // / Predicate returning true if and only if 'FirstT' is a reduction class and
396- // / all types except the last one from 'RestT' are reductions as well.
397- template <typename FirstT, typename ... RestT>
398- struct are_all_but_last_reductions {
395+ // / Predicate returning true if all template type parameters except the last one
396+ // / are reductions.
397+ template <typename FirstT, typename ... RestT> struct AreAllButLastReductions {
399398 static constexpr bool value =
400399 std::is_base_of<reduction_impl_base, FirstT>::value &&
401- are_all_but_last_reductions <RestT...>::value;
400+ AreAllButLastReductions <RestT...>::value;
402401};
403402
404- // / Helper specialization of are_all_but_last_reductions for one element only.
405- // / Returns true if the last and only typename is not a reduction.
406- template <typename T> struct are_all_but_last_reductions <T> {
403+ // / Helper specialization of AreAllButLastReductions for one element only.
404+ // / Returns true if the template parameter is not a reduction.
405+ template <typename T> struct AreAllButLastReductions <T> {
407406 static constexpr bool value = !std::is_base_of<reduction_impl_base, T>::value;
408407};
409408
@@ -1097,9 +1096,11 @@ reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
10971096// / the reductions for which a local accessors are needed, this function creates
10981097// / those local accessors and returns a tuple consisting of them.
10991098template <typename ... Reductions, size_t ... Is>
1100- std::tuple<typename Reductions::local_accessor_type...>
1101- createReduLocalAccs (size_t Size, handler &CGH, std::index_sequence<Is...>) {
1102- return {Reductions::getReadWriteLocalAcc (Size, CGH)...};
1099+ auto createReduLocalAccs (size_t Size, handler &CGH,
1100+ std::index_sequence<Is...>) {
1101+ return std::make_tuple (
1102+ std::tuple_element_t <Is, std::tuple<Reductions...>>::getReadWriteLocalAcc (
1103+ Size, CGH)...);
11031104}
11041105
11051106// / For the given 'Reductions' types pack and indices enumerating them this
@@ -1154,7 +1155,7 @@ void callReduUserKernelFunc(KernelType KernelFunc, nd_item<Dims> NDIt,
11541155 KernelFunc (NDIt, std::get<Is>(Reducers)...);
11551156}
11561157
1157- template <bool UniformPow2WG , typename ... LocalAccT, typename ... ReducerT,
1158+ template <bool Pow2WG , typename ... LocalAccT, typename ... ReducerT,
11581159 typename ... ResultT, size_t ... Is>
11591160void initReduLocalAccs (size_t LID, size_t WGSize,
11601161 std::tuple<LocalAccT...> LocalAccs,
@@ -1163,7 +1164,11 @@ void initReduLocalAccs(size_t LID, size_t WGSize,
11631164 std::index_sequence<Is...>) {
11641165 std::tie (std::get<Is>(LocalAccs)[LID]...) =
11651166 std::make_tuple (std::get<Is>(Reducers).MValue ...);
1166- if (!UniformPow2WG)
1167+
1168+ // For work-groups, which size is not power of two, local accessors have
1169+ // an additional element with index WGSize that is used by the tree-reduction
1170+ // algorithm. Initialize those additional elements with identity values here.
1171+ if (!Pow2WG)
11671172 std::tie (std::get<Is>(LocalAccs)[WGSize]...) =
11681173 std::make_tuple (std::get<Is>(Identities)...);
11691174}
@@ -1175,12 +1180,22 @@ void initReduLocalAccs(size_t LID, size_t GID, size_t NWorkItems, size_t WGSize,
11751180 std::tuple<LocalAccT...> InputAccs,
11761181 const std::tuple<ResultT...> Identities,
11771182 std::index_sequence<Is...>) {
1183+ // Normally, the local accessors are initialized with elements from the input
1184+ // accessors. The exception is the case when (GID >= NWorkItems), which
1185+ // possible only when UniformPow2WG is false. For that case the elements of
1186+ // local accessors are initialized with identity value, so they would not
1187+ // give any impact into the final partial sums during the tree-reduction
1188+ // algorithm work.
11781189 if (UniformPow2WG || GID < NWorkItems)
11791190 std::tie (std::get<Is>(LocalAccs)[LID]...) =
11801191 std::make_tuple (std::get<Is>(InputAccs)[GID]...);
11811192 else
11821193 std::tie (std::get<Is>(LocalAccs)[LID]...) =
11831194 std::make_tuple (std::get<Is>(Identities)...);
1195+
1196+ // For work-groups, which size is not power of two, local accessors have
1197+ // an additional element with index WGSize that is used by the tree-reduction
1198+ // algorithm. Initialize those additional elements with identity values here.
11841199 if (!UniformPow2WG)
11851200 std::tie (std::get<Is>(LocalAccs)[WGSize]...) =
11861201 std::make_tuple (std::get<Is>(Identities)...);
@@ -1196,7 +1211,7 @@ void reduceReduLocalAccs(size_t IndexA, size_t IndexB,
11961211 std::get<Is>(LocalAccs)[IndexB]))...);
11971212}
11981213
1199- template <bool UniformPow2WG , typename ... Reductions, typename ... OutAccT,
1214+ template <bool Pow2WG , typename ... Reductions, typename ... OutAccT,
12001215 typename ... LocalAccT, typename ... BOPsT, size_t ... Is,
12011216 size_t ... RWIs>
12021217void writeReduSumsToOutAccs (size_t OutAccIndex, size_t WGSize,
@@ -1214,11 +1229,16 @@ void writeReduSumsToOutAccs(size_t OutAccIndex, size_t WGSize,
12141229 std::tuple_element_t <RWIs, std::tuple<Reductions...>>::getOutPointer (
12151230 std::get<RWIs>(OutAccs))[OutAccIndex])...);
12161231
1217- if (UniformPow2WG) {
1232+ if (Pow2WG) {
1233+ // The partial sums for the work-group are stored in 0-th elements of local
1234+ // accessors. Simply write those sums to output accessors.
12181235 std::tie (std::tuple_element_t <Is, std::tuple<Reductions...>>::getOutPointer (
12191236 std::get<Is>(OutAccs))[OutAccIndex]...) =
12201237 std::make_tuple (std::get<Is>(LocalAccs)[0 ]...);
12211238 } else {
1239+ // Each of local accessors keeps two partial sums: in 0-th and WGsize-th
1240+ // elements. Combine them into final partial sums and write to output
1241+ // accessors.
12221242 std::tie (std::tuple_element_t <Is, std::tuple<Reductions...>>::getOutPointer (
12231243 std::get<Is>(OutAccs))[OutAccIndex]...) =
12241244 std::make_tuple (std::get<Is>(BOPs)(std::get<Is>(LocalAccs)[0 ],
@@ -1300,15 +1320,15 @@ constexpr auto filterSequence(FunctorT F, std::index_sequence<Is...> Indices) {
13001320 return filterSequenceHelper<T...>(F, Indices);
13011321}
13021322
1303- template <typename KernelName, bool UniformPow2WG , bool IsOneWG,
1304- typename KernelType, int Dims, typename ... Reductions, size_t ... Is>
1323+ template <typename KernelName, bool Pow2WG , bool IsOneWG, typename KernelType ,
1324+ int Dims, typename ... Reductions, size_t ... Is>
13051325void reduCGFuncImpl (handler &CGH, KernelType KernelFunc,
13061326 const nd_range<Dims> &Range,
13071327 std::tuple<Reductions...> &ReduTuple,
13081328 std::index_sequence<Is...> ReduIndices) {
13091329
13101330 size_t WGSize = Range.get_local_range ().size ();
1311- size_t LocalAccSize = WGSize + (UniformPow2WG ? 0 : 1 );
1331+ size_t LocalAccSize = WGSize + (Pow2WG ? 0 : 1 );
13121332 auto LocalAccsTuple =
13131333 createReduLocalAccs<Reductions...>(LocalAccSize, CGH, ReduIndices);
13141334
@@ -1318,10 +1338,8 @@ void reduCGFuncImpl(handler &CGH, KernelType KernelFunc,
13181338 auto IdentitiesTuple = getReduIdentities (ReduTuple, ReduIndices);
13191339 auto BOPsTuple = getReduBOPs (ReduTuple, ReduIndices);
13201340
1321- using Name =
1322- typename get_reduction_main_kernel_name_t <KernelName, KernelType,
1323- UniformPow2WG, IsOneWG,
1324- decltype (OutAccsTuple)>::name;
1341+ using Name = typename get_reduction_main_kernel_name_t <
1342+ KernelName, KernelType, Pow2WG, IsOneWG, decltype (OutAccsTuple)>::name;
13251343 CGH.parallel_for <Name>(Range, [=](nd_item<Dims> NDIt) {
13261344 auto ReduIndices = std::index_sequence_for<Reductions...>();
13271345 auto ReducersTuple =
@@ -1332,8 +1350,8 @@ void reduCGFuncImpl(handler &CGH, KernelType KernelFunc,
13321350
13331351 size_t WGSize = NDIt.get_local_range ().size ();
13341352 size_t LID = NDIt.get_local_linear_id ();
1335- initReduLocalAccs<UniformPow2WG >(LID, WGSize, LocalAccsTuple, ReducersTuple,
1336- IdentitiesTuple, ReduIndices);
1353+ initReduLocalAccs<Pow2WG >(LID, WGSize, LocalAccsTuple, ReducersTuple,
1354+ IdentitiesTuple, ReduIndices);
13371355 NDIt.barrier ();
13381356
13391357 size_t PrevStep = WGSize;
@@ -1342,7 +1360,7 @@ void reduCGFuncImpl(handler &CGH, KernelType KernelFunc,
13421360 // LocalReds[LID] = BOp(LocalReds[LID], LocalReds[LID + CurStep]);
13431361 reduceReduLocalAccs (LID, LID + CurStep, LocalAccsTuple, BOPsTuple,
13441362 ReduIndices);
1345- } else if (!UniformPow2WG && LID == CurStep && (PrevStep & 0x1 )) {
1363+ } else if (!Pow2WG && LID == CurStep && (PrevStep & 0x1 )) {
13461364 // LocalReds[WGSize] = BOp(LocalReds[WGSize], LocalReds[PrevStep - 1]);
13471365 reduceReduLocalAccs (WGSize, PrevStep - 1 , LocalAccsTuple, BOPsTuple,
13481366 ReduIndices);
@@ -1363,7 +1381,7 @@ void reduCGFuncImpl(handler &CGH, KernelType KernelFunc,
13631381 Predicate;
13641382 auto RWReduIndices =
13651383 filterSequence<Reductions...>(Predicate, ReduIndices);
1366- writeReduSumsToOutAccs<UniformPow2WG >(
1384+ writeReduSumsToOutAccs<Pow2WG >(
13671385 GrID, WGSize, (std::tuple<Reductions...> *)nullptr , OutAccsTuple,
13681386 LocalAccsTuple, BOPsTuple, ReduIndices, RWReduIndices);
13691387 }
@@ -1376,21 +1394,18 @@ void reduCGFunc(handler &CGH, KernelType KernelFunc,
13761394 const nd_range<Dims> &Range,
13771395 std::tuple<Reductions...> &ReduTuple,
13781396 std::index_sequence<Is...> ReduIndices) {
1379- size_t NWorkItems = Range.get_global_range ().size ();
13801397 size_t WGSize = Range.get_local_range ().size ();
13811398 size_t NWorkGroups = Range.get_group_range ().size ();
1382-
13831399 bool Pow2WG = (WGSize & (WGSize - 1 )) == 0 ;
1384- bool HasUniformWG = Pow2WG && (NWorkGroups * WGSize == NWorkItems);
13851400 if (NWorkGroups == 1 ) {
1386- if (HasUniformWG )
1401+ if (Pow2WG )
13871402 reduCGFuncImpl<KernelName, true , true >(CGH, KernelFunc, Range, ReduTuple,
13881403 ReduIndices);
13891404 else
13901405 reduCGFuncImpl<KernelName, false , true >(CGH, KernelFunc, Range, ReduTuple,
13911406 ReduIndices);
13921407 } else {
1393- if (HasUniformWG )
1408+ if (Pow2WG )
13941409 reduCGFuncImpl<KernelName, true , false >(CGH, KernelFunc, Range, ReduTuple,
13951410 ReduIndices);
13961411 else
0 commit comments