Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Choose predictor only when it's training. #9344

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions src/gbm/gbtree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ void GBTree::Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, Gradien
out_model.param.num_parallel_tree = model_.param.num_parallel_tree;
}

void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool,
void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
bst_layer_t layer_begin, bst_layer_t layer_end) {
CHECK(configured_);
if (layer_end == 0) {
Expand All @@ -526,7 +526,7 @@ void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool
CHECK_EQ(out_preds->version, 0);
}

auto const& predictor = GetPredictor(&out_preds->predictions, p_fmat);
auto const& predictor = GetPredictor(&out_preds->predictions, p_fmat, is_training);
if (out_preds->version == 0) {
// out_preds->Size() can be non-zero as it's initialized here before any
// tree is built at the 0^th iterator.
Expand All @@ -546,9 +546,8 @@ void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool
}
}

std::unique_ptr<Predictor> const &
GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
DMatrix *f_dmat) const {
std::unique_ptr<Predictor> const& GBTree::GetPredictor(HostDeviceVector<float> const* out_pred,
DMatrix* f_dmat, bool is_training) const {
CHECK(configured_);
if (tparam_.predictor != PredictorType::kAuto) {
if (tparam_.predictor == PredictorType::kGPUPredictor) {
Expand All @@ -574,7 +573,7 @@ GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,

// Data comes from SparsePageDMatrix. Since we are loading data in pages, no need to
// prevent data copy.
if (f_dmat && !f_dmat->SingleColBlock()) {
if ((f_dmat && !f_dmat->SingleColBlock()) || !is_training) {
if (ctx_->IsCPU()) {
return cpu_predictor_;
} else {
Expand All @@ -589,12 +588,11 @@ GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
}

// Data comes from Device DMatrix.
auto is_ellpack = f_dmat && f_dmat->PageExists<EllpackPage>() &&
!f_dmat->PageExists<SparsePage>();
auto is_ellpack =
f_dmat && f_dmat->PageExists<EllpackPage>() && !f_dmat->PageExists<SparsePage>();
// Data comes from device memory, like CuDF or CuPy.
auto is_from_device =
f_dmat && f_dmat->PageExists<SparsePage>() &&
(*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead();
auto is_from_device = f_dmat && f_dmat->PageExists<SparsePage>() &&
(*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead();
auto on_device = is_ellpack || is_from_device;

// Use GPU Predictor if data is already on device and gpu_id is set.
Expand Down Expand Up @@ -750,7 +748,7 @@ class Dart : public GBTree {
bool training, unsigned layer_begin,
unsigned layer_end) const {
CHECK(!this->model_.learner_model_param->IsVectorLeaf()) << "dart" << MTNotImplemented();
auto &predictor = this->GetPredictor(&p_out_preds->predictions, p_fmat);
auto &predictor = this->GetPredictor(&p_out_preds->predictions, p_fmat, training);
CHECK(predictor);
predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
model_);
Expand Down Expand Up @@ -843,7 +841,7 @@ class Dart : public GBTree {
}
CHECK(success) << msg;
} else {
predictor = this->GetPredictor().get();
predictor = this->GetPredictor(nullptr, nullptr, false).get();
bool success = predictor->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1);
CHECK(success) << msg << std::endl
<< "Current Predictor: "
Expand Down Expand Up @@ -886,7 +884,7 @@ class Dart : public GBTree {
std::vector<bst_float> *out_preds,
unsigned layer_begin, unsigned layer_end) override {
DropTrees(false);
auto &predictor = this->GetPredictor();
auto& predictor = this->GetPredictor(nullptr, nullptr, false);
uint32_t _, tree_end;
std::tie(_, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end);
predictor->PredictInstance(inst, out_preds, model_, tree_end);
Expand Down
17 changes: 9 additions & 8 deletions src/gbm/gbtree.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ class GBTree : public GradientBooster {
}
LOG(FATAL) << msg;
} else {
bool success = this->GetPredictor()->InplacePredict(p_m, model_, missing, out_preds,
tree_begin, tree_end);
bool success = this->GetPredictor(nullptr, nullptr, false)
->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end);
CHECK(success) << msg << std::endl
<< "Current Predictor: "
<< (tparam_.predictor == PredictorType::kCPUPredictor
Expand Down Expand Up @@ -349,7 +349,7 @@ class GBTree : public GradientBooster {
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
CHECK_EQ(tree_begin, 0) << "Predict leaf supports only iteration end: (0, "
"n_iteration), use model slicing instead.";
this->GetPredictor()->PredictLeaf(p_fmat, out_preds, model_, tree_end);
this->GetPredictor(nullptr, nullptr, false)->PredictLeaf(p_fmat, out_preds, model_, tree_end);
}

void PredictContribution(DMatrix* p_fmat,
Expand All @@ -361,7 +361,7 @@ class GBTree : public GradientBooster {
CHECK_EQ(tree_begin, 0)
<< "Predict contribution supports only iteration end: (0, "
"n_iteration), using model slicing instead.";
this->GetPredictor()->PredictContribution(
this->GetPredictor(nullptr, nullptr, false)->PredictContribution(
p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
}

Expand All @@ -373,8 +373,9 @@ class GBTree : public GradientBooster {
CHECK_EQ(tree_begin, 0)
<< "Predict interaction contribution supports only iteration end: (0, "
"n_iteration), using model slicing instead.";
this->GetPredictor()->PredictInteractionContributions(
p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
this->GetPredictor(nullptr, nullptr, false)
->PredictInteractionContributions(p_fmat, out_contribs, model_, tree_end, nullptr,
approximate);
}

[[nodiscard]] std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
Expand All @@ -390,8 +391,8 @@ class GBTree : public GradientBooster {
std::vector<HostDeviceVector<bst_node_t>>* out_position,
std::vector<std::unique_ptr<RegTree>>* ret);

std::unique_ptr<Predictor> const& GetPredictor(HostDeviceVector<float> const* out_pred = nullptr,
DMatrix* f_dmat = nullptr) const;
std::unique_ptr<Predictor> const& GetPredictor(HostDeviceVector<float> const* out_pred,
DMatrix* f_dmat, bool is_training) const;

// commit new trees all at once
virtual void CommitModel(TreesOneIter&& new_trees);
Expand Down