Skip to content

Commit c72d95e

Browse files
Update associateWithHandler uses
1 parent 6e476c3 commit c72d95e

File tree

1 file changed

+12
-30
lines changed

1 file changed

+12
-30
lines changed

sycl/include/sycl/ext/oneapi/reduction.hpp

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -578,15 +578,6 @@ class reduction_impl_algo : public reduction_impl_common<T, BinaryOperation> {
578578
RedOutVar RedOut)
579579
: base(Identity, BinaryOp, Init), MRedOut(std::move(RedOut)){};
580580

581-
/// Associates the reduction accessor to user's memory with \p CGH handler
582-
/// to keep the accessor alive until the command group finishes the work.
583-
/// This function does not do anything for USM reductions.
584-
void associateWithHandler(handler &CGH) {
585-
if constexpr (is_acc) {
586-
CGH.associateWithHandler(&MRedOut, access::target::device);
587-
}
588-
}
589-
590581
/// Creates and returns a local accessor with the \p Size elements.
591582
/// By default the local accessor elements are of the same type as the
592583
/// elements processed by the reduction, but may it be altered by specifying
@@ -624,7 +615,7 @@ class reduction_impl_algo : public reduction_impl_common<T, BinaryOperation> {
624615
rw_accessor_type getWriteAccForPartialReds(size_t Size, handler &CGH) {
625616
if constexpr (is_rw_acc) {
626617
if (Size == 1) {
627-
associateWithHandler(CGH);
618+
CGH.associateWithHandler(&MRedOut, access::target::device);
628619
return MRedOut;
629620
}
630621
}
@@ -800,7 +791,7 @@ class reduction_impl
800791
reduction_impl(RedOutVar &Acc, handler &CGH, bool InitializeToIdentity)
801792
: algo(reducer_type::getIdentity(), BinaryOperation(),
802793
InitializeToIdentity, Acc) {
803-
algo::associateWithHandler(CGH);
794+
associateWithHandler(CGH, &Acc, access::target::device);
804795
if (Acc.size() != 1)
805796
throw sycl::runtime_error(errc::invalid,
806797
"Reduction variable must be a scalar.",
@@ -830,7 +821,7 @@ class reduction_impl
830821
reduction_impl(RedOutVar &Acc, handler &CGH, const T &Identity,
831822
BinaryOperation BOp, bool InitializeToIdentity)
832823
: algo(chooseIdentity(Identity), BOp, InitializeToIdentity, Acc) {
833-
algo::associateWithHandler(CGH);
824+
associateWithHandler(CGH, &Acc, access::target::device);
834825
if (Acc.size() != 1)
835826
throw sycl::runtime_error(errc::invalid,
836827
"Reduction variable must be a scalar.",
@@ -1553,7 +1544,7 @@ template <typename KernelName, class Reduction>
15531544
std::enable_if_t<!Reduction::is_usm>
15541545
reduSaveFinalResultToUserMem(handler &CGH, Reduction &Redu) {
15551546
auto InAcc = Redu.getReadAccToPreviousPartialReds(CGH);
1556-
Redu.associateWithHandler(CGH);
1547+
associateWithHandler(CGH, &Redu.getUserRedVar(), access::target::device);
15571548
CGH.copy(InAcc, Redu.getUserRedVar());
15581549
}
15591550

@@ -2081,26 +2072,16 @@ void reduCGFuncAtomic64(handler &CGH, KernelType KernelFunc,
20812072
CGH, KernelFunc, Range, Redu, Out);
20822073
}
20832074

2084-
inline void associateReduAccsWithHandlerHelper(handler &) {}
2085-
2086-
template <typename ReductionT>
2087-
void associateReduAccsWithHandlerHelper(handler &CGH, ReductionT &Redu) {
2088-
Redu.associateWithHandler(CGH);
2089-
}
2090-
2091-
template <typename ReductionT, typename... RestT,
2092-
enable_if_t<(sizeof...(RestT) > 0), int> Z = 0>
2093-
void associateReduAccsWithHandlerHelper(handler &CGH, ReductionT &Redu,
2094-
RestT &...Rest) {
2095-
Redu.associateWithHandler(CGH);
2096-
associateReduAccsWithHandlerHelper(CGH, Rest...);
2097-
}
2098-
20992075
template <typename... Reductions, size_t... Is>
21002076
void associateReduAccsWithHandler(handler &CGH,
21012077
std::tuple<Reductions...> &ReduTuple,
21022078
std::index_sequence<Is...>) {
2103-
associateReduAccsWithHandlerHelper(CGH, std::get<Is>(ReduTuple)...);
2079+
auto ProcessOne = [&CGH](auto Redu) {
2080+
if constexpr (decltype(Redu)::is_acc) {
2081+
associateWithHandler(CGH, &Redu.getUserRedVar(), access::target::device);
2082+
}
2083+
};
2084+
(ProcessOne(std::get<Is>(ReduTuple)), ...);
21042085
}
21052086

21062087
/// All scalar reductions are processed together; there is one loop of log2(N)
@@ -2371,7 +2352,8 @@ void reduSaveFinalResultToUserMemHelper(
23712352
handler::withAuxHandler(Queue, IsHost, [&](handler &CopyHandler) {
23722353
auto InAcc = Redu.getReadAccToPreviousPartialReds(CopyHandler);
23732354
auto OutAcc = Redu.getUserRedVar();
2374-
Redu.associateWithHandler(CopyHandler);
2355+
associateWithHandler(CopyHandler, &Redu.getUserRedVar(),
2356+
access::target::device);
23752357
if (!Events.empty())
23762358
CopyHandler.depends_on(Events.back());
23772359
CopyHandler.copy(InAcc, OutAcc);

0 commit comments

Comments
 (0)