diff --git a/sycl/include/CL/sycl/ONEAPI/reduction.hpp b/sycl/include/CL/sycl/ONEAPI/reduction.hpp index 9c3eb5a1b55c8..f72667fce6a3b 100644 --- a/sycl/include/CL/sycl/ONEAPI/reduction.hpp +++ b/sycl/include/CL/sycl/ONEAPI/reduction.hpp @@ -238,6 +238,36 @@ class reducer { T getIdentity() const { return MIdentity; } + template + enable_if_t::value> + operator+=(const _T &Partial) { + combine(Partial); + } + + template + enable_if_t::value> + operator*=(const _T &Partial) { + combine(Partial); + } + + template + enable_if_t::value> + operator|=(const _T &Partial) { + combine(Partial); + } + + template + enable_if_t::value> + operator^=(const _T &Partial) { + combine(Partial); + } + + template + enable_if_t::value> + operator&=(const _T &Partial) { + combine(Partial); + } + T MValue; private: @@ -281,48 +311,33 @@ class reducer - enable_if_t::value && - IsReduPlus::value, - reducer &> + enable_if_t::value> operator+=(const _T &Partial) { combine(Partial); - return *this; } template - enable_if_t::value && - IsReduMultiplies::value, - reducer &> + enable_if_t::value> operator*=(const _T &Partial) { combine(Partial); - return *this; } template - enable_if_t::value && - IsReduBitOR::value, - reducer &> + enable_if_t::value> operator|=(const _T &Partial) { combine(Partial); - return *this; } template - enable_if_t::value && - IsReduBitXOR::value, - reducer &> + enable_if_t::value> operator^=(const _T &Partial) { combine(Partial); - return *this; } template - enable_if_t::value && - IsReduBitAND::value, - reducer &> + enable_if_t::value> operator&=(const _T &Partial) { combine(Partial); - return *this; } /// Atomic ADD operation: *ReduVarPtr += MValue;