Skip to content

Commit

Permalink
Make sure the number of features is correct.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 18, 2021
1 parent d097ed5 commit 3d82f2b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/c_api/c_api_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,14 @@ inline void GenerateFeatureMap(Learner const *learner,
// Use the feature names and types from booster.
std::vector<std::string> feature_names;
learner->GetFeatureNames(&feature_names);
if (!feature_names.empty()) {
CHECK_EQ(feature_names.size(), n_features) << "Incorrect number of feature names.";
}
std::vector<std::string> feature_types;
learner->GetFeatureTypes(&feature_types);
if (!feature_types.empty()) {
CHECK_EQ(feature_types.size(), n_features) << "Incorrect number of feature types.";
}
for (size_t i = 0; i < n_features; ++i) {
feature_map.PushBack(
i,
Expand Down
17 changes: 17 additions & 0 deletions tests/python/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,23 @@ def test_dump(self):
dump4j = json.loads(dump4[0])
assert 'gain' in dump4j, "Expected 'gain' to be dumped in JSON."

def test_feature_score(self):
rng = np.random.RandomState(0)
data = rng.randn(100, 2)
target = np.array([0, 1] * 50)
features = ["F0"]
with pytest.raises(ValueError):
xgb.DMatrix(data, label=target, feature_names=features)

params = {"objective": "binary:logistic"}
dm = xgb.DMatrix(data, label=target, feature_names=["F0", "F1"])
booster = xgb.train(params, dm, num_boost_round=1)
# no error since feature names might be assigned before the booster seeing data
# and booster doesn't known about the actual number of features.
booster.feature_names = ["F0"]
with pytest.raises(ValueError):
booster.get_fscore()

def test_load_file_invalid(self):
with pytest.raises(xgb.core.XGBoostError):
xgb.Booster(model_file='incorrect_path')
Expand Down

0 comments on commit 3d82f2b

Please sign in to comment.