Skip to content

Commit 5dc608d

Browse files
committed
Emit predicted category using an appropriate JSON type. (elastic#877)
1 parent dcf7370 commit 5dc608d

File tree

4 files changed

+102
-17
lines changed

4 files changed

+102
-17
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ tree which is trained for both regression and classification. (See {ml-pull}811[
4545
(See {ml-pull}818[#818].)
4646
* Reduce memory usage of {ml} native processes on Windows. (See {ml-pull}844[#844].)
4747
* Reduce runtime of classification and regression. (See {ml-pull}863[#863].)
48+
* Emit `prediction_field_name` in ml results using the type provided as the
49+
`prediction_field_type` parameter. (See {ml-pull}877[#877].)
4850

4951
=== Bug Fixes
5052
* Fixes potential memory corruption when determining seasonality. (See {ml-pull}852[#852].)

include/api/CDataFrameTrainBoostedTreeClassifierRunner.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@ namespace api {
1919
class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
2020
: public CDataFrameTrainBoostedTreeRunner {
2121
public:
22+
//! The JSON type with which predicted category values are written.
enum EPredictionFieldType {
    //! Write the category as a JSON string.
    E_PredictionFieldTypeString = 0,
    //! Write the category as a JSON integer.
    E_PredictionFieldTypeInt = 1,
    //! Write the category as a JSON boolean.
    E_PredictionFieldTypeBool = 2
};
27+
2228
static const CDataFrameAnalysisConfigReader& parameterReader();
2329

2430
//! This is not intended to be called directly: use CDataFrameTrainBoostedTreeClassifierRunnerFactory.
@@ -44,6 +50,10 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
4450
const TRowRef& row,
4551
core::CRapidJsonConcurrentLineWriter& writer) const;
4652

53+
//! Write the predicted category value as string, int or bool.
54+
void writePredictedCategoryValue(const std::string& categoryValue,
55+
core::CRapidJsonConcurrentLineWriter& writer) const;
56+
4757
//! \return A serialisable definition of the trained classification model.
4858
TInferenceModelDefinitionUPtr
4959
inferenceModelDefinition(const TStrVec& fieldNames,
@@ -55,6 +65,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
5565

5666
private:
5767
std::size_t m_NumTopClasses;
68+
EPredictionFieldType m_PredictionFieldType;
5869
};
5970

6071
//! \brief Makes a core::CDataFrame boosted tree classification runner.

lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ using TSizeVec = std::vector<std::size_t>;
3232

3333
// Configuration
3434
const std::string NUM_TOP_CLASSES{"num_top_classes"};
35+
const std::string PREDICTION_FIELD_TYPE{"prediction_field_type"};
3536
const std::string BALANCED_CLASS_LOSS{"balanced_class_loss"};
3637

3738
// Output
@@ -45,8 +46,16 @@ const std::string CLASS_PROBABILITY_FIELD_NAME{"class_probability"};
4546
const CDataFrameAnalysisConfigReader&
4647
CDataFrameTrainBoostedTreeClassifierRunner::parameterReader() {
4748
static const CDataFrameAnalysisConfigReader PARAMETER_READER{[] {
49+
const std::string typeString{"string"};
50+
const std::string typeInt{"int"};
51+
const std::string typeBool{"bool"};
4852
auto theReader = CDataFrameTrainBoostedTreeRunner::parameterReader();
4953
theReader.addParameter(NUM_TOP_CLASSES, CDataFrameAnalysisConfigReader::E_OptionalParameter);
54+
theReader.addParameter(PREDICTION_FIELD_TYPE,
55+
CDataFrameAnalysisConfigReader::E_OptionalParameter,
56+
{{typeString, int{E_PredictionFieldTypeString}},
57+
{typeInt, int{E_PredictionFieldTypeInt}},
58+
{typeBool, int{E_PredictionFieldTypeBool}}});
5059
theReader.addParameter(BALANCED_CLASS_LOSS,
5160
CDataFrameAnalysisConfigReader::E_OptionalParameter);
5261
return theReader;
@@ -60,6 +69,8 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier
6069
: CDataFrameTrainBoostedTreeRunner{spec, parameters} {
6170

6271
m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0});
72+
m_PredictionFieldType =
73+
parameters[PREDICTION_FIELD_TYPE].fallback(E_PredictionFieldTypeString);
6374
this->boostedTreeFactory().balanceClassTrainingLoss(
6475
parameters[BALANCED_CLASS_LOSS].fallback(true));
6576

@@ -119,7 +130,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
119130

120131
writer.StartObject();
121132
writer.Key(this->predictionFieldName());
122-
writer.String(categoryValues[predictedCategoryId]);
133+
writePredictedCategoryValue(categoryValues[predictedCategoryId], writer);
123134
writer.Key(PREDICTION_PROBABILITY_FIELD_NAME);
124135
writer.Double(probabilityOfCategory[predictedCategoryId]);
125136
writer.Key(IS_TRAINING_FIELD_NAME);
@@ -135,7 +146,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
135146
for (std::size_t i = 0; i < std::min(categoryIds.size(), m_NumTopClasses); ++i) {
136147
writer.StartObject();
137148
writer.Key(CLASS_NAME_FIELD_NAME);
138-
writer.String(categoryValues[categoryIds[i]]);
149+
writePredictedCategoryValue(categoryValues[categoryIds[i]], writer);
139150
writer.Key(CLASS_PROBABILITY_FIELD_NAME);
140151
writer.Double(probabilityOfCategory[i]);
141152
writer.EndObject();
@@ -158,6 +169,32 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
158169
columnHoldingPrediction, row, writer);
159170
}
160171

172+
void CDataFrameTrainBoostedTreeClassifierRunner::writePredictedCategoryValue(
173+
const std::string& categoryValue,
174+
core::CRapidJsonConcurrentLineWriter& writer) const {
175+
176+
double doubleValue;
177+
switch (m_PredictionFieldType) {
178+
case E_PredictionFieldTypeString:
179+
writer.String(categoryValue);
180+
break;
181+
case E_PredictionFieldTypeInt:
182+
if (core::CStringUtils::stringToType(categoryValue, doubleValue)) {
183+
writer.Int64(static_cast<std::int64_t>(doubleValue));
184+
} else {
185+
writer.String(categoryValue);
186+
}
187+
break;
188+
case E_PredictionFieldTypeBool:
189+
if (core::CStringUtils::stringToType(categoryValue, doubleValue)) {
190+
writer.Bool(doubleValue != 0.0);
191+
} else {
192+
writer.String(categoryValue);
193+
}
194+
break;
195+
}
196+
}
197+
161198
CDataFrameTrainBoostedTreeClassifierRunner::TLossFunctionUPtr
162199
CDataFrameTrainBoostedTreeClassifierRunner::chooseLossFunction(const core::CDataFrame& frame,
163200
std::size_t dependentVariableColumn) const {

lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,20 @@ BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) {
4545
BOOST_TEST_REQUIRE(errors[0] == "Input error: prediction_field_name must not be equal to any of [is_training, prediction_probability, top_classes].");
4646
}
4747

48-
BOOST_AUTO_TEST_CASE(testWriteOneRow) {
48+
template<typename T>
49+
void testWriteOneRow(const std::string& dependentVariableField,
50+
const std::string& predictionFieldType,
51+
T (rapidjson::Value::*extract)() const,
52+
const std::vector<T>& expectedPredictions) {
4953
// Prepare input data frame
50-
const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", "x5_prediction"};
51-
const TStrVec categoricalColumns{"x1", "x2", "x5"};
54+
const std::string predictionField = dependentVariableField + "_prediction";
55+
const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", predictionField};
56+
const TStrVec categoricalColumns{"x1", "x2", "x3", "x4", "x5"};
5257
const TStrVecVec rows{{"a", "b", "1.0", "1.0", "cat", "-1.0"},
53-
{"a", "b", "2.0", "2.0", "cat", "-0.5"},
54-
{"a", "b", "5.0", "5.0", "dog", "-0.1"},
55-
{"c", "d", "5.0", "5.0", "dog", "1.0"},
56-
{"e", "f", "5.0", "5.0", "dog", "1.5"}};
58+
{"a", "b", "1.0", "1.0", "cat", "-0.5"},
59+
{"a", "b", "5.0", "0.0", "dog", "-0.1"},
60+
{"c", "d", "5.0", "0.0", "dog", "1.0"},
61+
{"e", "f", "5.0", "0.0", "dog", "1.5"}};
5762
std::unique_ptr<core::CDataFrame> frame =
5863
core::makeMainStorageDataFrame(columnNames.size()).first;
5964
frame->columnNames(columnNames);
@@ -67,10 +72,21 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
6772

6873
// Create classification analysis runner object
6974
const auto spec{test::CDataFrameAnalysisSpecificationFactory::predictionSpec(
70-
"classification", "x5", rows.size(), columnNames.size(), 13000000, 0, 0,
71-
categoricalColumns)};
75+
"classification", dependentVariableField, rows.size(),
76+
columnNames.size(), 13000000, 0, 0, categoricalColumns)};
7277
rapidjson::Document jsonParameters;
73-
jsonParameters.Parse("{\"dependent_variable\": \"x5\"}");
78+
if (predictionFieldType.empty()) {
79+
jsonParameters.Parse("{\"dependent_variable\": \"" + dependentVariableField + "\"}");
80+
} else {
81+
jsonParameters.Parse("{"
82+
" \"dependent_variable\": \"" +
83+
dependentVariableField +
84+
"\","
85+
" \"prediction_field_type\": \"" +
86+
predictionFieldType +
87+
"\""
88+
"}");
89+
}
7490
const auto parameters{
7591
api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader().read(jsonParameters)};
7692
api::CDataFrameTrainBoostedTreeClassifierRunner runner(*spec, parameters);
@@ -83,10 +99,10 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
8399

84100
frame->readRows(1, [&](TRowItr beginRows, TRowItr endRows) {
85101
const auto columnHoldingDependentVariable{
86-
std::find(columnNames.begin(), columnNames.end(), "x5") -
102+
std::find(columnNames.begin(), columnNames.end(), dependentVariableField) -
87103
columnNames.begin()};
88104
const auto columnHoldingPrediction{
89-
std::find(columnNames.begin(), columnNames.end(), "x5_prediction") -
105+
std::find(columnNames.begin(), columnNames.end(), predictionField) -
90106
columnNames.begin()};
91107
for (auto row = beginRows; row != endRows; ++row) {
92108
runner.writeOneRow(*frame, columnHoldingDependentVariable,
@@ -95,17 +111,17 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
95111
});
96112
}
97113
// Verify results
98-
const TStrVec expectedPredictions{"cat", "cat", "cat", "dog", "dog"};
99114
rapidjson::Document arrayDoc;
100115
arrayDoc.Parse<rapidjson::kParseDefaultFlags>(output.str().c_str());
101116
BOOST_TEST_REQUIRE(arrayDoc.IsArray());
102117
BOOST_TEST_REQUIRE(arrayDoc.Size() == rows.size());
118+
BOOST_TEST_REQUIRE(arrayDoc.Size() == expectedPredictions.size());
103119
for (std::size_t i = 0; i < arrayDoc.Size(); ++i) {
104120
BOOST_TEST_CONTEXT("Result for row " << i) {
105121
const rapidjson::Value& object = arrayDoc[rapidjson::SizeType(i)];
106122
BOOST_TEST_REQUIRE(object.IsObject());
107-
BOOST_TEST_REQUIRE(object.HasMember("x5_prediction"));
108-
BOOST_TEST_REQUIRE(object["x5_prediction"].GetString() ==
123+
BOOST_TEST_REQUIRE(object.HasMember(predictionField));
124+
BOOST_TEST_REQUIRE((object[predictionField].*extract)() ==
109125
expectedPredictions[i]);
110126
BOOST_TEST_REQUIRE(object.HasMember("prediction_probability"));
111127
BOOST_TEST_REQUIRE(object["prediction_probability"].GetDouble() > 0.5);
@@ -115,4 +131,23 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
115131
}
116132
}
117133

134+
// With prediction_field_type "int" the predictions for "x3" are expected to
// be readable as JSON integers.
BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsInt) {
    const std::vector<int> expectedPredictions{1, 1, 1, 5, 5};
    testWriteOneRow("x3", "int", &rapidjson::Value::GetInt, expectedPredictions);
}
137+
138+
// With prediction_field_type "bool" the predictions for "x4" are expected to
// be readable as JSON booleans.
BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsBool) {
    const std::vector<bool> expectedPredictions{true, true, true, false, false};
    testWriteOneRow("x4", "bool", &rapidjson::Value::GetBool, expectedPredictions);
}
142+
143+
// With prediction_field_type "string" the predictions for "x5" are expected
// to be readable as JSON strings.
BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsString) {
    const std::vector<const char*> expectedPredictions{"cat", "cat", "cat", "dog", "dog"};
    testWriteOneRow("x5", "string", &rapidjson::Value::GetString, expectedPredictions);
}
147+
148+
// When prediction_field_type is not supplied (empty string here) the
// predictions for "x5" are expected to be readable as JSON strings.
BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsMissing) {
    const std::vector<const char*> expectedPredictions{"cat", "cat", "cat", "dog", "dog"};
    testWriteOneRow("x5", "", &rapidjson::Value::GetString, expectedPredictions);
}
152+
118153
BOOST_AUTO_TEST_SUITE_END()

0 commit comments

Comments
 (0)