diff --git a/sycl/include/sycl/reduction.hpp b/sycl/include/sycl/reduction.hpp
index bc84422033fbe..2a5d61a20b303 100644
--- a/sycl/include/sycl/reduction.hpp
+++ b/sycl/include/sycl/reduction.hpp
@@ -1568,48 +1568,6 @@ template <> struct NDRangeReduction {
   }
 };
 
-// Auto-dispatch. Must be the last one.
-template <> struct NDRangeReduction<reduction::strategy::auto_select> {
-  // Some readability aliases, to increase signal/noise ratio below.
-  template <reduction::strategy Strategy>
-  using Impl = NDRangeReduction<Strategy>;
-  using S = reduction::strategy;
-
-  template <typename KernelName, int Dims, typename PropertiesT,
-            typename KernelType, typename Reduction>
-  static void run(handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
-                  nd_range<Dims> NDRange, PropertiesT &Properties,
-                  Reduction &Redu, KernelType &KernelFunc) {
-    auto Delegate = [&](auto Impl) {
-      Impl.template run<KernelName>(CGH, Queue, NDRange, Properties, Redu,
-                                    KernelFunc);
-    };
-
-    if constexpr (Reduction::has_float64_atomics) {
-      if (getDeviceFromHandler(CGH).has(aspect::atomic64))
-        return Delegate(Impl<S::group_reduce_and_atomic_cross_wg>{});
-
-      if constexpr (Reduction::has_fast_reduce)
-        return Delegate(Impl<S::group_reduce_and_last_wg_detection>{});
-      else
-        return Delegate(Impl<S::basic>{});
-    } else if constexpr (Reduction::has_fast_atomics) {
-      if constexpr (Reduction::has_fast_reduce) {
-        return Delegate(Impl<S::group_reduce_and_atomic_cross_wg>{});
-      } else {
-        return Delegate(Impl<S::local_atomic_and_atomic_cross_wg>{});
-      }
-    } else {
-      if constexpr (Reduction::has_fast_reduce)
-        return Delegate(Impl<S::group_reduce_and_last_wg_detection>{});
-      else
-        return Delegate(Impl<S::basic>{});
-    }
-
-    assert(false && "Must be unreachable!");
-  }
-};
-
 /// For the given 'Reductions' types pack and indices enumerating them this
 /// function either creates new temporary accessors for partial sums (if IsOneWG
 /// is false) or returns user's accessor/USM-pointer if (IsOneWG is true).
@@ -2227,21 +2185,109 @@ tuple_select_elements(TupleT Tuple, std::index_sequence<Is...>) {
   return {std::get<Is>(std::move(Tuple))...};
 }
 
