diff --git a/sycl/include/sycl/reduction.hpp b/sycl/include/sycl/reduction.hpp index 47dc36c3e0bc6..2144f547cbf61 100644 --- a/sycl/include/sycl/reduction.hpp +++ b/sycl/include/sycl/reduction.hpp @@ -147,6 +147,9 @@ __SYCL_EXPORT size_t reduComputeWGSize(size_t NWorkItems, size_t MaxWGSize, __SYCL_EXPORT size_t reduGetPreferredWGSize(std::shared_ptr &Queue, size_t LocalMemBytesPerWorkItem); +template +class ReducerElement; + /// Helper class for accessing reducer-defined types in CRTP /// May prove to be useful for other things later template struct ReducerTraits; @@ -160,6 +163,7 @@ struct ReducerTraits; }; /// Helper class for accessing internal reducer member functions. @@ -275,10 +279,16 @@ template class combiner { void atomic_combine_impl(T *ReduVarPtr, AtomicFunctor Functor) const { auto reducer = static_cast(this); for (size_t E = 0; E < Extent; ++E) { + const auto &ReducerElem = getReducerAccess(*reducer).getElement(E); + + // If the reducer element doesn't have a value we can skip the combine. + if (!ReducerElem) + continue; + auto AtomicRef = sycl::atomic_ref(), Space>( address_space_cast(ReduVarPtr)[E]); - Functor(std::move(AtomicRef), getReducerAccess(*reducer).getElement(E)); + Functor(std::move(AtomicRef), *ReducerElem); } } @@ -361,13 +371,6 @@ template class combiner { } }; -template class reducer_common { -public: - using value_type = T; - using binary_operation = BinaryOperation; - static constexpr int dimensions = Dims; -}; - /// Templated class for common functionality of all reduction implementation /// classes. template +class ReducerElement { + using value_type = std::conditional_t, T>; + + template + constexpr value_type GetInitValue( + const ReductionIdentityContainer + &IdentityContainer) { + constexpr bool ContainerHasIdentity = + ReductionIdentityContainer::has_identity; + static_assert(IsOptional || ContainerHasIdentity); + if constexpr (!ContainerHasIdentity) + return std::nullopt; + else + return IdentityContainer.getIdentity(); + } + +public: + ReducerElement() = default; + ReducerElement(T Value) : MValue{Value} {} + + template + ReducerElement( + const ReductionIdentityContainer + &IdentityContainer) + : MValue(GetInitValue(IdentityContainer)) {} + + ReducerElement &combine(BinaryOperation BinOp, const T &OtherValue) { + if constexpr (IsOptional) + MValue = MValue ? BinOp(*MValue, OtherValue) : OtherValue; + else + MValue = BinOp(MValue, OtherValue); + return *this; + } + + ReducerElement &combine(BinaryOperation BinOp, const ReducerElement &Other) { + if constexpr (IsOptional) { + if (Other.MValue) + return combine(BinOp, *Other.MValue); + // If the other value doesn't have a value it is a no-op. + return *this; + } else { + return combine(BinOp, Other.MValue); + } + } + + constexpr T &operator*() noexcept { + if constexpr (IsOptional) + return *MValue; + else + return MValue; + } + constexpr const T &operator*() const noexcept { + if constexpr (IsOptional) + return *MValue; + else + return MValue; + } + + constexpr explicit operator bool() const { + if constexpr (IsOptional) + return MValue.has_value(); + return true; + } + +private: + value_type MValue; +}; + +template class reducer_common { +public: + using value_type = T; + using binary_operation = BinaryOperation; + static constexpr int dimensions = Dims; +}; + // Token class to help with the in-place construction of reducers. template struct ReducerToken { @@ -450,21 +532,13 @@ class reducer< !detail::IsKnownIdentityOp::value>>>, public detail::reducer_common { static constexpr bool has_identity = IdentityContainerT::has_identity; - using internal_value_type = - std::conditional_t>; - - constexpr internal_value_type - GetInitialValue(const IdentityContainerT &IdentityContainer) { - if constexpr (has_identity) - return IdentityContainer.getIdentity(); - else - return std::nullopt; - } + using element_type = + detail::ReducerElement; public: reducer(const IdentityContainerT &IdentityContainer, BinaryOperation BOp) - : MValue(GetInitialValue(IdentityContainer)), - MIdentity(IdentityContainer), MBinaryOp(BOp) {} + : MValue(IdentityContainer), MIdentity(IdentityContainer), + MBinaryOp(BOp) {} reducer( const detail::ReducerToken &Token) : reducer(Token.IdentityContainer, Token.BOp) {} @@ -475,10 +549,7 @@ class reducer< reducer &operator=(reducer &&) = delete; reducer &combine(const T &Partial) { - if constexpr (has_identity) - MValue = MBinaryOp(MValue, Partial); - else - MValue = MValue ? MBinaryOp(*MValue, Partial) : Partial; + MValue.combine(MBinaryOp, Partial); return *this; } @@ -491,20 +562,10 @@ class reducer< private: template friend class detail::ReducerAccess; - T &getElement(size_t) { - if constexpr (has_identity) - return MValue; - else - return *MValue; - } - const T &getElement(size_t) const { - if constexpr (has_identity) - return MValue; - else - return *MValue; - } + element_type &getElement(size_t) { return MValue; } + const element_type &getElement(size_t) const { return MValue; } - internal_value_type MValue; + detail::ReducerElement MValue; const IdentityContainerT MIdentity; BinaryOperation MBinaryOp; }; @@ -526,6 +587,10 @@ class reducer< Dims == 0 && Extent == 1 && View == false && detail::IsKnownIdentityOp::value>>>, public detail::reducer_common { + static constexpr bool has_identity = IdentityContainerT::has_identity; + using element_type = + detail::ReducerElement; + public: reducer() : MValue(getIdentity()) {} reducer(const IdentityContainerT & /* Identity */, BinaryOperation) @@ -541,7 +606,7 @@ class reducer< reducer &combine(const T &Partial) { BinaryOperation BOp; - MValue = BOp(MValue, Partial); + MValue.combine(BOp, Partial); return *this; } @@ -554,9 +619,9 @@ class reducer< return detail::known_identity_impl::value; } - T &getElement(size_t) { return MValue; } - const T &getElement(size_t) const { return MValue; } - T MValue; + element_type &getElement(size_t) { return MValue; } + const element_type &getElement(size_t) const { return MValue; } + detail::ReducerElement MValue; }; /// Component of 'reducer' class for array reductions, representing a single @@ -570,11 +635,11 @@ class reducer>>, public detail::reducer_common { static constexpr bool has_identity = IdentityContainerT::has_identity; - using internal_value_type = - std::conditional_t>; + using element_type = + detail::ReducerElement; public: - reducer(internal_value_type &Ref, BinaryOperation BOp) + reducer(element_type &Ref, BinaryOperation BOp) : MElement(Ref), MBinaryOp(BOp) {} reducer( const detail::ReducerToken &Token) @@ -586,17 +651,17 @@ class reducer friend class detail::ReducerAccess; - internal_value_type &MElement; + element_type &getElement(size_t) { return MElement; } + const element_type &getElement(size_t) const { return MElement; } + + element_type &MElement; BinaryOperation MBinaryOp; }; @@ -615,21 +680,13 @@ class reducer< !detail::IsKnownIdentityOp::value>>>, public detail::reducer_common { static constexpr bool has_identity = IdentityContainerT::has_identity; - using internal_value_type = - std::conditional_t>; - - constexpr internal_value_type - GetInitialValue(const IdentityContainerT &IdentityContainer) { - if constexpr (has_identity) - return IdentityContainer.getIdentity(); - else - return std::nullopt; - } + using element_type = + detail::ReducerElement; public: reducer(const IdentityContainerT &IdentityContainer, BinaryOperation BOp) - : MValue(GetInitialValue(IdentityContainer)), - MIdentity(IdentityContainer), MBinaryOp(BOp) {} + : MValue(IdentityContainer), MIdentity(IdentityContainer), + MBinaryOp(BOp) {} reducer( const detail::ReducerToken &Token) : reducer(Token.IdentityContainer, Token.BOp) {} @@ -653,20 +710,10 @@ class reducer< private: template friend class detail::ReducerAccess; - T &getElement(size_t E) { - if constexpr (has_identity) - return MValue[E]; - else - return *(MValue[E]); - } - const T &getElement(size_t E) const { - if constexpr (has_identity) - return MValue[E]; - else - return *(MValue[E]); - } + element_type &getElement(size_t E) { return MValue[E]; } + const element_type &getElement(size_t E) const { return MValue[E]; } - marray MValue; + marray MValue; const IdentityContainerT MIdentity; BinaryOperation MBinaryOp; }; @@ -685,6 +732,10 @@ class reducer< Dims == 1 && View == false && detail::IsKnownIdentityOp::value>>>, public detail::reducer_common { + static constexpr bool has_identity = IdentityContainerT::has_identity; + using element_type = + detail::ReducerElement; + public: reducer() : MValue(getIdentity()) {} reducer(const IdentityContainerT & /* Identity */, BinaryOperation) @@ -714,10 +765,10 @@ class reducer< return detail::known_identity_impl::value; } - T &getElement(size_t E) { return MValue[E]; } - const T &getElement(size_t E) const { return MValue[E]; } + element_type &getElement(size_t E) { return MValue[E]; } + const element_type &getElement(size_t E) const { return MValue[E]; } - marray MValue; + marray MValue; }; namespace detail { @@ -739,10 +790,8 @@ template struct get_red_t { using type = T; }; -template -struct get_red_t< - accessor> { +template +struct get_red_t> { using type = T; }; @@ -797,6 +846,8 @@ class reduction_impl_algo { detail::ReducerToken; using reducer_type = reducer; + using reducer_element_type = + typename ReducerTraits::element_type; using result_type = T; using binary_operation = BinaryOperation; @@ -846,9 +897,10 @@ class reduction_impl_algo { // If there is only one WG we can avoid creation of temporary buffer with // partial sums and write directly into user's reduction variable. if constexpr (IsOneWG) { - return MRedOut; + return getUserRedVarAccess(CGH); } else { - MOutBufPtr = std::make_shared>(range<1>(Size)); + MOutBufPtr = + std::make_shared>(range<1>(Size)); CGH.addReduction(MOutBufPtr); return accessor{*MOutBufPtr, CGH}; } @@ -867,17 +919,24 @@ class reduction_impl_algo { /// Otherwise, a new buffer is created and accessor to that buffer is /// returned. auto getWriteAccForPartialReds(size_t Size, handler &CGH) { - if constexpr (!is_usm) { + static_assert(!has_identity || sizeof(reducer_element_type) == sizeof(T), + "Unexpected size of reducer element type."); + + // We can only use the output memory directly if it is not USM and we have + // and identity, i.e. it has a thin element wrapper. + if constexpr (!is_usm && has_identity) { if (Size == 1) { - CGH.associateWithHandler(&MRedOut, access::target::device); - return MRedOut; + auto ReinterpretRedOut = + MRedOut.template reinterpret(); + return accessor{ReinterpretRedOut, CGH}; } } // Create a new output buffer and return an accessor to it. // // Array reductions are performed element-wise to avoid stack growth. - MOutBufPtr = std::make_shared>(range<1>(Size)); + MOutBufPtr = + std::make_shared>(range<1>(Size)); CGH.addReduction(MOutBufPtr); return accessor{*MOutBufPtr, CGH}; } @@ -934,8 +993,8 @@ class reduction_impl_algo { } }); } else { - associateWithHandler(CopyHandler, &Out, access::target::device); - CopyHandler.copy(Mem, Out); + accessor OutAcc{Out, CGH}; + CopyHandler.copy(Mem, OutAcc); } }); }; @@ -947,7 +1006,7 @@ class reduction_impl_algo { if (initializeToIdentity()) DoIt(MRedOut); else - Func(MRedOut); + Func(accessor{MRedOut, CGH}); } } @@ -959,7 +1018,7 @@ class reduction_impl_algo { std::ignore = CGH; assert(!initializeToIdentity() && "Initialize to identity not allowed for identity-less reductions."); - Func(MRedOut); + Func(accessor{MRedOut, CGH}); } const identity_container_type &getIdentityContainer() { @@ -997,7 +1056,13 @@ class reduction_impl_algo { BinaryOperation getBinaryOperation() const { return MBinaryOp; } bool initializeToIdentity() const { return InitializeToIdentity; } - RedOutVar &getUserRedVar() { return MRedOut; } + auto getUserRedVarAccess(handler &CGH) { + std::ignore = CGH; + if constexpr (is_usm) + return MRedOut; + else + return accessor{MRedOut, CGH}; + } private: // Object holding the identity if available. @@ -1005,7 +1070,7 @@ class reduction_impl_algo { // Array reduction is performed element-wise to avoid stack growth, hence // 1-dimensional always. - std::shared_ptr> MOutBufPtr; + std::shared_ptr> MOutBufPtr; BinaryOperation MBinaryOp; bool InitializeToIdentity; @@ -1103,15 +1168,16 @@ void reduSaveFinalResultToUserMem(handler &CGH, Reduction &Redu) { "USM reduction, not a buffer-based one."); size_t NElements = Reduction::num_elements; auto InAcc = Redu.getReadAccToPreviousPartialReds(CGH); - auto UserVarPtr = Redu.getUserRedVar(); + auto UserVarPtr = Redu.getUserRedVarAccess(CGH); bool IsUpdateOfUserVar = !Redu.initializeToIdentity(); auto BOp = Redu.getBinaryOperation(); CGH.single_task([=] { for (int i = 0; i < NElements; ++i) { + auto Elem = InAcc[i]; if (IsUpdateOfUserVar) - UserVarPtr[i] = BOp(UserVarPtr[i], InAcc.get_pointer()[i]); + UserVarPtr[i] = BOp(UserVarPtr[i], *Elem); else - UserVarPtr[i] = InAcc.get_pointer()[i]; + UserVarPtr[i] = *Elem; } }); } @@ -1121,6 +1187,10 @@ template struct MainKrn; template struct AuxKrn; } // namespace reduction +// Tag structs to help creating unique kernels for multi-reduction cases. +struct KernelOneWGTag {}; +struct KernelMultipleWGTag {}; + /// A helper to pass undefined (sycl::detail::auto_name) names unmodified. We /// must do that to avoid name collisions. template