Skip to content

Commit

Permalink
Fix the test.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Feb 9, 2020
1 parent 4a8ba82 commit 4f31bfb
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 12 deletions.
4 changes: 0 additions & 4 deletions src/predictor/cpu_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,21 +188,17 @@ class CPUPredictor : public Predictor {
// update cache entry
auto* out = &predts->predictions;
if (predts->predictions.Size() == 0) {
std::cout << __func__ << " Building the cache from sratch." << std::endl;
InitOutPredictions(m->Info(), out, model);
this->PredLoopInternal(m, &out->HostVector(), model, 0, model.trees.size());
} else if (model.learner_model_param_->num_output_group == 1 &&
updaters->size() > 0 &&
num_new_trees == 1 &&
updaters->back()->UpdatePredictionCache(m, out)) {
std::cout << __func__ << " The cache is updated by updater." << std::endl;
{}
} else {
std::cout << __func__ << " The cache is updated by CPU Predictor." << std::endl;
PredLoopInternal(m, &out->HostVector(), model, old_ntree, model.trees.size());
}
auto delta = num_new_trees / model.learner_model_param_->num_output_group;
std::cout << __func__ << " delta: " << delta;
predts->Update(delta);
}

Expand Down
18 changes: 12 additions & 6 deletions tests/cpp/helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,19 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
}
}

gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param) {
std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
(*trees.back())[0].SetLeaf(1.5f);
(*trees.back()).Stat(0).sum_hess = 1.0f;
gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, size_t n_classes) {
gbm::GBTreeModel model(param);
model.CommitModel(std::move(trees), 0);

for (size_t i = 0; i < n_classes; ++i) {
std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
if (i == 0) {
(*trees.back())[0].SetLeaf(1.5f);
(*trees.back()).Stat(0).sum_hess = 1.0f;
}
model.CommitModel(std::move(trees), i);
}

return model;
}

Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
size_t n_rows, size_t n_cols, size_t page_size, bool deterministic,
const dmlc::TemporaryDirectory& tempdir = dmlc::TemporaryDirectory());

gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param);
gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, size_t n_classes = 1);

std::unique_ptr<GradientBooster> CreateTrainedGBM(
std::string name, Args kwargs, size_t kRows, size_t kCols,
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/predictor/test_gpu_predictor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ TEST(gpu_predictor, ExternalMemoryTest) {
param.num_output_group = n_classes;
param.base_score = 0.5;

gbm::GBTreeModel model = CreateTestModel(&param);
gbm::GBTreeModel model = CreateTestModel(&param, n_classes);
std::vector<std::unique_ptr<DMatrix>> dmats;
dmlc::TemporaryDirectory tmpdir;
std::string file0 = tmpdir.path + "/big_0.libsvm";
Expand Down

0 comments on commit 4f31bfb

Please sign in to comment.