Skip to content

[SYCL] Fix constexpr initialization of vec for half #8503

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions sycl/include/sycl/detail/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,56 @@ static T<NewDim> convertToArrayOfN(T<OldDim> OldObj) {
return NewObj;
}

// Helper function for concatenating two std::array.
template <typename T, std::size_t... Is1, std::size_t... Is2>
constexpr std::array<T, sizeof...(Is1) + sizeof...(Is2)>
ConcatArrays(const std::array<T, sizeof...(Is1)> &A1,
const std::array<T, sizeof...(Is2)> &A2,
std::index_sequence<Is1...>, std::index_sequence<Is2...>) {
return {A1[Is1]..., A2[Is2]...};
}
template <typename T, std::size_t N1, std::size_t N2>
constexpr std::array<T, N1 + N2> ConcatArrays(const std::array<T, N1> &A1,
const std::array<T, N2> &A2) {
return ConcatArrays(A1, A2, std::make_index_sequence<N1>(),
std::make_index_sequence<N2>());
}

// Utility for creating an std::array from the results of flattening the
// arguments using a flattening functor.
template <typename DataT, template <typename, typename> typename FlattenF,
typename... ArgTN>
struct ArrayCreator;
template <typename DataT, template <typename, typename> typename FlattenF,
typename ArgT, typename... ArgTN>
struct ArrayCreator<DataT, FlattenF, ArgT, ArgTN...> {
static constexpr auto Create(const ArgT &Arg, const ArgTN &...Args) {
auto ImmArray = FlattenF<DataT, ArgT>()(Arg);
if constexpr (sizeof...(Args))
return ConcatArrays(
ImmArray, ArrayCreator<DataT, FlattenF, ArgTN...>::Create(Args...));
else
return ImmArray;
}
};
template <typename DataT, template <typename, typename> typename FlattenF>
struct ArrayCreator<DataT, FlattenF> {
static constexpr auto Create() { return std::array<DataT, 0>{}; }
};

