@@ -45,15 +45,20 @@ BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) {
45
45
BOOST_TEST_REQUIRE (errors[0 ] == " Input error: prediction_field_name must not be equal to any of [is_training, prediction_probability, top_classes]." );
46
46
}
47
47
48
- BOOST_AUTO_TEST_CASE (testWriteOneRow) {
48
+ template <typename T>
49
+ void testWriteOneRow (const std::string& dependentVariableField,
50
+ const std::string& predictionFieldType,
51
+ T (rapidjson::Value::*extract)() const ,
52
+ const std::vector<T>& expectedPredictions) {
49
53
// Prepare input data frame
50
- const TStrVec columnNames{" x1" , " x2" , " x3" , " x4" , " x5" , " x5_prediction" };
51
- const TStrVec categoricalColumns{" x1" , " x2" , " x5" };
54
+ const std::string predictionField = dependentVariableField + " _prediction" ;
55
+ const TStrVec columnNames{" x1" , " x2" , " x3" , " x4" , " x5" , predictionField};
56
+ const TStrVec categoricalColumns{" x1" , " x2" , " x3" , " x4" , " x5" };
52
57
const TStrVecVec rows{{" a" , " b" , " 1.0" , " 1.0" , " cat" , " -1.0" },
53
- {" a" , " b" , " 2 .0" , " 2 .0" , " cat" , " -0.5" },
54
- {" a" , " b" , " 5.0" , " 5 .0" , " dog" , " -0.1" },
55
- {" c" , " d" , " 5.0" , " 5 .0" , " dog" , " 1.0" },
56
- {" e" , " f" , " 5.0" , " 5 .0" , " dog" , " 1.5" }};
58
+ {" a" , " b" , " 1 .0" , " 1 .0" , " cat" , " -0.5" },
59
+ {" a" , " b" , " 5.0" , " 0 .0" , " dog" , " -0.1" },
60
+ {" c" , " d" , " 5.0" , " 0 .0" , " dog" , " 1.0" },
61
+ {" e" , " f" , " 5.0" , " 0 .0" , " dog" , " 1.5" }};
57
62
std::unique_ptr<core::CDataFrame> frame =
58
63
core::makeMainStorageDataFrame (columnNames.size ()).first ;
59
64
frame->columnNames (columnNames);
@@ -67,10 +72,21 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
67
72
68
73
// Create classification analysis runner object
69
74
const auto spec{test::CDataFrameAnalysisSpecificationFactory::predictionSpec (
70
- " classification" , " x5 " , rows.size (), columnNames. size (), 13000000 , 0 , 0 ,
71
- categoricalColumns)};
75
+ " classification" , dependentVariableField , rows.size (),
76
+ columnNames. size (), 13000000 , 0 , 0 , categoricalColumns)};
72
77
rapidjson::Document jsonParameters;
73
- jsonParameters.Parse (" {\" dependent_variable\" : \" x5\" }" );
78
+ if (predictionFieldType.empty ()) {
79
+ jsonParameters.Parse (" {\" dependent_variable\" : \" " + dependentVariableField + " \" }" );
80
+ } else {
81
+ jsonParameters.Parse (" {"
82
+ " \" dependent_variable\" : \" " +
83
+ dependentVariableField +
84
+ " \" ,"
85
+ " \" prediction_field_type\" : \" " +
86
+ predictionFieldType +
87
+ " \" "
88
+ " }" );
89
+ }
74
90
const auto parameters{
75
91
api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader ().read (jsonParameters)};
76
92
api::CDataFrameTrainBoostedTreeClassifierRunner runner (*spec, parameters);
@@ -83,10 +99,10 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
83
99
84
100
frame->readRows (1 , [&](TRowItr beginRows, TRowItr endRows) {
85
101
const auto columnHoldingDependentVariable{
86
- std::find (columnNames.begin (), columnNames.end (), " x5 " ) -
102
+ std::find (columnNames.begin (), columnNames.end (), dependentVariableField ) -
87
103
columnNames.begin ()};
88
104
const auto columnHoldingPrediction{
89
- std::find (columnNames.begin (), columnNames.end (), " x5_prediction " ) -
105
+ std::find (columnNames.begin (), columnNames.end (), predictionField ) -
90
106
columnNames.begin ()};
91
107
for (auto row = beginRows; row != endRows; ++row) {
92
108
runner.writeOneRow (*frame, columnHoldingDependentVariable,
@@ -95,17 +111,17 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
95
111
});
96
112
}
97
113
// Verify results
98
- const TStrVec expectedPredictions{" cat" , " cat" , " cat" , " dog" , " dog" };
99
114
rapidjson::Document arrayDoc;
100
115
arrayDoc.Parse <rapidjson::kParseDefaultFlags >(output.str ().c_str ());
101
116
BOOST_TEST_REQUIRE (arrayDoc.IsArray ());
102
117
BOOST_TEST_REQUIRE (arrayDoc.Size () == rows.size ());
118
+ BOOST_TEST_REQUIRE (arrayDoc.Size () == expectedPredictions.size ());
103
119
for (std::size_t i = 0 ; i < arrayDoc.Size (); ++i) {
104
120
BOOST_TEST_CONTEXT (" Result for row " << i) {
105
121
const rapidjson::Value& object = arrayDoc[rapidjson::SizeType (i)];
106
122
BOOST_TEST_REQUIRE (object.IsObject ());
107
- BOOST_TEST_REQUIRE (object.HasMember (" x5_prediction " ));
108
- BOOST_TEST_REQUIRE (object[" x5_prediction " ]. GetString () ==
123
+ BOOST_TEST_REQUIRE (object.HasMember (predictionField ));
124
+ BOOST_TEST_REQUIRE (( object[predictionField].*extract) () ==
109
125
expectedPredictions[i]);
110
126
BOOST_TEST_REQUIRE (object.HasMember (" prediction_probability" ));
111
127
BOOST_TEST_REQUIRE (object[" prediction_probability" ].GetDouble () > 0.5 );
@@ -115,4 +131,23 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
115
131
}
116
132
}
117
133
134
+ BOOST_AUTO_TEST_CASE (testWriteOneRowPredictionFieldTypeIsInt) {
135
+ testWriteOneRow (" x3" , " int" , &rapidjson::Value::GetInt, {1 , 1 , 1 , 5 , 5 });
136
+ }
137
+
138
+ BOOST_AUTO_TEST_CASE (testWriteOneRowPredictionFieldTypeIsBool) {
139
+ testWriteOneRow (" x4" , " bool" , &rapidjson::Value::GetBool,
140
+ {true , true , true , false , false });
141
+ }
142
+
143
+ BOOST_AUTO_TEST_CASE (testWriteOneRowPredictionFieldTypeIsString) {
144
+ testWriteOneRow (" x5" , " string" , &rapidjson::Value::GetString,
145
+ {" cat" , " cat" , " cat" , " dog" , " dog" });
146
+ }
147
+
148
+ BOOST_AUTO_TEST_CASE (testWriteOneRowPredictionFieldTypeIsMissing) {
149
+ testWriteOneRow (" x5" , " " , &rapidjson::Value::GetString,
150
+ {" cat" , " cat" , " cat" , " dog" , " dog" });
151
+ }
152
+
118
153
BOOST_AUTO_TEST_SUITE_END ()
0 commit comments