forked from elastic/ml-cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce classification analysis runner. (elastic#701)
- Loading branch information
1 parent
6b8d3f6
commit 06007cd
Showing
22 changed files
with
610 additions
and
109 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License; | ||
* you may not use this file except in compliance with the Elastic License. | ||
*/ | ||
|
||
#ifndef INCLUDED_ml_api_CDataFrameClassificationRunner_h | ||
#define INCLUDED_ml_api_CDataFrameClassificationRunner_h | ||
|
||
#include <core/CDataSearcher.h> | ||
|
||
#include <api/CDataFrameAnalysisConfigReader.h> | ||
#include <api/CDataFrameAnalysisRunner.h> | ||
#include <api/CDataFrameAnalysisSpecification.h> | ||
#include <api/CDataFrameBoostedTreeRunner.h> | ||
#include <api/ImportExport.h> | ||
|
||
#include <rapidjson/fwd.h> | ||
|
||
#include <atomic> | ||
|
||
namespace ml { | ||
namespace api { | ||
|
||
//! \brief Runs boosted tree classification on a core::CDataFrame. | ||
class API_EXPORT CDataFrameClassificationRunner final : public CDataFrameBoostedTreeRunner { | ||
public: | ||
static const CDataFrameAnalysisConfigReader getParameterReader(); | ||
|
||
//! This is not intended to be called directly: use CDataFrameClassificationRunnerFactory. | ||
CDataFrameClassificationRunner(const CDataFrameAnalysisSpecification& spec, | ||
const CDataFrameAnalysisConfigReader::CParameters& parameters); | ||
|
||
//! This is not intended to be called directly: use CDataFrameClassificationRunnerFactory. | ||
CDataFrameClassificationRunner(const CDataFrameAnalysisSpecification& spec); | ||
|
||
//! \return Indicator of columns for which empty value should be treated as missing. | ||
TBoolVec columnsForWhichEmptyIsMissing(const TStrVec& fieldNames) const override; | ||
|
||
//! Write the prediction for \p row to \p writer. | ||
void writeOneRow(const TStrVec& featureNames, | ||
const TStrVecVec& categoricalFieldValues, | ||
TRowRef row, | ||
core::CRapidJsonConcurrentLineWriter& writer) const override; | ||
|
||
private: | ||
std::size_t m_NumTopClasses; | ||
}; | ||
|
||
//! \brief Makes a core::CDataFrame boosted tree classification runner. | ||
class API_EXPORT CDataFrameClassificationRunnerFactory final | ||
: public CDataFrameAnalysisRunnerFactory { | ||
public: | ||
const std::string& name() const override; | ||
|
||
private: | ||
static const std::string NAME; | ||
|
||
private: | ||
TRunnerUPtr makeImpl(const CDataFrameAnalysisSpecification& spec) const override; | ||
TRunnerUPtr makeImpl(const CDataFrameAnalysisSpecification& spec, | ||
const rapidjson::Value& jsonParameters) const override; | ||
}; | ||
} | ||
} | ||
|
||
#endif // INCLUDED_ml_api_CDataFrameClassificationRunner_h |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License; | ||
* you may not use this file except in compliance with the Elastic License. | ||
*/ | ||
|
||
#ifndef INCLUDED_ml_api_CDataFrameRegressionRunner_h | ||
#define INCLUDED_ml_api_CDataFrameRegressionRunner_h | ||
|
||
#include <core/CDataSearcher.h> | ||
|
||
#include <api/CDataFrameAnalysisConfigReader.h> | ||
#include <api/CDataFrameAnalysisSpecification.h> | ||
#include <api/CDataFrameBoostedTreeRunner.h> | ||
#include <api/ImportExport.h> | ||
|
||
#include <rapidjson/fwd.h> | ||
|
||
#include <atomic> | ||
|
||
namespace ml { | ||
namespace api { | ||
|
||
//! \brief Runs boosted tree regression on a core::CDataFrame. | ||
class API_EXPORT CDataFrameRegressionRunner final : public CDataFrameBoostedTreeRunner { | ||
public: | ||
static const CDataFrameAnalysisConfigReader getParameterReader(); | ||
|
||
//! This is not intended to be called directly: use CDataFrameRegressionRunnerFactory. | ||
CDataFrameRegressionRunner(const CDataFrameAnalysisSpecification& spec, | ||
const CDataFrameAnalysisConfigReader::CParameters& parameters); | ||
|
||
//! This is not intended to be called directly: use CDataFrameRegressionRunnerFactory. | ||
CDataFrameRegressionRunner(const CDataFrameAnalysisSpecification& spec); | ||
|
||
//! Write the prediction for \p row to \p writer. | ||
void writeOneRow(const TStrVec& featureNames, | ||
const TStrVecVec& categoricalFieldValues, | ||
TRowRef row, | ||
core::CRapidJsonConcurrentLineWriter& writer) const override; | ||
}; | ||
|
||
//! \brief Makes a core::CDataFrame boosted tree regression runner. | ||
class API_EXPORT CDataFrameRegressionRunnerFactory final : public CDataFrameAnalysisRunnerFactory { | ||
public: | ||
const std::string& name() const override; | ||
|
||
private: | ||
static const std::string NAME; | ||
|
||
private: | ||
TRunnerUPtr makeImpl(const CDataFrameAnalysisSpecification& spec) const override; | ||
TRunnerUPtr makeImpl(const CDataFrameAnalysisSpecification& spec, | ||
const rapidjson::Value& jsonParameters) const override; | ||
}; | ||
} | ||
} | ||
|
||
#endif // INCLUDED_ml_api_CDataFrameRegressionRunner_h |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.