|
15 | 15 | #include <sycl/pointers.hpp> |
16 | 16 | #include <sycl/types.hpp> |
17 | 17 |
|
| 18 | +#include <algorithm> |
| 19 | + |
18 | 20 | // TODO Decide whether to mark functions with this attribute. |
19 | 21 | #define __NOEXC /*noexcept*/ |
20 | 22 |
|
@@ -1082,12 +1084,14 @@ detail::enable_if_t<detail::is_ugeninteger<T>::value, T> abs(T x) __NOEXC { |
1082 | 1084 | return __sycl_std::__invoke_u_abs<T>(x); |
1083 | 1085 | } |
1084 | 1086 |
|
1085 | | -// ugeninteger abs (geninteger x) |
| 1087 | +// igeninteger abs (geninteger x) |
1086 | 1088 | template <typename T> |
1087 | | -detail::enable_if_t<detail::is_igeninteger<T>::value, |
1088 | | - detail::make_unsigned_t<T>> |
1089 | | -abs(T x) __NOEXC { |
1090 | | - return __sycl_std::__invoke_s_abs<detail::make_unsigned_t<T>>(x); |
| 1089 | +detail::enable_if_t<detail::is_igeninteger<T>::value, T> abs(T x) __NOEXC { |
| 1090 | + auto res = __sycl_std::__invoke_s_abs<detail::make_unsigned_t<T>>(x); |
| 1091 | + if constexpr (detail::is_vigeninteger<T>::value) { |
| 1092 | + return res.template convert<detail::vector_element_t<T>>(); |
| 1093 | + } else |
| 1094 | + return detail::make_signed_t<decltype(res)>(res); |
1091 | 1095 | } |
1092 | 1096 |
|
1093 | 1097 | // ugeninteger abs_diff (geninteger x, geninteger y) |
@@ -1434,9 +1438,8 @@ mul24(T x, T y) __NOEXC { |
1434 | 1438 |
|
// Defines NAME for marray<T, N> arguments, enabled when
// detail::is_igeninteger<T> holds. The result type is marray<T, N>, i.e. the
// same type as the argument. The function body is produced by
// __SYCL_MARRAY_INTEGER_FUNCTION_OVERLOAD_IMPL from NAME and the trailing
// macro arguments.
#define __SYCL_MARRAY_INTEGER_FUNCTION_ABS_I_OVERLOAD(NAME, ARG, ...)          \
  template <typename T, size_t N>                                              \
  std::enable_if_t<detail::is_igeninteger<T>::value, marray<T, N>>             \
  NAME(marray<T, N> ARG) __NOEXC {                                             \
    __SYCL_MARRAY_INTEGER_FUNCTION_OVERLOAD_IMPL(NAME, __VA_ARGS__)            \
  }
1442 | 1445 |
|
@@ -2073,19 +2076,27 @@ detail::enable_if_t<detail::is_gentype<T>::value, T> bitselect(T a, T b, |
2073 | 2076 | template <typename T> |
2074 | 2077 | detail::enable_if_t<detail::is_sgentype<T>::value, T> select(T a, T b, |
2075 | 2078 | bool c) __NOEXC { |
2076 | | - return __sycl_std::__invoke_select<T>(a, b, static_cast<int>(c)); |
2077 | | -} |
2078 | | - |
2079 | | -// mgentype select (mgentype a, mgentype b, marray<bool, { N }> c) |
2080 | | -template <typename T, |
2081 | | - typename = std::enable_if_t<detail::is_mgenfloat<T>::value>> |
2082 | | -sycl::marray<detail::marray_element_t<T>, T::size()> |
2083 | | -select(T a, T b, sycl::marray<bool, T::size()> c) __NOEXC { |
2084 | | - sycl::marray<detail::marray_element_t<T>, T::size()> res; |
2085 | | - for (int i = 0; i < a.size(); i++) { |
2086 | | - res[i] = select(a[i], b[i], c[i]); |
2087 | | - } |
2088 | | - return res; |
| 2079 | + constexpr size_t SizeT = sizeof(T); |
| 2080 | + |
| 2081 | + // sycl::select(sgentype a, sgentype b, bool c) calls OpenCL built-in |
| 2082 | + // select(sgentype a, sgentype b, igentype c). This type trait makes the |
| 2083 | + // proper conversion for argument c from bool to igentype, based on sgentype |
| 2084 | + // == T. |
| 2085 | + using get_select_opencl_builtin_c_arg_type = typename std::conditional_t< |
| 2086 | + SizeT == 1, char, |
| 2087 | + std::conditional_t< |
| 2088 | + SizeT == 2, short, |
| 2089 | + std::conditional_t< |
| 2090 | + (detail::is_contained< |
| 2091 | + T, detail::type_list<long, unsigned long>>::value && |
| 2092 | + (SizeT == 4 || SizeT == 8)), |
| 2093 | + long, // long and ulong are 32-bit on |
| 2094 | + // Windows and 64-bit on Linux |
| 2095 | + std::conditional_t<SizeT == 4, int, |
| 2096 | + std::conditional_t<SizeT == 8, long, void>>>>>; |
| 2097 | + |
| 2098 | + return __sycl_std::__invoke_select<T>( |
| 2099 | + a, b, static_cast<get_select_opencl_builtin_c_arg_type>(c)); |
2089 | 2100 | } |
2090 | 2101 |
|
2091 | 2102 | // geninteger select (geninteger a, geninteger b, igeninteger c) |
@@ -2164,6 +2175,40 @@ select(T a, T b, T2 c) __NOEXC { |
2164 | 2175 | return __sycl_std::__invoke_select<T>(a, b, c); |
2165 | 2176 | } |
2166 | 2177 |
|
| 2178 | +// other marray relational functions |
| 2179 | + |
| 2180 | +template <typename T, size_t N> |
| 2181 | +detail::enable_if_t<detail::is_sigeninteger<T>::value, bool> |
| 2182 | +any(marray<T, N> x) __NOEXC { |
| 2183 | + return std::any_of(x.begin(), x.end(), [](T i) { return any(i); }); |
| 2184 | +} |
| 2185 | + |
| 2186 | +template <typename T, size_t N> |
| 2187 | +detail::enable_if_t<detail::is_sigeninteger<T>::value, bool> |
| 2188 | +all(marray<T, N> x) __NOEXC { |
| 2189 | + return std::all_of(x.begin(), x.end(), [](T i) { return all(i); }); |
| 2190 | +} |
| 2191 | + |
| 2192 | +template <typename T, size_t N> |
| 2193 | +detail::enable_if_t<detail::is_gentype<T>::value, marray<T, N>> |
| 2194 | +bitselect(marray<T, N> a, marray<T, N> b, marray<T, N> c) __NOEXC { |
| 2195 | + marray<T, N> res; |
| 2196 | + for (int i = 0; i < N; i++) { |
| 2197 | + res[i] = bitselect(a[i], b[i], c[i]); |
| 2198 | + } |
| 2199 | + return res; |
| 2200 | +} |
| 2201 | + |
| 2202 | +template <typename T, size_t N> |
| 2203 | +detail::enable_if_t<detail::is_gentype<T>::value, marray<T, N>> |
| 2204 | +select(marray<T, N> a, marray<T, N> b, marray<bool, N> c) __NOEXC { |
| 2205 | + marray<T, N> res; |
| 2206 | + for (int i = 0; i < N; i++) { |
| 2207 | + res[i] = select(a[i], b[i], c[i]); |
| 2208 | + } |
| 2209 | + return res; |
| 2210 | +} |
| 2211 | + |
2167 | 2212 | namespace native { |
2168 | 2213 | /* ----------------- 4.13.3 Math functions. ---------------------------------*/ |
2169 | 2214 |
|
|
0 commit comments