Skip to content

Commit

Permalink
Implement new analysis type: classification (elastic#46537)
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek committed Oct 4, 2019
1 parent 31a5e1c commit efcc4d1
Show file tree
Hide file tree
Showing 27 changed files with 1,833 additions and 427 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Objects;

/**
 * Client-side description of the {@code classification} data frame analysis.
 *
 * <p>Instances are immutable. They are created either through {@link #builder(String)} or by
 * parsing XContent with {@link #fromXContent(XContentParser)}. Only the dependent variable is
 * mandatory; every other setting may be {@code null}, in which case it is left out of the
 * serialized form produced by {@link #toXContent(XContentBuilder, Params)}.
 */
public class Classification implements DataFrameAnalysis {

    public static final ParseField NAME = new ParseField("classification");

    static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
    static final ParseField LAMBDA = new ParseField("lambda");
    static final ParseField GAMMA = new ParseField("gamma");
    static final ParseField ETA = new ParseField("eta");
    static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
    static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
    static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
    static final ParseField TRAINING_PERCENT = new ParseField("training_percent");

    // Must stay below the ParseField constants: createParser() reads them during class initialization.
    private static final ConstructingObjectParser<Classification, Void> PARSER = createParser();

    private final String dependentVariable;
    private final Double lambda;
    private final Double gamma;
    private final Double eta;
    private final Integer maximumNumberTrees;
    private final Double featureBagFraction;
    private final String predictionFieldName;
    private final Double trainingPercent;

    private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
                           @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction,
                           @Nullable String predictionFieldName, @Nullable Double trainingPercent) {
        this.dependentVariable = Objects.requireNonNull(dependentVariable);
        this.lambda = lambda;
        this.gamma = gamma;
        this.eta = eta;
        this.maximumNumberTrees = maximumNumberTrees;
        this.featureBagFraction = featureBagFraction;
        this.predictionFieldName = predictionFieldName;
        this.trainingPercent = trainingPercent;
    }

    /**
     * Builds the parser used by {@link #fromXContent(XContentParser)}.
     *
     * <p>The order of the {@code declare*} calls defines the positions of the values handed to
     * {@link #fromParsedArgs(Object[])}, so the two must be kept in sync.
     */
    private static ConstructingObjectParser<Classification, Void> createParser() {
        ConstructingObjectParser<Classification, Void> parser =
            new ConstructingObjectParser<>(NAME.getPreferredName(), true, Classification::fromParsedArgs);
        parser.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
        parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA);
        parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA);
        parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
        parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
        parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
        parser.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
        parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
        return parser;
    }

    /**
     * Turns the positional values collected by the parser into an instance.
     *
     * @param args parsed values, indexed in the order the fields are declared in {@link #createParser()}
     */
    private static Classification fromParsedArgs(Object[] args) {
        return new Classification(
            (String) args[0],
            (Double) args[1],
            (Double) args[2],
            (Double) args[3],
            (Integer) args[4],
            (Double) args[5],
            (String) args[6],
            (Double) args[7]);
    }

    /**
     * Parses a {@code Classification} from the given parser.
     *
     * @param parser parser positioned at the analysis object
     * @return the parsed analysis
     */
    public static Classification fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * Starts a builder for a classification analysis.
     *
     * @param dependentVariable value written under {@code dependent_variable}; must not be {@code null}
     * @return a new builder
     */
    public static Builder builder(String dependentVariable) {
        return new Builder(dependentVariable);
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    /** @return the dependent variable; never {@code null} */
    public String getDependentVariable() {
        return dependentVariable;
    }

    /** @return the {@code lambda} setting, or {@code null} if it was not set */
    public Double getLambda() {
        return lambda;
    }

    /** @return the {@code gamma} setting, or {@code null} if it was not set */
    public Double getGamma() {
        return gamma;
    }

    /** @return the {@code eta} setting, or {@code null} if it was not set */
    public Double getEta() {
        return eta;
    }

    /** @return the {@code maximum_number_trees} setting, or {@code null} if it was not set */
    public Integer getMaximumNumberTrees() {
        return maximumNumberTrees;
    }

    /** @return the {@code feature_bag_fraction} setting, or {@code null} if it was not set */
    public Double getFeatureBagFraction() {
        return featureBagFraction;
    }

    /** @return the {@code prediction_field_name} setting, or {@code null} if it was not set */
    public String getPredictionFieldName() {
        return predictionFieldName;
    }

    /** @return the {@code training_percent} setting, or {@code null} if it was not set */
    public Double getTrainingPercent() {
        return trainingPercent;
    }

    /**
     * Writes this analysis as a single object. The dependent variable is always emitted;
     * each optional setting is emitted only when it is non-null.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject().field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
        if (lambda != null) {
            builder.field(LAMBDA.getPreferredName(), lambda);
        }
        if (gamma != null) {
            builder.field(GAMMA.getPreferredName(), gamma);
        }
        if (eta != null) {
            builder.field(ETA.getPreferredName(), eta);
        }
        if (maximumNumberTrees != null) {
            builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees);
        }
        if (featureBagFraction != null) {
            builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
        }
        if (predictionFieldName != null) {
            builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
        }
        if (trainingPercent != null) {
            builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
        }
        return builder.endObject();
    }

    @Override
    public int hashCode() {
        return Objects.hash(
            dependentVariable,
            lambda,
            gamma,
            eta,
            maximumNumberTrees,
            featureBagFraction,
            predictionFieldName,
            trainingPercent);
    }

    /** Two instances are equal when they have the same runtime class and all eight settings are equal. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null) {
            return false;
        }
        if (getClass() != o.getClass()) {
            return false;
        }
        Classification other = (Classification) o;
        return Objects.equals(dependentVariable, other.dependentVariable)
            && Objects.equals(lambda, other.lambda)
            && Objects.equals(gamma, other.gamma)
            && Objects.equals(eta, other.eta)
            && Objects.equals(maximumNumberTrees, other.maximumNumberTrees)
            && Objects.equals(featureBagFraction, other.featureBagFraction)
            && Objects.equals(predictionFieldName, other.predictionFieldName)
            && Objects.equals(trainingPercent, other.trainingPercent);
    }

    @Override
    public String toString() {
        return Strings.toString(this);
    }

    /**
     * Fluent builder for {@link Classification}. The dependent variable is fixed at creation;
     * each setter stores its argument as-is (including {@code null}) and returns this builder.
     */
    public static class Builder {
        private final String dependentVariable;
        private Double lambda;
        private Double gamma;
        private Double eta;
        private Integer maximumNumberTrees;
        private Double featureBagFraction;
        private String predictionFieldName;
        private Double trainingPercent;

        private Builder(String dependentVariable) {
            this.dependentVariable = Objects.requireNonNull(dependentVariable);
        }

        public Builder setLambda(Double lambda) {
            this.lambda = lambda;
            return this;
        }

        public Builder setGamma(Double gamma) {
            this.gamma = gamma;
            return this;
        }

        public Builder setEta(Double eta) {
            this.eta = eta;
            return this;
        }

        public Builder setMaximumNumberTrees(Integer maximumNumberTrees) {
            this.maximumNumberTrees = maximumNumberTrees;
            return this;
        }

        public Builder setFeatureBagFraction(Double featureBagFraction) {
            this.featureBagFraction = featureBagFraction;
            return this;
        }

        public Builder setPredictionFieldName(String predictionFieldName) {
            this.predictionFieldName = predictionFieldName;
            return this;
        }

        public Builder setTrainingPercent(Double trainingPercent) {
            this.trainingPercent = trainingPercent;
            return this;
        }

        /** @return a new {@link Classification} carrying the values currently held by this builder */
        public Classification build() {
            return new Classification(
                dependentVariable,
                lambda,
                gamma,
                eta,
                maximumNumberTrees,
                featureBagFraction,
                predictionFieldName,
                trainingPercent);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
new NamedXContentRegistry.Entry(
DataFrameAnalysis.class,
Regression.NAME,
(p, c) -> Regression.fromXContent(p)));
(p, c) -> Regression.fromXContent(p)),
new NamedXContentRegistry.Entry(
DataFrameAnalysis.class,
Classification.NAME,
(p, c) -> Classification.fromXContent(p)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,19 @@ public static Builder builder(String dependentVariable) {
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");

private static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), true,
a -> new Regression(
(String) a[0],
(Double) a[1],
(Double) a[2],
(Double) a[3],
(Integer) a[4],
(Double) a[5],
(String) a[6],
(Double) a[7]));
private static final ConstructingObjectParser<Regression, Void> PARSER =
new ConstructingObjectParser<>(
NAME.getPreferredName(),
true,
a -> new Regression(
(String) a[0],
(Double) a[1],
(Double) a[2],
(Double) a[3],
(Integer) a[4],
(Double) a[5],
(String) a[6],
(Double) a[7]));

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1315,6 +1315,41 @@ public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception {
assertThat(createdConfig.getDescription(), equalTo("this is a regression"));
}

/**
 * Stores a data frame analytics config that uses the classification analysis and checks that
 * the config returned by the put call matches what was sent, plus the expected default values.
 */
public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Exception {
    final String sourceIndex = "put-test-source-index";
    final String destIndex = "put-test-dest-index";
    final String description = "this is a classification";

    MachineLearningClient mlClient = highLevelClient().machineLearning();

    DataFrameAnalyticsSource source = DataFrameAnalyticsSource.builder()
        .setIndex(sourceIndex)
        .build();
    DataFrameAnalyticsDest dest = DataFrameAnalyticsDest.builder()
        .setIndex(destIndex)
        .build();
    DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder()
        .setId("test-put-df-analytics-classification")
        .setSource(source)
        .setDest(dest)
        .setAnalysis(org.elasticsearch.client.ml.dataframe.Classification
            .builder("my_dependent_variable")
            .setTrainingPercent(80.0)
            .build())
        .setDescription(description)
        .build();

    // The source index is created before the config is submitted.
    createIndex(sourceIndex, defaultMappingForTest());

    PutDataFrameAnalyticsRequest putRequest = new PutDataFrameAnalyticsRequest(config);
    PutDataFrameAnalyticsResponse putResponse =
        execute(putRequest, mlClient::putDataFrameAnalytics, mlClient::putDataFrameAnalyticsAsync);
    DataFrameAnalyticsConfig stored = putResponse.getConfig();

    assertThat(stored.getId(), equalTo(config.getId()));
    assertThat(stored.getSource().getIndex(), equalTo(config.getSource().getIndex()));
    assertThat(stored.getSource().getQueryConfig(), equalTo(new QueryConfig(new MatchAllQueryBuilder()))); // default value
    assertThat(stored.getDest().getIndex(), equalTo(config.getDest().getIndex()));
    assertThat(stored.getDest().getResultsField(), equalTo("ml")); // default value
    assertThat(stored.getAnalysis(), equalTo(config.getAnalysis()));
    assertThat(stored.getAnalyzedFields(), equalTo(config.getAnalyzedFields()));
    assertThat(stored.getModelMemoryLimit(), equalTo(ByteSizeValue.parseBytesSizeValue("1gb", ""))); // default value
    assertThat(stored.getDescription(), equalTo(description));
}

public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception {
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
String configId = "get-test-config";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ public void testDefaultNamedXContents() {

public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(44, namedXContents.size());
assertEquals(48, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
Expand Down Expand Up @@ -718,9 +718,10 @@ public void testProvidedNamedXContents() {
assertTrue(names.contains(ShrinkAction.NAME));
assertTrue(names.contains(FreezeAction.NAME));
assertTrue(names.contains(SetPriorityAction.NAME));
assertEquals(Integer.valueOf(2), categories.get(DataFrameAnalysis.class));
assertEquals(Integer.valueOf(3), categories.get(DataFrameAnalysis.class));
assertTrue(names.contains(OutlierDetection.NAME.getPreferredName()));
assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Regression.NAME.getPreferredName()));
assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Classification.NAME.getPreferredName()));
assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class));
assertTrue(names.contains(TimeSyncConfig.NAME));
assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
Expand Down
Loading

0 comments on commit efcc4d1

Please sign in to comment.