Skip to content

Commit

Permalink
Improve parameter validation (#6769)
Browse files Browse the repository at this point in the history
* Add quotes to unused parameters.
* Check for whitespace.
  • Loading branch information
trivialfis committed Mar 19, 2021
1 parent 23b4165 commit f6fe15d
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 7 deletions.
2 changes: 1 addition & 1 deletion R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ test_that("parameter validation works", {
xgb.train(params = params, data = dtrain, nrounds = nrounds))
print(output)
}
expect_output(incorrect(), "bar, foo")
expect_output(incorrect(), '\\\\"bar\\\\", \\\\"foo\\\\"')
})


Expand Down
11 changes: 7 additions & 4 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -537,15 +537,18 @@ class LearnerConfiguration : public Learner {
}
}

// FIXME(trivialfis): Make eval_metric a training parameter.
keys.emplace_back(kEvalMetric);
keys.emplace_back("verbosity");
keys.emplace_back("num_output_group");

std::sort(keys.begin(), keys.end());

std::vector<std::string> provided;
for (auto const &kv : cfg_) {
// FIXME(trivialfis): Make eval_metric a training parameter.
if (std::any_of(kv.first.cbegin(), kv.first.cend(),
[](char ch) { return std::isspace(ch); })) {
LOG(FATAL) << "Invalid parameter \"" << kv.first << "\" contains whitespace.";
}
provided.push_back(kv.first);
}
std::sort(provided.begin(), provided.end());
Expand All @@ -557,9 +560,9 @@ class LearnerConfiguration : public Learner {
std::stringstream ss;
ss << "\nParameters: { ";
for (size_t i = 0; i < diff.size() - 1; ++i) {
ss << diff[i] << ", ";
ss << "\"" << diff[i] << "\", ";
}
ss << diff.back();
ss << "\"" << diff.back() << "\"";
ss << R"W( } might not be used.
This may not be accurate due to some parameters are only used in language bindings but
Expand Down
8 changes: 6 additions & 2 deletions tests/cpp/test_learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,19 @@ TEST(Learner, ParameterValidation) {

auto learner = std::unique_ptr<Learner>(Learner::Create({p_mat}));
learner->SetParam("validate_parameters", "1");
learner->SetParam("Knock Knock", "Who's there?");
learner->SetParam("Knock-Knock", "Who's-there?");
learner->SetParam("Silence", "....");
learner->SetParam("tree_method", "exact");

testing::internal::CaptureStderr();
learner->Configure();
std::string output = testing::internal::GetCapturedStderr();

ASSERT_TRUE(output.find("Parameters: { Knock Knock, Silence }") != std::string::npos);
ASSERT_TRUE(output.find(R"(Parameters: { "Knock-Knock", "Silence" })") != std::string::npos);

// whitespace
learner->SetParam("tree method", "exact");
EXPECT_THROW(learner->Configure(), dmlc::Error);
}

TEST(Learner, CheckGroup) {
Expand Down

0 comments on commit f6fe15d

Please sign in to comment.