Skip to content

Commit 29a73b8

Browse files
committed
Split the type list.
1 parent 8253481 commit 29a73b8

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

src/encoder/ordinal.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,21 @@ struct CatStrArrayView {
7373
return this->offsets.size_bytes() + values.size_bytes();
7474
}
7575
};
76+
77+
// We keep a single type list here for supported types and use various transformations to
78+
// add specializations. This way we can modify the type list with ease.
79+
80+
/**
81+
* @brief All the primitive types supported by the encoder.
82+
*/
83+
using CatPrimIndexTypes =
84+
std::tuple<std::int8_t, std::int16_t, std::int32_t, std::int64_t, float, double>;
85+
7686
/**
77-
* @brief All the types supported by the encoder.
87+
* @brief All the column types supported by the encoder.
7888
*/
79-
using CatIndexViewTypes =
80-
std::tuple<enc::CatStrArrayView, Span<std::int8_t const>, Span<std::int16_t const>,
81-
Span<std::int32_t const>, Span<std::int64_t const>, Span<float const>,
82-
Span<double const>>;
89+
using CatIndexViewTypes = decltype(std::tuple_cat(std::tuple<enc::CatStrArrayView>{},
90+
PrimToSpan<CatPrimIndexTypes>::Type{}));
8391

8492
/**
8593
* @brief Host categories view for a single column.

src/encoder/types.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include <tuple> // for tuple
1313
#include <variant> // for variant
1414

15+
#include "xgboost/span.h" // for Span
16+
1517
#if defined(XGBOOST_USE_CUDA)
1618

1719
#include <cuda/std/variant> // for variant
@@ -27,7 +29,24 @@ struct Overloaded : Ts... {
2729
template <typename... Ts>
2830
ENC_DEVICE Overloaded(Ts...) -> Overloaded<Ts...>;
2931

32+
// Whether a type is a member of a type list (a.k.a tuple).
33+
template <typename... Ts>
34+
struct MemberOf;
35+
36+
template <typename T, typename... Ts>
37+
struct MemberOf<T, std::tuple<Ts...>> : public std::disjunction<std::is_same<T, Ts>...> {};
38+
39+
// Convert primitive types to span types.
40+
template <typename... Ts>
41+
struct PrimToSpan;
42+
43+
template <typename... Ts>
44+
struct PrimToSpan<std::tuple<Ts...>> {
45+
using Type = std::tuple<xgboost::common::Span<std::add_const_t<Ts>>...>;
46+
};
47+
3048
namespace cpu_impl {
49+
// Convert tuple of types to variant of types.
3150
template <typename... Ts>
3251
struct TupToVar;
3352

@@ -42,6 +61,7 @@ using TupToVarT = typename TupToVar<Ts...>::Type;
4261

4362
#if defined(XGBOOST_USE_CUDA)
4463
namespace cuda_impl {
64+
// Convert tuple of types to CUDA variant of types.
4565
template <typename... Ts>
4666
struct TupToVar {};
4767

0 commit comments

Comments
 (0)