diff --git a/sycl/include/sycl/ext/oneapi/experimental/complex/complex.hpp b/sycl/include/sycl/ext/oneapi/experimental/complex/complex.hpp index ad6653081ff48..aeac3ced79c5f 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/complex/complex.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/complex/complex.hpp @@ -12,5 +12,7 @@ #include "./detail/complex.hpp" #include "./detail/complex_math.hpp" +#include "./detail/marray.hpp" +#include "./detail/marray_math.hpp" #endif // SYCL_EXT_ONEAPI_COMPLEX diff --git a/sycl/include/sycl/ext/oneapi/experimental/complex/detail/marray.hpp b/sycl/include/sycl/ext/oneapi/experimental/complex/detail/marray.hpp new file mode 100644 index 0000000000000..1e3a7adb06f73 --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/experimental/complex/detail/marray.hpp @@ -0,0 +1,220 @@ +//===- marray.hpp ---------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "common.hpp" + +#include +#include + +namespace sycl { +inline namespace _V1 { + +template +class marray, NumElements> { +private: + using ComplexDataT = sycl::ext::oneapi::experimental::complex; + using MarrayDataT = typename sycl::detail::vec_helper::RetType; + +public: + using value_type = ComplexDataT; + using reference = ComplexDataT &; + using const_reference = const ComplexDataT &; + using iterator = ComplexDataT *; + using const_iterator = const ComplexDataT *; + +private: + value_type MData[NumElements]; + + template + constexpr marray(const std::array &Arr, + std::index_sequence) + : MData{Arr[Is]...} {} + + // detail::FlattenMArrayArgHelper::MArrayToArray needs to have access to + // MData. + // FIXME: If the subscript operator is made constexpr this can be removed. + friend class detail::FlattenMArrayArgHelper; + +public: + constexpr marray() : MData{} {}; + + explicit constexpr marray(const value_type &arg) + : marray{sycl::detail::RepeatValue( + static_cast(arg)), + std::make_index_sequence()} {} + + template < + typename... ArgTN, + typename = std::enable_if_t< + sycl::detail::AllSuitableArgTypes::value && + sycl::detail::GetMArrayArgsSize::value == NumElements>> + constexpr marray(const ArgTN &...Args) + : marray{ + sycl::detail::MArrayArgArrayCreator::Create( + Args...), + std::make_index_sequence()} {} + + constexpr marray(const marray &rhs) = default; + constexpr marray(marray &&rhs) = default; + + // Available only when: NumElements == 1 + template > + operator value_type() const { + return MData[0]; + } + + static constexpr std::size_t size() noexcept { return NumElements; } + + // subscript operator + reference operator[](std::size_t i) { return MData[i]; } + const_reference operator[](std::size_t i) const { return MData[i]; } + + marray &operator=(const marray &rhs) = default; + marray &operator=(const value_type &rhs) { + for (std::size_t i = 0; i < NumElements; ++i) { + MData[i] = rhs; + } + return *this; + } + + // iterator functions + iterator begin() { return MData; } + const_iterator begin() const { return MData; } + + iterator end() { return MData + NumElements; } + const_iterator end() const { return MData + NumElements; } + + /// ASSIGNMENT OPERATORS + +#ifdef IMPL_ASSIGN_MARRAY_CPLX_OP +#error "Multiple definition of IMPL_ASSIGN_MARRAY_CPLX_OP" +#endif + +#define IMPL_ASSIGN_MARRAY_CPLX_OP(op) \ + friend marray &operator op(marray & lhs, const marray & rhs) { \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + lhs[i] op rhs[i]; \ + } \ + return lhs; \ + } \ + \ + friend marray &operator op(marray & lhs, const value_type & rhs) { \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + lhs[i] op rhs; \ + } \ + return lhs; \ + } + + IMPL_ASSIGN_MARRAY_CPLX_OP(+=) + IMPL_ASSIGN_MARRAY_CPLX_OP(-=) + IMPL_ASSIGN_MARRAY_CPLX_OP(*=) + IMPL_ASSIGN_MARRAY_CPLX_OP(/=) + +#undef IMPL_ASSIGN_MARRAY_CPLX_OP + + /// ARITHMETIC OPERATORS + +#ifdef IMPL_UNARY_MARRAY_CPLX_OP +#error "Multiple definition of IMPL_UNARY_MARRAY_CPLX_OP" +#endif + +#define IMPL_UNARY_MARRAY_CPLX_OP(op) \ + friend marray operator op(const marray &lhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = op lhs[i]; \ + } \ + return rtn; \ + } + + IMPL_UNARY_MARRAY_CPLX_OP(+) + IMPL_UNARY_MARRAY_CPLX_OP(-) + +#undef IMPL_UNARY_MARRAY_CPLX_OP + +#ifdef IMPL_ARITH_MARRAY_CPLX_OP +#error "Multiple definition of IMPL_ARITH_MARRAY_CPLX_OP" +#endif + +#define IMPL_ARITH_MARRAY_CPLX_OP(op) \ + friend marray operator op(const marray &lhs, const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = lhs[i] op rhs[i]; \ + } \ + return rtn; \ + } \ + \ + friend marray operator op(const marray &lhs, const value_type &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = lhs[i] op rhs; \ + } \ + return rtn; \ + } \ + \ + friend marray operator op(const value_type &lhs, const marray &rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = lhs op rhs[i]; \ + } \ + return rtn; \ + } + + IMPL_ARITH_MARRAY_CPLX_OP(+) + IMPL_ARITH_MARRAY_CPLX_OP(-) + IMPL_ARITH_MARRAY_CPLX_OP(*) + IMPL_ARITH_MARRAY_CPLX_OP(/) + +#undef IMPL_ARITH_MARRAY_CPLX_OP + + /// COMPARAISON OPERATORS + +#ifdef IMPL_COMP_MARRAY_CPLX_OP +#error "Multiple definition of IMPL_COMP_MARRAY_CPLX_OP" +#endif + +#define IMPL_COMP_MARRAY_CPLX_OP(op) \ + friend marray operator op(const marray & lhs, \ + const marray & rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = lhs[i] op rhs[i]; \ + } \ + return rtn; \ + } \ + \ + friend marray operator op(const marray & lhs, \ + const value_type & rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = lhs[i] op rhs; \ + } \ + return rtn; \ + } \ + \ + friend marray operator op(const value_type & lhs, \ + const marray & rhs) { \ + marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = lhs op rhs[i]; \ + } \ + return rtn; \ + } + + IMPL_COMP_MARRAY_CPLX_OP(==) + IMPL_COMP_MARRAY_CPLX_OP(!=) + +#undef IMPL_COMP_MARRAY_CPLX_OP +}; + +} // namespace _V1 +} // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/experimental/complex/detail/marray_math.hpp b/sycl/include/sycl/ext/oneapi/experimental/complex/detail/marray_math.hpp new file mode 100644 index 0000000000000..b9aaf793fe013 --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/experimental/complex/detail/marray_math.hpp @@ -0,0 +1,129 @@ +//===- marray_math.hpp ----------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "common.hpp" + +#include + +namespace sycl { +inline namespace _V1 { + +namespace ext { +namespace oneapi { +namespace experimental { + +#ifdef ONE_ARG_MARRAY_TYPE +#error "Multiple definition of ONE_ARG_MARRAY_TYPE" +#endif +#ifdef TWO_ARGS_MARRAY_TYPE +#error "Multiple definition of TWO_ARGS_MARRAY_TYPE" +#endif + +#ifdef TWO_ARGS_POLAR_S1_MARRAY_TYPE +#error "Multiple definition of TWO_ARGS_POLAR_S1_MARRAY_TYPE" +#endif +#ifdef TWO_ARGS_POLAR_S2_MARRAY_TYPE +#error "Multiple definition of TWO_ARGS_POLAR_S2_MARRAY_TYPE" +#endif + +#ifdef TWO_ARGS_POW_S1_MARRAY_TYPE +#error "Multiple definition of TWO_ARGS_POW_S1_MARRAY_TYPE" +#endif +#ifdef TWO_ARGS_POW_S2_MARRAY_TYPE +#error "Multiple definition of TWO_ARGS_POW_S2_MARRAY_TYPE" +#endif + +#ifdef MARRAY_CPLX_MATH_OP +#error "Multiple definition of MARRAY_CPLX_MATH_OP" +#endif + +// clang-format off +#define ONE_ARG_MARRAY_TYPE(TYPE) const sycl::marray &x +#define TWO_ARGS_MARRAY_TYPE(TYPE1, TYPE2) const sycl::marray &x, const sycl::marray &y + +#define TWO_ARGS_POLAR_S1_MARRAY_TYPE(TYPE1, TYPE2) const sycl::marray &x, const TYPE2 &y = 0 +#define TWO_ARGS_POLAR_S2_MARRAY_TYPE(TYPE1, TYPE2) const TYPE1 &x, const sycl::marray &y + +#define TWO_ARGS_POW_S1_MARRAY_TYPE(TYPE1, TYPE2) const sycl::marray &x, const TYPE2 &y +#define TWO_ARGS_POW_S2_MARRAY_TYPE(TYPE1, TYPE2) const TYPE1 &x, const sycl::marray &y + +#define MARRAY_CPLX_MATH_OP(NUM_ARGS, RTN_TYPE, NAME, F, ...) \ +template \ +_SYCL_EXT_CPLX_INLINE_VISIBILITY \ +typename std::enable_if_t::value, sycl::marray> \ +NAME(NUM_ARGS##_MARRAY_TYPE(__VA_ARGS__)) { \ + sycl::marray rtn; \ + for (std::size_t i = 0; i < NumElements; ++i) { \ + rtn[i] = F; \ + } \ + return rtn; \ +} + +// MARRAY_CPLX_MATH_OP(NUMBER_OF_ARGUMENTS, RETURN_TYPE, FUNCTION_NAME, FUNCTION_LOGIC, ARGUMENTS ... +MARRAY_CPLX_MATH_OP( ONE_ARG, T, abs, abs(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, T, arg, arg(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, T, arg, arg(x[i]), T); +MARRAY_CPLX_MATH_OP( ONE_ARG, T, norm, norm(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, T, norm, norm(x[i]), T); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, conj, conj(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, conj, conj(x[i]), T); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, proj, proj(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, proj, proj(x[i]), T); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, log, log(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, log10, log10(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, sqrt, sqrt(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, exp, exp(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, asinh, asinh(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, acosh, acosh(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, atanh, atanh(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, sinh, sinh(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, cosh, cosh(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, tanh, tanh(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, asin, asin(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, acos, acos(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, atan, atan(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, sin, sin(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, cos, cos(x[i]), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, complex, tan, tan(x[i]), complex); + +MARRAY_CPLX_MATH_OP( ONE_ARG, T, real, x[i].real(), complex); +MARRAY_CPLX_MATH_OP( ONE_ARG, T, imag, x[i].imag(), complex); + +MARRAY_CPLX_MATH_OP( TWO_ARGS, complex, polar, polar(x[i], y[i]), T, T); +MARRAY_CPLX_MATH_OP( TWO_ARGS_POLAR_S1, complex, polar, polar(x[i], y), T, T); +MARRAY_CPLX_MATH_OP( TWO_ARGS_POLAR_S2, complex, polar, polar(x, y[i]), T, T); + +MARRAY_CPLX_MATH_OP( TWO_ARGS, complex, pow, pow(x[i], y[i]), complex, T); +MARRAY_CPLX_MATH_OP( TWO_ARGS, complex, pow, pow(x[i], y[i]), complex, complex); +MARRAY_CPLX_MATH_OP( TWO_ARGS, complex, pow, pow(x[i], y[i]), T, complex); + +MARRAY_CPLX_MATH_OP( TWO_ARGS_POW_S1, complex, pow, pow(x[i], y), complex, T); +MARRAY_CPLX_MATH_OP( TWO_ARGS_POW_S1, complex, pow, pow(x[i], y), complex, complex); +MARRAY_CPLX_MATH_OP( TWO_ARGS_POW_S1, complex, pow, pow(x[i], y), T, complex); + +MARRAY_CPLX_MATH_OP( TWO_ARGS_POW_S2, complex, pow, pow(x, y[i]), complex, T); +MARRAY_CPLX_MATH_OP( TWO_ARGS_POW_S2, complex, pow, pow(x, y[i]), complex, complex); +MARRAY_CPLX_MATH_OP( TWO_ARGS_POW_S2, complex, pow, pow(x, y[i]), T, complex); +// clang-format on + +#undef ONE_ARG_MARRAY_TYPE +#undef TWO_ARGS_MARRAY_TYPE +#undef TWO_ARGS_POLAR_S1_MARRAY_TYPE +#undef TWO_ARGS_POLAR_S2_MARRAY_TYPE +#undef TWO_ARGS_POW_S1_MARRAY_TYPE +#undef TWO_ARGS_POW_S2_MARRAY_TYPE +#undef MARRAY_CPLX_MATH_OP + +} // namespace experimental +} // namespace oneapi +} // namespace ext + +} // namespace _V1 +} // namespace sycl diff --git a/sycl/include/sycl/marray.hpp b/sycl/include/sycl/marray.hpp index bc45a45424312..2b4ff20a56569 100644 --- a/sycl/include/sycl/marray.hpp +++ b/sycl/include/sycl/marray.hpp @@ -40,6 +40,51 @@ template struct GetMArrayArgsSize { static constexpr std::size_t value = 1 + GetMArrayArgsSize::value; }; +// Trait for checking if an argument type is either convertible to the data +// type or an array of types convertible to the data type. +template +struct IsSuitableArgType : std::is_convertible {}; +template +struct IsSuitableArgType> : std::is_convertible { +}; + +// Trait for computing the conjunction of of IsSuitableArgType. The empty type +// list will trivially evaluate to true. +template +struct AllSuitableArgTypes + : std::conjunction...> {}; + +class FlattenMArrayArgHelper { +private: + // Utility trait for creating an std::array from an marray argument. + template + static constexpr std::array + MArrayToArray(const marray &A, std::index_sequence) { + return {static_cast(A.MData[Is])...}; + } + +public: + template + static constexpr std::array FlattenMArray(const marray &A) { + return MArrayToArray(A, std::make_index_sequence()); + } + template + static constexpr auto FlattenMArray(const T &A) { + return std::array{static_cast(A)}; + } +}; + +template struct FlattenMArrayArg { + constexpr auto operator()(const T &A) const { + return FlattenMArrayArgHelper::FlattenMArray(A); + } +}; + +// Alias for shortening the marray arguments to array converter. +template +using MArrayArgArrayCreator = + detail::ArrayCreator; + } // namespace detail /// Provides a cross-platform math array class template that works on @@ -59,47 +104,13 @@ template class marray { private: value_type MData[NumElements]; - // Trait for checking if an argument type is either convertible to the data - // type or an array of types convertible to the data type. - template - struct IsSuitableArgType : std::is_convertible {}; - template - struct IsSuitableArgType> : std::is_convertible {}; - - // Trait for computing the conjunction of of IsSuitableArgType. The empty type - // list will trivially evaluate to true. - template - struct AllSuitableArgTypes : std::conjunction...> {}; - - // Utility trait for creating an std::array from an marray argument. - template - static constexpr std::array - MArrayToArray(const marray &A, std::index_sequence) { - return {static_cast(A.MData[Is])...}; - } - template - static constexpr std::array - FlattenMArrayArgHelper(const marray &A) { - return MArrayToArray(A, std::make_index_sequence()); - } - template - static constexpr auto FlattenMArrayArgHelper(const T &A) { - return std::array{static_cast(A)}; - } - template struct FlattenMArrayArg { - constexpr auto operator()(const T &A) const { - return FlattenMArrayArgHelper(A); - } - }; - - // Alias for shortening the marray arguments to array converter. - template - using MArrayArgArrayCreator = - detail::ArrayCreator; - - // FIXME: Other marray specializations needs to be a friend to access MData. - // If the subscript operator is made constexpr this can be removed. + // Other marray specializations needs to be a friend to access MData. + // FIXME: If the subscript operator is made constexpr these can be removed. template friend class marray; + // detail::FlattenMArrayArgHelper::MArrayToArray needs to be a friend to + // access MData. + // FIXME: If the subscript operator is made constexpr these can be removed. + friend class detail::FlattenMArrayArgHelper; constexpr void initialize_data(const Type &Arg) { for (size_t i = 0; i < NumElements; ++i) { @@ -121,10 +132,10 @@ template class marray { template ::value && + detail::AllSuitableArgTypes::value && detail::GetMArrayArgsSize::value == NumElements>> constexpr marray(const ArgTN &...Args) - : marray{MArrayArgArrayCreator::Create(Args...), + : marray{detail::MArrayArgArrayCreator::Create(Args...), std::make_index_sequence()} {} constexpr marray(const marray &Rhs) = default; @@ -142,7 +153,6 @@ template class marray { // subscript operator reference operator[](std::size_t index) { return MData[index]; } - const_reference operator[](std::size_t index) const { return MData[index]; } marray &operator=(const marray &Rhs) = default; diff --git a/sycl/test-e2e/Complex/sycl_complex_helper.hpp b/sycl/test-e2e/Complex/sycl_complex_helper.hpp index a57f736f5244e..d90e50e215724 100644 --- a/sycl/test-e2e/Complex/sycl_complex_helper.hpp +++ b/sycl/test-e2e/Complex/sycl_complex_helper.hpp @@ -33,7 +33,8 @@ template <> const char *get_typename() { return "double"; } template <> const char *get_typename() { return "float"; } template <> const char *get_typename() { return "sycl::half"; } -// Helper to test each complex specilization +/// Helper to test each complex specilization + // Overload for cplx_test_cases template