Skip to content

Commit

Permalink
Use DFS to get a valid but smaller new node order (#570)
Browse files Browse the repository at this point in the history
* Use DFS to get a valid but smaller new node order

* Update lightgbm.cc

* Add a unit test

---------

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
  • Loading branch information
tmct and hcho3 authored Jul 13, 2024
1 parent 27b81f7 commit f1b910e
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 10 deletions.
20 changes: 11 additions & 9 deletions src/model_loader/lightgbm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -526,22 +526,24 @@ inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi) {
for (auto const& lgb_tree : lgb_trees_) {
builder->StartTree();

// Assign node ID's so that a breadth-wise traversal would yield
// Assign node ID's so that a depth-wise traversal would yield
// the monotonic sequence 0, 1, 2, ...
// We re-arrange nodes here, since LightGBM uses negative indices to distinguish leaf nodes
// from internal nodes.
std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
std::deque<std::pair<int, int>> Q; // (old ID, new ID) pair
int dfs_index = 1;
if (lgb_tree.num_leaves == 0) {
continue;
} else if (lgb_tree.num_leaves == 1) {
// A constant-value tree with a single root node that's also a leaf
Q.emplace(-1, 0);
Q.emplace_front(-1, dfs_index);
} else {
Q.emplace(0, 0);
Q.emplace_front(0, dfs_index);
}
dfs_index++;
while (!Q.empty()) {
auto [old_node_id, new_node_id] = Q.front();
Q.pop();
Q.pop_front();
builder->StartNode(new_node_id);
if (old_node_id < 0) { // leaf
builder->LeafScalar(lgb_tree.leaf_value[~old_node_id]);
Expand All @@ -554,9 +556,9 @@ inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi) {
auto const split_index = static_cast<std::int32_t>(lgb_tree.split_feature[old_node_id]);
auto const missing_type = GetMissingType(lgb_tree.decision_type[old_node_id]);
int const left_child_old_id = lgb_tree.left_child[old_node_id];
int const left_child_new_id = new_node_id * 2 + 1;
int const left_child_new_id = dfs_index++;
int const right_child_old_id = lgb_tree.right_child[old_node_id];
int const right_child_new_id = new_node_id * 2 + 2;
int const right_child_new_id = dfs_index++;

if (GetDecisionType(lgb_tree.decision_type[old_node_id], kCategoricalMask)) {
// Categorical split
Expand Down Expand Up @@ -591,8 +593,8 @@ inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi) {
if (!lgb_tree.split_gain.empty()) {
builder->Gain(lgb_tree.split_gain[old_node_id]);
}
Q.emplace(left_child_old_id, left_child_new_id);
Q.emplace(right_child_old_id, right_child_new_id);
Q.emplace_front(left_child_old_id, left_child_new_id);
Q.emplace_front(right_child_old_id, right_child_new_id);
}
builder->EndNode();
}
Expand Down
36 changes: 36 additions & 0 deletions tests/examples/deep_lightgbm/model.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
tree
version=v4
num_class=1
num_tree_per_iteration=1
label_index=0
max_feature_idx=0
objective=regression
feature_names=this
feature_infos=[0:100]
tree_sizes=1119

Tree=0
num_leaves=32
num_cat=0
split_feature=0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
split_gain=0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1
threshold=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
decision_type=2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
left_child=1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 -31
right_child=-1 -2 -3 -4 -5 -6 -7 -8 -9 -10 -11 -12 -13 -14 -15 -16 -17 -18 -19 -20 -21 -22 -23 -24 -25 -26 -27 -28 -29 -30 -32
leaf_value=31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3 2 0 1
leaf_weight=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
leaf_count=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
internal_value=0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
internal_weight=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
internal_count=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
is_linear=0
shrinkage=1


end of trees

feature_importances:
this=31

pandas_categorical:null
25 changes: 24 additions & 1 deletion tests/python/test_lightgbm_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

try:
from hypothesis import given, settings
from hypothesis.extra.numpy import arrays
from hypothesis.strategies import data as hypothesis_callback
from hypothesis.strategies import integers, just, sampled_from
from hypothesis.strategies import integers, just, sampled_from, tuples
except ImportError:
pytest.skip("hypothesis not installed; skipping", allow_module_level=True)

Expand Down Expand Up @@ -259,3 +260,25 @@ def test_lightgbm_sparse_categorical_model():
expected_pred = load_txt(dataset_db[dataset].expected_margin).reshape((-1, 1, 1))
out_pred = treelite.gtil.predict(tl_model, X, pred_margin=True)
np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5)


@given(
    X=arrays(
        dtype=sampled_from([np.float32, np.float64]),
        shape=tuples(integers(min_value=10, max_value=100), just(1)),
    )
)
def test_lightgbm_deep_tree(X):
    """Check that Treelite matches LightGBM on the deep single-feature fixture.

    The fixture at tests/examples/deep_lightgbm/model.txt holds one tree with
    32 leaves whose internal nodes form a single left-leaning chain, so the
    loader has to cope with a tree that is far deeper than it is wide.

    X is a hypothesis-generated array with 10 to 100 rows and exactly one
    column, in either float32 or float64.
    """
    tests_dir = pathlib.Path(__file__).parent.parent
    model_path = tests_dir.joinpath("examples", "deep_lightgbm", "model.txt")

    # Reference output comes straight from LightGBM; reshape it to the
    # 3-D layout (rows, 1, -1) that the GTIL output is compared against.
    num_rows = X.shape[0]
    booster = lgb.Booster(model_file=model_path)
    expected_pred = booster.predict(X).reshape((num_rows, 1, -1))

    # Load the same file through the Treelite frontend and predict with GTIL.
    tl_model = treelite.frontend.load_lightgbm_model(model_path)
    out_pred = treelite.gtil.predict(tl_model, X)
    np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5)

0 comments on commit f1b910e

Please sign in to comment.