Skip to content

Commit

Permalink
Strict feature map.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 12, 2021
1 parent 3973aba commit 9a53f35
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 25 deletions.
41 changes: 21 additions & 20 deletions src/tree/tree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,37 +95,38 @@ class TreeGenerator {
std::string result;
auto is_categorical = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
if (split_index < fmap_.Size()) {
auto check_categorical = [&]() {
CHECK(is_categorical)
<< fmap_.Name(split_index)
<< " in feature map is numerical but tree node is categorical.";
};
auto check_numerical = [&]() {
auto is_numerical = !is_categorical;
CHECK(is_numerical)
<< fmap_.Name(split_index)
<< " in feature map is categorical but tree node is numerical.";
};

switch (fmap_.TypeOf(split_index)) {
case FeatureMap::kCategorical: {
auto name = split_index < fmap_.Size()
? fmap_.Name(split_index)
: "feat" + std::to_string(split_index);
CHECK(is_categorical)
<< name << " is numerical but tree node is categorical.";
check_categorical();
result = this->Categorical(tree, nid, depth);
break;
}
case FeatureMap::kIndicator: {
if (is_categorical) {
result = this->Categorical(tree, nid, depth);
} else {
result = this->Indicator(tree, nid, depth);
}
check_numerical();
result = this->Indicator(tree, nid, depth);
break;
}
case FeatureMap::kInteger: {
if (is_categorical) {
result = this->Categorical(tree, nid, depth);
} else {
result = this->Integer(tree, nid, depth);
}
check_numerical();
result = this->Integer(tree, nid, depth);
break;
}
case FeatureMap::kFloat:
case FeatureMap::kQuantitive: {
if (is_categorical) {
result = this->Categorical(tree, nid, depth);
} else {
result = this->Quantitive(tree, nid, depth);
}
check_numerical();
result = this->Quantitive(tree, nid, depth);
break;
}
default:
Expand Down
10 changes: 5 additions & 5 deletions tests/cpp/tree/test_tree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,16 +254,16 @@ RegTree ConstructTreeCat(std::vector<bst_cat_t>* cond) {
cond->push_back(14);
cond->push_back(32);

tree.ExpandCategorical(0, 0, cats_storage, true, 0.0f, 2.0, 3.00, 11.0, 2.0,
3.0, 4.0);
tree.ExpandCategorical(0, /*split_index=*/0, cats_storage, true, 0.0f, 2.0,
3.00, 11.0, 2.0, 3.0, 4.0);
auto left = tree[0].LeftChild();
auto right = tree[0].RightChild();
tree.ExpandNode(
/*nid=*/left, /*split_index=*/1, /*split_value=*/1.0f,
/*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
/*right_sum=*/0.0f);
tree.ExpandCategorical(right, 0, cats_storage, true, 0.0f, 2.0, 3.00, 11.0,
2.0, 3.0, 4.0);
tree.ExpandCategorical(right, /*split_index=*/0, cats_storage, true, 0.0f,
2.0, 3.00, 11.0, 2.0, 3.0, 4.0);
return tree;
}

Expand All @@ -285,7 +285,7 @@ void TestCategoricalTreeDump(std::string format, std::string sep) {
pos = str.find(cond_str, pos + 1);
ASSERT_NE(pos, std::string::npos);

fmap.PushBack(0, "feat_0", "i");
fmap.PushBack(0, "feat_0", "categorical");
fmap.PushBack(1, "feat_1", "q");
fmap.PushBack(2, "feat_2", "int");

Expand Down

0 comments on commit 9a53f35

Please sign in to comment.