From f6fe15d11fd05ff1cfe77441fa6c5d7280ee2131 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 20 Mar 2021 01:56:55 +0800 Subject: [PATCH] Improve parameter validation (#6769) * Add quotes to unused parameters. * Check for whitespace. --- R-package/tests/testthat/test_basic.R | 2 +- src/learner.cc | 11 +++++++---- tests/cpp/test_learner.cc | 8 ++++++-- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index 0a427ded4973..ddf2c4318854 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -66,7 +66,7 @@ test_that("parameter validation works", { xgb.train(params = params, data = dtrain, nrounds = nrounds)) print(output) } - expect_output(incorrect(), "bar, foo") + expect_output(incorrect(), '\\\\"bar\\\\", \\\\"foo\\\\"') }) diff --git a/src/learner.cc b/src/learner.cc index 019603261bbf..8d9e05652e2d 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -537,15 +537,18 @@ class LearnerConfiguration : public Learner { } } + // FIXME(trivialfis): Make eval_metric a training parameter. keys.emplace_back(kEvalMetric); - keys.emplace_back("verbosity"); keys.emplace_back("num_output_group"); std::sort(keys.begin(), keys.end()); std::vector provided; for (auto const &kv : cfg_) { - // FIXME(trivialfis): Make eval_metric a training parameter. + if (std::any_of(kv.first.cbegin(), kv.first.cend(), + [](char ch) { return std::isspace(ch); })) { + LOG(FATAL) << "Invalid parameter \"" << kv.first << "\" contains whitespace."; + } provided.push_back(kv.first); } std::sort(provided.begin(), provided.end()); @@ -557,9 +560,9 @@ class LearnerConfiguration : public Learner { std::stringstream ss; ss << "\nParameters: { "; for (size_t i = 0; i < diff.size() - 1; ++i) { - ss << diff[i] << ", "; + ss << "\"" << diff[i] << "\", "; } - ss << diff.back(); + ss << "\"" << diff.back() << "\""; ss << R"W( } might not be used. This may not be accurate due to some parameters are only used in language bindings but diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index 237cb559cc94..703af54f2553 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -40,7 +40,7 @@ TEST(Learner, ParameterValidation) { auto learner = std::unique_ptr(Learner::Create({p_mat})); learner->SetParam("validate_parameters", "1"); - learner->SetParam("Knock Knock", "Who's there?"); + learner->SetParam("Knock-Knock", "Who's-there?"); learner->SetParam("Silence", "...."); learner->SetParam("tree_method", "exact"); @@ -48,7 +48,11 @@ TEST(Learner, ParameterValidation) { learner->Configure(); std::string output = testing::internal::GetCapturedStderr(); - ASSERT_TRUE(output.find("Parameters: { Knock Knock, Silence }") != std::string::npos); + ASSERT_TRUE(output.find(R"(Parameters: { "Knock-Knock", "Silence" })") != std::string::npos); + + // whitespace + learner->SetParam("tree method", "exact"); + EXPECT_THROW(learner->Configure(), dmlc::Error); } TEST(Learner, CheckGroup) {