Skip to content

Commit

Permalink
HD.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 4, 2023
1 parent 54af6f4 commit c1cb30e
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions include/xgboost/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ inline LINALG_HD int Popc(uint64_t v) {
}

template <std::size_t D, typename Head>
auto IdxToArr(std::size_t (&arr)[D], Head head) {
LINALG_HD void IndexToArr(std::size_t (&arr)[D], Head head) {
static_assert(std::is_integral<std::remove_reference_t<Head>>::value, "Invalid index type.");
arr[D - 1] = head;
}
Expand All @@ -164,21 +164,24 @@ auto IdxToArr(std::size_t (&arr)[D], Head head) {
* \brief Convert index from parameter pack to C-style array.
*/
template <std::size_t D, typename Head, typename... Rest>
auto IdxToArr(std::size_t (&arr)[D], Head head, Rest &&...index) {
LINALG_HD void IndexToArr(std::size_t (&arr)[D], Head head, Rest &&...index) {
static_assert(sizeof...(Rest) < D, "Index overflow.");
static_assert(std::is_integral<std::remove_reference_t<Head>>::value, "Invalid index type.");
arr[D - sizeof...(Rest) - 1] = head;
IdxToArr(arr, std::forward<Rest>(index)...);
IndexToArr(arr, std::forward<Rest>(index)...);
}

template <class T, std::size_t N, std::size_t... Idx>
constexpr auto Arr2Tup(T (&arr)[N], std::index_sequence<Idx...>) {
constexpr auto ArrToTuple(T (&arr)[N], std::index_sequence<Idx...>) {
return std::make_tuple(arr[Idx]...);
}

/**
* \brief Convert C-styple array to std::tuple.
*/
template <class T, std::size_t N>
constexpr auto Arr2Tup(T (&arr)[N]) {
return Arr2Tup(arr, std::make_index_sequence<N>{});
constexpr auto ArrToTuple(T (&arr)[N]) {
return ArrToTuple(arr, std::make_index_sequence<N>{});
}

// uint division optimization inspired by the CIndexer in cupy. Division operation is
Expand All @@ -201,7 +204,7 @@ LINALG_HD auto UnravelImpl(I idx, common::Span<size_t const, D> shape) {
}
}
index[0] = idx;
return Arr2Tup(index);
return ArrToTuple(index);
}

template <size_t dim, typename I, int32_t D>
Expand Down Expand Up @@ -568,14 +571,14 @@ template <typename Container, typename... S,
auto MakeTensorView(Context const *ctx, Container &data, S &&...shape) { // NOLINT
using T = typename Container::value_type;
std::size_t in_shape[sizeof...(S)];
detail::IdxToArr(in_shape, std::forward<S>(shape)...);
detail::IndexToArr(in_shape, std::forward<S>(shape)...);
return TensorView<T, sizeof...(S)>{data, in_shape, ctx->gpu_id};
}

template <typename T, typename... S>
LINALG_HD auto MakeTensorView(Context const *ctx, common::Span<T> data, S &&...shape) {
std::size_t in_shape[sizeof...(S)];
detail::IdxToArr(in_shape, std::forward<S>(shape)...);
detail::IndexToArr(in_shape, std::forward<S>(shape)...);
return TensorView<T, sizeof...(S)>{data, in_shape, ctx->gpu_id};
}

Expand Down

0 comments on commit c1cb30e

Please sign in to comment.