diff --git a/src/model_loader/lightgbm.cc b/src/model_loader/lightgbm.cc index 9c1bd4af..00c7ac56 100644 --- a/src/model_loader/lightgbm.cc +++ b/src/model_loader/lightgbm.cc @@ -526,22 +526,24 @@ inline std::unique_ptr ParseStream(std::istream& fi) { for (auto const& lgb_tree : lgb_trees_) { builder->StartTree(); - // Assign node ID's so that a breadth-wise traversal would yield + // Assign node ID's so that a depth-wise traversal would yield // the monotonic sequence 0, 1, 2, ... // We re-arrange nodes here, since LightGBM uses negative indices to distinguish leaf nodes // from internal nodes. - std::queue> Q; // (old ID, new ID) pair + std::deque> Q; // (old ID, new ID) pair + int dfs_index = 1; if (lgb_tree.num_leaves == 0) { continue; } else if (lgb_tree.num_leaves == 1) { // A constant-value tree with a single root node that's also a leaf - Q.emplace(-1, 0); + Q.emplace_front(-1, dfs_index); } else { - Q.emplace(0, 0); + Q.emplace_front(0, dfs_index); } + dfs_index++; while (!Q.empty()) { auto [old_node_id, new_node_id] = Q.front(); - Q.pop(); + Q.pop_front(); builder->StartNode(new_node_id); if (old_node_id < 0) { // leaf builder->LeafScalar(lgb_tree.leaf_value[~old_node_id]); @@ -554,9 +556,9 @@ inline std::unique_ptr ParseStream(std::istream& fi) { auto const split_index = static_cast(lgb_tree.split_feature[old_node_id]); auto const missing_type = GetMissingType(lgb_tree.decision_type[old_node_id]); int const left_child_old_id = lgb_tree.left_child[old_node_id]; - int const left_child_new_id = new_node_id * 2 + 1; + int const left_child_new_id = dfs_index++; int const right_child_old_id = lgb_tree.right_child[old_node_id]; - int const right_child_new_id = new_node_id * 2 + 2; + int const right_child_new_id = dfs_index++; if (GetDecisionType(lgb_tree.decision_type[old_node_id], kCategoricalMask)) { // Categorical split @@ -591,8 +593,8 @@ inline std::unique_ptr ParseStream(std::istream& fi) { if (!lgb_tree.split_gain.empty()) { builder->Gain(lgb_tree.split_gain[old_node_id]); } - Q.emplace(left_child_old_id, left_child_new_id); - Q.emplace(right_child_old_id, right_child_new_id); + Q.emplace_front(left_child_old_id, left_child_new_id); + Q.emplace_front(right_child_old_id, right_child_new_id); } builder->EndNode(); } diff --git a/tests/examples/deep_lightgbm/model.txt b/tests/examples/deep_lightgbm/model.txt new file mode 100644 index 00000000..943917f0 --- /dev/null +++ b/tests/examples/deep_lightgbm/model.txt @@ -0,0 +1,36 @@ +tree +version=v4 +num_class=1 +num_tree_per_iteration=1 +label_index=0 +max_feature_idx=0 +objective=regression +feature_names=this +feature_infos=[0:100] +tree_sizes=1119 + +Tree=0 +num_leaves=32 +num_cat=0 +split_feature=0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 +split_gain=0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 +threshold=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 +decision_type=2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 +left_child=1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 -31 +right_child=-1 -2 -3 -4 -5 -6 -7 -8 -9 -10 -11 -12 -13 -14 -15 -16 -17 -18 -19 -20 -21 -22 -23 -24 -25 -26 -27 -28 -29 -30 -32 +leaf_value=31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3 2 0 1 +leaf_weight=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 +leaf_count=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 +internal_value=0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 +internal_weight=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 +internal_count=1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 +is_linear=0 +shrinkage=1 + + +end of trees + +feature_importances: +this=31 + +pandas_categorical:null diff --git a/tests/python/test_lightgbm_integration.py b/tests/python/test_lightgbm_integration.py index a78aced9..26ea8ab8 100644 --- a/tests/python/test_lightgbm_integration.py +++ b/tests/python/test_lightgbm_integration.py @@ -10,8 +10,9 @@ try: from hypothesis import given, settings + from hypothesis.extra.numpy import arrays from hypothesis.strategies import data as hypothesis_callback - from hypothesis.strategies import integers, just, sampled_from + from hypothesis.strategies import integers, just, sampled_from, tuples except ImportError: pytest.skip("hypothesis not installed; skipping", allow_module_level=True) @@ -259,3 +260,25 @@ def test_lightgbm_sparse_categorical_model(): expected_pred = load_txt(dataset_db[dataset].expected_margin).reshape((-1, 1, 1)) out_pred = treelite.gtil.predict(tl_model, X, pred_margin=True) np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5) + + +@given( + X=arrays( + dtype=sampled_from([np.float32, np.float64]), + shape=tuples(integers(min_value=10, max_value=100), just(1)), + ) +) +def test_lightgbm_deep_tree(X): + """Test LightGBM model with depth 32+""" + path = ( + pathlib.Path(__file__).parent.parent + / "examples" + / "deep_lightgbm" + / "model.txt" + ) + bst = lgb.Booster(model_file=path) + expected_pred = bst.predict(X).reshape((X.shape[0], 1, -1)) + + tl_model = treelite.frontend.load_lightgbm_model(path) + out_pred = treelite.gtil.predict(tl_model, X) + np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5)