@@ -144,6 +144,19 @@ using is_plus_or_multiplies_if_complex = std::integral_constant<
144144 is_multiplies<T, BinaryOperation>::value)
145145 : std::true_type::value)>;
146146
147+ // used to transform a vector op to a scalar op;
148+ // e.g. sycl::plus<sycl::vec<T, N>> to sycl::plus<T>
149+ template <typename T> struct get_scalar_binary_op ; // primary template: declared only; no definition appears in this hunk
150+
151+ template <template <typename > typename F, typename T, int n>
152+ struct get_scalar_binary_op <F<sycl::vec<T, n>>> { // matches a functor template F instantiated on sycl::vec<T, n>
153+ using type = F<T>; // same functor template, re-instantiated on the element type T
154+ };
155+
156+ template <template <typename > typename F> struct get_scalar_binary_op <F<void >> { // matches a functor template F instantiated on void
157+ using type = F<void >; // nothing to unwrap: F<void> maps to itself
158+ };
159+
147160// ---- identity_for_ga_op
148161// the group algorithms support std::complex, limited to sycl::plus operation
149162// get the correct identity for group algorithm operation.
@@ -201,11 +214,8 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
201214 detail::is_native_op<T, BinaryOperation>::value),
202215 T>
203216reduce_over_group (Group g, T x, BinaryOperation binary_op) {
204- // FIXME: Do not special-case for half precision
205217 static_assert (
206- std::is_same_v<decltype (binary_op (x, x)), T> ||
207- (std::is_same_v<T, half> &&
208- std::is_same_v<decltype (binary_op (x, x)), float >),
218+ std::is_same_v<decltype (binary_op (x, x)), T>,
209219 " Result type of binary_op must match reduction accumulation type." );
210220#ifdef __SYCL_DEVICE_ONLY__
211221#if defined(__NVPTX__)
@@ -251,24 +261,21 @@ reduce_over_group(Group g, T x, BinaryOperation binary_op) {
251261#endif
252262}
253263
254- template <typename Group, typename T, int N, class BinaryOperation >
255- std::enable_if_t <
256- (is_group_v<std::decay_t <Group>> &&
257- detail::is_vector_arithmetic_or_complex<sycl::vec<T, N>>::value &&
258- detail::is_native_op<sycl::vec<T, N>, BinaryOperation>::value),
259- sycl::vec<T, N>>
260- reduce_over_group (Group g, sycl::vec<T, N> x, BinaryOperation binary_op) {
261- // FIXME: Do not special-case for half precision
264+ template <typename Group, typename T, class BinaryOperation >
265+ std::enable_if_t <(is_group_v<std::decay_t <Group>> &&
266+ detail::is_vector_arithmetic_or_complex<T>::value &&
267+ detail::is_native_op<T, BinaryOperation>::value),
268+ T>
269+ reduce_over_group (Group g, T x, BinaryOperation binary_op) { // x is indexed per component (x[s]) below; the result has the same type T
262270 static_assert (
263- std::is_same_v<decltype (binary_op (x[0 ], x[0 ])),
264- typename sycl::vec<T, N>::element_type> ||
265- (std::is_same_v<sycl::vec<T, N>, half> &&
266- std::is_same_v<decltype (binary_op (x[0 ], x[0 ])), float >),
271+ std::is_same_v<decltype (binary_op (x, x)), T>, // binary_op applied to two values of type T must itself yield T
267272 " Result type of binary_op must match reduction accumulation type." );
268- sycl::vec<T, N> result;
269-
270- detail::loop<N>(
271- [&](size_t s) { result[s] = reduce_over_group (g, x[s], binary_op); });
273+ T result;
274+ typename detail::get_scalar_binary_op<BinaryOperation>::type // op type produced by get_scalar_binary_op for BinaryOperation
275+ scalar_binary_op{}; // value-initialized; the binary_op argument itself is only used in the static_assert above
276+ detail::loop<x.size ()>([&](size_t s) { // detail::loop is defined elsewhere; presumably invokes the lambda for each s in [0, x.size()) - confirm
277+ result[s] = reduce_over_group (g, x[s], scalar_binary_op); // reduces component s on its own, using the scalar op
278+ });
272279 return result;
273280}
274281
@@ -284,11 +291,8 @@ std::enable_if_t<
284291 std::is_convertible_v<V, T>),
285292 T>
286293reduce_over_group (Group g, V x, T init, BinaryOperation binary_op) {
287- // FIXME: Do not special-case for half precision
288294 static_assert (
289- std::is_same_v<decltype (binary_op (init, x)), T> ||
290- (std::is_same_v<T, half> &&
291- std::is_same_v<decltype (binary_op (init, x)), float >),
295+ std::is_same_v<decltype (binary_op (init, x)), T>,
292296 " Result type of binary_op must match reduction accumulation type." );
293297#ifdef __SYCL_DEVICE_ONLY__
294298 return binary_op (init, reduce_over_group (g, T (x), binary_op));
@@ -307,17 +311,16 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
307311 detail::is_native_op<T, BinaryOperation>::value),
308312 T>
309313reduce_over_group (Group g, V x, T init, BinaryOperation binary_op) {
310- // FIXME: Do not special-case for half precision
311314 static_assert (
312- std::is_same_v<decltype (binary_op (init[0 ], x[0 ])),
313- typename T::element_type> ||
314- (std::is_same_v<T, half> &&
315- std::is_same_v<decltype (binary_op (init[0 ], x[0 ])), float >),
315+ std::is_same_v<decltype (binary_op (init, x)), T>,
316316 " Result type of binary_op must match reduction accumulation type." );
317+ typename detail::get_scalar_binary_op<BinaryOperation>::type
318+ scalar_binary_op{};
317319#ifdef __SYCL_DEVICE_ONLY__
318320 T result = init;
319321 for (int s = 0 ; s < x.size (); ++s) {
320- result[s] = binary_op (init[s], reduce_over_group (g, x[s], binary_op));
322+ result[s] =
323+ scalar_binary_op (init[s], reduce_over_group (g, x[s], scalar_binary_op));
321324 }
322325 return result;
323326#else
@@ -338,11 +341,8 @@ std::enable_if_t<
338341 detail::is_native_op<T, BinaryOperation>::value),
339342 T>
340343joint_reduce (Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op) {
341- // FIXME: Do not special-case for half precision
342344 static_assert (
343- std::is_same_v<decltype (binary_op (init, *first)), T> ||
344- (std::is_same_v<T, half> &&
345- std::is_same_v<decltype (binary_op (init, *first)), float >),
345+ std::is_same_v<decltype (binary_op (init, *first)), T>,
346346 " Result type of binary_op must match reduction accumulation type." );
347347#ifdef __SYCL_DEVICE_ONLY__
348348 T partial = detail::identity_for_ga_op<T, BinaryOperation>();
@@ -667,10 +667,7 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
667667 detail::is_native_op<T, BinaryOperation>::value),
668668 T>
669669exclusive_scan_over_group (Group g, T x, BinaryOperation binary_op) {
670- // FIXME: Do not special-case for half precision
671- static_assert (std::is_same_v<decltype (binary_op (x, x)), T> ||
672- (std::is_same_v<T, half> &&
673- std::is_same_v<decltype (binary_op (x, x)), float >),
670+ static_assert (std::is_same_v<decltype (binary_op (x, x)), T>,
674671 " Result type of binary_op must match scan accumulation type." );
675672#ifdef __SYCL_DEVICE_ONLY__
676673#if defined(__NVPTX__)
@@ -718,15 +715,13 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
718715 detail::is_native_op<T, BinaryOperation>::value),
719716 T>
720717exclusive_scan_over_group (Group g, T x, BinaryOperation binary_op) {
721- // FIXME: Do not special-case for half precision
722- static_assert (std::is_same_v<decltype (binary_op (x[0 ], x[0 ])),
723- typename T::element_type> ||
724- (std::is_same_v<T, half> &&
725- std::is_same_v<decltype (binary_op (x[0 ], x[0 ])), float >),
718+ static_assert (std::is_same_v<decltype (binary_op (x, x)), T>,
726719 " Result type of binary_op must match scan accumulation type." );
727720 T result;
721+ typename detail::get_scalar_binary_op<BinaryOperation>::type
722+ scalar_binary_op{};
728723 for (int s = 0 ; s < x.size (); ++s) {
729- result[s] = exclusive_scan_over_group (g, x[s], binary_op );
724+ result[s] = exclusive_scan_over_group (g, x[s], scalar_binary_op );
730725 }
731726 return result;
732727}
@@ -741,15 +736,13 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
741736 detail::is_native_op<T, BinaryOperation>::value),
742737 T>
743738exclusive_scan_over_group (Group g, V x, T init, BinaryOperation binary_op) {
744- // FIXME: Do not special-case for half precision
745- static_assert (std::is_same_v<decltype (binary_op (init[0 ], x[0 ])),
746- typename T::element_type> ||
747- (std::is_same_v<T, half> &&
748- std::is_same_v<decltype (binary_op (init[0 ], x[0 ])), float >),
739+ static_assert (std::is_same_v<decltype (binary_op (init, x)), T>,
749740 " Result type of binary_op must match scan accumulation type." );
750741 T result;
742+ typename detail::get_scalar_binary_op<BinaryOperation>::type
743+ scalar_binary_op{};
751744 for (int s = 0 ; s < x.size (); ++s) {
752- result[s] = exclusive_scan_over_group (g, x[s], init[s], binary_op );
745+ result[s] = exclusive_scan_over_group (g, x[s], init[s], scalar_binary_op );
753746 }
754747 return result;
755748}
@@ -764,10 +757,7 @@ std::enable_if_t<
764757 std::is_convertible_v<V, T>),
765758 T>
766759exclusive_scan_over_group (Group g, V x, T init, BinaryOperation binary_op) {
767- // FIXME: Do not special-case for half precision
768- static_assert (std::is_same_v<decltype (binary_op (init, x)), T> ||
769- (std::is_same_v<T, half> &&
770- std::is_same_v<decltype (binary_op (init, x)), float >),
760+ static_assert (std::is_same_v<decltype (binary_op (init, x)), T>,
771761 " Result type of binary_op must match scan accumulation type." );
772762#ifdef __SYCL_DEVICE_ONLY__
773763 typename Group::linear_id_type local_linear_id =
@@ -804,10 +794,7 @@ std::enable_if_t<
804794 OutPtr>
805795joint_exclusive_scan (Group g, InPtr first, InPtr last, OutPtr result, T init,
806796 BinaryOperation binary_op) {
807- // FIXME: Do not special-case for half precision
808- static_assert (std::is_same_v<decltype (binary_op (init, *first)), T> ||
809- (std::is_same_v<T, half> &&
810- std::is_same_v<decltype (binary_op (init, *first)), float >),
797+ static_assert (std::is_same_v<decltype (binary_op (init, *first)), T>,
811798 " Result type of binary_op must match scan accumulation type." );
812799#ifdef __SYCL_DEVICE_ONLY__
813800 ptrdiff_t offset = sycl::detail::get_local_linear_id (g);
@@ -859,14 +846,9 @@ std::enable_if_t<
859846 OutPtr>
860847joint_exclusive_scan (Group g, InPtr first, InPtr last, OutPtr result,
861848 BinaryOperation binary_op) {
862- // FIXME: Do not special-case for half precision
863- static_assert (
864- std::is_same_v<decltype (binary_op (*first, *first)),
865- typename detail::remove_pointer<OutPtr>::type> ||
866- (std::is_same_v<typename detail::remove_pointer<OutPtr>::type,
867- half> &&
868- std::is_same_v<decltype (binary_op (*first, *first)), float >),
869- " Result type of binary_op must match scan accumulation type." );
849+ static_assert (std::is_same_v<decltype (binary_op (*first, *first)),
850+ typename detail::remove_pointer<OutPtr>::type>,
851+ " Result type of binary_op must match scan accumulation type." );
870852 using T = typename detail::remove_pointer<OutPtr>::type;
871853 T init = detail::identity_for_ga_op<T, BinaryOperation>();
872854 return joint_exclusive_scan (g, first, last, result, init, binary_op);
@@ -882,15 +864,13 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
882864 detail::is_native_op<T, BinaryOperation>::value),
883865 T>
884866inclusive_scan_over_group (Group g, T x, BinaryOperation binary_op) {
885- // FIXME: Do not special-case for half precision
886- static_assert (std::is_same_v<decltype (binary_op (x[0 ], x[0 ])),
887- typename T::element_type> ||
888- (std::is_same_v<T, half> &&
889- std::is_same_v<decltype (binary_op (x[0 ], x[0 ])), float >),
867+ static_assert (std::is_same_v<decltype (binary_op (x, x)), T>,
890868 " Result type of binary_op must match scan accumulation type." );
891869 T result;
870+ typename detail::get_scalar_binary_op<BinaryOperation>::type
871+ scalar_binary_op{};
892872 for (int s = 0 ; s < x.size (); ++s) {
893- result[s] = inclusive_scan_over_group (g, x[s], binary_op );
873+ result[s] = inclusive_scan_over_group (g, x[s], scalar_binary_op );
894874 }
895875 return result;
896876}
@@ -903,10 +883,7 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
903883 detail::is_native_op<T, BinaryOperation>::value),
904884 T>
905885inclusive_scan_over_group (Group g, T x, BinaryOperation binary_op) {
906- // FIXME: Do not special-case for half precision
907- static_assert (std::is_same_v<decltype (binary_op (x, x)), T> ||
908- (std::is_same_v<T, half> &&
909- std::is_same_v<decltype (binary_op (x, x)), float >),
886+ static_assert (std::is_same_v<decltype (binary_op (x, x)), T>,
910887 " Result type of binary_op must match scan accumulation type." );
911888#ifdef __SYCL_DEVICE_ONLY__
912889#if defined(__NVPTX__)
@@ -959,10 +936,7 @@ std::enable_if_t<
959936 std::is_convertible_v<V, T>),
960937 T>
961938inclusive_scan_over_group (Group g, V x, BinaryOperation binary_op, T init) {
962- // FIXME: Do not special-case for half precision
963- static_assert (std::is_same_v<decltype (binary_op (init, x)), T> ||
964- (std::is_same_v<T, half> &&
965- std::is_same_v<decltype (binary_op (init, x)), float >),
939+ static_assert (std::is_same_v<decltype (binary_op (init, x)), T>,
966940 " Result type of binary_op must match scan accumulation type." );
967941#ifdef __SYCL_DEVICE_ONLY__
968942 T y = x;
@@ -985,14 +959,13 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
985959 detail::is_native_op<T, BinaryOperation>::value),
986960 T>
987961inclusive_scan_over_group (Group g, V x, BinaryOperation binary_op, T init) {
988- // FIXME: Do not special-case for half precision
989- static_assert (std::is_same_v<decltype (binary_op (init[0 ], x[0 ])), T> ||
990- (std::is_same_v<T, half> &&
991- std::is_same_v<decltype (binary_op (init[0 ], x[0 ])), float >),
962+ static_assert (std::is_same_v<decltype (binary_op (init, x)), T>,
992963 " Result type of binary_op must match scan accumulation type." );
993964 T result;
965+ typename detail::get_scalar_binary_op<BinaryOperation>::type
966+ scalar_binary_op{};
994967 for (int s = 0 ; s < x.size (); ++s) {
995- result[s] = inclusive_scan_over_group (g, x[s], binary_op , init[s]);
968+ result[s] = inclusive_scan_over_group (g, x[s], scalar_binary_op , init[s]);
996969 }
997970 return result;
998971}
@@ -1013,10 +986,7 @@ std::enable_if_t<
1013986 OutPtr>
1014987joint_inclusive_scan (Group g, InPtr first, InPtr last, OutPtr result,
1015988 BinaryOperation binary_op, T init) {
1016- // FIXME: Do not special-case for half precision
1017- static_assert (std::is_same_v<decltype (binary_op (init, *first)), T> ||
1018- (std::is_same_v<T, half> &&
1019- std::is_same_v<decltype (binary_op (init, *first)), float >),
989+ static_assert (std::is_same_v<decltype (binary_op (init, *first)), T>,
1020990 " Result type of binary_op must match scan accumulation type." );
1021991#ifdef __SYCL_DEVICE_ONLY__
1022992 ptrdiff_t offset = sycl::detail::get_local_linear_id (g);
@@ -1065,14 +1035,9 @@ std::enable_if_t<
10651035 OutPtr>
10661036joint_inclusive_scan (Group g, InPtr first, InPtr last, OutPtr result,
10671037 BinaryOperation binary_op) {
1068- // FIXME: Do not special-case for half precision
1069- static_assert (
1070- std::is_same_v<decltype (binary_op (*first, *first)),
1071- typename detail::remove_pointer<OutPtr>::type> ||
1072- (std::is_same_v<typename detail::remove_pointer<OutPtr>::type,
1073- half> &&
1074- std::is_same_v<decltype (binary_op (*first, *first)), float >),
1075- " Result type of binary_op must match scan accumulation type." );
1038+ static_assert (std::is_same_v<decltype (binary_op (*first, *first)),
1039+ typename detail::remove_pointer<OutPtr>::type>,
1040+ " Result type of binary_op must match scan accumulation type." );
10761041
10771042 using T = typename detail::remove_pointer<OutPtr>::type;
10781043 T init = detail::identity_for_ga_op<T, BinaryOperation>();
0 commit comments