Skip to content

Commit

Permalink
Refactored FindBestSplitsFromHistograms.
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles Auguste committed Jan 28, 2020
1 parent 818c228 commit 63dbcf4
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 29 deletions.
65 changes: 36 additions & 29 deletions src/treelearner/serial_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -542,25 +542,16 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>&
OMP_LOOP_EX_BEGIN();
if (!is_feature_used[feature_index]) { continue; }
const int tid = omp_get_thread_num();
SplitInfo smaller_split;
train_data_->FixHistogram(feature_index,
smaller_leaf_splits_->sum_gradients(), smaller_leaf_splits_->sum_hessians(),
smaller_leaf_splits_->num_data_in_leaf(),
smaller_leaf_histogram_array_[feature_index].RawData());
int real_fidx = train_data_->RealFeatureIndex(feature_index);
smaller_leaf_histogram_array_[feature_index].FindBestThreshold(
smaller_leaf_splits_->sum_gradients(),
smaller_leaf_splits_->sum_hessians(),
smaller_leaf_splits_->num_data_in_leaf(),
constraints_per_leaf_[smaller_leaf_splits_->LeafIndex()],
&smaller_split);
smaller_split.feature = real_fidx;
if (cegb_ != nullptr) {
smaller_split.gain -= cegb_->DetlaGain(feature_index, real_fidx, smaller_leaf_splits_->LeafIndex(), smaller_leaf_splits_->num_data_in_leaf(), smaller_split);
}
if (smaller_split > smaller_best[tid] && smaller_node_used_features[feature_index]) {
smaller_best[tid] = smaller_split;
}

ComputeBestSplitForFeature(smaller_leaf_histogram_array_,
smaller_leaf_splits_, feature_index, real_fidx,
smaller_node_used_features, tid, smaller_best);

// only has root leaf
if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->LeafIndex() < 0) { continue; }

Expand All @@ -571,21 +562,11 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>&
larger_leaf_splits_->num_data_in_leaf(),
larger_leaf_histogram_array_[feature_index].RawData());
}
SplitInfo larger_split;
// find best threshold for larger child
larger_leaf_histogram_array_[feature_index].FindBestThreshold(
larger_leaf_splits_->sum_gradients(),
larger_leaf_splits_->sum_hessians(),
larger_leaf_splits_->num_data_in_leaf(),
constraints_per_leaf_[larger_leaf_splits_->LeafIndex()],
&larger_split);
larger_split.feature = real_fidx;
if (cegb_ != nullptr) {
larger_split.gain -= cegb_->DetlaGain(feature_index, real_fidx, larger_leaf_splits_->LeafIndex(), larger_leaf_splits_->num_data_in_leaf(), larger_split);
}
if (larger_split > larger_best[tid] && larger_node_used_features[feature_index]) {
larger_best[tid] = larger_split;
}

ComputeBestSplitForFeature(larger_leaf_histogram_array_,
larger_leaf_splits_, feature_index, real_fidx,
larger_node_used_features, tid, larger_best);

OMP_LOOP_EX_END();
}
OMP_THROW_EX();
Expand Down Expand Up @@ -881,4 +862,30 @@ void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj
}
}

/*!
 * \brief Evaluate a single feature on one leaf and keep the resulting split
 *        in the calling thread's slot when it is an improvement.
 * \param histogram_array_ Histograms of the leaf, indexed by feature_index
 * \param leaf_splits_ Leaf statistics (gradient/hessian sums, data count, leaf index)
 * \param feature_index Index used for the histogram array and node_used_features
 * \param real_fidx Value stored in the candidate split's `feature` field
 * \param node_used_features Per-feature flags; a zero entry prevents the candidate from being stored
 * \param tid Index of the slot in `best` that may be overwritten
 * \param best Per-slot best splits; only best[tid] is read or written here
 */
void SerialTreeLearner::ComputeBestSplitForFeature(
    FeatureHistogram *histogram_array_,
    const std::unique_ptr<LeafSplits>& leaf_splits_, int feature_index, int real_fidx,
    const std::vector<int8_t>& node_used_features, const int tid,
    std::vector<SplitInfo>& best) {
  const auto leaf_index = leaf_splits_->LeafIndex();
  const auto leaf_data_count = leaf_splits_->num_data_in_leaf();

  // Best threshold for this feature under the leaf's constraints.
  SplitInfo candidate;
  histogram_array_[feature_index].FindBestThreshold(
      leaf_splits_->sum_gradients(),
      leaf_splits_->sum_hessians(),
      leaf_data_count,
      constraints_per_leaf_[leaf_index],
      &candidate);
  candidate.feature = real_fidx;

  // When a CEGB object is configured, its delta gain is subtracted from the split gain.
  if (cegb_ != nullptr) {
    candidate.gain -= cegb_->DetlaGain(feature_index, real_fidx, leaf_index,
                                       leaf_data_count, candidate);
  }

  // The candidate is always computed; it is only stored when it beats the
  // current entry and the feature is flagged as usable for this node.
  SplitInfo& thread_best = best[tid];
  if (candidate > thread_best && node_used_features[feature_index]) {
    thread_best = candidate;
  }
}

} // namespace LightGBM
6 changes: 6 additions & 0 deletions src/treelearner/serial_tree_learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ class SerialTreeLearner: public TreeLearner {
void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const override;

/*!
 * \brief Find the best threshold of one feature on one leaf and update best[tid]
 *        when the resulting split is better and the feature is flagged in node_used_features.
 *        If cegb_ is set, its delta gain is subtracted from the split gain before comparing.
 * \param histogram_array_ Histograms of the leaf, indexed by feature_index
 * \param leaf_splits_ Leaf statistics (sum of gradients/hessians, number of data, leaf index)
 * \param feature_index Index into histogram_array_ and node_used_features
 * \param real_fidx Feature index stored into the resulting SplitInfo
 * \param node_used_features Per-feature flags; the split is only stored when the entry is non-zero
 * \param tid Slot of `best` to compare against and update (callers pass the OpenMP thread id)
 * \param best Best split found so far, one entry per slot
 */
void ComputeBestSplitForFeature(FeatureHistogram *histogram_array_,
const std::unique_ptr<LeafSplits>& leaf_splits_,
int feature_index, int real_fidx,
const std::vector<int8_t> &node_used_features,
const int tid, std::vector<SplitInfo> &best);

protected:
virtual std::vector<int8_t> GetUsedFeatures(bool is_tree_level);
/*!
Expand Down

0 comments on commit 63dbcf4

Please sign in to comment.