Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always respect forced splits, even when feature_fraction < 1.0 (fixes #4601) #4725

Merged
merged 17 commits into from
Nov 10, 2021
42 changes: 34 additions & 8 deletions src/treelearner/serial_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include <algorithm>
#include <queue>
#include <set>
#include <unordered_map>
#include <utility>

Expand Down Expand Up @@ -322,17 +323,22 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int
}

// Backward-compatible overload: search best splits with no forced features
// (delegates to the two-argument overload with a null force_features set).
void SerialTreeLearner::FindBestSplits(const Tree* tree) {
FindBestSplits(tree, nullptr);
}

void SerialTreeLearner::FindBestSplits(const Tree* tree, const std::set<int>* force_features) {
std::vector<int8_t> is_feature_used(num_features_, 0);
#pragma omp parallel for schedule(static, 256) if (num_features_ >= 512)
#pragma omp parallel for schedule(static, 256) if (num_features_ >= 512)
tongwu-sh marked this conversation as resolved.
Show resolved Hide resolved
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
if (!col_sampler_.is_feature_used_bytree()[feature_index]) continue;
if (!col_sampler_.is_feature_used_bytree()[feature_index] && (force_features == nullptr || force_features->find(feature_index) == force_features->end())) continue;
if (parent_leaf_histogram_array_ != nullptr
&& !parent_leaf_histogram_array_[feature_index].is_splittable()) {
&& !parent_leaf_histogram_array_[feature_index].is_splittable()) {
smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
continue;
}
is_feature_used[feature_index] = 1;
}

bool use_subtract = parent_leaf_histogram_array_ != nullptr;

#ifdef USE_CUDA
Expand All @@ -344,6 +350,7 @@ void SerialTreeLearner::FindBestSplits(const Tree* tree) {
#else
ConstructHistograms(is_feature_used, use_subtract);
#endif

FindBestSplitsFromHistograms(is_feature_used, use_subtract, tree);
}

Expand Down Expand Up @@ -463,11 +470,8 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, int* left_leaf,
std::unordered_map<int, SplitInfo> forceSplitMap;
q.push(std::make_pair(left, *left_leaf));
while (!q.empty()) {
// before processing next node from queue, store info for current left/right leaf
// store "best split" for left and right, even if they might be overwritten by forced split
if (BeforeFindBestSplit(tree, *left_leaf, *right_leaf)) {
FindBestSplits(tree);
}
FindBestSplitsForForceSplitLeaf(tree, left_leaf, right_leaf, left, right);

// then, compute own splits
SplitInfo left_split;
SplitInfo right_split;
Expand Down Expand Up @@ -561,6 +565,28 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, int* left_leaf,
return result_count;
}

// Before processing the next node from the forced-split queue, store the
// "best split" info for the current left/right leaves, even though it may
// later be overwritten by the forced split itself. Any feature named in the
// left/right forced-split settings is passed down as a forced feature so it
// is considered even if column sampling did not select it for this tree.
void SerialTreeLearner::FindBestSplitsForForceSplitLeaf(Tree* tree, int* left_leaf, int* right_leaf, Json left_force_split_leaf_setting, Json right_force_split_leaf_setting) {
  if (!BeforeFindBestSplit(tree, *left_leaf, *right_leaf)) {
    return;
  }
  // Collect inner feature indices required by the forced-split settings.
  std::set<int> forced_inner_features;
  for (const Json& leaf_setting : {left_force_split_leaf_setting, right_force_split_leaf_setting}) {
    if (!leaf_setting.is_null()) {
      const int outer_feature_index = leaf_setting["feature"].int_value();
      forced_inner_features.insert(train_data_->InnerFeatureIndex(outer_feature_index));
    }
  }
  FindBestSplits(tree, &forced_inner_features);
}

void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
int* right_leaf, bool update_cnt) {
Common::FunctionTimer fun_timer("SerialTreeLearner::SplitInner", global_timer);
Expand Down
5 changes: 5 additions & 0 deletions src/treelearner/serial_tree_learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <memory>
#include <random>
#include <vector>
#include <set>

#include "col_sampler.hpp"
#include "data_partition.hpp"
Expand Down Expand Up @@ -142,6 +143,8 @@ class SerialTreeLearner: public TreeLearner {

virtual void FindBestSplits(const Tree* tree);

virtual void FindBestSplits(const Tree* tree, const std::set<int>* force_features);

virtual void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);

virtual void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree*);
Expand All @@ -165,6 +168,8 @@ class SerialTreeLearner: public TreeLearner {
int32_t ForceSplits(Tree* tree, int* left_leaf, int* right_leaf,
int* cur_depth);

void FindBestSplitsForForceSplitLeaf(LightGBM::Tree* tree, int* left_leaf, int* right_leaf, Json left, Json right);

/*!
* \brief Get the number of data in a leaf
* \param leaf_idx The index of leaf
Expand Down
28 changes: 28 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# coding: utf-8
import copy
import itertools
import json
import math
import pickle
import platform
Expand Down Expand Up @@ -2887,3 +2888,30 @@ def hook(obj):
dumped_model_str = str(bst.dump_model(5, 0, object_hook=hook))
assert "leaf_value" not in dumped_model_str
assert "LV" in dumped_model_str


def test_force_split_with_feature_fraction():
    # Regression test: a forced split must be honored even when
    # feature_fraction < 1.0 could exclude the forced feature from sampling.
    X, y = load_boston(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
    lgb_train = lgb.Dataset(X_train, y_train)

    forced_split = {
        "feature": 0,
        "threshold": 0.5
    }

    with open("forced_split.json", "w") as f:
        json.dump(forced_split, f)

    params = {
        "objective": "regression",
        "feature_fraction": 0.6,
        "force_col_wise": True,
        "feature_fraction_seed": 1,
        "forcedsplits_filename": "forced_split.json"
    }

    gbm = lgb.train(params, lgb_train)
    predictions = gbm.predict(X_test)
    ret = mean_absolute_error(y_test, predictions)

    assert ret < 2.0