From efcc4d179ca222d45ce7abd91b0880ee6fb1b67c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Fri, 4 Oct 2019 11:46:13 +0200 Subject: [PATCH] Implement new analysis type: classification (#46537) --- .../client/ml/dataframe/Classification.java | 245 +++++++++++++ ...ataFrameAnalysisNamedXContentProvider.java | 6 +- .../client/ml/dataframe/Regression.java | 23 +- .../client/MachineLearningIT.java | 35 ++ .../client/RestHighLevelClientTests.java | 5 +- .../ml/dataframe/ClassificationTests.java | 54 +++ .../xpack/core/XPackClientPlugin.java | 2 + .../dataframe/analyses/BoostedTreeParams.java | 156 ++++++++ .../ml/dataframe/analyses/Classification.java | 186 ++++++++++ ...ataFrameAnalysisNamedXContentProvider.java | 41 ++- .../ml/dataframe/analyses/Regression.java | 135 ++----- .../core/ml/dataframe/analyses/Types.java | 24 +- .../persistence/ElasticsearchMappings.java | 43 ++- .../ml/job/results/ReservedFieldNames.java | 17 +- .../analyses/BoostedTreeParamsTests.java | 105 ++++++ .../analyses/ClassificationTests.java | 68 ++++ .../dataframe/analyses/RegressionTests.java | 84 +---- .../ml/qa/ml-with-security/build.gradle | 13 + .../ml/integration/ClassificationIT.java | 314 ++++++++++++++++ ...NativeDataFrameAnalyticsIntegTestCase.java | 21 +- .../OutlierDetectionWithMissingFieldsIT.java | 3 +- .../xpack/ml/integration/RegressionIT.java | 283 +++++++-------- .../integration/RunDataFrameAnalyticsIT.java | 22 +- .../CustomProcessorFactory.java | 10 +- ...a => DatasetSplittingCustomProcessor.java} | 10 +- ...DatasetSplittingCustomProcessorTests.java} | 15 +- .../test/ml/data_frame_analytics_crud.yml | 340 ++++++++++++++++++ 27 files changed, 1833 insertions(+), 427 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java create mode 100644 
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java create mode 100644 x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/{RegressionCustomProcessor.java => DatasetSplittingCustomProcessor.java} (86%) rename x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/{RegressionCustomProcessorTests.java => DatasetSplittingCustomProcessorTests.java} (87%) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java new file mode 100644 index 0000000000000..fb9234d25b84e --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java @@ -0,0 +1,245 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class Classification implements DataFrameAnalysis { + + public static Classification fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public static Builder builder(String dependentVariable) { + return new Builder(dependentVariable); + } + + public static final ParseField NAME = new ParseField("classification"); + + static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable"); + static final ParseField LAMBDA = new ParseField("lambda"); + static final ParseField GAMMA = new ParseField("gamma"); + static final ParseField ETA = new ParseField("eta"); + static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees"); + static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); + static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + NAME.getPreferredName(), + true, + a -> new Classification( + (String) a[0], + 
(Double) a[1], + (Double) a[2], + (Double) a[3], + (Integer) a[4], + (Double) a[5], + (String) a[6], + (Double) a[7])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); + PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA); + PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA); + PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES); + PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); + PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); + } + + private final String dependentVariable; + private final Double lambda; + private final Double gamma; + private final Double eta; + private final Integer maximumNumberTrees; + private final Double featureBagFraction; + private final String predictionFieldName; + private final Double trainingPercent; + + private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, + @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName, + @Nullable Double trainingPercent) { + this.dependentVariable = Objects.requireNonNull(dependentVariable); + this.lambda = lambda; + this.gamma = gamma; + this.eta = eta; + this.maximumNumberTrees = maximumNumberTrees; + this.featureBagFraction = featureBagFraction; + this.predictionFieldName = predictionFieldName; + this.trainingPercent = trainingPercent; + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + public String getDependentVariable() { + return dependentVariable; + } + + public Double getLambda() { + return lambda; + } + + public Double getGamma() { + 
return gamma; + } + + public Double getEta() { + return eta; + } + + public Integer getMaximumNumberTrees() { + return maximumNumberTrees; + } + + public Double getFeatureBagFraction() { + return featureBagFraction; + } + + public String getPredictionFieldName() { + return predictionFieldName; + } + + public Double getTrainingPercent() { + return trainingPercent; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); + if (lambda != null) { + builder.field(LAMBDA.getPreferredName(), lambda); + } + if (gamma != null) { + builder.field(GAMMA.getPreferredName(), gamma); + } + if (eta != null) { + builder.field(ETA.getPreferredName(), eta); + } + if (maximumNumberTrees != null) { + builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees); + } + if (featureBagFraction != null) { + builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); + } + if (predictionFieldName != null) { + builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); + } + if (trainingPercent != null) { + builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent); + } + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, + trainingPercent); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Classification that = (Classification) o; + return Objects.equals(dependentVariable, that.dependentVariable) + && Objects.equals(lambda, that.lambda) + && Objects.equals(gamma, that.gamma) + && Objects.equals(eta, that.eta) + && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) + && Objects.equals(featureBagFraction, 
that.featureBagFraction) + && Objects.equals(predictionFieldName, that.predictionFieldName) + && Objects.equals(trainingPercent, that.trainingPercent); + } + + @Override + public String toString() { + return Strings.toString(this); + } + + public static class Builder { + private String dependentVariable; + private Double lambda; + private Double gamma; + private Double eta; + private Integer maximumNumberTrees; + private Double featureBagFraction; + private String predictionFieldName; + private Double trainingPercent; + + private Builder(String dependentVariable) { + this.dependentVariable = Objects.requireNonNull(dependentVariable); + } + + public Builder setLambda(Double lambda) { + this.lambda = lambda; + return this; + } + + public Builder setGamma(Double gamma) { + this.gamma = gamma; + return this; + } + + public Builder setEta(Double eta) { + this.eta = eta; + return this; + } + + public Builder setMaximumNumberTrees(Integer maximumNumberTrees) { + this.maximumNumberTrees = maximumNumberTrees; + return this; + } + + public Builder setFeatureBagFraction(Double featureBagFraction) { + this.featureBagFraction = featureBagFraction; + return this; + } + + public Builder setPredictionFieldName(String predictionFieldName) { + this.predictionFieldName = predictionFieldName; + return this; + } + + public Builder setTrainingPercent(Double trainingPercent) { + this.trainingPercent = trainingPercent; + return this; + } + + public Classification build() { + return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, + trainingPercent); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java index 809317d735b54..e2692385bd508 100644 --- 
a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java @@ -36,6 +36,10 @@ public List getNamedXContentParsers() { new NamedXContentRegistry.Entry( DataFrameAnalysis.class, Regression.NAME, - (p, c) -> Regression.fromXContent(p))); + (p, c) -> Regression.fromXContent(p)), + new NamedXContentRegistry.Entry( + DataFrameAnalysis.class, + Classification.NAME, + (p, c) -> Classification.fromXContent(p))); } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java index 450da1a3e0c94..3c1edece6fc16 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java @@ -49,16 +49,19 @@ public static Builder builder(String dependentVariable) { static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), true, - a -> new Regression( - (String) a[0], - (Double) a[1], - (Double) a[2], - (Double) a[3], - (Integer) a[4], - (Double) a[5], - (String) a[6], - (Double) a[7])); + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + NAME.getPreferredName(), + true, + a -> new Regression( + (String) a[0], + (Double) a[1], + (Double) a[2], + (Double) a[3], + (Integer) a[4], + (Double) a[5], + (String) a[6], + (Double) a[7])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); diff --git 
a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index df85f0483d67a..9c8663d8eb383 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1315,6 +1315,41 @@ public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception { assertThat(createdConfig.getDescription(), equalTo("this is a regression")); } + public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Exception { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "test-put-df-analytics-classification"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder() + .setId(configId) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex("put-test-source-index") + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex("put-test-dest-index") + .build()) + .setAnalysis(org.elasticsearch.client.ml.dataframe.Classification + .builder("my_dependent_variable") + .setTrainingPercent(80.0) + .build()) + .setDescription("this is a classification") + .build(); + + createIndex("put-test-source-index", defaultMappingForTest()); + + PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + DataFrameAnalyticsConfig createdConfig = putDataFrameAnalyticsResponse.getConfig(); + assertThat(createdConfig.getId(), equalTo(config.getId())); + assertThat(createdConfig.getSource().getIndex(), equalTo(config.getSource().getIndex())); + assertThat(createdConfig.getSource().getQueryConfig(), equalTo(new QueryConfig(new MatchAllQueryBuilder()))); // default value + assertThat(createdConfig.getDest().getIndex(), 
equalTo(config.getDest().getIndex())); + assertThat(createdConfig.getDest().getResultsField(), equalTo("ml")); // default value + assertThat(createdConfig.getAnalysis(), equalTo(config.getAnalysis())); + assertThat(createdConfig.getAnalyzedFields(), equalTo(config.getAnalyzedFields())); + assertThat(createdConfig.getModelMemoryLimit(), equalTo(ByteSizeValue.parseBytesSizeValue("1gb", ""))); // default value + assertThat(createdConfig.getDescription(), equalTo("this is a classification")); + } + public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception { MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); String configId = "get-test-config"; diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index d71043ce8cdcf..bd9a9bb6212c8 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -684,7 +684,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(44, namedXContents.size()); + assertEquals(48, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -718,9 +718,10 @@ public void testProvidedNamedXContents() { assertTrue(names.contains(ShrinkAction.NAME)); assertTrue(names.contains(FreezeAction.NAME)); assertTrue(names.contains(SetPriorityAction.NAME)); - assertEquals(Integer.valueOf(2), categories.get(DataFrameAnalysis.class)); + assertEquals(Integer.valueOf(3), categories.get(DataFrameAnalysis.class)); assertTrue(names.contains(OutlierDetection.NAME.getPreferredName())); 
assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Regression.NAME.getPreferredName())); + assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Classification.NAME.getPreferredName())); assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class)); assertTrue(names.contains(TimeSyncConfig.NAME)); assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java new file mode 100644 index 0000000000000..9f0a418178dd5 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java @@ -0,0 +1,54 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class ClassificationTests extends AbstractXContentTestCase { + + public static Classification randomClassification() { + return Classification.builder(randomAlphaOfLength(10)) + .setLambda(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true)) + .setGamma(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true)) + .setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true)) + .setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000)) + .setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false)) + .setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10)) + .setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true)) + .build(); + } + + @Override + protected Classification createTestInstance() { + return randomClassification(); + } + + @Override + protected Classification doParseInstance(XContentParser parser) throws IOException { + return Classification.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index 6c9342cc710eb..590914fd93dec 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -133,6 +133,7 @@ import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; +import 
org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; @@ -466,6 +467,7 @@ public List getNamedWriteables() { // ML - Data frame analytics new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new), new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new), + new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Classification.NAME.getPreferredName(), Classification::new), // ML - Data frame evaluation new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), BinarySoftClassification::new), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java new file mode 100644 index 0000000000000..ed3cff7d73c0c --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java @@ -0,0 +1,156 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.AbstractObjectParser; +import org.elasticsearch.common.xcontent.ToXContentFragment; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +/** + * Parameters used by both {@link Classification} and {@link Regression} analyses. + */ +public class BoostedTreeParams implements ToXContentFragment, Writeable { + + static final String NAME = "boosted_tree_params"; + + public static final ParseField LAMBDA = new ParseField("lambda"); + public static final ParseField GAMMA = new ParseField("gamma"); + public static final ParseField ETA = new ParseField("eta"); + public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees"); + public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + + static void declareFields(AbstractObjectParser parser) { + parser.declareDouble(optionalConstructorArg(), LAMBDA); + parser.declareDouble(optionalConstructorArg(), GAMMA); + parser.declareDouble(optionalConstructorArg(), ETA); + parser.declareInt(optionalConstructorArg(), MAXIMUM_NUMBER_TREES); + parser.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION); + } + + private final Double lambda; + private final Double gamma; + private final Double eta; + private final Integer maximumNumberTrees; + private final Double featureBagFraction; + + BoostedTreeParams(@Nullable Double lambda, + 
@Nullable Double gamma, + @Nullable Double eta, + @Nullable Integer maximumNumberTrees, + @Nullable Double featureBagFraction) { + if (lambda != null && lambda < 0) { + throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName()); + } + if (gamma != null && gamma < 0) { + throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", GAMMA.getPreferredName()); + } + if (eta != null && (eta < 0.001 || eta > 1)) { + throw ExceptionsHelper.badRequestException("[{}] must be a double in [0.001, 1]", ETA.getPreferredName()); + } + if (maximumNumberTrees != null && (maximumNumberTrees <= 0 || maximumNumberTrees > 2000)) { + throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, 2000]", MAXIMUM_NUMBER_TREES.getPreferredName()); + } + if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1.0)) { + throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", FEATURE_BAG_FRACTION.getPreferredName()); + } + this.lambda = lambda; + this.gamma = gamma; + this.eta = eta; + this.maximumNumberTrees = maximumNumberTrees; + this.featureBagFraction = featureBagFraction; + } + + BoostedTreeParams() { + this(null, null, null, null, null); + } + + BoostedTreeParams(StreamInput in) throws IOException { + lambda = in.readOptionalDouble(); + gamma = in.readOptionalDouble(); + eta = in.readOptionalDouble(); + maximumNumberTrees = in.readOptionalVInt(); + featureBagFraction = in.readOptionalDouble(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalDouble(lambda); + out.writeOptionalDouble(gamma); + out.writeOptionalDouble(eta); + out.writeOptionalVInt(maximumNumberTrees); + out.writeOptionalDouble(featureBagFraction); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (lambda != null) { + builder.field(LAMBDA.getPreferredName(), lambda); + } + if 
(gamma != null) { + builder.field(GAMMA.getPreferredName(), gamma); + } + if (eta != null) { + builder.field(ETA.getPreferredName(), eta); + } + if (maximumNumberTrees != null) { + builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees); + } + if (featureBagFraction != null) { + builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); + } + return builder; + } + + Map getParams() { + Map params = new HashMap<>(); + if (lambda != null) { + params.put(LAMBDA.getPreferredName(), lambda); + } + if (gamma != null) { + params.put(GAMMA.getPreferredName(), gamma); + } + if (eta != null) { + params.put(ETA.getPreferredName(), eta); + } + if (maximumNumberTrees != null) { + params.put(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees); + } + if (featureBagFraction != null) { + params.put(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); + } + return params; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BoostedTreeParams that = (BoostedTreeParams) o; + return Objects.equals(lambda, that.lambda) + && Objects.equals(gamma, that.gamma) + && Objects.equals(eta, that.eta) + && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) + && Objects.equals(featureBagFraction, that.featureBagFraction); + } + + @Override + public int hashCode() { + return Objects.hash(lambda, gamma, eta, maximumNumberTrees, featureBagFraction); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java new file mode 100644 index 0000000000000..96c03b7692f20 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -0,0 +1,186 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class Classification implements DataFrameAnalysis { + + public static final ParseField NAME = new ParseField("classification"); + + public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable"); + public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); + public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); + public static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); + + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>( + NAME.getPreferredName(), + lenient, + a -> new Classification( + (String) a[0], + new BoostedTreeParams((Double) a[1], (Double) a[2], 
(Double) a[3], (Integer) a[4], (Double) a[5]), + (String) a[6], + (Integer) a[7], + (Double) a[8])); + parser.declareString(constructorArg(), DEPENDENT_VARIABLE); + BoostedTreeParams.declareFields(parser); + parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); + parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES); + parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT); + return parser; + } + + public static Classification fromXContent(XContentParser parser, boolean ignoreUnknownFields) { + return ignoreUnknownFields ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); + } + + private final String dependentVariable; + private final BoostedTreeParams boostedTreeParams; + private final String predictionFieldName; + private final int numTopClasses; + private final double trainingPercent; + + public Classification(String dependentVariable, + BoostedTreeParams boostedTreeParams, + @Nullable String predictionFieldName, + @Nullable Integer numTopClasses, + @Nullable Double trainingPercent) { + if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) { + throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName()); + } + if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) { + throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName()); + } + this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE); + this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME); + this.predictionFieldName = predictionFieldName; + this.numTopClasses = numTopClasses == null ? 0 : numTopClasses; + this.trainingPercent = trainingPercent == null ? 
100.0 : trainingPercent; + } + + public Classification(String dependentVariable) { + this(dependentVariable, new BoostedTreeParams(), null, null, null); + } + + public Classification(StreamInput in) throws IOException { + dependentVariable = in.readString(); + boostedTreeParams = new BoostedTreeParams(in); + predictionFieldName = in.readOptionalString(); + numTopClasses = in.readOptionalVInt(); + trainingPercent = in.readDouble(); + } + + public String getDependentVariable() { + return dependentVariable; + } + + public double getTrainingPercent() { + return trainingPercent; + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(dependentVariable); + boostedTreeParams.writeTo(out); + out.writeOptionalString(predictionFieldName); + out.writeOptionalVInt(numTopClasses); + out.writeDouble(trainingPercent); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); + boostedTreeParams.toXContent(builder, params); + builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); + if (predictionFieldName != null) { + builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); + } + builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent); + builder.endObject(); + return builder; + } + + @Override + public Map getParams() { + Map params = new HashMap<>(); + params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); + params.putAll(boostedTreeParams.getParams()); + params.put(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); + if (predictionFieldName != null) { + params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); + } + return params; + } + + @Override + public boolean supportsCategoricalFields() { + return true; + } + + 
@Override + public List getRequiredFields() { + return Collections.singletonList(new RequiredField(dependentVariable, Types.categorical())); + } + + @Override + public boolean supportsMissingValues() { + return true; + } + + @Override + public boolean persistsState() { + return false; + } + + @Override + public String getStateDocId(String jobId) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Classification that = (Classification) o; + return Objects.equals(dependentVariable, that.dependentVariable) + && Objects.equals(boostedTreeParams, that.boostedTreeParams) + && Objects.equals(predictionFieldName, that.predictionFieldName) + && Objects.equals(numTopClasses, that.numTopClasses) + && trainingPercent == that.trainingPercent; + } + + @Override + public int hashCode() { + return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java index e33a774859224..120cd24f6278c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java @@ -9,35 +9,34 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.plugins.spi.NamedXContentProvider; -import java.util.ArrayList; +import java.util.Arrays; import java.util.List; public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentProvider { @Override public List getNamedXContentParsers() { - 
List namedXContent = new ArrayList<>(); - - namedXContent.add(new NamedXContentRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME, (p, c) -> { - boolean ignoreUnknownFields = (boolean) c; - return OutlierDetection.fromXContent(p, ignoreUnknownFields); - })); - namedXContent.add(new NamedXContentRegistry.Entry(DataFrameAnalysis.class, Regression.NAME, (p, c) -> { - boolean ignoreUnknownFields = (boolean) c; - return Regression.fromXContent(p, ignoreUnknownFields); - })); - - return namedXContent; + return Arrays.asList( + new NamedXContentRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME, (p, c) -> { + boolean ignoreUnknownFields = (boolean) c; + return OutlierDetection.fromXContent(p, ignoreUnknownFields); + }), + new NamedXContentRegistry.Entry(DataFrameAnalysis.class, Regression.NAME, (p, c) -> { + boolean ignoreUnknownFields = (boolean) c; + return Regression.fromXContent(p, ignoreUnknownFields); + }), + new NamedXContentRegistry.Entry(DataFrameAnalysis.class, Classification.NAME, (p, c) -> { + boolean ignoreUnknownFields = (boolean) c; + return Classification.fromXContent(p, ignoreUnknownFields); + }) + ); } public List getNamedWriteables() { - List namedWriteables = new ArrayList<>(); - - namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), - OutlierDetection::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), - Regression::new)); - - return namedWriteables; + return Arrays.asList( + new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new), + new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new), + new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Classification.NAME.getPreferredName(), Classification::new) + ); } } diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index 6e60f1de57a4d..e804c7d176189 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -21,16 +21,14 @@ import java.util.Map; import java.util.Objects; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + public class Regression implements DataFrameAnalysis { public static final ParseField NAME = new ParseField("regression"); public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable"); - public static final ParseField LAMBDA = new ParseField("lambda"); - public static final ParseField GAMMA = new ParseField("gamma"); - public static final ParseField ETA = new ParseField("eta"); - public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees"); - public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); public static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); @@ -38,17 +36,18 @@ public class Regression implements DataFrameAnalysis { private static final ConstructingObjectParser STRICT_PARSER = createParser(false); private static ConstructingObjectParser createParser(boolean lenient) { - ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME.getPreferredName(), lenient, - a -> new Regression((String) a[0], (Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (String) a[6], + ConstructingObjectParser parser = new 
ConstructingObjectParser<>( + NAME.getPreferredName(), + lenient, + a -> new Regression( + (String) a[0], + new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]), + (String) a[6], (Double) a[7])); - parser.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); - parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA); - parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA); - parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA); - parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES); - parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION); - parser.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); - parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); + parser.declareString(constructorArg(), DEPENDENT_VARIABLE); + BoostedTreeParams.declareFields(parser); + parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); + parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT); return parser; } @@ -57,63 +56,30 @@ public static Regression fromXContent(XContentParser parser, boolean ignoreUnkno } private final String dependentVariable; - private final Double lambda; - private final Double gamma; - private final Double eta; - private final Integer maximumNumberTrees; - private final Double featureBagFraction; + private final BoostedTreeParams boostedTreeParams; private final String predictionFieldName; private final double trainingPercent; - public Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, - @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName, + public Regression(String dependentVariable, + BoostedTreeParams boostedTreeParams, + @Nullable 
String predictionFieldName, @Nullable Double trainingPercent) { - this.dependentVariable = Objects.requireNonNull(dependentVariable); - - if (lambda != null && lambda < 0) { - throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName()); - } - this.lambda = lambda; - - if (gamma != null && gamma < 0) { - throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", GAMMA.getPreferredName()); - } - this.gamma = gamma; - - if (eta != null && (eta < 0.001 || eta > 1)) { - throw ExceptionsHelper.badRequestException("[{}] must be a double in [0.001, 1]", ETA.getPreferredName()); - } - this.eta = eta; - - if (maximumNumberTrees != null && (maximumNumberTrees <= 0 || maximumNumberTrees > 2000)) { - throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, 2000]", MAXIMUM_NUMBER_TREES.getPreferredName()); - } - this.maximumNumberTrees = maximumNumberTrees; - - if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1.0)) { - throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", FEATURE_BAG_FRACTION.getPreferredName()); - } - this.featureBagFraction = featureBagFraction; - - this.predictionFieldName = predictionFieldName; - if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) { throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName()); } + this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE); + this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME); + this.predictionFieldName = predictionFieldName; this.trainingPercent = trainingPercent == null ? 
100.0 : trainingPercent; } public Regression(String dependentVariable) { - this(dependentVariable, null, null, null, null, null, null, null); + this(dependentVariable, new BoostedTreeParams(), null, null); } public Regression(StreamInput in) throws IOException { dependentVariable = in.readString(); - lambda = in.readOptionalDouble(); - gamma = in.readOptionalDouble(); - eta = in.readOptionalDouble(); - maximumNumberTrees = in.readOptionalVInt(); - featureBagFraction = in.readOptionalDouble(); + boostedTreeParams = new BoostedTreeParams(in); predictionFieldName = in.readOptionalString(); trainingPercent = in.readDouble(); } @@ -134,11 +100,7 @@ public String getWriteableName() { @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(dependentVariable); - out.writeOptionalDouble(lambda); - out.writeOptionalDouble(gamma); - out.writeOptionalDouble(eta); - out.writeOptionalVInt(maximumNumberTrees); - out.writeOptionalDouble(featureBagFraction); + boostedTreeParams.writeTo(out); out.writeOptionalString(predictionFieldName); out.writeDouble(trainingPercent); } @@ -147,21 +109,7 @@ public void writeTo(StreamOutput out) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); - if (lambda != null) { - builder.field(LAMBDA.getPreferredName(), lambda); - } - if (gamma != null) { - builder.field(GAMMA.getPreferredName(), gamma); - } - if (eta != null) { - builder.field(ETA.getPreferredName(), eta); - } - if (maximumNumberTrees != null) { - builder.field(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees); - } - if (featureBagFraction != null) { - builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); - } + boostedTreeParams.toXContent(builder, params); if (predictionFieldName != null) { builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } 
@@ -174,21 +122,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public Map getParams() { Map params = new HashMap<>(); params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); - if (lambda != null) { - params.put(LAMBDA.getPreferredName(), lambda); - } - if (gamma != null) { - params.put(GAMMA.getPreferredName(), gamma); - } - if (eta != null) { - params.put(ETA.getPreferredName(), eta); - } - if (maximumNumberTrees != null) { - params.put(MAXIMUM_NUMBER_TREES.getPreferredName(), maximumNumberTrees); - } - if (featureBagFraction != null) { - params.put(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); - } + params.putAll(boostedTreeParams.getParams()); if (predictionFieldName != null) { params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } @@ -220,24 +154,19 @@ public String getStateDocId(String jobId) { return jobId + "_regression_state#1"; } - @Override - public int hashCode() { - return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent); - } - @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Regression that = (Regression) o; return Objects.equals(dependentVariable, that.dependentVariable) - && Objects.equals(lambda, that.lambda) - && Objects.equals(gamma, that.gamma) - && Objects.equals(eta, that.eta) - && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) - && Objects.equals(featureBagFraction, that.featureBagFraction) + && Objects.equals(boostedTreeParams, that.boostedTreeParams) && Objects.equals(predictionFieldName, that.predictionFieldName) && trainingPercent == that.trainingPercent; } + + @Override + public int hashCode() { + return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent); + } } diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Types.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Types.java index ba7cac81d7fda..fc991c86f5663 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Types.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Types.java @@ -7,9 +7,7 @@ import org.elasticsearch.index.mapper.NumberFieldMapper; -import java.util.Arrays; import java.util.Collections; -import java.util.HashSet; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -21,17 +19,17 @@ public final class Types { private Types() {} - private static final Set CATEGORICAL_TYPES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList("text", "keyword", "ip"))); - - private static final Set NUMERICAL_TYPES; - - static { - Set numericalTypes = Stream.of(NumberFieldMapper.NumberType.values()) - .map(NumberFieldMapper.NumberType::typeName) - .collect(Collectors.toSet()); - numericalTypes.add("scaled_float"); - NUMERICAL_TYPES = Collections.unmodifiableSet(numericalTypes); - } + private static final Set CATEGORICAL_TYPES = + Collections.unmodifiableSet( + Stream.of("text", "keyword", "ip") + .collect(Collectors.toSet())); + + private static final Set NUMERICAL_TYPES = + Collections.unmodifiableSet( + Stream.concat( + Stream.of(NumberFieldMapper.NumberType.values()).map(NumberFieldMapper.NumberType::typeName), + Stream.of("scaled_float")) + .collect(Collectors.toSet())); public static Set categorical() { return CATEGORICAL_TYPES; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java index 804e9c8dcda69..1eec2a04be4f9 100644 --- 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java @@ -30,6 +30,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; @@ -449,19 +451,19 @@ public static void addDataFrameAnalyticsFields(XContentBuilder builder) throws I .startObject(Regression.DEPENDENT_VARIABLE.getPreferredName()) .field(TYPE, KEYWORD) .endObject() - .startObject(Regression.LAMBDA.getPreferredName()) + .startObject(BoostedTreeParams.LAMBDA.getPreferredName()) .field(TYPE, DOUBLE) .endObject() - .startObject(Regression.GAMMA.getPreferredName()) + .startObject(BoostedTreeParams.GAMMA.getPreferredName()) .field(TYPE, DOUBLE) .endObject() - .startObject(Regression.ETA.getPreferredName()) + .startObject(BoostedTreeParams.ETA.getPreferredName()) .field(TYPE, DOUBLE) .endObject() - .startObject(Regression.MAXIMUM_NUMBER_TREES.getPreferredName()) + .startObject(BoostedTreeParams.MAXIMUM_NUMBER_TREES.getPreferredName()) .field(TYPE, INTEGER) .endObject() - .startObject(Regression.FEATURE_BAG_FRACTION.getPreferredName()) + .startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName()) .field(TYPE, DOUBLE) .endObject() .startObject(Regression.PREDICTION_FIELD_NAME.getPreferredName()) @@ -472,6 +474,37 @@ public static void addDataFrameAnalyticsFields(XContentBuilder builder) throws I .endObject() 
.endObject() .endObject() + .startObject(Classification.NAME.getPreferredName()) + .startObject(PROPERTIES) + .startObject(Classification.DEPENDENT_VARIABLE.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(BoostedTreeParams.LAMBDA.getPreferredName()) + .field(TYPE, DOUBLE) + .endObject() + .startObject(BoostedTreeParams.GAMMA.getPreferredName()) + .field(TYPE, DOUBLE) + .endObject() + .startObject(BoostedTreeParams.ETA.getPreferredName()) + .field(TYPE, DOUBLE) + .endObject() + .startObject(BoostedTreeParams.MAXIMUM_NUMBER_TREES.getPreferredName()) + .field(TYPE, INTEGER) + .endObject() + .startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName()) + .field(TYPE, DOUBLE) + .endObject() + .startObject(Classification.PREDICTION_FIELD_NAME.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(Classification.NUM_TOP_CLASSES.getPreferredName()) + .field(TYPE, INTEGER) + .endObject() + .startObject(Classification.TRAINING_PERCENT.getPreferredName()) + .field(TYPE, DOUBLE) + .endObject() + .endObject() + .endObject() .endObject() .endObject() // re-used: CREATE_TIME diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index cba7a7f634d2e..c40fa2f026bfb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -13,6 +13,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; 
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; @@ -303,13 +305,18 @@ public final class ReservedFieldNames { OutlierDetection.FEATURE_INFLUENCE_THRESHOLD.getPreferredName(), Regression.NAME.getPreferredName(), Regression.DEPENDENT_VARIABLE.getPreferredName(), - Regression.LAMBDA.getPreferredName(), - Regression.GAMMA.getPreferredName(), - Regression.ETA.getPreferredName(), - Regression.MAXIMUM_NUMBER_TREES.getPreferredName(), - Regression.FEATURE_BAG_FRACTION.getPreferredName(), Regression.PREDICTION_FIELD_NAME.getPreferredName(), Regression.TRAINING_PERCENT.getPreferredName(), + Classification.NAME.getPreferredName(), + Classification.DEPENDENT_VARIABLE.getPreferredName(), + Classification.PREDICTION_FIELD_NAME.getPreferredName(), + Classification.NUM_TOP_CLASSES.getPreferredName(), + Classification.TRAINING_PERCENT.getPreferredName(), + BoostedTreeParams.LAMBDA.getPreferredName(), + BoostedTreeParams.GAMMA.getPreferredName(), + BoostedTreeParams.ETA.getPreferredName(), + BoostedTreeParams.MAXIMUM_NUMBER_TREES.getPreferredName(), + BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName(), ElasticsearchMappings.CONFIG_TYPE, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java new file mode 100644 index 0000000000000..145533df407cd --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java @@ -0,0 +1,105 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; + +import static org.hamcrest.Matchers.equalTo; + +public class BoostedTreeParamsTests extends AbstractSerializingTestCase { + + @Override + protected BoostedTreeParams doParseInstance(XContentParser parser) throws IOException { + ConstructingObjectParser objParser = + new ConstructingObjectParser<>( + BoostedTreeParams.NAME, + true, + a -> new BoostedTreeParams((Double) a[0], (Double) a[1], (Double) a[2], (Integer) a[3], (Double) a[4])); + BoostedTreeParams.declareFields(objParser); + return objParser.apply(parser, null); + } + + @Override + protected BoostedTreeParams createTestInstance() { + return createRandom(); + } + + public static BoostedTreeParams createRandom() { + Double lambda = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true); + Double gamma = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true); + Double eta = randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true); + Integer maximumNumberTrees = randomBoolean() ? null : randomIntBetween(1, 2000); + Double featureBagFraction = randomBoolean() ? 
null : randomDoubleBetween(0.0, 1.0, false); + return new BoostedTreeParams(lambda, gamma, eta, maximumNumberTrees, featureBagFraction); + } + + @Override + protected Writeable.Reader instanceReader() { + return BoostedTreeParams::new; + } + + public void testConstructor_GivenNegativeLambda() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new BoostedTreeParams(-0.00001, 0.0, 0.5, 500, 0.3)); + + assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double")); + } + + public void testConstructor_GivenNegativeGamma() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new BoostedTreeParams(0.0, -0.00001, 0.5, 500, 0.3)); + + assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double")); + } + + public void testConstructor_GivenEtaIsZero() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new BoostedTreeParams(0.0, 0.0, 0.0, 500, 0.3)); + + assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]")); + } + + public void testConstructor_GivenEtaIsGreaterThanOne() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new BoostedTreeParams(0.0, 0.0, 1.00001, 500, 0.3)); + + assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]")); + } + + public void testConstructor_GivenMaximumNumberTreesIsZero() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new BoostedTreeParams(0.0, 0.0, 0.5, 0, 0.3)); + + assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]")); + } + + public void testConstructor_GivenMaximumNumberTreesIsGreaterThan2k() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new BoostedTreeParams(0.0, 0.0, 0.5, 2001, 0.3)); + + assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 
2000]")); + } + + public void testConstructor_GivenFeatureBagFractionIsLessThanZero() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new BoostedTreeParams(0.0, 0.0, 0.5, 500, -0.00001)); + + assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]")); + } + + public void testConstructor_GivenFeatureBagFractionIsGreaterThanOne() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.00001)); + + assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]")); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java new file mode 100644 index 0000000000000..e67f297094620 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; + +import static org.hamcrest.Matchers.equalTo; + +public class ClassificationTests extends AbstractSerializingTestCase { + + @Override + protected Classification doParseInstance(XContentParser parser) throws IOException { + return Classification.fromXContent(parser, false); + } + + @Override + protected Classification createTestInstance() { + return createRandom(); + } + + public static Classification createRandom() { + String dependentVariableName = randomAlphaOfLength(10); + BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom(); + String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10); + Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000); + Double trainingPercent = randomBoolean() ? 
null : randomDoubleBetween(1.0, 100.0, true); + return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent); + } + + @Override + protected Writeable.Reader instanceReader() { + return Classification::new; + } + + public void testConstructor_GivenTrainingPercentIsNull() { + Classification classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, null); + assertThat(classification.getTrainingPercent(), equalTo(100.0)); + } + + public void testConstructor_GivenTrainingPercentIsBoundary() { + Classification classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 1.0); + assertThat(classification.getTrainingPercent(), equalTo(1.0)); + classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 100.0); + assertThat(classification.getTrainingPercent(), equalTo(100.0)); + } + + public void testConstructor_GivenTrainingPercentIsLessThanOne() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 0.999)); + + assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); + } + + public void testConstructor_GivenTrainingPercentIsGreaterThan100() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 100.0001)); + + assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java index fbbd7bce0dc07..9e7a898afe66c 100644 --- 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java @@ -28,15 +28,11 @@ protected Regression createTestInstance() { } public static Regression createRandom() { - Double lambda = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true); - Double gamma = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true); - Double eta = randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true); - Integer maximumNumberTrees = randomBoolean() ? null : randomIntBetween(1, 2000); - Double featureBagFraction = randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false); + String dependentVariableName = randomAlphaOfLength(10); + BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom(); String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10); Double trainingPercent = randomBoolean() ? 
null : randomDoubleBetween(1.0, 100.0, true); - return new Regression(randomAlphaOfLength(10), lambda, gamma, eta, maximumNumberTrees, featureBagFraction, - predictionFieldName, trainingPercent); + return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent); } @Override @@ -44,84 +40,28 @@ protected Writeable.Reader instanceReader() { return Regression::new; } - public void testRegression_GivenNegativeLambda() { - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", -0.00001, 0.0, 0.5, 500, 0.3, "result", 100.0)); - - assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double")); - } - - public void testRegression_GivenNegativeGamma() { - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", 0.0, -0.00001, 0.5, 500, 0.3, "result", 100.0)); - - assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double")); - } - - public void testRegression_GivenEtaIsZero() { - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", 0.0, 0.0, 0.0, 500, 0.3, "result", 100.0)); - - assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]")); - } - - public void testRegression_GivenEtaIsGreaterThanOne() { - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", 0.0, 0.0, 1.00001, 500, 0.3, "result", 100.0)); - - assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]")); - } - - public void testRegression_GivenMaximumNumberTreesIsZero() { - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", 0.0, 0.0, 0.5, 0, 0.3, "result", 100.0)); - - assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]")); - } - - public void 
testRegression_GivenMaximumNumberTreesIsGreaterThan2k() { - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", 0.0, 0.0, 0.5, 2001, 0.3, "result", 100.0)); - - assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]")); - } - - public void testRegression_GivenFeatureBagFractionIsLessThanZero() { - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", 0.0, 0.0, 0.5, 500, -0.00001, "result", 100.0)); - - assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]")); - } - - public void testRegression_GivenFeatureBagFractionIsGreaterThanOne() { - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.00001, "result", 100.0)); - - assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]")); - } - - public void testRegression_GivenTrainingPercentIsNull() { - Regression regression = new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", null); + public void testConstructor_GivenTrainingPercentIsNull() { + Regression regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", null); assertThat(regression.getTrainingPercent(), equalTo(100.0)); } - public void testRegression_GivenTrainingPercentIsBoundary() { - Regression regression = new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 1.0); + public void testConstructor_GivenTrainingPercentIsBoundary() { + Regression regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 1.0); assertThat(regression.getTrainingPercent(), equalTo(1.0)); - regression = new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 100.0); + regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 100.0); assertThat(regression.getTrainingPercent(), 
equalTo(100.0)); } - public void testRegression_GivenTrainingPercentIsLessThanOne() { + public void testConstructor_GivenTrainingPercentIsLessThanOne() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 0.999)); + () -> new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 0.999)); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } - public void testRegression_GivenTrainingPercentIsGreaterThan100() { + public void testConstructor_GivenTrainingPercentIsGreaterThan100() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 100.0001)); + () -> new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 100.0001)); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 37a29098cea1d..6dfa5798c4b7e 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -73,6 +73,19 @@ integTest.runner { 'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is greater than one', 'ml/data_frame_analytics_crud/Test put regression given training_percent is less than one', 'ml/data_frame_analytics_crud/Test put regression given training_percent is greater than hundred', + 'ml/data_frame_analytics_crud/Test put classification given dependent_variable is not defined', + 'ml/data_frame_analytics_crud/Test put classification given negative lambda', + 'ml/data_frame_analytics_crud/Test put classification given negative gamma', + 'ml/data_frame_analytics_crud/Test put classification given eta less than 1e-3', + 'ml/data_frame_analytics_crud/Test put classification given eta greater 
than one', + 'ml/data_frame_analytics_crud/Test put classification given maximum_number_trees is zero', + 'ml/data_frame_analytics_crud/Test put classification given maximum_number_trees is greater than 2k', + 'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is negative', + 'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is greater than one', + 'ml/data_frame_analytics_crud/Test put classification given num_top_classes is less than zero', + 'ml/data_frame_analytics_crud/Test put classification given num_top_classes is greater than 1k', + 'ml/data_frame_analytics_crud/Test put classification given training_percent is less than one', + 'ml/data_frame_analytics_crud/Test put classification given training_percent is greater than hundred', 'ml/evaluate_data_frame/Test given missing index', 'ml/evaluate_data_frame/Test given index does not exist', 'ml/evaluate_data_frame/Test given missing evaluation', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java new file mode 100644 index 0000000000000..8a8040f586f80 --- /dev/null +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -0,0 +1,314 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.integration; + +import com.google.common.collect.Ordering; +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.get.GetResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; +import org.junit.After; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasKey; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.in; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { + + private static final String NUMERICAL_FEATURE_FIELD = "feature"; + private static final String DEPENDENT_VARIABLE_FIELD = "variable"; + private static final List NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0)); + private static final List DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat", "cow")); + + private String jobId; + private String sourceIndex; + private String destIndex; + + @After + public void cleanup() throws Exception { + cleanUp(); + } + + public void 
testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception { + initialize("classification_single_numeric_feature_and_mixed_data_set"); + + { // Index 350 rows, 300 of them being training rows. + client().admin().indices().prepareCreate(sourceIndex) + .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword") + .get(); + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + for (int i = 0; i < 300; i++) { + Double field = NUMERICAL_FEATURE_VALUES.get(i % 3); + String value = DEPENDENT_VARIABLE_VALUES.get(i % 3); + + IndexRequest indexRequest = new IndexRequest(sourceIndex) + .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value); + bulkRequestBuilder.add(indexRequest); + } + for (int i = 300; i < 350; i++) { + Double field = NUMERICAL_FEATURE_VALUES.get(i % 3); + + IndexRequest indexRequest = new IndexRequest(sourceIndex) + .source(NUMERICAL_FEATURE_FIELD, field); + bulkRequestBuilder.add(indexRequest); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } + } + + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD)); + registerAnalytics(config); + putAnalytics(config); + + assertState(jobId, DataFrameAnalyticsState.STOPPED); + assertProgress(jobId, 0, 0, 0, 0); + + startAnalytics(jobId); + waitUntilAnalyticsIsStopped(jobId); + + SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); + for (SearchHit hit : sourceData.getHits()) { + Map destDoc = getDestDoc(config, hit); + Map resultsObject = getMlResultsObjectFromDestDoc(destDoc); + + assertThat(resultsObject.containsKey("variable_prediction"), is(true)); + assertThat((String) resultsObject.get("variable_prediction"), 
is(in(DEPENDENT_VARIABLE_VALUES))); + assertThat(resultsObject.containsKey("is_training"), is(true)); + assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD))); + assertThat(resultsObject.containsKey("top_classes"), is(false)); + } + + assertProgress(jobId, 100, 100, 100, 100); + assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + assertThatAuditMessagesMatch(jobId, + "Created analytics with analysis type [classification]", + "Estimated memory usage for this analytics to be", + "Started analytics", + "Creating destination index [" + destIndex + "]", + "Finished reindexing to destination index [" + destIndex + "]", + "Finished analysis"); + } + + public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception { + initialize("classification_only_training_data_and_training_percent_is_100"); + indexTrainingData(sourceIndex, 300); + + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD)); + registerAnalytics(config); + putAnalytics(config); + + assertState(jobId, DataFrameAnalyticsState.STOPPED); + assertProgress(jobId, 0, 0, 0, 0); + + startAnalytics(jobId); + waitUntilAnalyticsIsStopped(jobId); + + SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); + for (SearchHit hit : sourceData.getHits()) { + Map resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit)); + + assertThat(resultsObject.containsKey("variable_prediction"), is(true)); + assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES))); + assertThat(resultsObject.containsKey("is_training"), is(true)); + assertThat(resultsObject.get("is_training"), is(true)); + assertThat(resultsObject.containsKey("top_classes"), is(false)); + } + + assertProgress(jobId, 100, 100, 100, 100); + 
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + assertThatAuditMessagesMatch(jobId, + "Created analytics with analysis type [classification]", + "Estimated memory usage for this analytics to be", + "Started analytics", + "Creating destination index [" + destIndex + "]", + "Finished reindexing to destination index [" + destIndex + "]", + "Finished analysis"); + } + + public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception { + initialize("classification_only_training_data_and_training_percent_is_50"); + indexTrainingData(sourceIndex, 300); + + DataFrameAnalyticsConfig config = + buildAnalytics( + jobId, + sourceIndex, + destIndex, + null, + new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0)); + registerAnalytics(config); + putAnalytics(config); + + assertState(jobId, DataFrameAnalyticsState.STOPPED); + assertProgress(jobId, 0, 0, 0, 0); + + startAnalytics(jobId); + waitUntilAnalyticsIsStopped(jobId); + + int trainingRowsCount = 0; + int nonTrainingRowsCount = 0; + SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); + for (SearchHit hit : sourceData.getHits()) { + Map resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit)); + assertThat(resultsObject.containsKey("variable_prediction"), is(true)); + assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES))); + + assertThat(resultsObject.containsKey("is_training"), is(true)); + // Let's just assert there's both training and non-training results + if ((boolean) resultsObject.get("is_training")) { + trainingRowsCount++; + } else { + nonTrainingRowsCount++; + } + assertThat(resultsObject.containsKey("top_classes"), is(false)); + } + assertThat(trainingRowsCount, greaterThan(0)); + assertThat(nonTrainingRowsCount, greaterThan(0)); + + assertProgress(jobId, 100, 100, 100, 100); + 
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + assertThatAuditMessagesMatch(jobId, + "Created analytics with analysis type [classification]", + "Estimated memory usage for this analytics to be", + "Started analytics", + "Creating destination index [" + destIndex + "]", + "Finished reindexing to destination index [" + destIndex + "]", + "Finished analysis"); + } + + @AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/issues/712") + public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested() throws Exception { + initialize("classification_top_classes_requested"); + indexTrainingData(sourceIndex, 300); + + int numTopClasses = 2; + DataFrameAnalyticsConfig config = + buildAnalytics( + jobId, + sourceIndex, + destIndex, + null, + new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null)); + registerAnalytics(config); + putAnalytics(config); + + assertState(jobId, DataFrameAnalyticsState.STOPPED); + assertProgress(jobId, 0, 0, 0, 0); + + startAnalytics(jobId); + waitUntilAnalyticsIsStopped(jobId); + + SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); + for (SearchHit hit : sourceData.getHits()) { + Map destDoc = getDestDoc(config, hit); + Map resultsObject = getMlResultsObjectFromDestDoc(destDoc); + + assertThat(resultsObject.containsKey("variable_prediction"), is(true)); + assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES))); + assertTopClasses(resultsObject, numTopClasses); + } + + assertProgress(jobId, 100, 100, 100, 100); + assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + assertThatAuditMessagesMatch(jobId, + "Created analytics with analysis type [classification]", + "Estimated memory usage for this analytics to be", + "Started analytics", + "Creating destination index [" + destIndex + "]", + "Finished 
reindexing to destination index [" + destIndex + "]", + "Finished analysis"); + } + + private void initialize(String jobId) { + this.jobId = jobId; + this.sourceIndex = jobId + "_source_index"; + this.destIndex = sourceIndex + "_results"; + } + + private static void indexTrainingData(String sourceIndex, int numRows) { + client().admin().indices().prepareCreate(sourceIndex) + .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword") + .get(); + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + for (int i = 0; i < numRows; i++) { + Double field = NUMERICAL_FEATURE_VALUES.get(i % 3); + String value = DEPENDENT_VARIABLE_VALUES.get(i % 3); + + IndexRequest indexRequest = new IndexRequest(sourceIndex) + .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value); + bulkRequestBuilder.add(indexRequest); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } + } + + private static Map getDestDoc(DataFrameAnalyticsConfig config, SearchHit hit) { + GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get(); + assertThat(destDocGetResponse.isExists(), is(true)); + Map sourceDoc = hit.getSourceAsMap(); + Map destDoc = destDocGetResponse.getSource(); + for (String field : sourceDoc.keySet()) { + assertThat(destDoc.containsKey(field), is(true)); + assertThat(destDoc.get(field), equalTo(sourceDoc.get(field))); + } + return destDoc; + } + + private static Map getMlResultsObjectFromDestDoc(Map destDoc) { + assertThat(destDoc.containsKey("ml"), is(true)); + @SuppressWarnings("unchecked") + Map resultsObject = (Map) destDoc.get("ml"); + return resultsObject; + } + + private static void assertTopClasses(Map resultsObject, int numTopClasses) { + 
assertThat(resultsObject.containsKey("top_classes"), is(true)); + List> topClasses = (List>) resultsObject.get("top_classes"); + assertThat(topClasses, hasSize(numTopClasses)); + List classNames = new ArrayList<>(topClasses.size()); + List classProbabilities = new ArrayList<>(topClasses.size()); + for (Map topClass : topClasses) { + assertThat(topClass, allOf(hasKey("class_name"), hasKey("class_probability"))); + classNames.add((String) topClass.get("class_name")); + classProbabilities.add((Double) topClass.get("class_probability")); + } + // Assert that all the class names come from the set of dependent variable values. + classNames.forEach(className -> assertThat(className, is(in(DEPENDENT_VARIABLE_VALUES)))); + // Assert that the first class listed in top classes is the same as the predicted class. + assertThat(classNames.get(0), equalTo(resultsObject.get("variable_prediction"))); + // Assert that all the class probabilities lie within [0, 1] interval. + classProbabilities.forEach(p -> assertThat(p, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)))); + // Assert that the top classes are listed in the order of decreasing probabilities. 
+ assertThat(Ordering.natural().reverse().isOrdered(classProbabilities), is(true)); + } +} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java index fe5c347de3f77..0cb121ac67215 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java @@ -27,8 +27,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.notifications.AuditorField; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; @@ -136,13 +135,13 @@ protected List getAnalyticsStat return response.getResponse().results(); } - protected static DataFrameAnalyticsConfig buildOutlierDetectionAnalytics(String id, String[] sourceIndex, String destIndex, - @Nullable String resultsField) { + protected static DataFrameAnalyticsConfig buildAnalytics(String id, String sourceIndex, String destIndex, + @Nullable String resultsField, DataFrameAnalysis analysis) { DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(); configBuilder.setId(id); - configBuilder.setSource(new 
DataFrameAnalyticsSource(sourceIndex, null)); + configBuilder.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, null)); configBuilder.setDest(new DataFrameAnalyticsDest(destIndex, resultsField)); - configBuilder.setAnalysis(new OutlierDetection()); + configBuilder.setAnalysis(analysis); return configBuilder.build(); } @@ -175,16 +174,6 @@ protected SearchResponse searchStoredProgress(String id) { .get(); } - protected static DataFrameAnalyticsConfig buildRegressionAnalytics(String id, String[] sourceIndex, String destIndex, - @Nullable String resultsField, Regression regression) { - DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(); - configBuilder.setId(id); - configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null)); - configBuilder.setDest(new DataFrameAnalyticsDest(destIndex, resultsField)); - configBuilder.setAnalysis(regression); - return configBuilder.build(); - } - /** * Asserts whether the audit messages fetched from index match provided prefixes. 
* More specifically, in order to pass: diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java index 89782f18c0c72..c1c2fec780b62 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java @@ -14,6 +14,7 @@ import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.junit.After; import java.util.Map; @@ -68,7 +69,7 @@ public void testMissingFields() throws Exception { } String id = "test_outlier_detection_with_missing_fields"; - DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, new String[] {sourceIndex}, sourceIndex + "-results", null); + DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, sourceIndex + "-results", null, new OutlierDetection()); registerAnalytics(config); putAnalytics(config); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index 42854a29e41ea..fed47445d0666 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -15,11 +15,13 @@ import 
org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.junit.After; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -30,41 +32,52 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { + private static final String NUMERICAL_FEATURE_FIELD = "feature"; + private static final String DEPENDENT_VARIABLE_FIELD = "variable"; + private static final List NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0)); + private static final List DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList(10.0, 20.0, 30.0)); + + private String jobId; + private String sourceIndex; + private String destIndex; + @After - public void cleanup() { + public void cleanup() throws Exception { cleanUp(); } public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception { - String jobId = "regression_single_numeric_feature_and_mixed_data_set"; - String sourceIndex = jobId + "_source_index"; - - BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); - bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - List featureValues = Arrays.asList(1.0, 2.0, 3.0); - List dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0); - - for (int i = 0; i < 350; i++) { - Double field = featureValues.get(i % 3); - Double value = dependentVariableValues.get(i % 3); + initialize("regression_single_numeric_feature_and_mixed_data_set"); + + { // Index 350 rows, 300 of them being training rows. 
+ client().admin().indices().prepareCreate(sourceIndex) + .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=double") + .get(); + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + for (int i = 0; i < 300; i++) { + Double field = NUMERICAL_FEATURE_VALUES.get(i % 3); + Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3); + + IndexRequest indexRequest = new IndexRequest(sourceIndex) + .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value); + bulkRequestBuilder.add(indexRequest); + } + for (int i = 300; i < 350; i++) { + Double field = NUMERICAL_FEATURE_VALUES.get(i % 3); - IndexRequest indexRequest = new IndexRequest(sourceIndex); - if (i < 300) { - indexRequest.source("feature", field, "variable", value); - } else { - indexRequest.source("feature", field); + IndexRequest indexRequest = new IndexRequest(sourceIndex) + .source(NUMERICAL_FEATURE_FIELD, field); + bulkRequestBuilder.add(indexRequest); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); } - bulkRequestBuilder.add(indexRequest); - } - BulkResponse bulkResponse = bulkRequestBuilder.get(); - if (bulkResponse.hasFailures()) { - fail("Failed to index data: " + bulkResponse.buildFailureMessage()); } - String destIndex = sourceIndex + "_results"; - DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null, - new Regression("variable")); + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD)); registerAnalytics(config); putAnalytics(config); @@ -76,71 +89,54 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws SearchResponse sourceData = 
client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); for (SearchHit hit : sourceData.getHits()) { - GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get(); - assertThat(destDocGetResponse.isExists(), is(true)); - Map sourceDoc = hit.getSourceAsMap(); - Map destDoc = destDocGetResponse.getSource(); - for (String field : sourceDoc.keySet()) { - assertThat(destDoc.containsKey(field), is(true)); - assertThat(destDoc.get(field), equalTo(sourceDoc.get(field))); - } - assertThat(destDoc.containsKey("ml"), is(true)); - - @SuppressWarnings("unchecked") - Map resultsObject = (Map) destDoc.get("ml"); - - assertThat(resultsObject.containsKey("variable_prediction"), is(true)); + Map destDoc = getDestDoc(config, hit); + Map resultsObject = getMlResultsObjectFromDestDoc(destDoc); // TODO reenable this assertion when the backend is stable // it seems for this case values can be as far off as 2.0 - // double featureValue = (double) destDoc.get("feature"); + // double featureValue = (double) destDoc.get(NUMERICAL_FEATURE_FIELD); // double predictionValue = (double) resultsObject.get("variable_prediction"); // assertThat(predictionValue, closeTo(10 * featureValue, 2.0)); - boolean expectedIsTraining = destDoc.containsKey("variable"); + assertThat(resultsObject.containsKey("variable_prediction"), is(true)); assertThat(resultsObject.containsKey("is_training"), is(true)); - assertThat(resultsObject.get("is_training"), is(expectedIsTraining)); + assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD))); } assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + assertModelStatePersisted(jobId); assertThatAuditMessagesMatch(jobId, "Created analytics with analysis type [regression]", "Estimated memory usage for this analytics to be", "Started analytics", - "Creating destination index 
[regression_single_numeric_feature_and_mixed_data_set_source_index_results]", - "Finished reindexing to destination index [regression_single_numeric_feature_and_mixed_data_set_source_index_results]", + "Creating destination index [" + destIndex + "]", + "Finished reindexing to destination index [" + destIndex + "]", "Finished analysis"); - assertModelStatePersisted(jobId); } public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception { - String jobId = "regression_only_training_data_and_training_percent_is_hundred"; - String sourceIndex = jobId + "_source_index"; - - BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); - bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - List featureValues = Arrays.asList(1.0, 2.0, 3.0); - List dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0); - - for (int i = 0; i < 350; i++) { - Double field = featureValues.get(i % 3); - Double value = dependentVariableValues.get(i % 3); - - IndexRequest indexRequest = new IndexRequest(sourceIndex); - indexRequest.source("feature", field, "variable", value); - bulkRequestBuilder.add(indexRequest); - } - BulkResponse bulkResponse = bulkRequestBuilder.get(); - if (bulkResponse.hasFailures()) { - fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + initialize("regression_only_training_data_and_training_percent_is_100"); + + { // Index 350 rows, all of them being training rows. 
+ BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + for (int i = 0; i < 350; i++) { + Double field = NUMERICAL_FEATURE_VALUES.get(i % 3); + Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3); + + IndexRequest indexRequest = new IndexRequest(sourceIndex) + .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value); + bulkRequestBuilder.add(indexRequest); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } } - String destIndex = sourceIndex + "_results"; - DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null, - new Regression("variable")); + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD)); registerAnalytics(config); putAnalytics(config); @@ -152,18 +148,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); for (SearchHit hit : sourceData.getHits()) { - GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get(); - assertThat(destDocGetResponse.isExists(), is(true)); - Map sourceDoc = hit.getSourceAsMap(); - Map destDoc = destDocGetResponse.getSource(); - for (String field : sourceDoc.keySet()) { - assertThat(destDoc.containsKey(field), is(true)); - assertThat(destDoc.get(field), equalTo(sourceDoc.get(field))); - } - assertThat(destDoc.containsKey("ml"), is(true)); - - @SuppressWarnings("unchecked") - Map resultsObject = (Map) destDoc.get("ml"); + Map resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit)); assertThat(resultsObject.containsKey("variable_prediction"), is(true)); 
assertThat(resultsObject.containsKey("is_training"), is(true)); @@ -172,42 +157,43 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + assertModelStatePersisted(jobId); assertThatAuditMessagesMatch(jobId, "Created analytics with analysis type [regression]", "Estimated memory usage for this analytics to be", "Started analytics", - "Creating destination index [regression_only_training_data_and_training_percent_is_hundred_source_index_results]", - "Finished reindexing to destination index [regression_only_training_data_and_training_percent_is_hundred_source_index_results]", + "Creating destination index [" + destIndex + "]", + "Finished reindexing to destination index [" + destIndex + "]", "Finished analysis"); - assertModelStatePersisted(jobId); } public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception { - String jobId = "regression_only_training_data_and_training_percent_is_fifty"; - String sourceIndex = jobId + "_source_index"; - - BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); - bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - List featureValues = Arrays.asList(1.0, 2.0, 3.0); - List dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0); - - for (int i = 0; i < 350; i++) { - Double field = featureValues.get(i % 3); - Double value = dependentVariableValues.get(i % 3); - - IndexRequest indexRequest = new IndexRequest(sourceIndex); - indexRequest.source("feature", field, "variable", value); - bulkRequestBuilder.add(indexRequest); - } - BulkResponse bulkResponse = bulkRequestBuilder.get(); - if (bulkResponse.hasFailures()) { - fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + initialize("regression_only_training_data_and_training_percent_is_50"); + + { // Index 350 rows, all of them being training rows. 
+ BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + for (int i = 0; i < 350; i++) { + Double field = NUMERICAL_FEATURE_VALUES.get(i % 3); + Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3); + + IndexRequest indexRequest = new IndexRequest(sourceIndex) + .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value); + bulkRequestBuilder.add(indexRequest); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } } - String destIndex = sourceIndex + "_results"; - DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null, - new Regression("variable", null, null, null, null, null, null, 50.0)); + DataFrameAnalyticsConfig config = + buildAnalytics( + jobId, + sourceIndex, + destIndex, + null, + new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, 50.0)); registerAnalytics(config); putAnalytics(config); @@ -221,21 +207,9 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception int nonTrainingRowsCount = 0; SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); for (SearchHit hit : sourceData.getHits()) { - GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get(); - assertThat(destDocGetResponse.isExists(), is(true)); - Map sourceDoc = hit.getSourceAsMap(); - Map destDoc = destDocGetResponse.getSource(); - for (String field : sourceDoc.keySet()) { - assertThat(destDoc.containsKey(field), is(true)); - assertThat(destDoc.get(field), equalTo(sourceDoc.get(field))); - } - assertThat(destDoc.containsKey("ml"), is(true)); - - @SuppressWarnings("unchecked") - Map resultsObject = (Map) destDoc.get("ml"); + Map resultsObject = 
getMlResultsObjectFromDestDoc(getDestDoc(config, hit)); assertThat(resultsObject.containsKey("variable_prediction"), is(true)); - assertThat(resultsObject.containsKey("is_training"), is(true)); // Let's just assert there's both training and non-training results if ((boolean) resultsObject.get("is_training")) { @@ -249,32 +223,27 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + assertModelStatePersisted(jobId); assertThatAuditMessagesMatch(jobId, "Created analytics with analysis type [regression]", "Estimated memory usage for this analytics to be", "Started analytics", - "Creating destination index [regression_only_training_data_and_training_percent_is_fifty_source_index_results]", - "Finished reindexing to destination index [regression_only_training_data_and_training_percent_is_fifty_source_index_results]", + "Creating destination index [" + destIndex + "]", + "Finished reindexing to destination index [" + destIndex + "]", "Finished analysis"); - assertModelStatePersisted(jobId); } public void testStopAndRestart() throws Exception { - String jobId = "regression_stop_and_restart"; - String sourceIndex = jobId + "_source_index"; - - BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); - bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - List featureValues = Arrays.asList(1.0, 2.0, 3.0); - List dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0); + initialize("regression_stop_and_restart"); + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); for (int i = 0; i < 350; i++) { - Double field = featureValues.get(i % 3); - Double value = dependentVariableValues.get(i % 3); + Double field = NUMERICAL_FEATURE_VALUES.get(i % 3); + Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3); - IndexRequest 
indexRequest = new IndexRequest(sourceIndex); - indexRequest.source("feature", field, "variable", value); + IndexRequest indexRequest = new IndexRequest(sourceIndex) + .source("feature", field, "variable", value); bulkRequestBuilder.add(indexRequest); } BulkResponse bulkResponse = bulkRequestBuilder.get(); @@ -282,9 +251,7 @@ public void testStopAndRestart() throws Exception { fail("Failed to index data: " + bulkResponse.buildFailureMessage()); } - String destIndex = sourceIndex + "_results"; - DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null, - new Regression("variable")); + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD)); registerAnalytics(config); putAnalytics(config); @@ -317,18 +284,7 @@ public void testStopAndRestart() throws Exception { SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); for (SearchHit hit : sourceData.getHits()) { - GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get(); - assertThat(destDocGetResponse.isExists(), is(true)); - Map sourceDoc = hit.getSourceAsMap(); - Map destDoc = destDocGetResponse.getSource(); - for (String field : sourceDoc.keySet()) { - assertThat(destDoc.containsKey(field), is(true)); - assertThat(destDoc.get(field), equalTo(sourceDoc.get(field))); - } - assertThat(destDoc.containsKey("ml"), is(true)); - - @SuppressWarnings("unchecked") - Map resultsObject = (Map) destDoc.get("ml"); + Map resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit)); assertThat(resultsObject.containsKey("variable_prediction"), is(true)); assertThat(resultsObject.containsKey("is_training"), is(true)); @@ -340,7 +296,32 @@ public void testStopAndRestart() throws Exception { assertModelStatePersisted(jobId); } - private void assertModelStatePersisted(String jobId) { + 
private void initialize(String jobId) { + this.jobId = jobId; + this.sourceIndex = jobId + "_source_index"; + this.destIndex = sourceIndex + "_results"; + } + + private static Map getDestDoc(DataFrameAnalyticsConfig config, SearchHit hit) { + GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get(); + assertThat(destDocGetResponse.isExists(), is(true)); + Map sourceDoc = hit.getSourceAsMap(); + Map destDoc = destDocGetResponse.getSource(); + for (String field : sourceDoc.keySet()) { + assertThat(destDoc.containsKey(field), is(true)); + assertThat(destDoc.get(field), equalTo(sourceDoc.get(field))); + } + return destDoc; + } + + private static Map getMlResultsObjectFromDestDoc(Map destDoc) { + assertThat(destDoc.containsKey("ml"), is(true)); + @SuppressWarnings("unchecked") + Map resultsObject = (Map) destDoc.get("ml"); + return resultsObject; + } + + private static void assertModelStatePersisted(String jobId) { String docId = jobId + "_regression_state#1"; SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern()) .setQuery(QueryBuilders.idsQuery().addIds(docId)) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java index e55194b3592fa..145cdc97d2455 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java @@ -72,7 +72,7 @@ public void testOutlierDetectionWithFewDocuments() throws Exception { } String id = "test_outlier_detection_with_few_docs"; - DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, new String[] {sourceIndex}, 
sourceIndex + "-results", null); + DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, sourceIndex + "-results", null, new OutlierDetection()); registerAnalytics(config); putAnalytics(config); @@ -147,8 +147,7 @@ public void testOutlierDetectionWithEnoughDocumentsToScroll() throws Exception { } String id = "test_outlier_detection_with_enough_docs_to_scroll"; - DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics( - id, new String[] {sourceIndex}, sourceIndex + "-results", "custom_ml"); + DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, sourceIndex + "-results", "custom_ml", new OutlierDetection()); registerAnalytics(config); putAnalytics(config); @@ -217,7 +216,7 @@ public void testOutlierDetectionWithMoreFieldsThanDocValueFieldLimit() throws Ex } String id = "test_outlier_detection_with_more_fields_than_docvalue_limit"; - DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, new String[] {sourceIndex}, sourceIndex + "-results", null); + DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, sourceIndex + "-results", null, new OutlierDetection()); registerAnalytics(config); putAnalytics(config); @@ -280,8 +279,7 @@ public void testStopOutlierDetectionWithEnoughDocumentsToScroll() throws Excepti } String id = "test_stop_outlier_detection_with_enough_docs_to_scroll"; - DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics( - id, new String[] {sourceIndex}, sourceIndex + "-results", "custom_ml"); + DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, sourceIndex + "-results", "custom_ml", new OutlierDetection()); registerAnalytics(config); putAnalytics(config); @@ -345,7 +343,12 @@ public void testOutlierDetectionWithMultipleSourceIndices() throws Exception { } String id = "test_outlier_detection_with_multiple_source_indices"; - DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, sourceIndex, destIndex, null); + DataFrameAnalyticsConfig config = new 
DataFrameAnalyticsConfig.Builder() + .setId(id) + .setSource(new DataFrameAnalyticsSource(sourceIndex, null)) + .setDest(new DataFrameAnalyticsDest(destIndex, null)) + .setAnalysis(new OutlierDetection()) + .build(); registerAnalytics(config); putAnalytics(config); @@ -402,7 +405,7 @@ public void testOutlierDetectionWithPreExistingDestIndex() throws Exception { } String id = "test_outlier_detection_with_pre_existing_dest_index"; - DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, new String[] {sourceIndex}, destIndex, null); + DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, destIndex, null, new OutlierDetection()); registerAnalytics(config); putAnalytics(config); @@ -500,8 +503,7 @@ public void testOutlierDetectionStopAndRestart() throws Exception { } String id = "test_outlier_detection_stop_and_restart"; - DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics( - id, new String[] {sourceIndex}, sourceIndex + "-results", "custom_ml"); + DataFrameAnalyticsConfig config = buildAnalytics(id, sourceIndex, sourceIndex + "-results", "custom_ml", new OutlierDetection()); registerAnalytics(config); putAnalytics(config); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessorFactory.java index baf7e06346944..fd52a3fd8da58 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessorFactory.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.ml.dataframe.process.customprocessing; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import 
org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; @@ -21,7 +22,14 @@ public CustomProcessorFactory(List fieldNames) { public CustomProcessor create(DataFrameAnalysis analysis) { if (analysis instanceof Regression) { - return new RegressionCustomProcessor(fieldNames, (Regression) analysis); + Regression regression = (Regression) analysis; + return new DatasetSplittingCustomProcessor( + fieldNames, regression.getDependentVariable(), regression.getTrainingPercent()); + } + if (analysis instanceof Classification) { + Classification classification = (Classification) analysis; + return new DatasetSplittingCustomProcessor( + fieldNames, classification.getDependentVariable(), classification.getTrainingPercent()); } return row -> {}; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/RegressionCustomProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessor.java similarity index 86% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/RegressionCustomProcessor.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessor.java index 4b814d3504a83..ed42cf5198854 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/RegressionCustomProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessor.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.ml.dataframe.process.customprocessing; import org.elasticsearch.common.Randomness; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.util.List; @@ -18,7 +17,7 @@ * This relies on the fact that when the dependent variable field * is 
empty, then the row is not used for training but only to make predictions. */ -class RegressionCustomProcessor implements CustomProcessor { +class DatasetSplittingCustomProcessor implements CustomProcessor { private static final String EMPTY = ""; @@ -27,10 +26,9 @@ class RegressionCustomProcessor implements CustomProcessor { private final Random random = Randomness.get(); private boolean isFirstRow = true; - RegressionCustomProcessor(List fieldNames, Regression regression) { - this.dependentVariableIndex = findDependentVariableIndex(fieldNames, regression.getDependentVariable()); - this.trainingPercent = regression.getTrainingPercent(); - + DatasetSplittingCustomProcessor(List fieldNames, String dependentVariable, double trainingPercent) { + this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable); + this.trainingPercent = trainingPercent; } private static int findDependentVariableIndex(List fieldNames, String dependentVariable) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/RegressionCustomProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessorTests.java similarity index 87% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/RegressionCustomProcessorTests.java rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessorTests.java index adcd845059dac..d5973f8782461 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/RegressionCustomProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessorTests.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.ml.dataframe.process.customprocessing; import 
org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.junit.Before; import java.util.ArrayList; @@ -20,7 +19,7 @@ import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThan; -public class RegressionCustomProcessorTests extends ESTestCase { +public class DatasetSplittingCustomProcessorTests extends ESTestCase { private List fields; private int dependentVariableIndex; @@ -38,7 +37,7 @@ public void setUpTests() { } public void testProcess_GivenRowsWithoutDependentVariableValue() { - CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, 50.0)); + CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 50.0); for (int i = 0; i < 100; i++) { String[] row = new String[fields.size()]; @@ -56,7 +55,7 @@ public void testProcess_GivenRowsWithoutDependentVariableValue() { } public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() { - CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, 100.0)); + CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 100.0); for (int i = 0; i < 100; i++) { String[] row = new String[fields.size()]; @@ -76,7 +75,7 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() { double trainingPercent = randomDoubleBetween(1.0, 100.0, true); double trainingFraction = trainingPercent / 100; - CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, trainingPercent)); + CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, trainingPercent); int runCount = 20; int rowsCount = 1000; @@ -122,7 +121,7 @@ public void 
testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs } public void testProcess_ShouldHaveAtLeastOneTrainingRow() { - CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, 1.0)); + CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 1.0); // We have some non-training rows and then a training row to check // we maintain the first training row and not just the first row @@ -142,8 +141,4 @@ public void testProcess_ShouldHaveAtLeastOneTrainingRow() { assertThat(Arrays.equals(processedRow, row), is(true)); } } - - private static Regression regression(String dependentVariable, double trainingPercent) { - return new Regression(dependentVariable, null, null, null, null, null, null, trainingPercent); - } } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index 95c838509f0b8..939a8812d0413 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -1231,6 +1231,346 @@ setup: - is_true: create_time - is_true: version +--- +"Test put classification given dependent_variable is not defined": + + - do: + catch: /parse_exception/ + ml.put_data_frame_analytics: + id: "classification-without-dependent-variable" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "classification": {} + } + } + +--- +"Test put classification given negative lambda": + + - do: + catch: /\[lambda\] must be a non-negative double/ + ml.put_data_frame_analytics: + id: "classification-negative-lambda" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "classification": { + "dependent_variable": "foo", + 
"lambda": -1.0 + } + } + } + +--- +"Test put classification given negative gamma": + + - do: + catch: /\[gamma\] must be a non-negative double/ + ml.put_data_frame_analytics: + id: "classification-negative-gamma" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "classification": { + "dependent_variable": "foo", + "gamma": -1.0 + } + } + } + +--- +"Test put classification given eta less than 1e-3": + + - do: + catch: /\[eta\] must be a double in \[0.001, 1\]/ + ml.put_data_frame_analytics: + id: "classification-eta-greater-less-than-valid" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "classification": { + "dependent_variable": "foo", + "eta": 0.0009 + } + } + } + +--- +"Test put classification given eta greater than one": + + - do: + catch: /\[eta\] must be a double in \[0.001, 1\]/ + ml.put_data_frame_analytics: + id: "classification-eta-greater-than-one" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "classification": { + "dependent_variable": "foo", + "eta": 1.00001 + } + } + } + +--- +"Test put classification given maximum_number_trees is zero": + + - do: + catch: /\[maximum_number_trees\] must be an integer in \[1, 2000\]/ + ml.put_data_frame_analytics: + id: "classification-maximum-number-trees-is-zero" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "classification": { + "dependent_variable": "foo", + "maximum_number_trees": 0 + } + } + } + +--- +"Test put classification given maximum_number_trees is greater than 2k": + + - do: + catch: /\[maximum_number_trees\] must be an integer in \[1, 2000\]/ + ml.put_data_frame_analytics: + id: "classification-maximum-number-trees-greater-than-2k" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + 
}, + "analysis": { + "classification": { + "dependent_variable": "foo", + "maximum_number_trees": 2001 + } + } + } + +--- +"Test put classification given feature_bag_fraction is negative": + + - do: + catch: /\[feature_bag_fraction\] must be a double in \(0, 1\]/ + ml.put_data_frame_analytics: + id: "classification-feature-bag-fraction-is-negative" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "classification": { + "dependent_variable": "foo", + "feature_bag_fraction": -0.0001 + } + } + } + +--- +"Test put classification given feature_bag_fraction is greater than one": + + - do: + catch: /\[feature_bag_fraction\] must be a double in \(0, 1\]/ + ml.put_data_frame_analytics: + id: "classification-feature-bag-fraction-is-greater-than-one" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "classification": { + "dependent_variable": "foo", + "feature_bag_fraction": 1.0001 + } + } + } + +--- +"Test put classification given num_top_classes is less than zero": + + - do: + catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/ + ml.put_data_frame_analytics: + id: "classification-num-top-classes-is-less-than-zero" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "classification": { + "dependent_variable": "foo", + "num_top_classes": -1 + } + } + } + +--- +"Test put classification given num_top_classes is greater than 1k": + + - do: + catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/ + ml.put_data_frame_analytics: + id: "classification-num-top-classes-is-greater-than-1k" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "classification": { + "dependent_variable": "foo", + "num_top_classes": 1001 + } + } + } + +--- +"Test put classification given training_percent is less than
one": + + - do: + catch: /\[training_percent\] must be a double in \[1, 100\]/ + ml.put_data_frame_analytics: + id: "classification-training-percent-is-less-than-one" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "classification": { + "dependent_variable": "foo", + "training_percent": 0.999 + } + } + } + +--- +"Test put classification given training_percent is greater than hundred": + + - do: + catch: /\[training_percent\] must be a double in \[1, 100\]/ + ml.put_data_frame_analytics: + id: "classification-training-percent-is-greater-than-hundred" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "classification": { + "dependent_variable": "foo", + "training_percent": 100.1 + } + } + } + +--- +"Test put classification given valid": + + - do: + ml.put_data_frame_analytics: + id: "valid-classification" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "classification": { + "dependent_variable": "foo", + "lambda": 3.14, + "gamma": 0.42, + "eta": 0.5, + "maximum_number_trees": 400, + "feature_bag_fraction": 0.3, + "training_percent": 60.3 + } + } + } + - match: { id: "valid-classification" } + - match: { source.index: ["index-source"] } + - match: { dest.index: "index-dest" } + - match: { analysis: { + "classification":{ + "dependent_variable": "foo", + "lambda": 3.14, + "gamma": 0.42, + "eta": 0.5, + "maximum_number_trees": 400, + "feature_bag_fraction": 0.3, + "training_percent": 60.3, + "num_top_classes": 0 + } + }} + - is_true: create_time + - is_true: version + --- "Test put with description":