Skip to content

Commit

Permalink
[c++] Avoid copy on Refit (#6478)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbourjau authored Jul 10, 2024
1 parent cd4459a commit 1886bf5
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 22 deletions.
2 changes: 1 addition & 1 deletion include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class LIGHTGBM_EXPORT Boosting {
/*!
* \brief Update the tree output by new training data
* \param tree_leaf_prediction Leaf predictions used for the refit; in the pointer overload this is a flat buffer of nrow * ncol ints
* \param nrow Number of rows in the flat buffer (the GBDT implementation checks it against its num_data_)
* \param ncol Number of columns in the flat buffer (the GBDT implementation checks it against models_.size())
* \note The GBDT implementation reads element (i, j) of the flat buffer at index i * ncol + j.
* \note NOTE(review): the two declarations below look like the before/after lines of a diff hunk
*       ("1 addition & 1 deletion") rendered without +/- markers; presumably only the pointer overload remains - confirm.
*/
virtual void RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction) = 0;
virtual void RefitTree(const int* tree_leaf_prediction, const size_t nrow, const size_t ncol) = 0;

/*!
* \brief Training logic
Expand Down
23 changes: 18 additions & 5 deletions src/application/application.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,24 @@ void Application::Predict() {
config_.precise_float_parser);
TextReader<int> result_reader(config_.output_result.c_str(), false);
result_reader.ReadAllLines();
std::vector<std::vector<int>> pred_leaf(result_reader.Lines().size());

size_t nrow = result_reader.Lines().size();
size_t ncol = 0;
if (nrow > 0) {
ncol = Common::StringToArray<int>(result_reader.Lines()[0], '\t').size();
}
std::vector<int> pred_leaf;
pred_leaf.resize(nrow * ncol);

#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < static_cast<int>(result_reader.Lines().size()); ++i) {
pred_leaf[i] = Common::StringToArray<int>(result_reader.Lines()[i], '\t');
for (int irow = 0; irow < static_cast<int>(nrow); ++irow) {
auto line_vec = Common::StringToArray<int>(result_reader.Lines()[irow], '\t');
CHECK_EQ(line_vec.size(), ncol);
for (int i_row_item = 0; i_row_item < static_cast<int>(ncol); ++i_row_item) {
pred_leaf[irow * ncol + i_row_item] = line_vec[i_row_item];
}
// Free memory
result_reader.Lines()[i].clear();
result_reader.Lines()[irow].clear();
}
DatasetLoader dataset_loader(config_, nullptr,
config_.num_class, config_.data.c_str());
Expand All @@ -242,7 +254,8 @@ void Application::Predict() {
objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
boosting_->Init(&config_, train_data_.get(), objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
boosting_->RefitTree(pred_leaf);

boosting_->RefitTree(pred_leaf.data(), nrow, ncol);
boosting_->SaveModelToFile(0, -1, config_.saved_feature_importance_type,
config_.output_model.c_str());
Log::Info("Finished RefitTree");
Expand Down
18 changes: 10 additions & 8 deletions src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,32 +249,34 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
}
}

void GBDT::RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction) {
CHECK_GT(tree_leaf_prediction.size(), 0);
CHECK_EQ(static_cast<size_t>(num_data_), tree_leaf_prediction.size());
CHECK_EQ(static_cast<size_t>(models_.size()), tree_leaf_prediction[0].size());
void GBDT::RefitTree(const int* tree_leaf_prediction, const size_t nrow, const size_t ncol) {
CHECK_GT(nrow * ncol, 0);
CHECK_EQ(static_cast<size_t>(num_data_), nrow);
CHECK_EQ(models_.size(), ncol);

int num_iterations = static_cast<int>(models_.size() / num_tree_per_iteration_);
std::vector<int> leaf_pred(num_data_);
if (linear_tree_) {
std::vector<int> max_leaves_by_thread = std::vector<int>(OMP_NUM_THREADS(), 0);
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < static_cast<int>(tree_leaf_prediction.size()); ++i) {
for (int i = 0; i < static_cast<int>(nrow); ++i) {
int tid = omp_get_thread_num();
for (size_t j = 0; j < tree_leaf_prediction[i].size(); ++j) {
max_leaves_by_thread[tid] = std::max(max_leaves_by_thread[tid], tree_leaf_prediction[i][j]);
for (size_t j = 0; j < ncol; ++j) {
max_leaves_by_thread[tid] = std::max(max_leaves_by_thread[tid], tree_leaf_prediction[i * ncol + j]);
}
}
int max_leaves = *std::max_element(max_leaves_by_thread.begin(), max_leaves_by_thread.end());
max_leaves += 1;
tree_learner_->InitLinear(train_data_, max_leaves);
}

for (int iter = 0; iter < num_iterations; ++iter) {
Boosting();
for (int tree_id = 0; tree_id < num_tree_per_iteration_; ++tree_id) {
int model_index = iter * num_tree_per_iteration_ + tree_id;
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < num_data_; ++i) {
leaf_pred[i] = tree_leaf_prediction[i][model_index];
leaf_pred[i] = tree_leaf_prediction[i * ncol + model_index];
CHECK_LT(leaf_pred[i], models_[model_index]->num_leaves());
}
size_t offset = static_cast<size_t>(tree_id) * num_data_;
Expand Down
2 changes: 1 addition & 1 deletion src/boosting/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class GBDT : public GBDTBase {
*/
void Train(int snapshot_freq, const std::string& model_output_path) override;

// Refit override; the definition lives in src/boosting/gbdt.cpp.
// For the pointer overload, tree_leaf_prediction is a flat buffer of nrow * ncol ints that the
// definition indexes as i * ncol + model_index.
// NOTE(review): these two declarations look like the removed/added pair of a diff hunk
// ("1 addition & 1 deletion"); presumably only the pointer overload remains - confirm.
void RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction) override;
void RefitTree(const int* tree_leaf_prediction, const size_t nrow, const size_t ncol) override;

/*!
* \brief Training logic
Expand Down
8 changes: 1 addition & 7 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,13 +409,7 @@ class Booster {

// Refits the wrapped boosting_ model from leaf predictions passed as a flat buffer.
// leaf_preds holds nrow * ncol values; element (i, j) is read at i * ncol + j
// (the indexing used by the copy loop below).
// NOTE(review): this span looks like a diff hunk rendered without +/- markers
// ("1 addition & 7 deletions" in its header). Presumably the v_leaf_preds copy and
// RefitTree(v_leaf_preds) are the removed lines and RefitTree(leaf_preds, nrow, ncol)
// is the single added line - confirm against the repository before treating this as code.
void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
// Presumably acquires mutex_ for the duration of the call (UNIQUE_LOCK is defined elsewhere).
UNIQUE_LOCK(mutex_)
// Nested-vector path: allocates one inner vector per row and copies every element.
std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0));
for (int i = 0; i < nrow; ++i) {
for (int j = 0; j < ncol; ++j) {
// Indices are widened to size_t before multiplying so the flat offset is computed in size_t.
v_leaf_preds[i][j] = leaf_preds[static_cast<size_t>(i) * static_cast<size_t>(ncol) + static_cast<size_t>(j)];
}
}
boosting_->RefitTree(v_leaf_preds);
// Flat-buffer path: forwards the caller's pointer directly, with no intermediate copy.
// nrow/ncol are int32_t here and size_t in the RefitTree signature, so they convert implicitly.
boosting_->RefitTree(leaf_preds, nrow, ncol);
}

bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
Expand Down

0 comments on commit 1886bf5

Please sign in to comment.