diff --git a/include/LightGBM/tree.h b/include/LightGBM/tree.h index 4f5ede83102b..4f072a6406e3 100644 --- a/include/LightGBM/tree.h +++ b/include/LightGBM/tree.h @@ -243,6 +243,9 @@ class Tree { /*! \brief Serialize this object to json*/ std::string ToJSON() const; + /*! \brief Serialize linear model of tree node to json*/ + std::string LinearModelToJSON(int index) const; + /*! \brief Serialize this object to if-else statement*/ std::string ToIfElse(int index, bool predict_leaf_index) const; diff --git a/src/io/tree.cpp b/src/io/tree.cpp index 67e02af20cd8..e3c770491ff6 100644 --- a/src/io/tree.cpp +++ b/src/io/tree.cpp @@ -417,11 +417,39 @@ std::string Tree::ToJSON() const { str_buf << "\"num_cat\":" << num_cat_ << "," << '\n'; str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n'; if (num_leaves_ == 1) { - str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n'; + if (is_linear_) { + str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << ", " << "\n"; + str_buf << LinearModelToJSON(0) << "}" << "\n"; + } else { + str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n'; + } } else { str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n'; } + return str_buf.str(); +} +std::string Tree::LinearModelToJSON(int index) const { + std::stringstream str_buf; + Common::C_stringstream(str_buf); + str_buf << std::setprecision(std::numeric_limits::digits10 + 2); + str_buf << "\"leaf_const\":" << leaf_const_[index] << "," << "\n"; + int num_features = static_cast(leaf_features_[index].size()); + if (num_features > 0) { + str_buf << "\"leaf_features\":["; + for (int i = 0; i < num_features - 1; ++i) { + str_buf << leaf_features_[index][i] << ", "; + } + str_buf << leaf_features_[index][num_features - 1] << "]" << ", " << "\n"; + str_buf << "\"leaf_coeff\":["; + for (int i = 0; i < num_features - 1; ++i) { + str_buf << leaf_coeff_[index][i] << ", "; + } + str_buf << leaf_coeff_[index][num_features - 1] << "]" << "\n"; + } else { + str_buf << "\"leaf_features\":[],\n"; + str_buf << "\"leaf_coeff\":[]\n"; + } return str_buf.str(); } @@ -479,10 +507,14 @@ std::string Tree::NodeToJSON(int index) const { str_buf << "\"leaf_index\":" << index << "," << '\n'; str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n'; str_buf << "\"leaf_weight\":" << leaf_weight_[index] << "," << '\n'; - str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n'; + if (is_linear_) { + str_buf << "\"leaf_count\":" << leaf_count_[index] << "," << '\n'; + str_buf << LinearModelToJSON(index); + } else { + str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n'; + } str_buf << "}"; } - return str_buf.str(); } diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 49e89534da51..6ffec8cee7d9 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -2793,3 +2793,28 @@ def test_reset_params_works_with_metric_num_class_and_boosting(): expected_params = dict(dataset_params, **booster_params) assert bst.params == expected_params assert new_bst.params == expected_params + + +def test_dump_model(): + X, y = load_breast_cancer(return_X_y=True) + train_data = lgb.Dataset(X, label=y) + params = { + "objective": "binary", + "verbose": -1 + } + bst = lgb.train(params, train_data, num_boost_round=5) + dumped_model_str = str(bst.dump_model(5, 0)) + assert "leaf_features" not in dumped_model_str + assert "leaf_coeff" not in dumped_model_str + assert "leaf_const" not in dumped_model_str + assert "leaf_value" in dumped_model_str + assert "leaf_count" in dumped_model_str + params['linear_tree'] = True + train_data = lgb.Dataset(X, label=y) + bst = lgb.train(params, train_data, num_boost_round=5) + dumped_model_str = str(bst.dump_model(5, 0)) + assert "leaf_features" in dumped_model_str + assert "leaf_coeff" in dumped_model_str + assert "leaf_const" in dumped_model_str + assert "leaf_value" in dumped_model_str + assert "leaf_count" in dumped_model_str