From d2facced7e7c35ce1f49020cf6597a9db7a69f77 Mon Sep 17 00:00:00 2001
From: fis
Date: Wed, 27 Jan 2021 23:16:22 +0800
Subject: [PATCH 1/3] Make prediction thread safe.

---
 include/xgboost/predictor.h    | 10 ++--
 src/predictor/cpu_predictor.cc | 40 ++++++++--------
 src/predictor/gpu_predictor.cu | 87 +++++++++++++++++-----------------
 tests/cpp/test_learner.cc      |  5 +-
 4 files changed, 73 insertions(+), 69 deletions(-)

diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h
index 42a5275e1dbb..5ab734359a4f 100644
--- a/include/xgboost/predictor.h
+++ b/include/xgboost/predictor.h
@@ -132,7 +132,7 @@ class Predictor {
    */
   virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds,
                             const gbm::GBTreeModel& model, int tree_begin,
-                            uint32_t const ntree_limit = 0) = 0;
+                            uint32_t const ntree_limit = 0) const = 0;
 
   /**
    * \brief Inplace prediction.
@@ -161,7 +161,7 @@ class Predictor {
   virtual void PredictInstance(const SparsePage::Inst& inst,
                                std::vector<bst_float>* out_preds,
                                const gbm::GBTreeModel& model,
-                               unsigned ntree_limit = 0) = 0;
+                               unsigned ntree_limit = 0) const = 0;
 
   /**
    * \brief predict the leaf index of each tree, the output will be nsample *
@@ -175,7 +175,7 @@ class Predictor {
   virtual void PredictLeaf(DMatrix* dmat,
                            HostDeviceVector<bst_float>* out_preds,
                            const gbm::GBTreeModel& model,
-                           unsigned ntree_limit = 0) = 0;
+                           unsigned ntree_limit = 0) const = 0;
 
   /**
    * \fn virtual void Predictor::PredictContribution( DMatrix* dmat,
@@ -203,14 +203,14 @@ class Predictor {
                                    std::vector<bst_float>* tree_weights = nullptr,
                                    bool approximate = false,
                                    int condition = 0,
-                                   unsigned condition_feature = 0) = 0;
+                                   unsigned condition_feature = 0) const = 0;
 
   virtual void PredictInteractionContributions(DMatrix* dmat,
                                                HostDeviceVector<bst_float>* out_contribs,
                                                const gbm::GBTreeModel& model,
                                                unsigned ntree_limit = 0,
                                                std::vector<bst_float>* tree_weights = nullptr,
-                                               bool approximate = false) = 0;
+                                               bool approximate = false) const = 0;
 
   /**
diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc
index 80c9b0ee95d1..ba94d130438e 100644
--- a/src/predictor/cpu_predictor.cc
+++ b/src/predictor/cpu_predictor.cc
@@ -183,11 +183,10 @@ class CPUPredictor : public Predictor {
 
   void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
                       gbm::GBTreeModel const &model, int32_t tree_begin,
-                      int32_t tree_end) {
-    std::lock_guard<std::mutex> guard(lock_);
+                      int32_t tree_end) const {
     const int threads = omp_get_max_threads();
-    InitThreadTemp(threads*kBlockOfRowsSize, model.learner_model_param->num_feature,
-                   &this->thread_temp_);
+    std::vector<RegTree::FVec> feat_vecs;
+    InitThreadTemp(threads, model.learner_model_param->num_feature, &feat_vecs);
     for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
       CHECK_EQ(out_preds->size(),
                p_fmat->Info().num_row_ * model.learner_model_param->num_output_group);
@@ -195,7 +194,7 @@ class CPUPredictor : public Predictor {
       PredictBatchByBlockOfRowsKernel<SparsePageView,
                                       kBlockOfRowsSize>(SparsePageView{&batch},
                                                         out_preds, model, tree_begin,
-                                                        tree_end, &thread_temp_);
+                                                        tree_end, &feat_vecs);
     }
   }
 
@@ -238,7 +237,7 @@ class CPUPredictor : public Predictor {
   // multi-output and forest. Same problem exists for tree_begin
   void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
                     const gbm::GBTreeModel& model, int tree_begin,
-                    uint32_t const ntree_limit = 0) override {
+                    uint32_t const ntree_limit = 0) const override {
     // tree_begin is not used, right now we just enforce it to be 0.
     CHECK_EQ(tree_begin, 0);
     auto* out_preds = &predts->predictions;
@@ -326,10 +325,11 @@ class CPUPredictor : public Predictor {
 
   void PredictInstance(const SparsePage::Inst& inst,
                        std::vector<bst_float>* out_preds,
-                       const gbm::GBTreeModel& model, unsigned ntree_limit) override {
-    if (thread_temp_.size() == 0) {
-      thread_temp_.resize(1, RegTree::FVec());
-      thread_temp_[0].Init(model.learner_model_param->num_feature);
+                       const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
+    std::vector<RegTree::FVec> feat_vecs;
+    if (feat_vecs.size() == 0) {
+      feat_vecs.resize(1, RegTree::FVec());
+      feat_vecs[0].Init(model.learner_model_param->num_feature);
     }
     ntree_limit *= model.learner_model_param->num_output_group;
     if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
@@ -340,15 +340,16 @@ class CPUPredictor : public Predictor {
     // loop over output groups
     for (uint32_t gid = 0; gid < model.learner_model_param->num_output_group; ++gid) {
       (*out_preds)[gid] = PredValue(inst, model.trees, model.tree_info, gid,
-                                    &thread_temp_[0], 0, ntree_limit) +
+                                    &feat_vecs[0], 0, ntree_limit) +
                           model.learner_model_param->base_score;
     }
   }
 
   void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_preds,
-                   const gbm::GBTreeModel& model, unsigned ntree_limit) override {
+                   const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
     const int nthread = omp_get_max_threads();
-    InitThreadTemp(nthread, model.learner_model_param->num_feature, &this->thread_temp_);
+    std::vector<RegTree::FVec> feat_vecs;
+    InitThreadTemp(nthread, model.learner_model_param->num_feature, &feat_vecs);
     const MetaInfo& info = p_fmat->Info();
     // number of valid trees
     ntree_limit *= model.learner_model_param->num_output_group;
@@ -366,7 +367,7 @@ class CPUPredictor : public Predictor {
       for (bst_omp_uint i = 0; i < nsize; ++i) {
         const int tid = omp_get_thread_num();
         auto ridx = static_cast<size_t>(batch.base_rowid + i);
-        RegTree::FVec &feats = thread_temp_[tid];
+        RegTree::FVec &feats = feat_vecs[tid];
         feats.Fill(page[i]);
         for (unsigned j = 0; j < ntree_limit; ++j) {
           int tid = model.trees[j]->GetLeafIndex(feats);
@@ -381,9 +382,10 @@ class CPUPredictor : public Predictor {
                            const gbm::GBTreeModel& model, uint32_t ntree_limit,
                            std::vector<bst_float>* tree_weights,
                            bool approximate, int condition,
-                           unsigned condition_feature) override {
+                           unsigned condition_feature) const override {
     const int nthread = omp_get_max_threads();
-    InitThreadTemp(nthread, model.learner_model_param->num_feature, &this->thread_temp_);
+    std::vector<RegTree::FVec> feat_vecs;
+    InitThreadTemp(nthread, model.learner_model_param->num_feature, &feat_vecs);
     const MetaInfo& info = p_fmat->Info();
     // number of valid trees
     ntree_limit *= model.learner_model_param->num_output_group;
@@ -414,7 +416,7 @@ class CPUPredictor : public Predictor {
 #pragma omp parallel for schedule(static)
       for (bst_omp_uint i = 0; i < nsize; ++i) {
         auto row_idx = static_cast<size_t>(batch.base_rowid + i);
-        RegTree::FVec &feats = thread_temp_[omp_get_thread_num()];
+        RegTree::FVec &feats = feat_vecs[omp_get_thread_num()];
         std::vector<bst_float> this_tree_contribs(ncolumns);
         // loop over all classes
         for (int gid = 0; gid < ngroup; ++gid) {
@@ -452,7 +454,7 @@ class CPUPredictor : public Predictor {
   void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_contribs,
                                        const gbm::GBTreeModel& model, unsigned ntree_limit,
                                        std::vector<bst_float>* tree_weights,
-                                       bool approximate) override {
+                                       bool approximate) const override {
     const MetaInfo& info = p_fmat->Info();
     const int ngroup = model.learner_model_param->num_output_group;
     size_t const ncolumns = model.learner_model_param->num_feature;
@@ -501,8 +503,6 @@ class CPUPredictor : public Predictor {
   }
 
  private:
-  std::mutex lock_;
-  std::vector<RegTree::FVec> thread_temp_;
   static size_t constexpr kBlockOfRowsSize = 64;
 };
 
diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu
index 90e786f08b87..a438d75e81c8 100644
--- a/src/predictor/gpu_predictor.cu
+++ b/src/predictor/gpu_predictor.cu
@@ -501,17 +501,18 @@ size_t SharedMemoryBytes(size_t cols, size_t max_shared_memory_bytes) {
 class GPUPredictor : public xgboost::Predictor {
  private:
   void PredictInternal(const SparsePage& batch,
+                       DeviceModel const& model,
                        size_t num_features,
                        HostDeviceVector<bst_float>* predictions,
-                       size_t batch_offset) {
+                       size_t batch_offset) const {
     batch.offset.SetDevice(generic_param_->gpu_id);
     batch.data.SetDevice(generic_param_->gpu_id);
     const uint32_t BLOCK_THREADS = 128;
     size_t num_rows = batch.Size();
     auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
-
+    auto max_shared_memory_bytes = ConfigureDevice(generic_param_->gpu_id);
     size_t shared_memory_bytes =
-        SharedMemoryBytes<BLOCK_THREADS>(num_features, max_shared_memory_bytes_);
+        SharedMemoryBytes<BLOCK_THREADS>(num_features, max_shared_memory_bytes);
     bool use_shared = shared_memory_bytes != 0;
     size_t entry_start = 0;
 
@@ -519,51 +520,53 @@ class GPUPredictor : public xgboost::Predictor {
                         num_features);
     dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
         PredictKernel<SparsePageLoader, SparsePageView>, data,
-        model_.nodes.ConstDeviceSpan(),
+        model.nodes.ConstDeviceSpan(),
         predictions->DeviceSpan().subspan(batch_offset),
-        model_.tree_segments.ConstDeviceSpan(), model_.tree_group.ConstDeviceSpan(),
-        model_.split_types.ConstDeviceSpan(),
-        model_.categories_tree_segments.ConstDeviceSpan(),
-        model_.categories_node_segments.ConstDeviceSpan(),
-        model_.categories.ConstDeviceSpan(), model_.tree_beg_, model_.tree_end_,
-        num_features, num_rows, entry_start, use_shared, model_.num_group);
+        model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(),
+        model.split_types.ConstDeviceSpan(),
+        model.categories_tree_segments.ConstDeviceSpan(),
+        model.categories_node_segments.ConstDeviceSpan(),
+        model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
+        num_features, num_rows, entry_start, use_shared, model.num_group);
   }
 
   void PredictInternal(EllpackDeviceAccessor const& batch,
+                       DeviceModel const& model,
                        HostDeviceVector<bst_float>* out_preds,
-                       size_t batch_offset) {
+                       size_t batch_offset) const {
     const uint32_t BLOCK_THREADS = 256;
     size_t num_rows = batch.n_rows;
     auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
+    DeviceModel d_model;
 
     bool use_shared = false;
     size_t entry_start = 0;
     dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS} (
         PredictKernel<EllpackLoader, EllpackDeviceAccessor>, batch,
-        model_.nodes.ConstDeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset),
-        model_.tree_segments.ConstDeviceSpan(), model_.tree_group.ConstDeviceSpan(),
-        model_.split_types.ConstDeviceSpan(),
-        model_.categories_tree_segments.ConstDeviceSpan(),
-        model_.categories_node_segments.ConstDeviceSpan(),
-        model_.categories.ConstDeviceSpan(), model_.tree_beg_, model_.tree_end_,
+        model.nodes.ConstDeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset),
+        model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(),
+        model.split_types.ConstDeviceSpan(),
+        model.categories_tree_segments.ConstDeviceSpan(),
+        model.categories_node_segments.ConstDeviceSpan(),
+        model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
         batch.NumFeatures(), num_rows, entry_start, use_shared,
-        model_.num_group);
+        model.num_group);
   }
 
   void DevicePredictInternal(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
                              const gbm::GBTreeModel& model, size_t tree_begin,
-                             size_t tree_end) {
-    dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
+                             size_t tree_end) const {
     if (tree_end - tree_begin == 0) {
       return;
     }
-    model_.Init(model, tree_begin, tree_end, generic_param_->gpu_id);
     out_preds->SetDevice(generic_param_->gpu_id);
     auto const& info = dmat->Info();
+    DeviceModel d_model;
+    d_model.Init(model, tree_begin, tree_end, generic_param_->gpu_id);
 
     if (dmat->PageExists<SparsePage>()) {
       size_t batch_offset = 0;
       for (auto &batch : dmat->GetBatches<SparsePage>()) {
-        this->PredictInternal(batch, model.learner_model_param->num_feature,
+        this->PredictInternal(batch, d_model, model.learner_model_param->num_feature,
                               out_preds, batch_offset);
         batch_offset += batch.Size() * model.learner_model_param->num_output_group;
       }
@@ -572,6 +575,7 @@ class GPUPredictor : public xgboost::Predictor {
       for (auto const& page : dmat->GetBatches<EllpackPage>()) {
         this->PredictInternal(
             page.Impl()->GetDeviceAccessor(generic_param_->gpu_id),
+            d_model,
             out_preds,
             batch_offset);
         batch_offset += page.Impl()->n_rows;
@@ -591,10 +595,9 @@ class GPUPredictor : public xgboost::Predictor {
 
   void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
                     const gbm::GBTreeModel& model, int tree_begin,
-                    unsigned ntree_limit = 0) override {
+                    unsigned ntree_limit = 0) const override {
     // This function is duplicated with CPU predictor PredictBatch, see comments in there.
     // FIXME(trivialfis): Remove the duplication.
-    std::lock_guard<std::mutex> const guard(lock_);
     int device = generic_param_->gpu_id;
     CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data.";
     ConfigureDevice(device);
@@ -702,7 +705,7 @@ class GPUPredictor : public xgboost::Predictor {
                            const gbm::GBTreeModel& model, unsigned ntree_limit,
                            std::vector<bst_float>*,
                            bool approximate, int,
-                           unsigned) override {
+                           unsigned) const override {
     if (approximate) {
       LOG(FATAL) << "Approximated contribution is not implemented in GPU Predictor.";
     }
@@ -755,7 +758,7 @@ class GPUPredictor : public xgboost::Predictor {
                                        const gbm::GBTreeModel& model, unsigned ntree_limit,
                                        std::vector<bst_float>*,
-                                       bool approximate) override {
+                                       bool approximate) const override {
     if (approximate) {
       LOG(FATAL) << "[Internal error]: " << __func__
                  << " approximate is not implemented in GPU Predictor.";
@@ -828,21 +831,21 @@ class GPUPredictor : public xgboost::Predictor {
 
   void PredictInstance(const SparsePage::Inst&,
                        std::vector<bst_float>*,
-                       const gbm::GBTreeModel&, unsigned) override {
+                       const gbm::GBTreeModel&, unsigned) const override {
     LOG(FATAL) << "[Internal error]: " << __func__
                << " is not implemented in GPU Predictor.";
   }
 
   void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* predictions,
                    const gbm::GBTreeModel& model,
-                   unsigned ntree_limit) override {
+                   unsigned ntree_limit) const override {
     dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
-    ConfigureDevice(generic_param_->gpu_id);
+    auto max_shared_memory_bytes = ConfigureDevice(generic_param_->gpu_id);
 
     const MetaInfo& info = p_fmat->Info();
     constexpr uint32_t kBlockThreads = 128;
     size_t shared_memory_bytes =
-        SharedMemoryBytes<kBlockThreads>(info.num_col_, max_shared_memory_bytes_);
+        SharedMemoryBytes<kBlockThreads>(info.num_col_, max_shared_memory_bytes);
     bool use_shared = shared_memory_bytes != 0;
     bst_feature_t num_features = info.num_col_;
     bst_row_t num_rows = info.num_row_;
@@ -854,7 +857,8 @@ class GPUPredictor : public xgboost::Predictor {
     }
     predictions->SetDevice(generic_param_->gpu_id);
     predictions->Resize(num_rows * real_ntree_limit);
-    model_.Init(model, 0, real_ntree_limit, generic_param_->gpu_id);
+    DeviceModel d_model;
+    d_model.Init(model, 0, real_ntree_limit, this->generic_param_->gpu_id);
 
     if (p_fmat->PageExists<SparsePage>()) {
       for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
@@ -868,10 +872,10 @@ class GPUPredictor : public xgboost::Predictor {
             static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
         dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} (
             PredictLeafKernel<SparsePageLoader, SparsePageView>, data,
-            model_.nodes.ConstDeviceSpan(),
+            d_model.nodes.ConstDeviceSpan(),
             predictions->DeviceSpan().subspan(batch_offset),
-            model_.tree_segments.ConstDeviceSpan(),
-            model_.tree_beg_, model_.tree_end_, num_features, num_rows,
+            d_model.tree_segments.ConstDeviceSpan(),
+            d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
             entry_start, use_shared);
         batch_offset += batch.Size();
       }
@@ -884,10 +888,10 @@ class GPUPredictor : public xgboost::Predictor {
             static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
         dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} (
             PredictLeafKernel<EllpackLoader, EllpackDeviceAccessor>, data,
-            model_.nodes.ConstDeviceSpan(),
+            d_model.nodes.ConstDeviceSpan(),
             predictions->DeviceSpan().subspan(batch_offset),
-            model_.tree_segments.ConstDeviceSpan(),
-            model_.tree_beg_, model_.tree_end_, num_features, num_rows,
+            d_model.tree_segments.ConstDeviceSpan(),
+            d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
             entry_start, use_shared);
         batch_offset += batch.Size();
       }
@@ -900,15 +904,12 @@ class GPUPredictor : public xgboost::Predictor {
 
  private:
   /*! \brief Reconfigure the device when GPU is changed. */
-  void ConfigureDevice(int device) {
+  static size_t ConfigureDevice(int device) {
     if (device >= 0) {
-      max_shared_memory_bytes_ = dh::MaxSharedMemory(device);
+      return dh::MaxSharedMemory(device);
     }
+    return 0;
   }
-
-  std::mutex lock_;
-  DeviceModel model_;
-  size_t max_shared_memory_bytes_ { 0 };
 };
 
 XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc
index 7402c2cb71df..72761056910f 100644
--- a/tests/cpp/test_learner.cc
+++ b/tests/cpp/test_learner.cc
@@ -199,7 +199,7 @@ TEST(Learner, JsonModelIO) {
 // ```
 TEST(Learner, MultiThreadedPredict) {
   size_t constexpr kRows = 1000;
-  size_t constexpr kCols = 1000;
+  size_t constexpr kCols = 100;
 
   std::shared_ptr<DMatrix> p_dmat{
       RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix()};
@@ -219,8 +219,11 @@ TEST(Learner, MultiThreadedPredict) {
     threads.emplace_back([learner, p_data] {
       size_t constexpr kIters = 10;
       auto &entry = learner->GetThreadLocal().prediction_entry;
+      HostDeviceVector<float> predictions;
       for (size_t iter = 0; iter < kIters; ++iter) {
         learner->Predict(p_data, false, &entry.predictions);
+        learner->Predict(p_data, false, &predictions, 0, true);         // leaf
+        learner->Predict(p_data, false, &predictions, 0, false, true);  // contribs
       }
     });
   }
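Annotation on [PATCH 1/3], not part of the patch itself: the change is mechanical but the idea is easy to state. The predictors used to keep scratch buffers (thread_temp_), a cached DeviceModel, and a cached shared-memory size as members, mutated them during prediction, and guarded only some entry points with a mutex. Moving all of that state into function-local variables lets every Predict* method be marked const and makes concurrent calls safe with no lock at all. Below is a minimal standalone sketch of the pattern under those assumptions; ThreadSafePredictor, Scratch, and Model are invented illustrative names, not XGBoost types.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Invented stand-ins for RegTree::FVec and the tree model.
struct Scratch { std::vector<float> feats; };
struct Model   { std::size_t num_feature = 0; };

class ThreadSafePredictor {
 public:
  // const + per-call scratch: concurrent calls share no mutable state,
  // so no mutex is needed and callers never serialize on each other.
  void PredictBatch(Model const& model, std::vector<float> const& rows,
                    std::vector<float>* out_preds) const {
    std::vector<Scratch> feat_vecs(1);             // lives on this call's stack
    feat_vecs[0].feats.resize(model.num_feature);  // sized per call, like InitThreadTemp
    // ... fill feat_vecs[0] from each row and walk the trees ...
    out_preds->assign(rows.size() / std::max<std::size_t>(model.num_feature, 1), 0.0f);
  }
  // Deliberately absent: std::mutex lock_; std::vector<Scratch> thread_temp_;
  // The price is re-initialising scratch on every call; the payoff is reentrancy.
};
```

The trade-off is visible in the diff as well: buffers are rebuilt per invocation instead of being reused across calls, which the series accepts in exchange for thread safety.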
From 04685c3c91a1854749289ce764ab8eb99b4894fd Mon Sep 17 00:00:00 2001
From: fis
Date: Wed, 27 Jan 2021 23:34:23 +0800
Subject: [PATCH 2/3] Wrong size.

---
 src/predictor/cpu_predictor.cc | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc
index ba94d130438e..54a4427e6132 100644
--- a/src/predictor/cpu_predictor.cc
+++ b/src/predictor/cpu_predictor.cc
@@ -186,7 +186,8 @@ class CPUPredictor : public Predictor {
                       int32_t tree_end) const {
     const int threads = omp_get_max_threads();
     std::vector<RegTree::FVec> feat_vecs;
-    InitThreadTemp(threads, model.learner_model_param->num_feature, &feat_vecs);
+    InitThreadTemp(threads * kBlockOfRowsSize,
+                   model.learner_model_param->num_feature, &feat_vecs);
     for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
       CHECK_EQ(out_preds->size(),
                p_fmat->Info().num_row_ * model.learner_model_param->num_output_group);
@@ -327,10 +328,8 @@ class CPUPredictor : public Predictor {
                        std::vector<bst_float>* out_preds,
                        const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
     std::vector<RegTree::FVec> feat_vecs;
-    if (feat_vecs.size() == 0) {
-      feat_vecs.resize(1, RegTree::FVec());
-      feat_vecs[0].Init(model.learner_model_param->num_feature);
-    }
+    feat_vecs.resize(1, RegTree::FVec());
+    feat_vecs[0].Init(model.learner_model_param->num_feature);
     ntree_limit *= model.learner_model_param->num_output_group;
     if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
       ntree_limit = static_cast<uint32_t>(model.trees.size());
@@ -349,7 +348,7 @@ class CPUPredictor : public Predictor {
                    const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
     const int nthread = omp_get_max_threads();
     std::vector<RegTree::FVec> feat_vecs;
-    InitThreadTemp(nthread,  model.learner_model_param->num_feature, &feat_vecs);
+    InitThreadTemp(nthread, model.learner_model_param->num_feature, &feat_vecs);
     const MetaInfo& info = p_fmat->Info();
     // number of valid trees
     ntree_limit *= model.learner_model_param->num_output_group;

From 294572c9d43f09636319d4aa4773551cf5971318 Mon Sep 17 00:00:00 2001
From: fis
Date: Thu, 28 Jan 2021 21:45:19 +0800
Subject: [PATCH 3/3] Unused variable.

---
 src/predictor/gpu_predictor.cu | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu
index a438d75e81c8..f06bf722b1d0 100644
--- a/src/predictor/gpu_predictor.cu
+++ b/src/predictor/gpu_predictor.cu
@@ -536,7 +536,6 @@ class GPUPredictor : public xgboost::Predictor {
     const uint32_t BLOCK_THREADS = 256;
     size_t num_rows = batch.n_rows;
     auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
-    DeviceModel d_model;
 
     bool use_shared = false;
     size_t entry_start = 0;
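Annotation on the series, not part of the patches: [PATCH 2/3] matters because the block-of-rows kernel hands each OpenMP thread a contiguous slab of kBlockOfRowsSize feature vectors and indexes the pool at omp_get_thread_num() * kBlockOfRowsSize + i, so a pool sized to the thread count alone, as patch 1 briefly had it, would be indexed out of bounds. A hedged sketch of that sizing contract follows; FVecStub is an invented stand-in for RegTree::FVec and the loop bounds are arbitrary.

```cpp
#include <omp.h>
#include <cstddef>
#include <vector>

struct FVecStub { std::vector<float> feats; };  // stand-in for RegTree::FVec

constexpr std::size_t kBlockOfRowsSize = 64;    // rows kept in flight per thread

int main() {
  const int threads = omp_get_max_threads();
  // The fix in [PATCH 2/3]: one slot per (thread, row-in-block) pair,
  // i.e. threads * kBlockOfRowsSize entries rather than just threads.
  std::vector<FVecStub> feat_vecs(threads * kBlockOfRowsSize);

#pragma omp parallel for schedule(static)
  for (int block = 0; block < 128; ++block) {
    const std::size_t base =
        static_cast<std::size_t>(omp_get_thread_num()) * kBlockOfRowsSize;
    for (std::size_t i = 0; i < kBlockOfRowsSize; ++i) {
      // Each row of the block gets its own scratch vector; no two threads
      // (and no two in-flight rows of one thread) ever share a slot.
      feat_vecs[base + i].feats.assign(8, 0.0f);
    }
  }
  return 0;
}
```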