Skip to content

Commit

Permalink
Implement categorical prediction for CPU and GPU predict leaf. (#7001)
Browse files Browse the repository at this point in the history
* Categorical prediction with CPU predictor and GPU predict leaf.

* Implement categorical prediction for CPU prediction.
* Implement categorical prediction for GPU predict leaf.
* Refactor the prediction functions to have a unified get next node function.

Co-authored-by: Shvets Kirill <kirill.shvets@intel.com>
  • Loading branch information
trivialfis and Shvets Kirill authored Jun 11, 2021
1 parent 72f9daf commit f79cc4a
Show file tree
Hide file tree
Showing 10 changed files with 338 additions and 198 deletions.
66 changes: 18 additions & 48 deletions include/xgboost/tree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,10 @@ class RegTree : public Model {
bst_float right_leaf_weight, bst_float loss_change,
float sum_hess, float left_sum, float right_sum);

bool HasCategoricalSplit() const {
return !split_categories_.empty();
}

/*!
* \brief get current depth
* \param nid node id
Expand Down Expand Up @@ -537,13 +541,6 @@ class RegTree : public Model {
std::vector<Entry> data_;
bool has_missing_;
};
/*!
* \brief get the leaf index
* \param feat dense feature vector, if the feature is missing the field is set to NaN
* \return the leaf index of the given feature
*/
template <bool has_missing = true>
int GetLeafIndex(const FVec& feat) const;

/*!
* \brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree
Expand Down Expand Up @@ -582,14 +579,6 @@ class RegTree : public Model {
*/
void CalculateContributionsApprox(const RegTree::FVec& feat,
bst_float* out_contribs) const;
/*!
* \brief get next position of the tree given current pid
* \param pid Current node id.
* \param fvalue feature value if not missing.
* \param is_unknown Whether current required feature is missing.
*/
template <bool has_missing = true>
inline int GetNext(int pid, bst_float fvalue, bool is_unknown) const;
/*!
* \brief dump the model in the requested format as a text string
* \param fmap feature map that may help give interpretations of feature
Expand Down Expand Up @@ -627,6 +616,20 @@ class RegTree : public Model {
size_t size {0};
};

struct CategoricalSplitMatrix {
common::Span<FeatureType const> split_type;
common::Span<uint32_t const> categories;
common::Span<Segment const> node_ptr;
};

CategoricalSplitMatrix GetCategoriesMatrix() const {
CategoricalSplitMatrix view;
view.split_type = common::Span<FeatureType const>(this->GetSplitTypes());
view.categories = this->GetSplitCategories();
view.node_ptr = common::Span<Segment const>(split_categories_segments_);
return view;
}

private:
void LoadCategoricalSplit(Json const& in);
void SaveCategoricalSplit(Json* p_out) const;
Expand Down Expand Up @@ -724,38 +727,5 @@ inline bool RegTree::FVec::IsMissing(size_t i) const {
inline bool RegTree::FVec::HasMissing() const {
return has_missing_;
}

template <bool has_missing>
inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
bst_node_t nid = 0;
while (!(*this)[nid].IsLeaf()) {
unsigned split_index = (*this)[nid].SplitIndex();
nid = this->GetNext<has_missing>(nid, feat.GetFvalue(split_index),
has_missing && feat.IsMissing(split_index));
}
return nid;
}

/*! \brief get next position of the tree given current pid */
template <bool has_missing>
inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const {
if (has_missing) {
if (is_unknown) {
return (*this)[pid].DefaultChild();
} else {
if (fvalue < (*this)[pid].SplitCond()) {
return (*this)[pid].LeftChild();
} else {
return (*this)[pid].RightChild();
}
}
} else {
// 35% speed up due to reduced miss branch predictions
// The following expression returns the left child if (fvalue < split_cond);
// the right child otherwise.
return (*this)[pid].LeftChild() + !(fvalue < (*this)[pid].SplitCond());
}
}

} // namespace xgboost
#endif // XGBOOST_TREE_MODEL_H_
103 changes: 76 additions & 27 deletions src/predictor/cpu_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,31 @@
#include "xgboost/logging.h"
#include "xgboost/host_device_vector.h"

#include "predict_fn.h"
#include "../data/adapter.h"
#include "../common/math.h"
#include "../common/threading_utils.h"
#include "../common/categorical.h"
#include "../gbm/gbtree_model.h"

