diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc
index 4473173d2af1..0565a10f0ed6 100644
--- a/src/predictor/cpu_predictor.cc
+++ b/src/predictor/cpu_predictor.cc
@@ -1,52 +1,64 @@
 /**
  * Copyright 2017-2023 by XGBoost Contributors
  */
-#include <dmlc/registry.h>
-
-#include <any>  // for any, any_cast
-#include <cstddef>
-#include <limits>
-#include <mutex>
-
-#include "../collective/communicator-inl.h"
-#include "../common/categorical.h"
-#include "../common/math.h"
-#include "../common/threading_utils.h"
-#include "../data/adapter.h"
-#include "../data/gradient_index.h"
-#include "../gbm/gbtree_model.h"
-#include "cpu_treeshap.h"  // CalculateContributions
-#include "predict_fn.h"
-#include "xgboost/base.h"
-#include "xgboost/data.h"
-#include "xgboost/host_device_vector.h"
-#include "xgboost/logging.h"
-#include "xgboost/predictor.h"
-#include "xgboost/tree_model.h"
-
-namespace xgboost {
-namespace predictor {
+#include <algorithm>  // for max, fill, min
+#include <any>        // for any, any_cast
+#include <cassert>    // for assert
+#include <cstddef>    // for size_t
+#include <cstdint>    // for uint32_t, int32_t, uint64_t
+#include <memory>     // for unique_ptr, shared_ptr
+#include <ostream>    // for char_traits, operator<<, basic_ostream
+#include <typeinfo>   // for type_info
+#include <vector>     // for vector
+
+#include "../collective/communicator-inl.h"   // for Allreduce, IsDistributed
+#include "../collective/communicator.h"       // for Operation
+#include "../common/bitfield.h"               // for RBitField8
+#include "../common/categorical.h"            // for IsCat, Decision
+#include "../common/common.h"                 // for DivRoundUp
+#include "../common/math.h"                   // for CheckNAN
+#include "../common/threading_utils.h"        // for ParallelFor
+#include "../data/adapter.h"                  // for ArrayAdapter, CSRAdapter, CSRArrayAdapter
+#include "../data/gradient_index.h"           // for GHistIndexMatrix
+#include "../data/proxy_dmatrix.h"            // for DMatrixProxy
+#include "../gbm/gbtree_model.h"              // for GBTreeModel, GBTreeModelParam
+#include "cpu_treeshap.h"                     // for CalculateContributions
+#include "dmlc/registry.h"                    // for DMLC_REGISTRY_FILE_TAG
+#include "predict_fn.h"                       // for GetNextNode, GetNextNodeMulti
+#include "xgboost/base.h"                     // for bst_float, bst_node_t, bst_omp_uint, bst_fe...
+#include "xgboost/context.h"                  // for Context
+#include "xgboost/data.h"                     // for Entry, DMatrix, MetaInfo, SparsePage, Batch...
+#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/learner.h" // for LearnerModelParam +#include "xgboost/linalg.h" // for TensorView, All, VectorView, Tensor +#include "xgboost/logging.h" // for LogCheck_EQ, CHECK_EQ, CHECK, LogCheck_NE +#include "xgboost/multi_target_tree_model.h" // for MultiTargetTree +#include "xgboost/predictor.h" // for PredictionCacheEntry, Predictor, PredictorReg +#include "xgboost/span.h" // for Span +#include "xgboost/tree_model.h" // for RegTree, MTNotImplemented, RTreeNodeStat + +namespace xgboost::predictor { DMLC_REGISTRY_FILE_TAG(cpu_predictor); +namespace scalar { template bst_node_t GetLeafIndex(RegTree const &tree, const RegTree::FVec &feat, - RegTree::CategoricalSplitMatrix const& cats) { - bst_node_t nid = 0; - while (!tree[nid].IsLeaf()) { - unsigned split_index = tree[nid].SplitIndex(); + RegTree::CategoricalSplitMatrix const &cats) { + bst_node_t nidx{0}; + while (!tree[nidx].IsLeaf()) { + bst_feature_t split_index = tree[nidx].SplitIndex(); auto fvalue = feat.GetFvalue(split_index); - nid = GetNextNode( - tree[nid], nid, fvalue, has_missing && feat.IsMissing(split_index), cats); + nidx = GetNextNode( + tree[nidx], nidx, fvalue, has_missing && feat.IsMissing(split_index), cats); } - return nid; + return nidx; } bst_float PredValue(const SparsePage::Inst &inst, const std::vector> &trees, - const std::vector &tree_info, int bst_group, - RegTree::FVec *p_feats, unsigned tree_begin, - unsigned tree_end) { + const std::vector &tree_info, std::int32_t bst_group, + RegTree::FVec *p_feats, std::uint32_t tree_begin, std::uint32_t tree_end) { bst_float psum = 0.0f; p_feats->Fill(inst); for (size_t i = tree_begin; i < tree_end; ++i) { @@ -68,40 +80,92 @@ bst_float PredValue(const SparsePage::Inst &inst, } template -bst_float -PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree, - RegTree::CategoricalSplitMatrix const& cats) { - const bst_node_t leaf = p_feats.HasMissing() ? - GetLeafIndex(tree, p_feats, cats) : - GetLeafIndex(tree, p_feats, cats); +bst_float PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree, + RegTree::CategoricalSplitMatrix const &cats) { + const bst_node_t leaf = p_feats.HasMissing() + ? 
 void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
-                       const size_t tree_end, std::vector<bst_float> *out_preds,
-                       const size_t predict_offset, const size_t num_group,
-                       const std::vector<RegTree::FVec> &thread_temp,
-                       const size_t offset, const size_t block_size) {
-  std::vector<bst_float> &preds = *out_preds;
+                       const size_t tree_end, const size_t predict_offset,
+                       const std::vector<RegTree::FVec> &thread_temp, const size_t offset,
+                       const size_t block_size, linalg::TensorView<float, 2> out_predt) {
   for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
     const size_t gid = model.tree_info[tree_id];
     auto const &tree = *model.trees[tree_id];
-    auto const& cats = tree.GetCategoriesMatrix();
+    auto const &cats = tree.GetCategoriesMatrix();
     auto has_categorical = tree.HasCategoricalSplit();
 
     if (has_categorical) {
-      for (size_t i = 0; i < block_size; ++i) {
-        preds[(predict_offset + i) * num_group + gid] +=
+      for (std::size_t i = 0; i < block_size; ++i) {
+        out_predt(predict_offset + i, gid) +=
            PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
       }
     } else {
-      for (size_t i = 0; i < block_size; ++i) {
-        preds[(predict_offset + i) * num_group + gid] +=
-            PredValueByOneTree<false>(thread_temp[offset + i], tree, cats);
+      for (std::size_t i = 0; i < block_size; ++i) {
+        out_predt(predict_offset + i, gid) +=
+            PredValueByOneTree<false>(thread_temp[offset + i], tree, cats);
       }
     }
  }
 }
+}  // namespace scalar
+
+namespace multi {
+template <bool has_missing, bool has_categorical>
+bst_node_t GetLeafIndex(MultiTargetTree const &tree, const RegTree::FVec &feat,
+                        RegTree::CategoricalSplitMatrix const &cats) {
+  bst_node_t nidx{0};
+  while (!tree.IsLeaf(nidx)) {
+    unsigned split_index = tree.SplitIndex(nidx);
+    auto fvalue = feat.GetFvalue(split_index);
+    nidx = GetNextNodeMulti<has_missing, has_categorical>(
+        tree, nidx, fvalue, has_missing && feat.IsMissing(split_index), cats);
+  }
+  return nidx;
+}
+
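+// A vector-leaf tree is traversed exactly like a scalar tree; only the leaf payload
+// differs, holding one weight per target.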
+template <bool has_categorical>
+void PredValueByOneTree(const RegTree::FVec &p_feats, MultiTargetTree const &tree,
+                        RegTree::CategoricalSplitMatrix const &cats,
+                        linalg::VectorView<float> out_predt) {
+  bst_node_t const leaf = p_feats.HasMissing()
+                              ? GetLeafIndex<true, has_categorical>(tree, p_feats, cats)
+                              : GetLeafIndex<false, has_categorical>(tree, p_feats, cats);
+  auto leaf_value = tree.LeafValue(leaf);
+  assert(out_predt.Shape(0) == leaf_value.Shape(0) && "shape mismatch.");
+  for (size_t i = 0; i < leaf_value.Size(); ++i) {
+    out_predt(i) += leaf_value(i);
+  }
+}
+
+void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
+                       const size_t tree_end, const size_t predict_offset,
+                       const std::vector<RegTree::FVec> &thread_temp, const size_t offset,
+                       const size_t block_size, linalg::TensorView<float, 2> out_predt) {
+  for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
+    auto const &tree = *model.trees.at(tree_id);
+    auto cats = tree.GetCategoriesMatrix();
+    bool has_categorical = tree.HasCategoricalSplit();
+
+    if (has_categorical) {
+      for (std::size_t i = 0; i < block_size; ++i) {
+        auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
+        PredValueByOneTree<true>(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats,
+                                 t_predts);
+      }
+    } else {
+      for (std::size_t i = 0; i < block_size; ++i) {
+        auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
+        PredValueByOneTree<false>(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats,
+                                  t_predts);
+      }
+    }
+  }
+}
+}  // namespace multi
 
 template <typename DataView>
 void FVecFill(const size_t block_size, const size_t batch_offset, const int num_feature,
@@ -127,7 +191,7 @@ void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batc
 }
 
 namespace {
-static size_t constexpr kUnroll = 8;
+static std::size_t constexpr kUnroll = 8;
 }  // anonymous namespace
 
 struct SparsePageView {
@@ -227,15 +291,13 @@ class AdapterView {
 };
 
 template <typename DataView, size_t block_of_rows_size>
-void PredictBatchByBlockOfRowsKernel(
-    DataView batch, std::vector<bst_float> *out_preds,
-    gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end,
-    std::vector<RegTree::FVec> *p_thread_temp, int32_t n_threads) {
+void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &model,
+                                     int32_t tree_begin, int32_t tree_end,
+                                     std::vector<RegTree::FVec> *p_thread_temp, int32_t n_threads,
+                                     linalg::TensorView<float, 2> out_predt) {
   auto &thread_temp = *p_thread_temp;
-  int32_t const num_group = model.learner_model_param->num_output_group;
 
-  CHECK_EQ(model.param.size_leaf_vector, 0)
-      << "size_leaf_vector is enforced to 0 so far";
+  CHECK_EQ(model.param.size_leaf_vector, 0) << "size_leaf_vector is enforced to 0 so far";
   // parallel over local batch
   const auto nsize = static_cast<bst_omp_uint>(batch.Size());
   const int num_feature = model.learner_model_param->num_feature;
@@ -243,16 +305,19 @@ void PredictBatchByBlockOfRowsKernel(
   common::ParallelFor(n_blocks, n_threads, [&](bst_omp_uint block_id) {
     const size_t batch_offset = block_id * block_of_rows_size;
-    const size_t block_size =
-        std::min(nsize - batch_offset, block_of_rows_size);
+    const size_t block_size = std::min(nsize - batch_offset, block_of_rows_size);
     const size_t fvec_offset = omp_get_thread_num() * block_of_rows_size;
 
-    FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset,
-             p_thread_temp);
+    FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, p_thread_temp);
     // process block of rows through all trees to keep cache locality
-    PredictByAllTrees(model, tree_begin, tree_end, out_preds,
-                      batch_offset + batch.base_rowid, num_group, thread_temp,
-                      fvec_offset, block_size);
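+    // Whether leaves are scalar or vector is a property of the whole model, so the
+    // dispatch happens once per block of rows rather than once per tree.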
+    if (model.learner_model_param->IsVectorLeaf()) {
+      multi::PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid,
+                               thread_temp, fvec_offset, block_size, out_predt);
+    } else {
+      scalar::PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid,
+                                thread_temp, fvec_offset, block_size, out_predt);
+    }
+
+    FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp);
   });
 }
 
@@ -557,33 +622,6 @@ class ColumnSplitHelper {
 
 class CPUPredictor : public Predictor {
  protected:
-  void PredictGHistIndex(DMatrix *p_fmat, gbm::GBTreeModel const &model, int32_t tree_begin,
-                         int32_t tree_end, std::vector<bst_float> *out_preds) const {
-    auto const n_threads = this->ctx_->Threads();
-
-    constexpr double kDensityThresh = .5;
-    size_t total =
-        std::max(p_fmat->Info().num_row_ * p_fmat->Info().num_col_, static_cast<uint64_t>(1));
-    double density = static_cast<double>(p_fmat->Info().num_nonzero_) / static_cast<double>(total);
-    bool blocked = density > kDensityThresh;
-
-    std::vector<RegTree::FVec> feat_vecs;
-    InitThreadTemp(n_threads * (blocked ? kBlockOfRowsSize : 1), &feat_vecs);
-    std::vector<Entry> workspace(p_fmat->Info().num_col_ * kUnroll * n_threads);
-    auto ft = p_fmat->Info().feature_types.ConstHostVector();
-    for (auto const &batch : p_fmat->GetBatches<GHistIndexMatrix>({})) {
-      if (blocked) {
-        PredictBatchByBlockOfRowsKernel<GHistIndexMatrixView, kBlockOfRowsSize>(
-            GHistIndexMatrixView{batch, p_fmat->Info().num_col_, ft, workspace, n_threads},
-            out_preds, model, tree_begin, tree_end, &feat_vecs, n_threads);
-      } else {
-        PredictBatchByBlockOfRowsKernel<GHistIndexMatrixView, 1>(
-            GHistIndexMatrixView{batch, p_fmat->Info().num_col_, ft, workspace, n_threads},
-            out_preds, model, tree_begin, tree_end, &feat_vecs, n_threads);
-      }
-    }
-  }
-
   void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
                       gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
     if (p_fmat->IsColumnSplit()) {
@@ -592,11 +630,6 @@ class CPUPredictor : public Predictor {
       return;
     }
 
-    if (!p_fmat->PageExists()) {
-      this->PredictGHistIndex(p_fmat, model, tree_begin, tree_end, out_preds);
-      return;
-    }
-
     auto const n_threads = this->ctx_->Threads();
     constexpr double kDensityThresh = .5;
     size_t total =
@@ -606,16 +639,38 @@ class CPUPredictor : public Predictor {
     std::vector<RegTree::FVec> feat_vecs;
     InitThreadTemp(n_threads * (blocked ? kBlockOfRowsSize : 1), &feat_vecs);
-    for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
-      CHECK_EQ(out_preds->size(),
-               p_fmat->Info().num_row_ * model.learner_model_param->num_output_group);
-      if (blocked) {
-        PredictBatchByBlockOfRowsKernel<SparsePageView, kBlockOfRowsSize>(
-            SparsePageView{&batch}, out_preds, model, tree_begin, tree_end, &feat_vecs, n_threads);
-      } else {
-        PredictBatchByBlockOfRowsKernel<SparsePageView, 1>(
-            SparsePageView{&batch}, out_preds, model, tree_begin, tree_end, &feat_vecs, n_threads);
+    std::size_t n_samples = p_fmat->Info().num_row_;
+    std::size_t n_groups = model.learner_model_param->OutputLength();
+    CHECK_EQ(out_preds->size(), n_samples * n_groups);
+    linalg::TensorView<float, 2> out_predt{*out_preds, {n_samples, n_groups}, ctx_->gpu_id};
+
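+    // When no raw SparsePage has been materialized, predict directly from the
+    // quantized gradient index instead of reconstructing the raw page.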
+    if (!p_fmat->PageExists()) {
+      std::vector<Entry> workspace(p_fmat->Info().num_col_ * kUnroll * n_threads);
+      auto ft = p_fmat->Info().feature_types.ConstHostVector();
+      for (auto const &batch : p_fmat->GetBatches<GHistIndexMatrix>({})) {
+        if (blocked) {
+          PredictBatchByBlockOfRowsKernel<GHistIndexMatrixView, kBlockOfRowsSize>(
+              GHistIndexMatrixView{batch, p_fmat->Info().num_col_, ft, workspace, n_threads}, model,
+              tree_begin, tree_end, &feat_vecs, n_threads, out_predt);
+        } else {
+          PredictBatchByBlockOfRowsKernel<GHistIndexMatrixView, 1>(
+              GHistIndexMatrixView{batch, p_fmat->Info().num_col_, ft, workspace, n_threads}, model,
+              tree_begin, tree_end, &feat_vecs, n_threads, out_predt);
+        }
+      }
+    } else {
+      for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
+        if (blocked) {
+          PredictBatchByBlockOfRowsKernel<SparsePageView, kBlockOfRowsSize>(
+              SparsePageView{&batch}, model, tree_begin, tree_end, &feat_vecs, n_threads,
+              out_predt);
+
+        } else {
+          PredictBatchByBlockOfRowsKernel<SparsePageView, 1>(SparsePageView{&batch}, model,
+                                                             tree_begin, tree_end, &feat_vecs,
+                                                             n_threads, out_predt);
+        }
       }
     }
   }
 
@@ -623,17 +678,15 @@ class CPUPredictor : public Predictor {
  public:
   explicit CPUPredictor(Context const *ctx) : Predictor::Predictor{ctx} {}
 
-  void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts,
-                    const gbm::GBTreeModel &model, uint32_t tree_begin,
-                    uint32_t tree_end = 0) const override {
-    auto* out_preds = &predts->predictions;
+  void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts, const gbm::GBTreeModel &model,
+                    uint32_t tree_begin, uint32_t tree_end = 0) const override {
+    auto *out_preds = &predts->predictions;
     // This is actually already handled in gbm, but large amount of tests rely on the
     // behaviour.
     if (tree_end == 0) {
       tree_end = model.trees.size();
     }
-    this->PredictDMatrix(dmat, &out_preds->HostVector(), model, tree_begin,
-                         tree_end);
+    this->PredictDMatrix(dmat, &out_preds->HostVector(), model, tree_begin, tree_end);
   }
 
   template <typename Adapter, size_t kBlockSize>
@@ -653,13 +706,16 @@ class CPUPredictor : public Predictor {
       info.num_row_ = m->NumRows();
       this->InitOutPredictions(info, &(out_preds->predictions), model);
     }
+
     std::vector<Entry> workspace(m->NumColumns() * kUnroll * n_threads);
     auto &predictions = out_preds->predictions.HostVector();
     std::vector<RegTree::FVec> thread_temp;
     InitThreadTemp(n_threads * kBlockSize, &thread_temp);
+    std::size_t n_groups = model.learner_model_param->OutputLength();
+    linalg::TensorView<float, 2> out_predt{predictions, {m->NumRows(), n_groups}, Context::kCpuId};
     PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockSize>(
-        AdapterView<Adapter>(m.get(), missing, common::Span<Entry>{workspace}, n_threads),
-        &predictions, model, tree_begin, tree_end, &thread_temp, n_threads);
+        AdapterView<Adapter>(m.get(), missing, common::Span<Entry>{workspace}, n_threads), model,
+        tree_begin, tree_end, &thread_temp, n_threads, out_predt);
   }
 
   bool InplacePredict(std::shared_ptr<DMatrix> p_m, const gbm::GBTreeModel &model, float missing,
@@ -689,6 +745,7 @@ class CPUPredictor : public Predictor {
 
   void PredictInstance(const SparsePage::Inst& inst, std::vector<bst_float>* out_preds,
                        const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
+    CHECK(!model.learner_model_param->IsVectorLeaf()) << "predict instance" << MTNotImplemented();
     std::vector<RegTree::FVec> feat_vecs;
     feat_vecs.resize(1, RegTree::FVec());
     feat_vecs[0].Init(model.learner_model_param->num_feature);
@@ -701,31 +758,30 @@ class CPUPredictor : public Predictor {
     auto base_score = model.learner_model_param->BaseScore(ctx_)(0);
     // loop over output groups
     for (uint32_t gid = 0; gid < model.learner_model_param->num_output_group; ++gid) {
-      (*out_preds)[gid] =
-          PredValue(inst, model.trees, model.tree_info, gid, &feat_vecs[0], 0, ntree_limit) +
-          base_score;
+      (*out_preds)[gid] = scalar::PredValue(inst, model.trees, model.tree_info, gid, &feat_vecs[0],
+                                            0, ntree_limit) +
+                          base_score;
     }
   }
 
-  void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_preds,
-                   const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
+  void PredictLeaf(DMatrix *p_fmat, HostDeviceVector<bst_float> *out_preds,
+                   const gbm::GBTreeModel &model, unsigned ntree_limit) const override {
     auto const n_threads = this->ctx_->Threads();
     std::vector<RegTree::FVec> feat_vecs;
     const int num_feature = model.learner_model_param->num_feature;
     InitThreadTemp(n_threads, &feat_vecs);
-    const MetaInfo& info = p_fmat->Info();
+    const MetaInfo &info = p_fmat->Info();
     // number of valid trees
     if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
       ntree_limit = static_cast<uint32_t>(model.trees.size());
     }
-    std::vector<bst_float>& preds = out_preds->HostVector();
+    std::vector<bst_float> &preds = out_preds->HostVector();
     preds.resize(info.num_row_ * ntree_limit);
     // start collecting the prediction
     for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
       // parallel over local batch
       auto page = batch.GetView();
-      const auto nsize = static_cast<bst_omp_uint>(batch.Size());
-      common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) {
+      common::ParallelFor(page.Size(), n_threads, [&](auto i) {
         const int tid = omp_get_thread_num();
         auto ridx = static_cast<size_t>(batch.base_rowid + i);
         RegTree::FVec &feats = feat_vecs[tid];
@@ -733,23 +789,28 @@ class CPUPredictor : public Predictor {
           feats.Init(num_feature);
         }
         feats.Fill(page[i]);
-        for (unsigned j = 0; j < ntree_limit; ++j) {
-          auto const& tree = *model.trees[j];
-          auto const& cats = tree.GetCategoriesMatrix();
-          bst_node_t tid = GetLeafIndex<true, true>(tree, feats, cats);
-          preds[ridx * ntree_limit + j] = static_cast<bst_float>(tid);
+        for (std::uint32_t j = 0; j < ntree_limit; ++j) {
+          auto const &tree = *model.trees[j];
+          auto const &cats = tree.GetCategoriesMatrix();
+          bst_node_t nidx;
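+          // Only the leaf index is recorded here, so multi-target trees can be
+          // handled by plain traversal; no leaf values are read.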
+          if (tree.IsMultiTarget()) {
+            nidx = multi::GetLeafIndex<true, true>(*tree.GetMultiTargetTree(), feats, cats);
+          } else {
+            nidx = scalar::GetLeafIndex<true, true>(tree, feats, cats);
+          }
+          preds[ridx * ntree_limit + j] = static_cast<bst_float>(nidx);
         }
         feats.Drop(page[i]);
       });
     }
   }
 
-  void PredictContribution(DMatrix *p_fmat,
-                           HostDeviceVector<bst_float> *out_contribs,
+  void PredictContribution(DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
                            const gbm::GBTreeModel &model, uint32_t ntree_limit,
-                           std::vector<bst_float> const *tree_weights,
-                           bool approximate, int condition,
-                           unsigned condition_feature) const override {
+                           std::vector<bst_float> const *tree_weights, bool approximate,
+                           int condition, unsigned condition_feature) const override {
+    CHECK(!model.learner_model_param->IsVectorLeaf())
+        << "Predict contribution" << MTNotImplemented();
     auto const n_threads = this->ctx_->Threads();
     const int num_feature = model.learner_model_param->num_feature;
     std::vector<RegTree::FVec> feat_vecs;
@@ -825,11 +886,12 @@ class CPUPredictor : public Predictor {
     }
   }
 
-  void PredictInteractionContributions(
-      DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
-      const gbm::GBTreeModel &model, unsigned ntree_limit,
-      std::vector<bst_float> const *tree_weights,
-      bool approximate) const override {
+  void PredictInteractionContributions(DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
+                                       const gbm::GBTreeModel &model, unsigned ntree_limit,
+                                       std::vector<bst_float> const *tree_weights,
+                                       bool approximate) const override {
+    CHECK(!model.learner_model_param->IsVectorLeaf())
+        << "Predict interaction contribution" << MTNotImplemented();
     const MetaInfo& info = p_fmat->Info();
     const int ngroup = model.learner_model_param->num_output_group;
     size_t const ncolumns = model.learner_model_param->num_feature;
@@ -884,5 +946,4 @@ class CPUPredictor : public Predictor {
 XGBOOST_REGISTER_PREDICTOR(CPUPredictor, "cpu_predictor")
     .describe("Make predictions using CPU.")
     .set_body([](Context const *ctx) { return new CPUPredictor(ctx); });
-}  // namespace predictor
-}  // namespace xgboost
+}  // namespace xgboost::predictor
diff --git a/src/predictor/predict_fn.h b/src/predictor/predict_fn.h
index 5d0c175fcf65..dbaf4a75e060 100644
--- a/src/predictor/predict_fn.h
+++ b/src/predictor/predict_fn.h
@@ -1,13 +1,12 @@
-/*!
- * Copyright 2021 by XGBoost Contributors
+/**
+ * Copyright 2021-2023 by XGBoost Contributors
  */
 #ifndef XGBOOST_PREDICTOR_PREDICT_FN_H_
 #define XGBOOST_PREDICTOR_PREDICT_FN_H_
 #include "../common/categorical.h"
 #include "xgboost/tree_model.h"
 
-namespace xgboost {
-namespace predictor {
+namespace xgboost::predictor {
 template <bool has_missing, bool has_categorical>
 inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bst_node_t nid,
                                              float fvalue, bool is_missing,
@@ -24,6 +23,25 @@ inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bs
     }
   }
 }
-}  // namespace predictor
-}  // namespace xgboost
+
+template <bool has_missing, bool has_categorical>
+inline XGBOOST_DEVICE bst_node_t GetNextNodeMulti(MultiTargetTree const &tree,
+                                                  bst_node_t const nidx, float fvalue,
+                                                  bool is_missing,
+                                                  RegTree::CategoricalSplitMatrix const &cats) {
+  if (has_missing && is_missing) {
+    return tree.DefaultChild(nidx);
+  } else {
+    if (has_categorical && common::IsCat(cats.split_type, nidx)) {
+      auto node_categories =
+          cats.categories.subspan(cats.node_ptr[nidx].beg, cats.node_ptr[nidx].size);
+      return common::Decision(node_categories, fvalue) ? tree.LeftChild(nidx)
+                                                       : tree.RightChild(nidx);
+    } else {
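+      // The right child is allocated immediately after the left child, so adding
+      // !(fvalue < split_cond) selects left (+0) or right (+1) without branching.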
+      return tree.LeftChild(nidx) + !(fvalue < tree.SplitCond(nidx));
+    }
+  }
+}
+
+}  // namespace xgboost::predictor
 #endif  // XGBOOST_PREDICTOR_PREDICT_FN_H_
diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc
index ebb56d2d3633..9236f569fb2c 100644
--- a/tests/cpp/helpers.cc
+++ b/tests/cpp/helpers.cc
@@ -224,19 +224,18 @@ std::string RandomDataGenerator::GenerateArrayInterface(
   return out;
 }
 
-std::pair<std::vector<std::string>, std::string>
-RandomDataGenerator::GenerateArrayInterfaceBatch(
-    HostDeviceVector<float> *storage, size_t batches) const {
-  this->GenerateDense(storage);
+std::pair<std::vector<std::string>, std::string> MakeArrayInterfaceBatch(
+    HostDeviceVector<float> const* storage, std::size_t n_samples, bst_feature_t n_features,
+    std::size_t batches, std::int32_t device) {
   std::vector<std::string> result(batches);
   std::vector<Json> objects;
 
-  size_t const rows_per_batch = rows_ / batches;
+  size_t const rows_per_batch = n_samples / batches;
 
-  auto make_interface = [storage, this](size_t offset, size_t rows) {
+  auto make_interface = [storage, device, n_features](std::size_t offset, std::size_t rows) {
     Json array_interface{Object()};
     array_interface["data"] = std::vector<Json>(2);
-    if (device_ >= 0) {
+    if (device >= 0) {
       array_interface["data"][0] =
           Integer(reinterpret_cast<int64_t>(storage->DevicePointer() + offset));
       array_interface["stream"] = Null{};
@@ -249,22 +248,22 @@ RandomDataGenerator::GenerateArrayInterfaceBatch(
 
     array_interface["shape"] = std::vector<Json>(2);
     array_interface["shape"][0] = rows;
-    array_interface["shape"][1] = cols_;
+    array_interface["shape"][1] = n_features;
 
     array_interface["typestr"] = String("<f4");
     array_interface["version"] = 3;
     return array_interface;
   };
 
+std::pair<std::vector<std::string>, std::string> RandomDataGenerator::GenerateArrayInterfaceBatch(
+    HostDeviceVector<float>* storage, size_t batches) const {
+  this->GenerateDense(storage);
+  return MakeArrayInterfaceBatch(storage, rows_, cols_, batches, device_);
+}
+
 std::string RandomDataGenerator::GenerateColumnarArrayInterface(
     std::vector<HostDeviceVector<float>> *data) const {
   CHECK(data);
@@ -400,11 +405,14 @@ int NumpyArrayIterForTest::Next() {
   return 1;
 }
 
-std::shared_ptr<DMatrix>
-GetDMatrixFromData(const std::vector<float> &x, int num_rows, int num_columns){
+std::shared_ptr<DMatrix> GetDMatrixFromData(const std::vector<float>& x, std::size_t num_rows,
+                                            bst_feature_t num_columns) {
   data::DenseAdapter adapter(x.data(), num_rows, num_columns);
-  return std::shared_ptr<DMatrix>(new data::SimpleDMatrix(
-      &adapter, std::numeric_limits<float>::quiet_NaN(), 1));
+  auto p_fmat = std::shared_ptr<DMatrix>(
+      new data::SimpleDMatrix(&adapter, std::numeric_limits<float>::quiet_NaN(), 1));
+  CHECK_EQ(p_fmat->Info().num_row_, num_rows);
+  CHECK_EQ(p_fmat->Info().num_col_, num_columns);
+  return p_fmat;
 }
 
 std::unique_ptr<DMatrix> CreateSparsePageDMatrix(bst_row_t n_samples, bst_feature_t n_features,
@@ -572,12 +580,23 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(std::string name, Args kwargs,
   return gbm;
 }
 
-ArrayIterForTest::ArrayIterForTest(float sparsity, size_t rows, size_t cols,
-                                   size_t batches) : rows_{rows}, cols_{cols}, n_batches_{batches} {
+ArrayIterForTest::ArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches)
+    : rows_{rows}, cols_{cols}, n_batches_{batches} {
   XGProxyDMatrixCreate(&proxy_);
   rng_.reset(new RandomDataGenerator{rows_, cols_, sparsity});
+  std::tie(batches_, interface_) = rng_->GenerateArrayInterfaceBatch(&data_, n_batches_);
+}
+
+ArrayIterForTest::ArrayIterForTest(Context const* ctx, HostDeviceVector<float> const& data,
+                                   std::size_t n_samples, bst_feature_t n_features,
+                                   std::size_t n_batches)
+    : rows_{n_samples}, cols_{n_features}, n_batches_{n_batches} {
+  XGProxyDMatrixCreate(&proxy_);
+  this->data_.Resize(data.Size());
+  CHECK_EQ(this->data_.Size(), rows_ * cols_ * n_batches);
+  this->data_.Copy(data);
   std::tie(batches_, interface_) =
-      rng_->GenerateArrayInterfaceBatch(&data_, n_batches_);
+      MakeArrayInterfaceBatch(&data_, rows_, cols_, n_batches_, ctx->gpu_id);
 }
 
 ArrayIterForTest::~ArrayIterForTest() { XGDMatrixFree(proxy_); }
diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h
index ec0abf32b452..279e3f75951e 100644
--- a/tests/cpp/helpers.h
+++ b/tests/cpp/helpers.h
@@ -188,7 +188,7 @@ class SimpleRealUniformDistribution {
 };
 
 template <typename T>
-Json GetArrayInterface(HostDeviceVector<T> *storage, size_t rows, size_t cols) {
+Json GetArrayInterface(HostDeviceVector<T> const* storage, size_t rows, size_t cols) {
   Json array_interface{Object()};
   array_interface["data"] = std::vector<Json>(2);
   if (storage->DeviceCanRead()) {
@@ -318,8 +318,8 @@ GenerateRandomCategoricalSingleColumn(int n, size_t num_categories) {
   return x;
 }
 
-std::shared_ptr<DMatrix> GetDMatrixFromData(const std::vector<float> &x,
-                                            int num_rows, int num_columns);
+std::shared_ptr<DMatrix> GetDMatrixFromData(const std::vector<float>& x, std::size_t num_rows,
+                                            bst_feature_t num_columns);
 
 /**
  * \brief Create Sparse Page using data iterator.
@@ -394,7 +394,7 @@ typedef void *DMatrixHandle;  // NOLINT(*);
 class ArrayIterForTest {
  protected:
   HostDeviceVector<float> data_;
-  size_t iter_ {0};
+  size_t iter_{0};
   DMatrixHandle proxy_;
   std::unique_ptr<RandomDataGenerator> rng_;
 
@@ -418,6 +418,11 @@ class ArrayIterForTest {
   auto Proxy() -> decltype(proxy_) { return proxy_; }
 
   explicit ArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches);
+  /**
+   * \brief Create iterator with user provided data.
+   */
+  explicit ArrayIterForTest(Context const* ctx, HostDeviceVector<float> const& data,
+                            std::size_t n_samples, bst_feature_t n_features, std::size_t n_batches);
   virtual ~ArrayIterForTest();
 };
 
@@ -433,6 +438,10 @@ class NumpyArrayIterForTest : public ArrayIterForTest {
  public:
   explicit NumpyArrayIterForTest(float sparsity, size_t rows = Rows(), size_t cols = Cols(),
                                  size_t batches = Batches());
+  explicit NumpyArrayIterForTest(Context const* ctx, HostDeviceVector<float> const& data,
+                                 std::size_t n_samples, bst_feature_t n_features,
+                                 std::size_t n_batches)
+      : ArrayIterForTest{ctx, data, n_samples, n_features, n_batches} {}
   int Next() override;
   ~NumpyArrayIterForTest() override = default;
 };
diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc
index 9a0ebee18c53..401d33c4d04d 100644
--- a/tests/cpp/predictor/test_cpu_predictor.cc
+++ b/tests/cpp/predictor/test_cpu_predictor.cc
@@ -305,4 +305,10 @@ TEST(CpuPredictor, Sparse) {
   TestSparsePrediction(0.2, "cpu_predictor");
   TestSparsePrediction(0.8, "cpu_predictor");
 }
+
+TEST(CpuPredictor, Multi) {
+  Context ctx;
+  ctx.nthread = 1;
+  TestVectorLeafPrediction(&ctx);
+}
 }  // namespace xgboost
diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc
index 3e8a94c75ab9..4570a010df67 100644
--- a/tests/cpp/predictor/test_predictor.cc
+++ b/tests/cpp/predictor/test_predictor.cc
@@ -1,28 +1,34 @@
-/*!
- * Copyright 2020-2021 by Contributors
+/**
+ * Copyright 2020-2023 by XGBoost Contributors
  */
-
 #include "test_predictor.h"
 
 #include <gtest/gtest.h>
-#include <xgboost/data.h>
-#include <xgboost/generic_parameters.h>
-#include <xgboost/host_device_vector.h>
-#include <xgboost/predictor.h>
-
-#include "../../../src/common/bitfield.h"
-#include "../../../src/common/categorical.h"
-#include "../../../src/common/io.h"
-#include "../../../src/data/adapter.h"
-#include "../../../src/data/proxy_dmatrix.h"
-#include "../helpers.h"
+#include <xgboost/context.h>             // for Context
+#include <xgboost/data.h>                // for DMatrix, BatchIterator, BatchSet, MetaInfo
+#include <xgboost/host_device_vector.h>  // for HostDeviceVector
+#include <xgboost/predictor.h>           // for PredictionCacheEntry, Predictor, Predic...
+
+#include <algorithm>      // for max
+#include <limits>         // for numeric_limits
+#include <unordered_map>  // for unordered_map
+
+#include "../../../src/common/bitfield.h"         // for LBitField32
+#include "../../../src/data/iterative_dmatrix.h"  // for IterativeDMatrix
+#include "../../../src/data/proxy_dmatrix.h"      // for DMatrixProxy
+#include "../helpers.h"                           // for GetDMatrixFromData, RandomDataGenerator
+#include "xgboost/json.h"                         // for Json, Object, get, String
+#include "xgboost/linalg.h"                       // for MakeVec, Tensor, TensorView, Vector
+#include "xgboost/logging.h"                      // for CHECK
+#include "xgboost/span.h"                         // for operator!=, SpanIterator, Span
+#include "xgboost/tree_model.h"                   // for RegTree
 
 namespace xgboost {
 TEST(Predictor, PredictionCache) {
   size_t constexpr kRows = 16, kCols = 4;
 
   PredictionContainer container;
-  DMatrix* m;
+  DMatrix *m;
   // Add a cache that is immediately expired.
   auto add_cache = [&]() {
     auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
@@ -412,4 +418,101 @@ void TestSparsePrediction(float sparsity, std::string predictor) {
     }
   }
 }
+
+void TestVectorLeafPrediction(Context const *ctx) {
+  std::unique_ptr<Predictor> cpu_predictor =
+      std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", ctx));
+
+  size_t constexpr kRows = 5;
+  size_t constexpr kCols = 5;
+
+  LearnerModelParam mparam{static_cast<bst_feature_t>(kCols),
+                           linalg::Vector<float>{{0.5}, {1}, Context::kCpuId}, 1, 3,
+                           MultiStrategy::kMonolithic};
+
+  std::vector<std::unique_ptr<RegTree>> trees;
+  trees.emplace_back(new RegTree{mparam.LeafLength(), mparam.num_feature});
+
+  std::vector<float> p_w(mparam.LeafLength(), 0.0f);
+  std::vector<float> l_w(mparam.LeafLength(), 1.0f);
+  std::vector<float> r_w(mparam.LeafLength(), 2.0f);
+
+  auto &tree = trees.front();
+  tree->ExpandNode(0, static_cast<bst_feature_t>(1), 2.0, true,
+                   linalg::MakeVec(p_w.data(), p_w.size()), linalg::MakeVec(l_w.data(), l_w.size()),
+                   linalg::MakeVec(r_w.data(), r_w.size()));
+  ASSERT_TRUE(tree->IsMultiTarget());
+  ASSERT_TRUE(mparam.IsVectorLeaf());
+
+  gbm::GBTreeModel model{&mparam, ctx};
+  model.CommitModel(std::move(trees), 0);
+
+  auto run_test = [&](float expected, HostDeviceVector<float> *p_data) {
+    {
+      auto p_fmat = GetDMatrixFromData(p_data->ConstHostVector(), kRows, kCols);
+      PredictionCacheEntry predt_cache;
+      cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model);
+      ASSERT_EQ(predt_cache.predictions.Size(), kRows * mparam.LeafLength());
+      cpu_predictor->PredictBatch(p_fmat.get(), &predt_cache, model, 0, 1);
+      auto const &h_predt = predt_cache.predictions.HostVector();
+      for (auto v : h_predt) {
+        ASSERT_EQ(v, expected);
+      }
+    }
+
+    {
+      // inplace
+      PredictionCacheEntry predt_cache;
+      auto p_fmat = GetDMatrixFromData(p_data->ConstHostVector(), kRows, kCols);
+      cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model);
+      auto arr = GetArrayInterface(p_data, kRows, kCols);
+      std::string str;
+      Json::Dump(arr, &str);
+      auto proxy = std::shared_ptr<DMatrix>(new data::DMatrixProxy{});
+      dynamic_cast<data::DMatrixProxy *>(proxy.get())->SetArrayData(str.data());
+      cpu_predictor->InplacePredict(proxy, model, std::numeric_limits<float>::quiet_NaN(),
+                                    &predt_cache, 0, 1);
+      auto const &h_predt = predt_cache.predictions.HostVector();
+      for (auto v : h_predt) {
+        ASSERT_EQ(v, expected);
+      }
+    }
+
+    {
+      // ghist
+      PredictionCacheEntry predt_cache;
+      auto &h_data = p_data->HostVector();
+      // give it at least two bins, otherwise the histogram cuts only have min and max values.
+      for (std::size_t i = 0; i < 5; ++i) {
+        h_data[i] = 1.0;
+      }
+      auto p_fmat = GetDMatrixFromData(p_data->ConstHostVector(), kRows, kCols);
+
+      cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model);
+
+      auto iter = NumpyArrayIterForTest{ctx, *p_data, kRows, static_cast<bst_feature_t>(kCols),
+                                        static_cast<std::size_t>(1)};
+      p_fmat =
+          std::make_shared<data::IterativeDMatrix>(&iter, iter.Proxy(), nullptr, Reset, Next,
+                                                   std::numeric_limits<float>::quiet_NaN(), 0, 256);
+
+      cpu_predictor->InitOutPredictions(p_fmat->Info(), &predt_cache.predictions, model);
+      cpu_predictor->PredictBatch(p_fmat.get(), &predt_cache, model, 0, 1);
+      auto const &h_predt = predt_cache.predictions.HostVector();
+      // the smallest v uses the min_value from histogram cuts, which leads to a left leaf
+      // during prediction.
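+      // the leading entries whose input values were overwritten above are skipped.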
+      for (std::size_t i = 5; i < h_predt.size(); ++i) {
+        ASSERT_EQ(h_predt[i], expected) << i;
+      }
+    }
+  };
+
+  // go to right
+  HostDeviceVector<float> data(kRows * kCols, model.trees.front()->SplitCond(RegTree::kRoot) + 1.0);
+  run_test(2.5, &data);
+
+  // go to left
+  data.HostVector().assign(data.Size(), model.trees.front()->SplitCond(RegTree::kRoot) - 1.0);
+  run_test(1.5, &data);
+}
 }  // namespace xgboost
diff --git a/tests/cpp/predictor/test_predictor.h b/tests/cpp/predictor/test_predictor.h
index 61b05b31bb91..56c1523a1cf1 100644
--- a/tests/cpp/predictor/test_predictor.h
+++ b/tests/cpp/predictor/test_predictor.h
@@ -1,9 +1,16 @@
+/**
+ * Copyright 2020-2023 by XGBoost Contributors
+ */
 #ifndef XGBOOST_TEST_PREDICTOR_H_
 #define XGBOOST_TEST_PREDICTOR_H_
 
+#include <xgboost/context.h>  // for Context
 #include <gtest/gtest.h>
-#include <xgboost/predictor.h>
+
 #include <string>
+#include <cstddef>
+
+#include "../../../src/gbm/gbtree_model.h"  // for GBTreeModel
 #include "../helpers.h"
 
 namespace xgboost {
@@ -48,7 +55,7 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols,
     PredictionCacheEntry precise_out_predictions;
     predictor->InitOutPredictions(p_dmat->Info(), &precise_out_predictions.predictions, model);
     predictor->PredictBatch(p_dmat.get(), &precise_out_predictions, model, 0);
-    ASSERT_FALSE(p_dmat->PageExists());
+    CHECK(!p_dmat->PageExists());
   }
 }
 
@@ -69,6 +76,8 @@ void TestCategoricalPredictLeaf(StringView name);
 void TestIterationRange(std::string name);
 
 void TestSparsePrediction(float sparsity, std::string predictor);
+
+void TestVectorLeafPrediction(Context const* ctx);
 }  // namespace xgboost
 
 #endif  // XGBOOST_TEST_PREDICTOR_H_