Skip to content

Commit

Permalink
Add JSON schema to model dump. (#5660)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored May 15, 2020
1 parent 2c1a439 commit 535479e
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 26 deletions.
55 changes: 55 additions & 0 deletions doc/dump.schema
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
{
  "$schema": "http://json-schema.org/draft-07/schema#",
  "title": "XGBoost JSON tree dump",
  "description": "A single dumped tree: either a split node with nested children, or a lone leaf node when the tree has no split.",
  "definitions": {
    "split_node": {
      "description": "Internal node holding a split and its two child nodes.",
      "type": "object",
      "properties": {
        "nodeid": {
          "type": "integer",
          "minimum": 0
        },
        "depth": {
          "type": "integer",
          "minimum": 0
        },
        "yes": {
          "type": "integer",
          "minimum": 0
        },
        "no": {
          "type": "integer",
          "minimum": 0
        },
        "split": {
          "type": "string"
        },
        "children": {
          "type": "array",
          "items": {
            "oneOf": [
              {"$ref": "#/definitions/split_node"},
              {"$ref": "#/definitions/leaf_node"}
            ]
          },
          "minItems": 2,
          "maxItems": 2
        }
      },
      "required": ["nodeid", "depth", "yes", "no", "split", "children"]
    },
    "leaf_node": {
      "description": "Terminal node carrying the leaf value.",
      "type": "object",
      "properties": {
        "nodeid": {
          "type": "integer",
          "minimum": 0
        },
        "leaf": {
          "type": "number"
        }
      },
      "required": ["nodeid", "leaf"]
    }
  },
  "oneOf": [
    {"$ref": "#/definitions/split_node"},
    {"$ref": "#/definitions/leaf_node"}
  ]
}
51 changes: 26 additions & 25 deletions src/tree/tree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,20 @@ class TreeGenerator {
return result;
}

virtual std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) {
// Base-class default: contributes no text. Generators that render this node
// kind (e.g. TextGenerator, JsonGenerator) override it.
virtual std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const {
  return std::string{};
}
virtual std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) {
// Base-class default: contributes no text. Generators that render this node
// kind (e.g. TextGenerator, JsonGenerator) override it.
virtual std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) const {
  return std::string{};
}
virtual std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) {
// Base-class default: contributes no text. Generators that render this node
// kind (e.g. TextGenerator, JsonGenerator) override it.
virtual std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) const {
  return std::string{};
}
virtual std::string NodeStat(RegTree const& tree, int32_t nid) {
// Base-class default: no statistics text. TextGenerator and JsonGenerator
// override this to append their own statistics fragment.
virtual std::string NodeStat(RegTree const& tree, int32_t nid) const {
  return std::string{};
}

virtual std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
virtual std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const = 0;

virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) {
auto const split_index = tree[nid].SplitIndex();
Expand Down Expand Up @@ -110,7 +110,7 @@ class TreeGenerator {
return result;
}

virtual std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
virtual std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const = 0;
virtual std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) = 0;

public:
Expand Down Expand Up @@ -181,7 +181,7 @@ class TextGenerator : public TreeGenerator {
// Constructs a TextGenerator. Only `fmap` and `with_stats` are forwarded to the
// TreeGenerator base; `attrs` is accepted but not used by this constructor.
TextGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) :
TreeGenerator(fmap, with_stats) {}

