@@ -578,15 +578,6 @@ class reduction_impl_algo : public reduction_impl_common<T, BinaryOperation> {
578578 RedOutVar RedOut)
579579 : base(Identity, BinaryOp, Init), MRedOut(std::move(RedOut)){};
580580
581- // / Associates the reduction accessor to user's memory with \p CGH handler
582- // / to keep the accessor alive until the command group finishes the work.
583- // / This function does not do anything for USM reductions.
584- void associateWithHandler (handler &CGH) {
585- if constexpr (is_acc) {
586- CGH.associateWithHandler (&MRedOut, access::target::device);
587- }
588- }
589-
590581 // / Creates and returns a local accessor with the \p Size elements.
591582 // / By default the local accessor elements are of the same type as the
592583 // / elements processed by the reduction, but may it be altered by specifying
@@ -624,7 +615,7 @@ class reduction_impl_algo : public reduction_impl_common<T, BinaryOperation> {
624615 rw_accessor_type getWriteAccForPartialReds (size_t Size, handler &CGH) {
625616 if constexpr (is_rw_acc) {
626617 if (Size == 1 ) {
627- associateWithHandler (CGH );
618+ CGH. associateWithHandler (&MRedOut, access::target::device );
628619 return MRedOut;
629620 }
630621 }
@@ -800,7 +791,7 @@ class reduction_impl
800791 reduction_impl (RedOutVar &Acc, handler &CGH, bool InitializeToIdentity)
801792 : algo(reducer_type::getIdentity(), BinaryOperation(),
802793 InitializeToIdentity, Acc) {
803- algo:: associateWithHandler (CGH);
794+ associateWithHandler (CGH, &Acc, access::target::device );
804795 if (Acc.size () != 1 )
805796 throw sycl::runtime_error (errc::invalid,
806797 " Reduction variable must be a scalar." ,
@@ -830,7 +821,7 @@ class reduction_impl
830821 reduction_impl (RedOutVar &Acc, handler &CGH, const T &Identity,
831822 BinaryOperation BOp, bool InitializeToIdentity)
832823 : algo(chooseIdentity(Identity), BOp, InitializeToIdentity, Acc) {
833- algo:: associateWithHandler (CGH);
824+ associateWithHandler (CGH, &Acc, access::target::device );
834825 if (Acc.size () != 1 )
835826 throw sycl::runtime_error (errc::invalid,
836827 " Reduction variable must be a scalar." ,
@@ -1553,7 +1544,7 @@ template <typename KernelName, class Reduction>
15531544std::enable_if_t <!Reduction::is_usm>
15541545reduSaveFinalResultToUserMem (handler &CGH, Reduction &Redu) {
15551546 auto InAcc = Redu.getReadAccToPreviousPartialReds (CGH);
1556- Redu. associateWithHandler (CGH);
1547+ associateWithHandler (CGH, &Redu. getUserRedVar (), access::target::device );
15571548 CGH.copy (InAcc, Redu.getUserRedVar ());
15581549}
15591550
@@ -2081,26 +2072,16 @@ void reduCGFuncAtomic64(handler &CGH, KernelType KernelFunc,
20812072 CGH, KernelFunc, Range, Redu, Out);
20822073}
20832074
2084- inline void associateReduAccsWithHandlerHelper (handler &) {}
2085-
2086- template <typename ReductionT>
2087- void associateReduAccsWithHandlerHelper (handler &CGH, ReductionT &Redu) {
2088- Redu.associateWithHandler (CGH);
2089- }
2090-
2091- template <typename ReductionT, typename ... RestT,
2092- enable_if_t <(sizeof ...(RestT) > 0 ), int > Z = 0 >
2093- void associateReduAccsWithHandlerHelper (handler &CGH, ReductionT &Redu,
2094- RestT &...Rest) {
2095- Redu.associateWithHandler (CGH);
2096- associateReduAccsWithHandlerHelper (CGH, Rest...);
2097- }
2098-
20992075template <typename ... Reductions, size_t ... Is>
21002076void associateReduAccsWithHandler (handler &CGH,
21012077 std::tuple<Reductions...> &ReduTuple,
21022078 std::index_sequence<Is...>) {
2103- associateReduAccsWithHandlerHelper (CGH, std::get<Is>(ReduTuple)...);
2079+ auto ProcessOne = [&CGH](auto Redu) {
2080+ if constexpr (decltype (Redu)::is_acc) {
2081+ associateWithHandler (CGH, &Redu.getUserRedVar (), access::target::device);
2082+ }
2083+ };
2084+ (ProcessOne (std::get<Is>(ReduTuple)), ...);
21042085}
21052086
21062087// / All scalar reductions are processed together; there is one loop of log2(N)
@@ -2371,7 +2352,8 @@ void reduSaveFinalResultToUserMemHelper(
23712352 handler::withAuxHandler (Queue, IsHost, [&](handler &CopyHandler) {
23722353 auto InAcc = Redu.getReadAccToPreviousPartialReds (CopyHandler);
23732354 auto OutAcc = Redu.getUserRedVar ();
2374- Redu.associateWithHandler (CopyHandler);
2355+ associateWithHandler (CopyHandler, &Redu.getUserRedVar (),
2356+ access::target::device);
23752357 if (!Events.empty ())
23762358 CopyHandler.depends_on (Events.back ());
23772359 CopyHandler.copy (InAcc, OutAcc);
0 commit comments