+template <> struct NDRangeReduction<reduction::strategy::multi> {
+  template <typename KernelName, int Dims, typename PropertiesT,
+            typename... RestT>
+  static void run(handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
+                  nd_range<Dims> NDRange, PropertiesT &Properties,
+                  RestT... Rest) {
+    std::tuple<RestT...> ArgsTuple(Rest...);
+    constexpr size_t NumArgs = sizeof...(RestT);
+    auto KernelFunc = std::get<NumArgs - 1>(ArgsTuple);
+    auto ReduIndices = std::make_index_sequence<NumArgs - 1>();
+    auto ReduTuple = detail::tuple_select_elements(ArgsTuple, ReduIndices);
+
+    size_t LocalMemPerWorkItem = reduGetMemPerWorkItem(ReduTuple, ReduIndices);
+    // TODO: currently the maximal work group size is determined for the given
+    // queue/device, while it is safer to use queries to the kernel compiled
+    // for the device.
+    size_t MaxWGSize = reduGetMaxWGSize(Queue, LocalMemPerWorkItem);
+    if (NDRange.get_local_range().size() > MaxWGSize)
+      throw sycl::runtime_error("The implementation handling parallel_for with"
+                                " reduction requires work group size not bigger"
+                                " than " +
+                                    std::to_string(MaxWGSize),
+                                PI_ERROR_INVALID_WORK_GROUP_SIZE);
+
+    reduCGFuncMulti<KernelName>(CGH, KernelFunc, NDRange, Properties, ReduTuple,
+                                ReduIndices);
+    reduction::finalizeHandler(CGH);
+
+    size_t NWorkItems = NDRange.get_group_range().size();
+    while (NWorkItems > 1) {
+      reduction::withAuxHandler(CGH, [&](handler &AuxHandler) {
+        NWorkItems = reduAuxCGFunc<KernelName, decltype(KernelFunc)>(
+            AuxHandler, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
+      });
+    } // end while (NWorkItems > 1)
+  }
+};
+
+// Auto-dispatch. Must be the last one.
+template <> struct NDRangeReduction<reduction::strategy::auto_select> {
+  // Some readability aliases, to increase signal/noise ratio below.
+  template <reduction::strategy Strategy>
+  using Impl = NDRangeReduction<Strategy>;
+  using Strat = reduction::strategy;
+
+  template <typename KernelName, int Dims, typename PropertiesT,
+            typename KernelType, typename Reduction>
+  static void run(handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
+                  nd_range<Dims> NDRange, PropertiesT &Properties,
+                  Reduction &Redu, KernelType &KernelFunc) {
+    auto Delegate = [&](auto Impl) {
+      Impl.template run<KernelName>(CGH, Queue, NDRange, Properties, Redu,
+                                    KernelFunc);
+    };
+
+    if constexpr (Reduction::has_float64_atomics) {
+      if (getDeviceFromHandler(CGH).has(aspect::atomic64))
+        return Delegate(Impl<Strat::group_reduce_and_atomic_cross_wg>{});
+
+      if constexpr (Reduction::has_fast_reduce)
+        return Delegate(Impl<Strat::group_reduce_and_last_wg_detection>{});
+      else
+        return Delegate(Impl<Strat::basic>{});
+    } else if constexpr (Reduction::has_fast_atomics) {
+      if constexpr (Reduction::has_fast_reduce) {
+        return Delegate(Impl<Strat::group_reduce_and_atomic_cross_wg>{});
+      } else {
+        return Delegate(Impl<Strat::local_atomic_and_atomic_cross_wg>{});
+      }
+    } else {
+      if constexpr (Reduction::has_fast_reduce)
+        return Delegate(Impl<Strat::group_reduce_and_last_wg_detection>{});
+      else
+        return Delegate(Impl<Strat::basic>{});
+    }
+
+    assert(false && "Must be unreachable!");
+  }
+  template <typename KernelName, int Dims, typename PropertiesT,
+            typename... RestT>
+  static void run(handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
+                  nd_range<Dims> NDRange, PropertiesT &Properties,
+                  RestT... Rest) {
+    return Impl<Strat::multi>::run<KernelName>(CGH, Queue, NDRange, Properties,
+                                               Rest...);
+  }
+};
+
 template <typename KernelName, reduction::strategy Strategy, int Dims,
-          typename PropertiesT, typename Reduction, typename KernelType>
+          typename PropertiesT, typename... RestT>
 void reduction_parallel_for(handler &CGH,
                             std::shared_ptr<detail::queue_impl> Queue,
                             nd_range<Dims> NDRange, PropertiesT Properties,
-                            Reduction Redu, KernelType KernelFunc) {
-  NDRangeReduction<Strategy>::template run<KernelName>(
-      CGH, Queue, NDRange, Properties, Redu, KernelFunc);
+                            RestT... Rest) {
+  NDRangeReduction<Strategy>::template run<KernelName>(CGH, Queue, NDRange,
+                                                       Properties, Rest...);
 }
 
 __SYCL_EXPORT uint32_t
 reduGetMaxNumConcurrentWorkGroups(std::shared_ptr<detail::queue_impl> Queue);
 
-template <typename KernelName, int Dims, typename PropertiesT,
-          typename Reduction, typename KernelType>
+template <typename KernelName, reduction::strategy Strategy, int Dims,
+          typename PropertiesT, typename Reduction, typename KernelType>
 void reduction_parallel_for(handler &CGH,
                             std::shared_ptr<detail::queue_impl> Queue,
                             range<Dims> Range, PropertiesT Properties,
@@ -2300,7 +2346,10 @@ void reduction_parallel_for(handler &CGH,
       KernelFunc(getDelinearizedId(Range, I), Reducer);
   };
 
-  constexpr auto Strategy = [&]() {
+  constexpr auto StrategyToUse = [&]() {
+    if constexpr (Strategy != reduction::strategy::auto_select)
+      return Strategy;
+
     if constexpr (Reduction::has_fast_reduce)
       return reduction::strategy::group_reduce_and_last_wg_detection;
     else if constexpr (Reduction::has_fast_atomics)
@@ -2309,57 +2358,8 @@ void reduction_parallel_for(handler &CGH,
       return reduction::strategy::range_basic;
   }();
 
-  reduction_parallel_for<KernelName, Strategy>(CGH, Queue, NDRange, Properties,
-                                               Redu, UpdatedKernelFunc);
-}
-
-template <> struct NDRangeReduction<reduction::strategy::multi> {
-  template <typename KernelName, int Dims, typename PropertiesT,
-            typename... RestT>
-  static void run(handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
-                  nd_range<Dims> NDRange, PropertiesT &Properties,
-                  RestT... Rest) {
-    std::tuple<RestT...> ArgsTuple(Rest...);
-    constexpr size_t NumArgs = sizeof...(RestT);
-    auto KernelFunc = std::get<NumArgs - 1>(ArgsTuple);
-    auto ReduIndices = std::make_index_sequence<NumArgs - 1>();
-    auto ReduTuple = detail::tuple_select_elements(ArgsTuple, ReduIndices);
-
-    size_t LocalMemPerWorkItem = reduGetMemPerWorkItem(ReduTuple, ReduIndices);
-    // TODO: currently the maximal work group size is determined for the given
-    // queue/device, while it is safer to use queries to the kernel compiled
-    // for the device.
-    size_t MaxWGSize = reduGetMaxWGSize(Queue, LocalMemPerWorkItem);
-    if (NDRange.get_local_range().size() > MaxWGSize)
-      throw sycl::runtime_error("The implementation handling parallel_for with"
-                                " reduction requires work group size not bigger"
-                                " than " +
-                                    std::to_string(MaxWGSize),
-                                PI_ERROR_INVALID_WORK_GROUP_SIZE);
-
-    reduCGFuncMulti<KernelName>(CGH, KernelFunc, NDRange, Properties, ReduTuple,
-                                ReduIndices);
-    reduction::finalizeHandler(CGH);
-
-    size_t NWorkItems = NDRange.get_group_range().size();
-    while (NWorkItems > 1) {
-      reduction::withAuxHandler(CGH, [&](handler &AuxHandler) {
-        NWorkItems = reduAuxCGFunc<KernelName, decltype(KernelFunc)>(
-            AuxHandler, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
-      });
-    } // end while (NWorkItems > 1)
-  }
-};
-
-template <typename KernelName, int Dims, typename PropertiesT,
-          typename... RestT>
-void reduction_parallel_for(handler &CGH,
-                            std::shared_ptr<detail::queue_impl> Queue,
-                            nd_range<Dims> NDRange, PropertiesT Properties,
-                            RestT... Rest) {
-  constexpr auto Strategy = reduction::strategy::multi;
-  NDRangeReduction<Strategy>::template run<KernelName>(CGH, Queue, NDRange,
-                                                       Properties, Rest...);
+  reduction_parallel_for<KernelName, StrategyToUse>(
+      CGH, Queue, NDRange, Properties, Redu, UpdatedKernelFunc);
 }
 
 } // namespace detail
diff --git a/sycl/include/sycl/reduction_forward.hpp b/sycl/include/sycl/reduction_forward.hpp
index af6145b8b5362..d4bd92d98d5e2 100644
--- a/sycl/include/sycl/reduction_forward.hpp
+++ b/sycl/include/sycl/reduction_forward.hpp
@@ -44,24 +44,18 @@ inline void finalizeHandler(handler &CGH);
 template <class FunctorTy> void withAuxHandler(handler &CGH, FunctorTy Func);
 } // namespace reduction
 
-template <typename KernelName, int Dims, typename PropertiesT,
-          typename Reduction, typename KernelType>
-void reduction_parallel_for(handler &CGH,
-                            std::shared_ptr<detail::queue_impl> Queue,
-                            range<Dims> Range, PropertiesT Properties,
-                            Reduction Redu, KernelType KernelFunc);
-
 template <typename KernelName, reduction::strategy Strategy, int Dims,
           typename PropertiesT, typename Reduction, typename KernelType>
 void reduction_parallel_for(handler &CGH,
                             std::shared_ptr<detail::queue_impl> Queue,
-                            nd_range<Dims> NDRange, PropertiesT Properties,
+                            range<Dims> Range, PropertiesT Properties,
                             Reduction Redu, KernelType KernelFunc);
 
-template <typename KernelName, int Dims, typename PropertiesT,
-          typename... RestT>
+template <typename KernelName, reduction::strategy Strategy, int Dims,
+          typename PropertiesT, typename... RestT>
 void reduction_parallel_for(handler &CGH,
                             std::shared_ptr<detail::queue_impl> Queue,
                             nd_range<Dims> NDRange, PropertiesT Properties,