// Helper function for creating an arbitrary sized array with the same value
// repeating.
template <typename T, size_t... Is>
static constexpr std::array<T, sizeof...(Is)>
RepeatValueHelper(const T &Arg, std::index_sequence<Is...>) {
auto ReturnArg = [&](size_t) { return Arg; };
return {ReturnArg(Is)...};
}
template <size_t N, typename T>
static constexpr std::array<T, N> RepeatValue(const T &Arg) {
return RepeatValueHelper(Arg, std::make_index_sequence<N>());
}

} // namespace detail
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl
5 changes: 5 additions & 0 deletions sycl/include/sycl/half_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,11 @@ template <int NumElements> struct half_vec {
StorageT s[NumElements];

__SYCL_CONSTEXPR_HALF half_vec() : s{0.0f} { initialize_data(); }
template <typename... Ts,
typename = std::enable_if_t<(sizeof...(Ts) == NumElements) &&
(std::is_same_v<half, Ts> && ...)>>
__SYCL_CONSTEXPR_HALF half_vec(const Ts &...hs) : s{hs...} {}

constexpr void initialize_data() {
for (size_t i = 0; i < NumElements; ++i) {
s[i] = StorageT(0.0f);
Expand Down
96 changes: 30 additions & 66 deletions sycl/include/sycl/marray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,65 +39,6 @@ template <typename T, typename... Ts> struct GetMArrayArgsSize<T, Ts...> {
static constexpr std::size_t value = 1 + GetMArrayArgsSize<Ts...>::value;
};

// Helper function for concatenating two std::array.
template <typename T, std::size_t... Is1, std::size_t... Is2>
constexpr std::array<T, sizeof...(Is1) + sizeof...(Is2)>
ConcatArrays(const std::array<T, sizeof...(Is1)> &A1,
const std::array<T, sizeof...(Is2)> &A2,
std::index_sequence<Is1...>, std::index_sequence<Is2...>) {
return {A1[Is1]..., A2[Is2]...};
}
template <typename T, std::size_t N1, std::size_t N2>
constexpr std::array<T, N1 + N2> ConcatArrays(const std::array<T, N1> &A1,
const std::array<T, N2> &A2) {
return ConcatArrays(A1, A2, std::make_index_sequence<N1>(),
std::make_index_sequence<N2>());
}

// Utility trait for creating an std::array from an marray.
template <typename DataT, typename T, std::size_t... Is>
constexpr std::array<T, sizeof...(Is)>
MArrayToArray(const marray<T, sizeof...(Is)> &A, std::index_sequence<Is...>) {
return {static_cast<DataT>(A.MData[Is])...};
}
template <typename DataT, typename T, std::size_t N>
constexpr std::array<T, N> MArrayToArray(const marray<T, N> &A) {
return MArrayToArray<DataT>(A, std::make_index_sequence<N>());
}

// Utility for creating an std::array from a arguments of either types
// convertible to DataT or marrays of a type convertible to DataT.
template <typename DataT, typename... ArgTN> struct ArrayCreator;
template <typename DataT, typename ArgT, typename... ArgTN>
struct ArrayCreator<DataT, ArgT, ArgTN...> {
static constexpr std::array<DataT, GetMArrayArgsSize<ArgT, ArgTN...>::value>
Create(const ArgT &Arg, const ArgTN &...Args) {
std::array<DataT, 1> ImmArray{static_cast<DataT>(Arg)};
if constexpr (sizeof...(Args))
return ConcatArrays(ImmArray,
ArrayCreator<DataT, ArgTN...>::Create(Args...));
else
return ImmArray;
}
};
template <typename DataT, typename T, std::size_t N, typename... ArgTN>
struct ArrayCreator<DataT, marray<T, N>, ArgTN...> {
static constexpr std::array<DataT,
GetMArrayArgsSize<marray<T, N>, ArgTN...>::value>
Create(const marray<T, N> &Arg, const ArgTN &...Args) {
auto ImmArray = MArrayToArray<DataT>(Arg);
if constexpr (sizeof...(Args))
return ConcatArrays(ImmArray,
ArrayCreator<DataT, ArgTN...>::Create(Args...));
else
return ImmArray;
}
};
template <typename DataT> struct ArrayCreator<DataT> {
static constexpr std::array<DataT, 0> Create() {
return std::array<DataT, 0>{};
}
};
} // namespace detail

/// Provides a cross-platform math array class template that works on
Expand Down Expand Up @@ -129,12 +70,35 @@ template <typename Type, std::size_t NumElements> class marray {
template <typename... ArgTN>
struct AllSuitableArgTypes : std::conjunction<IsSuitableArgType<ArgTN>...> {};

// FIXME: MArrayToArray needs to be a friend to access MData. If the subscript
// operator is made constexpr this can be removed.
template <typename, typename T, std::size_t... Is>
friend constexpr std::array<T, sizeof...(Is)>
detail::MArrayToArray(const marray<T, sizeof...(Is)> &,
std::index_sequence<Is...>);
// Utility trait for creating an std::array from an marray argument.
template <typename DataT, typename T, std::size_t... Is>
static constexpr std::array<DataT, sizeof...(Is)>
MArrayToArray(const marray<T, sizeof...(Is)> &A, std::index_sequence<Is...>) {
return {static_cast<DataT>(A.MData[Is])...};
}
template <typename DataT, typename T, std::size_t N>
static constexpr std::array<DataT, N>
FlattenMArrayArgHelper(const marray<T, N> &A) {
return MArrayToArray<DataT>(A, std::make_index_sequence<N>());
}
template <typename DataT, typename T>
static constexpr auto FlattenMArrayArgHelper(const T &A) {
return std::array<DataT, 1>{static_cast<DataT>(A)};
}
template <typename DataT, typename T> struct FlattenMArrayArg {
constexpr auto operator()(const T &A) const {
return FlattenMArrayArgHelper<DataT>(A);
}
};

// Alias for shortening the marray arguments to array converter.
template <typename DataT, typename... ArgTN>
using MArrayArgArrayCreator =
detail::ArrayCreator<DataT, FlattenMArrayArg, ArgTN...>;

// FIXME: Other marray specializations needs to be a friend to access MData.
// If the subscript operator is made constexpr this can be removed.
template <typename Type_, std::size_t NumElements_> friend class marray;

constexpr void initialize_data(const Type &Arg) {
for (size_t i = 0; i < NumElements; ++i) {
Expand All @@ -159,7 +123,7 @@ template <typename Type, std::size_t NumElements> class marray {
AllSuitableArgTypes<ArgTN...>::value &&
detail::GetMArrayArgsSize<ArgTN...>::value == NumElements>>
constexpr marray(const ArgTN &...Args)
: marray{detail::ArrayCreator<DataT, ArgTN...>::Create(Args...),
: marray{MArrayArgArrayCreator<DataT, ArgTN...>::Create(Args...),
std::make_index_sequence<NumElements>()} {}

constexpr marray(const marray<Type, NumElements> &Rhs) = default;
Expand Down
Loading