Skip to content

Commit

Permalink
Blacklist a number of prediction field names. (#861) (#868)
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek authored Dec 1, 2019
1 parent 7275f29 commit f670c04
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/CHANGELOG.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ tree which is trained for both regression and classification. (See {ml-pull}811[

=== Bug Fixes
* Fixes potential memory corruption when determining seasonality. (See {ml-pull}852[#852].)
* Prevents `prediction_field_name` from clashing with other fields in ml results.
(See {ml-pull}861[#861].)


== {es} version 7.5.0
Expand Down
7 changes: 7 additions & 0 deletions lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier
this->dependentVariableFieldName()) == categoricalFieldNames.end()) {
HANDLE_FATAL(<< "Input error: trying to perform classification with numeric target.");
}
const std::set<std::string> predictionFieldNameBlacklist{
IS_TRAINING_FIELD_NAME, PREDICTION_PROBABILITY_FIELD_NAME, TOP_CLASSES_FIELD_NAME};
if (predictionFieldNameBlacklist.count(this->predictionFieldName()) > 0) {
HANDLE_FATAL(<< "Input error: prediction_field_name must not be equal to any of "
<< core::CContainerPrinter::print(predictionFieldNameBlacklist)
<< ".");
}
}

CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifierRunner(
Expand Down
6 changes: 6 additions & 0 deletions lib/api/CDataFrameTrainBoostedTreeRegressionRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ CDataFrameTrainBoostedTreeRegressionRunner::CDataFrameTrainBoostedTreeRegression
this->dependentVariableFieldName()) != categoricalFieldNames.end()) {
HANDLE_FATAL(<< "Input error: trying to perform regression with categorical target.");
}
const std::set<std::string> predictionFieldNameBlacklist{IS_TRAINING_FIELD_NAME};
if (predictionFieldNameBlacklist.count(this->predictionFieldName()) > 0) {
HANDLE_FATAL(<< "Input error: prediction_field_name must not be equal to any of "
<< core::CContainerPrinter::print(predictionFieldNameBlacklist)
<< ".");
}
}

CDataFrameTrainBoostedTreeRegressionRunner::CDataFrameTrainBoostedTreeRegressionRunner(
Expand Down
20 changes: 20 additions & 0 deletions lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,26 @@ using TStrVec = std::vector<std::string>;
using TStrVecVec = std::vector<TStrVec>;
}

// Verifies that building a classification runner whose prediction_field_name
// is "is_training" reports exactly one fatal error, and that the message lists
// the field names which must not be used.
BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) {
    // Record every fatal error message in a local vector while the guard
    // object below is alive, so the test can inspect what was reported.
    TStrVec capturedErrors;
    auto recordError = [&capturedErrors](std::string message) {
        capturedErrors.push_back(message);
    };
    core::CLogger::CScopeSetFatalErrorHandler handlerGuard{recordError};

    // NOTE(review): the trailing {"dep_var"} presumably marks the target as
    // categorical so the runner's "numeric target" check does not contribute a
    // second error - confirm against the factory's signature.
    const auto specification = test::CDataFrameAnalysisSpecificationFactory::predictionSpec(
        "classification", "dep_var", 5, 6, 13000000, 0, 0, {"dep_var"});

    // Runner parameters in which the prediction field name collides with
    // "is_training".
    const char* const parametersJson = "{"
                                       " \"dependent_variable\": \"dep_var\","
                                       " \"prediction_field_name\": \"is_training\""
                                       "}";
    rapidjson::Document parametersDocument;
    parametersDocument.Parse(parametersJson);
    const auto runnerParameters =
        api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader().read(parametersDocument);

    // Constructing the runner is what triggers the validation under test.
    api::CDataFrameTrainBoostedTreeClassifierRunner runner(*specification, runnerParameters);

    const std::string expectedMessage{"Input error: prediction_field_name must not be equal to any of [is_training, prediction_probability, top_classes]."};
    BOOST_TEST_REQUIRE(capturedErrors.size() == 1);
    BOOST_TEST_REQUIRE(capturedErrors.front() == expectedMessage);
}

BOOST_AUTO_TEST_CASE(testWriteOneRow) {
// Prepare input data frame
const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", "x5_prediction"};
Expand Down
46 changes: 46 additions & 0 deletions lib/api/unittest/CDataFrameTrainBoostedTreeRegressionRunnerTest.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

#include <core/CDataFrame.h>

#include <api/CDataFrameAnalysisConfigReader.h>
#include <api/CDataFrameTrainBoostedTreeRegressionRunner.h>

#include <test/CDataFrameAnalysisSpecificationFactory.h>

#include <boost/test/unit_test.hpp>

#include <string>
#include <vector>

// Opens the Boost.Test suite for this file; the matching
// BOOST_AUTO_TEST_SUITE_END() at the end of the file closes it.
BOOST_AUTO_TEST_SUITE(CDataFrameTrainBoostedTreeRegressionRunnerTest)

// Allows the test below to write core::, api:: and test:: without an ml::
// prefix.
using namespace ml;
namespace {
// Container type the test uses to collect fatal error messages.
using TStrVec = std::vector<std::string>;
}

// Verifies that building a regression runner whose prediction_field_name is
// "is_training" reports exactly one fatal error, and that the message names
// the field which must not be used.
BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) {
    // Record every fatal error message in a local vector while the guard
    // object below is alive, so the test can inspect what was reported.
    TStrVec capturedErrors;
    auto recordError = [&capturedErrors](std::string message) {
        capturedErrors.push_back(message);
    };
    core::CLogger::CScopeSetFatalErrorHandler handlerGuard{recordError};

    const auto specification = test::CDataFrameAnalysisSpecificationFactory::predictionSpec(
        "regression", "dep_var", 5, 6, 13000000, 0, 0);

    // Runner parameters in which the prediction field name collides with
    // "is_training".
    const char* const parametersJson = "{"
                                       " \"dependent_variable\": \"dep_var\","
                                       " \"prediction_field_name\": \"is_training\""
                                       "}";
    rapidjson::Document parametersDocument;
    parametersDocument.Parse(parametersJson);
    const auto runnerParameters =
        api::CDataFrameTrainBoostedTreeRegressionRunner::parameterReader().read(parametersDocument);

    // Constructing the runner is what triggers the validation under test.
    api::CDataFrameTrainBoostedTreeRegressionRunner runner(*specification, runnerParameters);

    const std::string expectedMessage{"Input error: prediction_field_name must not be equal to any of [is_training]."};
    BOOST_TEST_REQUIRE(capturedErrors.size() == 1);
    BOOST_TEST_REQUIRE(capturedErrors.front() == expectedMessage);
}

BOOST_AUTO_TEST_SUITE_END()
1 change: 1 addition & 0 deletions lib/api/unittest/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ SRCS=\
CDataFrameAnalyzerOutlierTest.cc \
CDataFrameAnalyzerTrainingTest.cc \
CDataFrameTrainBoostedTreeClassifierRunnerTest.cc \
CDataFrameTrainBoostedTreeRegressionRunnerTest.cc \
CDataFrameMockAnalysisRunner.cc \
CDetectionRulesJsonParserTest.cc \
CFieldConfigTest.cc \
Expand Down

0 comments on commit f670c04

Please sign in to comment.