std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string kLeafTemplate = "{tabs}{nid}:leaf={leaf}{stats}";
static std::string kStatTemplate = ",cover={cover}";
std::string result = SuperT::Match(
Expand All @@ -195,7 +195,7 @@ class TextGenerator : public TreeGenerator {
return result;
}

std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string const kIndicatorTemplate = "{nid}:[{fname}] yes={yes},no={no}";
int32_t nyes = tree[nid].DefaultLeft() ?
tree[nid].RightChild() : tree[nid].LeftChild();
Expand All @@ -211,7 +211,7 @@ class TextGenerator : public TreeGenerator {

std::string SplitNodeImpl(
RegTree const& tree, int32_t nid, std::string const& template_str,
std::string cond, uint32_t depth) {
std::string cond, uint32_t depth) const {
auto split_index = tree[nid].SplitIndex();
std::string const result = SuperT::Match(
template_str,
Expand All @@ -226,7 +226,7 @@ class TextGenerator : public TreeGenerator {
return result;
}

std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string const kIntegerTemplate =
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
auto cond = tree[nid].SplitCond();
Expand All @@ -238,21 +238,21 @@ class TextGenerator : public TreeGenerator {
std::to_string(integer_threshold), depth);
}

std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override {
// Text rendering of a quantitive split: converts the node's split condition
// with SuperT::ToStr and lets SplitNodeImpl fill in the line template.
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) const override {
  static std::string const kLineTemplate =
      "{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
  auto const threshold = SuperT::ToStr(tree[nid].SplitCond());
  return SplitNodeImpl(tree, nid, kLineTemplate, threshold, depth);
}

std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
// Text rendering of a plain split node. Same layout as the quantitive line,
// except that the feature placeholder is prefixed with a literal 'f'.
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
  static std::string const kLineTemplate =
      "{tabs}{nid}:[f{fname}<{cond}] yes={left},no={right},missing={missing}";
  auto const threshold = SuperT::ToStr(tree[nid].SplitCond());
  return SplitNodeImpl(tree, nid, kLineTemplate, threshold, depth);
}

std::string NodeStat(RegTree const& tree, int32_t nid) override {
std::string NodeStat(RegTree const& tree, int32_t nid) const override {
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
std::string const result = SuperT::Match(
kStatTemplate,
Expand Down Expand Up @@ -297,15 +297,15 @@ class JsonGenerator : public TreeGenerator {
// Constructs a JsonGenerator. Only `fmap` and `with_stats` are forwarded to the
// TreeGenerator base; `attrs` is taken by value and not used by this constructor.
JsonGenerator(FeatureMap const& fmap, std::string attrs, bool with_stats) :
TreeGenerator(fmap, with_stats) {}

std::string Indent(uint32_t depth) {
// Returns the padding for a node at `depth`: the pad unit repeated
// `depth + 1` times.
std::string Indent(uint32_t depth) const {
  std::string padding;
  // Count down from depth + 1 so the unit is appended exactly that many times.
  for (uint32_t remaining = depth + 1; remaining > 0; --remaining) {
    padding += " ";
  }
  return padding;
}

std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string const kLeafTemplate =
R"L({ "nodeid": {nid}, "leaf": {leaf} {stat}})L";
static std::string const kStatTemplate =
Expand All @@ -321,11 +321,11 @@ class JsonGenerator : public TreeGenerator {
return result;
}

std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const override {
int32_t nyes = tree[nid].DefaultLeft() ?
tree[nid].RightChild() : tree[nid].LeftChild();
static std::string const kIndicatorTemplate =
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no}})ID";
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no})ID";
auto split_index = tree[nid].SplitIndex();
auto result = SuperT::Match(
kIndicatorTemplate,
Expand All @@ -337,8 +337,9 @@ class JsonGenerator : public TreeGenerator {
return result;
}

std::string SplitNodeImpl(RegTree const& tree, int32_t nid,
std::string const& template_str, std::string cond, uint32_t depth) {
std::string SplitNodeImpl(RegTree const &tree, int32_t nid,
std::string const &template_str, std::string cond,
uint32_t depth) const {
auto split_index = tree[nid].SplitIndex();
std::string const result = SuperT::Match(
template_str,
Expand All @@ -353,7 +354,7 @@ class JsonGenerator : public TreeGenerator {
return result;
}

std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) const override {
auto cond = tree[nid].SplitCond();
const bst_float floored = std::floor(cond);
const int32_t integer_threshold
Expand All @@ -367,7 +368,7 @@ class JsonGenerator : public TreeGenerator {
std::to_string(integer_threshold), depth);
}

std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string const kQuantitiveTemplate =
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
Expand All @@ -376,7 +377,7 @@ class JsonGenerator : public TreeGenerator {
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
}

std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate =
R"I( "nodeid": {nid}, "depth": {depth}, "split": {fname}, )I"
Expand All @@ -385,7 +386,7 @@ class JsonGenerator : public TreeGenerator {
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
}

std::string NodeStat(RegTree const& tree, int32_t nid) override {
std::string NodeStat(RegTree const& tree, int32_t nid) const override {
static std::string kStatTemplate =
R"S(, "gain": {loss_chg}, "cover": {sum_hess})S";
auto result = SuperT::Match(
Expand Down Expand Up @@ -529,7 +530,7 @@ class GraphvizGenerator : public TreeGenerator {
protected:
// Only indicator is different, so we combine all different node types into this
// function.
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
auto split = tree[nid].SplitIndex();
auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate =
Expand Down Expand Up @@ -563,7 +564,7 @@ class GraphvizGenerator : public TreeGenerator {
return result;
};

std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
static std::string const kLeafTemplate =
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
auto result = SuperT::Match(kLeafTemplate, {
Expand Down
4 changes: 4 additions & 0 deletions tests/cpp/tree/test_tree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ TEST(Tree, DumpJson) {

str = tree.DumpModel(fmap, false, "json");
ASSERT_EQ(str.find("cover"), std::string::npos);


auto j_tree = Json::Load({str.c_str(), str.size()});
ASSERT_EQ(get<Array>(j_tree["children"]).size(), 2);
}

TEST(Tree, DumpText) {
Expand Down
34 changes: 33 additions & 1 deletion tests/python/test_basic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def test_model_json_io(self):
assert locale.getpreferredencoding(False) == loc

@pytest.mark.skipif(**tm.no_json_schema())
def test_json_schema(self):
def test_json_io_schema(self):
import jsonschema
model_path = 'test_json_schema.json'
path = os.path.dirname(
Expand All @@ -342,3 +342,35 @@ def test_json_schema(self):
jsonschema.validate(instance=json_model(model_path, parameters),
schema=schema)
os.remove(model_path)

@pytest.mark.skipif(**tm.no_json_schema())
def test_json_dump_schema(self):
    """Validate every tree of a JSON model dump against ``doc/dump.schema``.

    Trains a small multi-class model with the ``gbtree`` and the ``dart``
    booster and checks each dumped tree with ``jsonschema``.
    """
    import jsonschema

    # The schema is at <repo root>/doc/dump.schema; the repo root is three
    # ``dirname`` steps up from this file's absolute path.
    path = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    doc = os.path.join(path, 'doc', 'dump.schema')
    with open(doc, 'r') as fd:
        schema = json.load(fd)

    def validate_model(parameters):
        # Fixed seed so that a schema failure can be reproduced.
        rng = np.random.RandomState(1994)
        X = rng.random_sample((100, 30))
        y = rng.randint(0, 4, size=(100,))

        # Build a new dict instead of mutating the caller's one.
        parameters = dict(parameters, num_class=4)
        m = xgb.DMatrix(X, y)

        booster = xgb.train(parameters, m)
        dump = booster.get_dump(dump_format='json')

        for tree in dump:
            jsonschema.validate(instance=json.loads(tree), schema=schema)

    for booster_name in ('gbtree', 'dart'):
        validate_model({'tree_method': 'hist', 'booster': booster_name,
                        'objective': 'multi:softmax'})

0 comments on commit 535479e

Please sign in to comment.