diff --git a/include/xgboost/feature_map.h b/include/xgboost/feature_map.h index a48e28ba1bfa..6083bb95d541 100644 --- a/include/xgboost/feature_map.h +++ b/include/xgboost/feature_map.h @@ -1,5 +1,5 @@ /*! - * Copyright 2014 by Contributors + * Copyright 2014-2021 by Contributors * \file feature_map.h * \brief Feature map data structure to help visualization and model dump. * \author Tianqi Chen @@ -26,7 +26,8 @@ class FeatureMap { kIndicator = 0, kQuantitive = 1, kInteger = 2, - kFloat = 3 + kFloat = 3, + kCategorical = 4 }; /*! * \brief load feature map from input stream @@ -82,6 +83,7 @@ class FeatureMap { if (!strcmp("q", tname)) return kQuantitive; if (!strcmp("int", tname)) return kInteger; if (!strcmp("float", tname)) return kFloat; + if (!strcmp("categorical", tname)) return kCategorical; LOG(FATAL) << "unknown feature type, use i for indicator and q for quantity"; return kIndicator; } diff --git a/python-package/xgboost/plotting.py b/python-package/xgboost/plotting.py index 7fe187180a34..09fcc1a04542 100644 --- a/python-package/xgboost/plotting.py +++ b/python-package/xgboost/plotting.py @@ -3,6 +3,7 @@ # coding: utf-8 """Plotting Library.""" from io import BytesIO +import json import numpy as np from .core import Booster from .sklearn import XGBModel @@ -203,7 +204,7 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir=None, if kwargs: parameters += ':' - parameters += str(kwargs) + parameters += json.dumps(kwargs) tree = booster.get_dump( fmap=fmap, dump_format=parameters)[num_trees] diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 58d1633b8c63..e8e2a2a8dfb5 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -52,11 +52,6 @@ bst_float PredValue(const SparsePage::Inst &inst, if (tree_info[i] == bst_group) { auto const &tree = *trees[i]; bool has_categorical = tree.HasCategoricalSplit(); - - auto categories = common::Span{tree.GetSplitCategories()}; - auto split_types = tree.GetSplitTypes(); - auto categories_ptr = - common::Span{tree.GetSplitCategoriesPtr()}; auto cats = tree.GetCategoriesMatrix(); bst_node_t nidx = -1; if (has_categorical) { diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 1847015bd3a7..94e3bc97abb2 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2015-2020 by Contributors + * Copyright 2015-2021 by Contributors * \file tree_model.cc * \brief model structure for tree */ @@ -74,6 +74,7 @@ class TreeGenerator { int32_t /*nid*/, uint32_t /*depth*/) const { return ""; } + virtual std::string Categorical(RegTree const&, int32_t, uint32_t) const = 0; virtual std::string Integer(RegTree const& /*tree*/, int32_t /*nid*/, uint32_t /*depth*/) const { return ""; @@ -92,26 +93,51 @@ class TreeGenerator { virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) { auto const split_index = tree[nid].SplitIndex(); std::string result; + auto is_categorical = tree.GetSplitTypes()[nid] == FeatureType::kCategorical; if (split_index < fmap_.Size()) { + auto check_categorical = [&]() { + CHECK(is_categorical) + << fmap_.Name(split_index) + << " in feature map is numerical but tree node is categorical."; + }; + auto check_numerical = [&]() { + auto is_numerical = !is_categorical; + CHECK(is_numerical) + << fmap_.Name(split_index) + << " in feature map is categorical but tree node is numerical."; + }; + switch (fmap_.TypeOf(split_index)) { - case FeatureMap::kIndicator: { - result = this->Indicator(tree, nid, depth); - break; - } - case FeatureMap::kInteger: { - result = this->Integer(tree, nid, depth); - break; - } - case FeatureMap::kFloat: - case FeatureMap::kQuantitive: { - result = this->Quantitive(tree, nid, depth); - break; - } - default: - LOG(FATAL) << "Unknown feature map type."; + case FeatureMap::kCategorical: { + check_categorical(); + result = this->Categorical(tree, nid, depth); + break; + } + case FeatureMap::kIndicator: { + check_numerical(); + result = this->Indicator(tree, nid, depth); + break; + } + case FeatureMap::kInteger: { + check_numerical(); + result = this->Integer(tree, nid, depth); + break; + } + case FeatureMap::kFloat: + case FeatureMap::kQuantitive: { + check_numerical(); + result = this->Quantitive(tree, nid, depth); + break; + } + default: + LOG(FATAL) << "Unknown feature map type."; } } else { - result = this->PlainNode(tree, nid, depth); + if (is_categorical) { + result = this->Categorical(tree, nid, depth); + } else { + result = this->PlainNode(tree, nid, depth); + } } return result; } @@ -179,6 +205,32 @@ TreeGenerator* TreeGenerator::Create(std::string const& attrs, FeatureMap const& __make_ ## TreeGenReg ## _ ## UniqueId ## __ = \ ::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->__REGISTER__(Name) +std::vector GetSplitCategories(RegTree const &tree, int32_t nidx) { + auto const &csr = tree.GetCategoriesMatrix(); + auto seg = csr.node_ptr[nidx]; + auto split = common::KCatBitField{csr.categories.subspan(seg.beg, seg.size)}; + + std::vector cats; + for (size_t i = 0; i < split.Size(); ++i) { + if (split.Check(i)) { + cats.push_back(static_cast(i)); + } + } + return cats; +} + +std::string PrintCatsAsSet(std::vector const &cats) { + std::stringstream ss; + ss << "{"; + for (size_t i = 0; i < cats.size(); ++i) { + ss << cats[i]; + if (i != cats.size() - 1) { + ss << ","; + } + } + ss << "}"; + return ss.str(); +} class TextGenerator : public TreeGenerator { using SuperT = TreeGenerator; @@ -258,6 +310,17 @@ class TextGenerator : public TreeGenerator { return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth); } + std::string Categorical(RegTree const &tree, int32_t nid, + uint32_t depth) const override { + auto cats = GetSplitCategories(tree, nid); + std::string cats_str = PrintCatsAsSet(cats); + static std::string const kNodeTemplate = + "{tabs}{nid}:[{fname}:{cond}] yes={right},no={left},missing={missing}"; + std::string const result = + SplitNodeImpl(tree, nid, kNodeTemplate, cats_str, depth); + return result; + } + std::string NodeStat(RegTree const& tree, int32_t nid) const override { static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}"; std::string const result = SuperT::Match( @@ -343,6 +406,24 @@ class JsonGenerator : public TreeGenerator { return result; } + std::string Categorical(RegTree const& tree, int32_t nid, uint32_t depth) const override { + auto cats = GetSplitCategories(tree, nid); + static std::string const kCategoryTemplate = + R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I" + R"I("split_condition": {cond}, "yes": {right}, "no": {left}, )I" + R"I("missing": {missing})I"; + std::string cats_ptr = "["; + for (size_t i = 0; i < cats.size(); ++i) { + cats_ptr += std::to_string(cats[i]); + if (i != cats.size() - 1) { + cats_ptr += ", "; + } + } + cats_ptr += "]"; + auto results = SplitNodeImpl(tree, nid, kCategoryTemplate, cats_ptr, depth); + return results; + } + std::string SplitNodeImpl(RegTree const &tree, int32_t nid, std::string const &template_str, std::string cond, uint32_t depth) const { @@ -534,6 +615,27 @@ class GraphvizGenerator : public TreeGenerator { } protected: + template + std::string BuildEdge(RegTree const &tree, bst_node_t nid, int32_t child, bool left) const { + static std::string const kEdgeTemplate = + " {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n"; + // Is this the default child for missing value? + bool is_missing = tree[nid].DefaultChild() == child; + std::string branch; + if (is_categorical) { + branch = std::string{left ? "no" : "yes"} + std::string{is_missing ? ", missing" : ""}; + } else { + branch = std::string{left ? "yes" : "no"} + std::string{is_missing ? ", missing" : ""}; + } + std::string buffer = + SuperT::Match(kEdgeTemplate, + {{"{nid}", std::to_string(nid)}, + {"{child}", std::to_string(child)}, + {"{color}", is_missing ? param_.yes_color : param_.no_color}, + {"{branch}", branch}}); + return buffer; + } + // Only indicator is different, so we combine all different node types into this // function. std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override { @@ -552,27 +654,32 @@ class GraphvizGenerator : public TreeGenerator { {"{cond}", has_less ? SuperT::ToStr(cond) : ""}, {"{params}", param_.condition_node_params}}); - static std::string const kEdgeTemplate = - " {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n"; - auto MatchFn = SuperT::Match; // mingw failed to capture protected fn. - auto BuildEdge = - [&tree, nid, MatchFn, this](int32_t child, bool left) { - // Is this the default child for missing value? - bool is_missing = tree[nid].DefaultChild() == child; - std::string branch = std::string {left ? "yes" : "no"} + - std::string {is_missing ? ", missing" : ""}; - std::string buffer = MatchFn(kEdgeTemplate, { - {"{nid}", std::to_string(nid)}, - {"{child}", std::to_string(child)}, - {"{color}", is_missing ? param_.yes_color : param_.no_color}, - {"{branch}", branch}}); - return buffer; - }; - result += BuildEdge(tree[nid].LeftChild(), true); - result += BuildEdge(tree[nid].RightChild(), false); + result += BuildEdge(tree, nid, tree[nid].LeftChild(), true); + result += BuildEdge(tree, nid, tree[nid].RightChild(), false); + return result; }; + std::string Categorical(RegTree const& tree, int32_t nid, uint32_t) const override { + static std::string const kLabelTemplate = + " {nid} [ label=\"{fname}:{cond}\" {params}]\n"; + auto cats = GetSplitCategories(tree, nid); + auto cats_str = PrintCatsAsSet(cats); + auto split = tree[nid].SplitIndex(); + std::string result = SuperT::Match( + kLabelTemplate, + {{"{nid}", std::to_string(nid)}, + {"{fname}", split < fmap_.Size() ? fmap_.Name(split) + : 'f' + std::to_string(split)}, + {"{cond}", cats_str}, + {"{params}", param_.condition_node_params}}); + + result += BuildEdge(tree, nid, tree[nid].LeftChild(), true); + result += BuildEdge(tree, nid, tree[nid].RightChild(), false); + + return result; + } + std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t) const override { static std::string const kLeafTemplate = " {nid} [ label=\"leaf={leaf-value}\" {params}]\n"; @@ -588,9 +695,12 @@ class GraphvizGenerator : public TreeGenerator { return this->LeafNode(tree, nid, depth); } static std::string const kNodeTemplate = "{parent}\n{left}\n{right}"; + auto node = tree.GetSplitTypes()[nid] == FeatureType::kCategorical + ? this->Categorical(tree, nid, depth) + : this->PlainNode(tree, nid, depth); auto result = SuperT::Match( kNodeTemplate, - {{"{parent}", this->PlainNode(tree, nid, depth)}, + {{"{parent}", node}, {"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)}, {"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}}); return result; diff --git a/tests/cpp/tree/test_tree_model.cc b/tests/cpp/tree/test_tree_model.cc index ac87b25bc618..4f20ba69bf08 100644 --- a/tests/cpp/tree/test_tree_model.cc +++ b/tests/cpp/tree/test_tree_model.cc @@ -241,6 +241,65 @@ RegTree ConstructTree() { /*right_sum=*/0.0f); return tree; } + +RegTree ConstructTreeCat(std::vector* cond) { + RegTree tree; + std::vector cats_storage(common::CatBitField::ComputeStorageSize(33), 0); + common::CatBitField split_cats(cats_storage); + split_cats.Set(0); + split_cats.Set(14); + split_cats.Set(32); + + cond->push_back(0); + cond->push_back(14); + cond->push_back(32); + + tree.ExpandCategorical(0, /*split_index=*/0, cats_storage, true, 0.0f, 2.0, + 3.00, 11.0, 2.0, 3.0, 4.0); + auto left = tree[0].LeftChild(); + auto right = tree[0].RightChild(); + tree.ExpandNode( + /*nid=*/left, /*split_index=*/1, /*split_value=*/1.0f, + /*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f, + /*right_sum=*/0.0f); + tree.ExpandCategorical(right, /*split_index=*/0, cats_storage, true, 0.0f, + 2.0, 3.00, 11.0, 2.0, 3.0, 4.0); + return tree; +} + +void TestCategoricalTreeDump(std::string format, std::string sep) { + std::vector cond; + auto tree = ConstructTreeCat(&cond); + + FeatureMap fmap; + auto str = tree.DumpModel(fmap, true, format); + std::string cond_str; + for (size_t c = 0; c < cond.size(); ++c) { + cond_str += std::to_string(cond[c]); + if (c != cond.size() - 1) { + cond_str += sep; + } + } + auto pos = str.find(cond_str); + ASSERT_NE(pos, std::string::npos); + pos = str.find(cond_str, pos + 1); + ASSERT_NE(pos, std::string::npos); + + fmap.PushBack(0, "feat_0", "categorical"); + fmap.PushBack(1, "feat_1", "q"); + fmap.PushBack(2, "feat_2", "int"); + + str = tree.DumpModel(fmap, true, format); + pos = str.find(cond_str); + ASSERT_NE(pos, std::string::npos); + pos = str.find(cond_str, pos + 1); + ASSERT_NE(pos, std::string::npos); + + if (format == "json") { + // Make sure it's valid JSON + Json::Load(StringView{str}); + } +} } // anonymous namespace TEST(Tree, DumpJson) { @@ -278,6 +337,10 @@ TEST(Tree, DumpJson) { ASSERT_EQ(get(j_tree["children"]).size(), 2ul); } +TEST(Tree, DumpJsonCategorical) { + TestCategoricalTreeDump("json", ", "); +} + TEST(Tree, DumpText) { auto tree = ConstructTree(); FeatureMap fmap; @@ -313,6 +376,10 @@ TEST(Tree, DumpText) { ASSERT_EQ(str.find("cover"), std::string::npos); } +TEST(Tree, DumpTextCategorical) { + TestCategoricalTreeDump("text", ","); +} + TEST(Tree, DumpDot) { auto tree = ConstructTree(); FeatureMap fmap; @@ -350,6 +417,10 @@ TEST(Tree, DumpDot) { ASSERT_NE(str.find(R"(1 -> 4 [label="no, missing")"), std::string::npos); } +TEST(Tree, DumpDotCategorical) { + TestCategoricalTreeDump("dot", ","); +} + TEST(Tree, JsonIO) { RegTree tree; tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, diff --git a/tests/python-gpu/test_gpu_plotting.py b/tests/python-gpu/test_gpu_plotting.py new file mode 100644 index 000000000000..4bfda2dbd388 --- /dev/null +++ b/tests/python-gpu/test_gpu_plotting.py @@ -0,0 +1,40 @@ +import sys +import xgboost as xgb +import pytest +import json + +sys.path.append("tests/python") +import testing as tm + +try: + import matplotlib + + matplotlib.use("Agg") + from matplotlib.axes import Axes + from graphviz import Source +except ImportError: + pass + + +pytestmark = pytest.mark.skipif(**tm.no_multiple(tm.no_matplotlib(), tm.no_graphviz())) + + +class TestPlotting: + @pytest.mark.skipif(**tm.no_pandas()) + def test_categorical(self): + X, y = tm.make_categorical(1000, 31, 19, onehot=False) + reg = xgb.XGBRegressor( + enable_categorical=True, n_estimators=10, tree_method="gpu_hist" + ) + reg.fit(X, y) + trees = reg.get_booster().get_dump(dump_format="json") + for tree in trees: + j_tree = json.loads(tree) + assert "leaf" in j_tree.keys() or isinstance( + j_tree["split_condition"], list + ) + + graph = xgb.to_graphviz(reg, num_trees=len(j_tree) - 1) + assert isinstance(graph, Source) + ax = xgb.plot_tree(reg, num_trees=len(j_tree) - 1) + assert isinstance(ax, Axes) diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index dd2dd1973e9f..3c3a7e045058 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -71,7 +71,6 @@ def run_categorical_basic(self, rows, cols, rounds, cats): @settings(deadline=None) @pytest.mark.skipif(**tm.no_pandas()) def test_categorical(self, rows, cols, rounds, cats): - pytest.xfail(reason='TestGPUUpdaters::test_categorical is flaky') self.run_categorical_basic(rows, cols, rounds, cats) def test_categorical_32_cat(self): diff --git a/tests/python-gpu/test_gpu_with_sklearn.py b/tests/python-gpu/test_gpu_with_sklearn.py index a38001a51610..7658299b98c1 100644 --- a/tests/python-gpu/test_gpu_with_sklearn.py +++ b/tests/python-gpu/test_gpu_with_sklearn.py @@ -55,7 +55,6 @@ def test_categorical(): tree_method="gpu_hist", use_label_encoder=False, enable_categorical=True, - predictor="gpu_predictor", n_estimators=10, ) X = pd.DataFrame(X.todense()).astype("category")