@@ -225,6 +225,27 @@ struct known_identity_impl<BinaryOperation, AccumulatorT,
225225 : std::numeric_limits<AccumulatorT>::lowest();
226226};
227227
228+ #ifdef __SYCL_REDUCER_OP_EQ_CHECK_TRAIT
229+ #error "__SYCL_REDUCER_OP_EQ_CHECK_TRAIT must not be defined"
230+ #endif
231+
232+ #define __SYCL_REDUCER_OP_EQ_CHECK_TRAIT (OpName, Op ) \
233+ template <typename , typename = void > \
234+ struct HasSameTypeArg ##OpName##Eq : public std::false_type {}; \
235+ template <typename T> \
236+ struct HasSameTypeArg ##OpName##Eq< \
237+ T, std::enable_if_t <std::is_same< \
238+ decltype (static_cast <T &(T::*)(const T &)>(&T::operator +=)), \
239+ T &(T::*)(const T &)>::value>> : public std::true_type {};
240+
241+ __SYCL_REDUCER_OP_EQ_CHECK_TRAIT (Plus, +)
242+ __SYCL_REDUCER_OP_EQ_CHECK_TRAIT(Multiplies, *)
243+ __SYCL_REDUCER_OP_EQ_CHECK_TRAIT(BitwiseOR, |)
244+ __SYCL_REDUCER_OP_EQ_CHECK_TRAIT(BitwiseXOR, ^)
245+ __SYCL_REDUCER_OP_EQ_CHECK_TRAIT(BitwiseAND, &)
246+
247+ #undef __SYCL_REDUCER_OP_EQ_CHECK_TRAIT
248+
228249// / Class that is used to represent objects that are passed to user's lambda
229250// / functions and representing users' reduction variable.
230251// / The generic version of the class represents those reductions of those
@@ -238,6 +259,41 @@ class reducer {
238259
239260 T getIdentity () const { return MIdentity; }
240261
262+ template <typename _T = T>
263+ enable_if_t <HasSameTypeArgPlusEq<_T>::value, reducer &>
264+ operator +=(const _T &Partial) {
265+ MValue += Partial;
266+ return *this ;
267+ }
268+
269+ template <typename _T = T>
270+ enable_if_t <HasSameTypeArgMultipliesEq<_T>::value, reducer &>
271+ operator *=(const _T &Partial) {
272+ MValue *= Partial;
273+ return *this ;
274+ }
275+
276+ template <typename _T = T>
277+ enable_if_t <HasSameTypeArgBitwiseOREq<_T>::value, reducer &>
278+ operator |=(const _T &Partial) {
279+ MValue |= Partial;
280+ return *this ;
281+ }
282+
283+ template <typename _T = T>
284+ enable_if_t <HasSameTypeArgBitwiseXOREq<_T>::value, reducer &>
285+ operator ^=(const _T &Partial) {
286+ MValue ^= Partial;
287+ return *this ;
288+ }
289+
290+ template <typename _T = T>
291+ enable_if_t <HasSameTypeArgBitwiseANDEq<_T>::value, reducer &>
292+ operator &=(const _T &Partial) {
293+ MValue &= Partial;
294+ return *this ;
295+ }
296+
241297 T MValue;
242298
243299private:
0 commit comments