namespace xgboost {
namespace predictor {

DMLC_REGISTRY_FILE_TAG(cpu_predictor);

template <bool has_missing, bool has_categorical>
bst_node_t GetLeafIndex(RegTree const &tree, const RegTree::FVec &feat,
RegTree::CategoricalSplitMatrix const& cats) {
bst_node_t nid = 0;
while (!tree[nid].IsLeaf()) {
unsigned split_index = tree[nid].SplitIndex();
auto fvalue = feat.GetFvalue(split_index);
nid = GetNextNode<has_missing, has_categorical>(
tree[nid], nid, fvalue, has_missing && feat.IsMissing(split_index), cats);
}
return nid;
}

bst_float PredValue(const SparsePage::Inst &inst,
const std::vector<std::unique_ptr<RegTree>> &trees,
const std::vector<int> &tree_info, int bst_group,
Expand All @@ -35,32 +50,59 @@ bst_float PredValue(const SparsePage::Inst &inst,
p_feats->Fill(inst);
for (size_t i = tree_begin; i < tree_end; ++i) {
if (tree_info[i] == bst_group) {
int tid = trees[i]->GetLeafIndex(*p_feats);
psum += (*trees[i])[tid].LeafValue();
auto const &tree = *trees[i];
bool has_categorical = tree.HasCategoricalSplit();

auto categories = common::Span<uint32_t const>{tree.GetSplitCategories()};
auto split_types = tree.GetSplitTypes();
auto categories_ptr =
common::Span<RegTree::Segment const>{tree.GetSplitCategoriesPtr()};
auto cats = tree.GetCategoriesMatrix();
bst_node_t nidx = -1;
if (has_categorical) {
nidx = GetLeafIndex<true, true>(tree, *p_feats, cats);
} else {
nidx = GetLeafIndex<true, false>(tree, *p_feats, cats);
}
psum += (*trees[i])[nidx].LeafValue();
}
}
p_feats->Drop(inst);
return psum;
}

inline bst_float PredValueByOneTree(const RegTree::FVec& p_feats,
const std::unique_ptr<RegTree>& tree) {
const int lid = p_feats.HasMissing() ? tree->GetLeafIndex<true>(p_feats) :
tree->GetLeafIndex<false>(p_feats); // 35% speed up
return (*tree)[lid].LeafValue();
template <bool has_categorical>
bst_float
PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree,
RegTree::CategoricalSplitMatrix const& cats) {
const bst_node_t leaf = p_feats.HasMissing() ?
GetLeafIndex<true, has_categorical>(tree, p_feats, cats) :
GetLeafIndex<false, has_categorical>(tree, p_feats, cats);
return tree[leaf].LeafValue();
}

inline void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
const size_t tree_end, std::vector<bst_float>* out_preds,
const size_t predict_offset, const size_t num_group,
const std::vector<RegTree::FVec> &thread_temp,
const size_t offset, const size_t block_size) {
void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
const size_t tree_end, std::vector<bst_float> *out_preds,
const size_t predict_offset, const size_t num_group,
const std::vector<RegTree::FVec> &thread_temp,
const size_t offset, const size_t block_size) {
std::vector<bst_float> &preds = *out_preds;
for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
const size_t gid = model.tree_info[tree_id];
for (size_t i = 0; i < block_size; ++i) {
preds[(predict_offset + i) * num_group + gid] += PredValueByOneTree(thread_temp[offset + i],
model.trees[tree_id]);
auto const &tree = *model.trees[tree_id];
auto const& cats = tree.GetCategoriesMatrix();
auto has_categorical = tree.HasCategoricalSplit();

if (has_categorical) {
for (size_t i = 0; i < block_size; ++i) {
preds[(predict_offset + i) * num_group + gid] +=
PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
}
} else {
for (size_t i = 0; i < block_size; ++i) {
preds[(predict_offset + i) * num_group + gid] +=
PredValueByOneTree<false>(thread_temp[offset + i], tree, cats);
}
}
}
}
Expand All @@ -77,6 +119,7 @@ void FVecFill(const size_t block_size, const size_t batch_offset, const int num_
feats.Fill(inst);
}
}

template <typename DataView>
void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batch,
const size_t fvec_offset, std::vector<RegTree::FVec>* p_feats) {
Expand Down Expand Up @@ -145,28 +188,32 @@ class AdapterView {
};

template <typename DataView, size_t block_of_rows_size>
void PredictBatchByBlockOfRowsKernel(DataView batch, std::vector<bst_float> *out_preds,
gbm::GBTreeModel const &model, int32_t tree_begin,
int32_t tree_end,
std::vector<RegTree::FVec> *p_thread_temp) {
auto& thread_temp = *p_thread_temp;
void PredictBatchByBlockOfRowsKernel(
DataView batch, std::vector<bst_float> *out_preds,
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end,
std::vector<RegTree::FVec> *p_thread_temp) {
auto &thread_temp = *p_thread_temp;
int32_t const num_group = model.learner_model_param->num_output_group;

CHECK_EQ(model.param.size_leaf_vector, 0)
<< "size_leaf_vector is enforced to 0 so far";
// parallel over local batch
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
const int num_feature = model.learner_model_param->num_feature;
const bst_omp_uint n_row_blocks = (nsize) / block_of_rows_size + !!((nsize) % block_of_rows_size);
common::ParallelFor(n_row_blocks, [&](bst_omp_uint block_id) {
omp_ulong n_blocks = common::DivRoundUp(nsize, block_of_rows_size);

common::ParallelFor(n_blocks, [&](bst_omp_uint block_id) {
const size_t batch_offset = block_id * block_of_rows_size;
const size_t block_size = std::min(nsize - batch_offset, block_of_rows_size);
const size_t block_size =
std::min(nsize - batch_offset, block_of_rows_size);
const size_t fvec_offset = omp_get_thread_num() * block_of_rows_size;

FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, p_thread_temp);
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset,
p_thread_temp);
// process block of rows through all trees to keep cache locality
PredictByAllTrees(model, tree_begin, tree_end, out_preds, batch_offset + batch.base_rowid,
num_group, thread_temp, fvec_offset, block_size);
PredictByAllTrees(model, tree_begin, tree_end, out_preds,
batch_offset + batch.base_rowid, num_group, thread_temp,
fvec_offset, block_size);
FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp);
});
}
Expand Down Expand Up @@ -344,7 +391,9 @@ class CPUPredictor : public Predictor {
}
feats.Fill(page[i]);
for (unsigned j = 0; j < ntree_limit; ++j) {
int tid = model.trees[j]->GetLeafIndex(feats);
auto const& tree = *model.trees[j];
auto const& cats = tree.GetCategoriesMatrix();
bst_node_t tid = GetLeafIndex<true, true>(tree, feats, cats);
preds[ridx * ntree_limit + j] = static_cast<bst_float>(tid);
}
feats.Drop(page[i]);
Expand Down
Loading

0 comments on commit f79cc4a

Please sign in to comment.