Skip to content

Commit

Permalink
Fix inference with categorical feature. (#8591)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Dec 15, 2022
1 parent 7dc3e95 commit 43a647a
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 28 deletions.
10 changes: 5 additions & 5 deletions doc/tutorials/categorical.rst
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@ Miscellaneous

By default, XGBoost assumes input categories are integers starting from 0 till the number
of categories :math:`[0, n\_categories)`. However, user might provide inputs with invalid
values due to mistakes or missing values. It can be negative value, integer values that
can not be accurately represented by 32-bit floating point, or values that are larger than
actual number of unique categories. During training this is validated but for prediction
it's treated as the same as missing value for performance reasons. Lastly, missing values
are treated as the same as numerical features (using the learned split direction).
values due to mistakes or missing values in training dataset. It can be negative value,
integer values that can not be accurately represented by 32-bit floating point, or values
that are larger than actual number of unique categories. During training this is
validated but for prediction it's treated as the same as not-chosen category for
performance reasons.


**********
Expand Down
17 changes: 9 additions & 8 deletions src/common/categorical.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,21 @@ inline XGBOOST_DEVICE bool InvalidCat(float cat) {
return cat < 0 || cat >= kMaxCat;
}

/* \brief Whether should it traverse to left branch of a tree.
/**
* \brief Whether should it traverse to left branch of a tree.
*
* For one hot split, go to left if it's NOT the matching category.
* Go to left if it's NOT the matching category, which matches one-hot encoding.
*/
template <bool validate = true>
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat) {
KCatBitField const s_cats(cats);
// FIXME: Size() is not accurate since it represents the size of bit set instead of
// actual number of categories.
if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
return dft_left;
if (XGBOOST_EXPECT(InvalidCat(cat), false)) {
return true;
}

auto pos = KCatBitField::ToBitPos(cat);
// If the input category is larger than the size of the bit field, it implies that the
// category is not chosen. Otherwise the bit field would have the category instead of
// being smaller than the category value.
if (pos.int_pos >= cats.size()) {
return true;
}
Expand Down
4 changes: 2 additions & 2 deletions src/common/partition_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class PartitionBuilder {
auto gidx = gidx_calc(ridx);
bool go_left = default_left;
if (gidx > -1) {
go_left = Decision(node_cats, cut_values[gidx], default_left);
go_left = Decision(node_cats, cut_values[gidx]);
}
return go_left;
} else {
Expand All @@ -157,7 +157,7 @@ class PartitionBuilder {
bool go_left = default_left;
if (gidx > -1) {
if (is_cat) {
go_left = Decision(node_cats, cut_values[gidx], default_left);
go_left = Decision(node_cats, cut_values[gidx]);
} else {
go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value;
}
Expand Down
4 changes: 1 addition & 3 deletions src/predictor/predict_fn.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bs
if (has_categorical && common::IsCat(cats.split_type, nid)) {
auto node_categories =
cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
return common::Decision<true>(node_categories, fvalue, node.DefaultLeft())
? node.LeftChild()
: node.RightChild();
return common::Decision(node_categories, fvalue) ? node.LeftChild() : node.RightChild();
} else {
return node.LeftChild() + !(fvalue < node.SplitCond());
}
Expand Down
5 changes: 2 additions & 3 deletions src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -402,8 +402,7 @@ struct GPUHistMakerDevice {
go_left = data.split_node.DefaultLeft();
} else {
if (data.split_type == FeatureType::kCategorical) {
go_left = common::Decision<false>(data.node_cats.Bits(), cut_value,
data.split_node.DefaultLeft());
go_left = common::Decision(data.node_cats.Bits(), cut_value);
} else {
go_left = cut_value <= data.split_node.SplitCond();
}
Expand Down Expand Up @@ -480,7 +479,7 @@ struct GPUHistMakerDevice {
if (common::IsCat(d_feature_types, position)) {
auto node_cats = categories.subspan(categories_segments[position].beg,
categories_segments[position].size);
go_left = common::Decision<false>(node_cats, element, node.DefaultLeft());
go_left = common::Decision(node_cats, element);
} else {
go_left = element <= node.SplitCond();
}
Expand Down
63 changes: 56 additions & 7 deletions tests/cpp/common/test_categorical.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
/*!
* Copyright 2021 by XGBoost Contributors
* Copyright 2021-2022 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/json.h>
#include <xgboost/learner.h>

#include <limits>

#include "../../../src/common/categorical.h"
#include "../helpers.h"

namespace xgboost {
namespace common {
Expand All @@ -15,29 +18,75 @@ TEST(Categorical, Decision) {

ASSERT_TRUE(common::InvalidCat(a));
std::vector<uint32_t> cats(256, 0);
ASSERT_TRUE(Decision(cats, a, true));
ASSERT_TRUE(Decision(cats, a));

// larger than size
a = 256;
ASSERT_TRUE(Decision(cats, a, true));
ASSERT_TRUE(Decision(cats, a));

// negative
a = -1;
ASSERT_TRUE(Decision(cats, a, true));
ASSERT_TRUE(Decision(cats, a));

CatBitField bits{cats};
bits.Set(0);
a = -0.5;
ASSERT_TRUE(Decision(cats, a, true));
ASSERT_TRUE(Decision(cats, a));

// round toward 0
a = 0.5;
ASSERT_FALSE(Decision(cats, a, true));
ASSERT_FALSE(Decision(cats, a));

// valid
a = 13;
bits.Set(a);
ASSERT_FALSE(Decision(bits.Bits(), a, true));
ASSERT_FALSE(Decision(bits.Bits(), a));
}

/**
* Test for running inference with input category greater than the one stored in tree.
*/
TEST(Categorical, MinimalSet) {
std::size_t constexpr kRows = 256, kCols = 1, kCat = 3;
std::vector<FeatureType> types{FeatureType::kCategorical};
auto Xy =
RandomDataGenerator{kRows, kCols, 0.0}.Type(types).MaxCategory(kCat).GenerateDMatrix(true);

std::unique_ptr<Learner> learner{Learner::Create({Xy})};
learner->SetParam("max_depth", "1");
learner->SetParam("tree_method", "hist");
learner->Configure();
learner->UpdateOneIter(0, Xy);

Json model{Object{}};
learner->SaveModel(&model);
auto tree = model["learner"]["gradient_booster"]["model"]["trees"][0];
ASSERT_GE(get<I32Array const>(tree["categories"]).size(), 1);
auto v = get<I32Array const>(tree["categories"])[0];

HostDeviceVector<float> predt;
{
std::vector<float> data{kCat, kCat + 1, 32, 33, 34};
auto test = GetDMatrixFromData(data, data.size(), kCols);
learner->Predict(test, false, &predt, 0, 0, false, /*pred_leaf=*/true);
ASSERT_EQ(predt.Size(), data.size());
auto const& h_predt = predt.ConstHostSpan();
for (auto v : h_predt) {
ASSERT_EQ(v, 1); // left child of root node
}
}

{
std::unique_ptr<Learner> learner{Learner::Create({Xy})};
learner->LoadModel(model);
std::vector<float> data = {static_cast<float>(v)};
auto test = GetDMatrixFromData(data, data.size(), kCols);
learner->Predict(test, false, &predt, 0, 0, false, /*pred_leaf=*/true);
auto const& h_predt = predt.ConstHostSpan();
for (auto v : h_predt) {
ASSERT_EQ(v, 2); // right child of root node
}
}
}
} // namespace common
} // namespace xgboost

0 comments on commit 43a647a

Please sign in to comment.