diff --git a/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h b/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h index ccf0ccccf1..27bb972305 100644 --- a/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h +++ b/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h @@ -19,6 +19,12 @@ namespace api { class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final : public CDataFrameTrainBoostedTreeRunner { public: + enum EPredictionFieldType { + E_PredictionFieldTypeString, + E_PredictionFieldTypeInt, + E_PredictionFieldTypeBool + }; + static const CDataFrameAnalysisConfigReader& parameterReader(); //! This is not intended to be called directly: use CDataFrameTrainBoostedTreeClassifierRunnerFactory. @@ -59,7 +65,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final private: std::size_t m_NumTopClasses; - std::string m_PredictionFieldType; + EPredictionFieldType m_PredictionFieldType; }; //! \brief Makes a core::CDataFrame boosted tree classification runner. diff --git a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc index e60bcf5814..5b84439878 100644 --- a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc +++ b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc @@ -33,9 +33,6 @@ using TSizeVec = std::vector; // Configuration const std::string NUM_TOP_CLASSES{"num_top_classes"}; const std::string PREDICTION_FIELD_TYPE{"prediction_field_type"}; -const std::string PREDICTION_FIELD_TYPE_STRING{"string"}; -const std::string PREDICTION_FIELD_TYPE_INT{"int"}; -const std::string PREDICTION_FIELD_TYPE_BOOL{"bool"}; const std::string BALANCED_CLASS_LOSS{"balanced_class_loss"}; // Output @@ -49,10 +46,16 @@ const std::string CLASS_PROBABILITY_FIELD_NAME{"class_probability"}; const CDataFrameAnalysisConfigReader& CDataFrameTrainBoostedTreeClassifierRunner::parameterReader() { static const CDataFrameAnalysisConfigReader PARAMETER_READER{[] { + const std::string typeString{"string"}; + const std::string typeInt{"int"}; + const std::string typeBool{"bool"}; auto theReader = CDataFrameTrainBoostedTreeRunner::parameterReader(); theReader.addParameter(NUM_TOP_CLASSES, CDataFrameAnalysisConfigReader::E_OptionalParameter); theReader.addParameter(PREDICTION_FIELD_TYPE, - CDataFrameAnalysisConfigReader::E_OptionalParameter); + CDataFrameAnalysisConfigReader::E_OptionalParameter, + {{typeString, int{E_PredictionFieldTypeString}}, + {typeInt, int{E_PredictionFieldTypeInt}}, + {typeBool, int{E_PredictionFieldTypeBool}}}); theReader.addParameter(BALANCED_CLASS_LOSS, CDataFrameAnalysisConfigReader::E_OptionalParameter); return theReader; @@ -67,7 +70,7 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0}); m_PredictionFieldType = - parameters[PREDICTION_FIELD_TYPE].fallback(PREDICTION_FIELD_TYPE_STRING); + parameters[PREDICTION_FIELD_TYPE].fallback(E_PredictionFieldTypeString); this->boostedTreeFactory().balanceClassTrainingLoss( parameters[BALANCED_CLASS_LOSS].fallback(true)); @@ -170,20 +173,26 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writePredictedCategoryValue( const std::string& categoryValue, core::CRapidJsonConcurrentLineWriter& writer) const { - if (m_PredictionFieldType == PREDICTION_FIELD_TYPE_INT) { - double doubleValue; + double doubleValue; + switch (m_PredictionFieldType) { + case E_PredictionFieldTypeString: + writer.String(categoryValue); + break; + case E_PredictionFieldTypeInt: if (core::CStringUtils::stringToType(categoryValue, doubleValue)) { writer.Int64(static_cast(doubleValue)); - return; + } else { + writer.String(categoryValue); } - } else if (m_PredictionFieldType == PREDICTION_FIELD_TYPE_BOOL) { - double doubleValue; + break; + case E_PredictionFieldTypeBool: if (core::CStringUtils::stringToType(categoryValue, doubleValue)) { writer.Bool(static_cast(doubleValue) == 1.0); - return; + } else { + writer.String(categoryValue); } + break; } - writer.String(categoryValue); } CDataFrameTrainBoostedTreeClassifierRunner::TLossFunctionUPtr diff --git a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc index 1cb338d097..4dfca4c025 100644 --- a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc +++ b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc @@ -75,10 +75,14 @@ void testWriteOneRow(const std::string& dependentVariableField, "classification", dependentVariableField, rows.size(), columnNames.size(), 13000000, 0, 0, categoricalColumns)}; rapidjson::Document jsonParameters; - jsonParameters.Parse("{" - " \"dependent_variable\": \"" + dependentVariableField + "\"," - " \"prediction_field_type\": \"" + predictionFieldType + "\"" - "}"); + if (predictionFieldType.empty()) { + jsonParameters.Parse("{\"dependent_variable\": \"" + dependentVariableField + "\"}"); + } else { + jsonParameters.Parse("{" + " \"dependent_variable\": \"" + dependentVariableField + "\"," + " \"prediction_field_type\": \"" + predictionFieldType + "\"" + "}"); + } const auto parameters{ api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader().read(jsonParameters)}; api::CDataFrameTrainBoostedTreeClassifierRunner runner(*spec, parameters);