diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java index 9d384e6d86786..02861adc73845 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java @@ -46,6 +46,7 @@ public static Builder builder(String dependentVariable) { static final ParseField ETA = new ParseField("eta"); static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees"); static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); @@ -62,10 +63,11 @@ public static Builder builder(String dependentVariable) { (Double) a[3], (Integer) a[4], (Double) a[5], - (String) a[6], - (Double) a[7], - (Integer) a[8], - (Long) a[9])); + (Integer) a[6], + (String) a[7], + (Double) a[8], + (Integer) a[9], + (Long) a[10])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); @@ -74,6 +76,7 @@ public static Builder builder(String dependentVariable) { PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES); @@ -86,13 +89,15 @@ public static Builder builder(String dependentVariable) { private final Double eta; private final Integer maximumNumberTrees; private final Double featureBagFraction; + private final Integer numTopFeatureImportanceValues; private final String predictionFieldName; private final Double trainingPercent; private final Integer numTopClasses; private final Long randomizeSeed; private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, - @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName, + @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, + @Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName, @Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) { this.dependentVariable = Objects.requireNonNull(dependentVariable); this.lambda = lambda; @@ -100,6 +105,7 @@ private Classification(String dependentVariable, @Nullable Double lambda, @Nulla this.eta = eta; this.maximumNumberTrees = maximumNumberTrees; this.featureBagFraction = featureBagFraction; + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; this.predictionFieldName = predictionFieldName; this.trainingPercent = trainingPercent; this.numTopClasses = numTopClasses; @@ -135,6 +141,10 @@ public Double getFeatureBagFraction() { return featureBagFraction; } + public Integer getNumTopFeatureImportanceValues() { + return numTopFeatureImportanceValues; + } + public String getPredictionFieldName() { return predictionFieldName; } @@ -170,6 +180,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (featureBagFraction != null) { builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } if (predictionFieldName != null) { builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } @@ -188,8 +201,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public int hashCode() { - return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, randomizeSeed, numTopClasses); + return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues, + predictionFieldName, trainingPercent, randomizeSeed, numTopClasses); } @Override @@ -203,6 +216,7 @@ public boolean equals(Object o) { && Objects.equals(eta, that.eta) && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) && Objects.equals(featureBagFraction, that.featureBagFraction) + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues) && Objects.equals(predictionFieldName, that.predictionFieldName) && Objects.equals(trainingPercent, that.trainingPercent) && Objects.equals(randomizeSeed, that.randomizeSeed) @@ -221,6 +235,7 @@ public static class Builder { private Double eta; private Integer maximumNumberTrees; private Double featureBagFraction; + private Integer numTopFeatureImportanceValues; private String predictionFieldName; private Double trainingPercent; private Integer numTopClasses; @@ -255,6 +270,11 @@ public Builder setFeatureBagFraction(Double featureBagFraction) { return this; } + public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + public Builder setPredictionFieldName(String predictionFieldName) { this.predictionFieldName = predictionFieldName; return this; @@ -276,8 +296,8 @@ public Builder setNumTopClasses(Integer numTopClasses) { } public Classification build() { - return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, numTopClasses, randomizeSeed); + return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, + numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed); } } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java index fa55ee40b27fb..d7e374a2563a1 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java @@ -46,6 +46,7 @@ public static Builder builder(String dependentVariable) { static final ParseField ETA = new ParseField("eta"); static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees"); static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); @@ -61,9 +62,10 @@ public static Builder builder(String dependentVariable) { (Double) a[3], (Integer) a[4], (Double) a[5], - (String) a[6], - (Double) a[7], - (Long) a[8])); + (Integer) a[6], + (String) a[7], + (Double) a[8], + (Long) a[9])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); @@ -72,6 +74,7 @@ public static Builder builder(String dependentVariable) { PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED); @@ -83,12 +86,14 @@ public static Builder builder(String dependentVariable) { private final Double eta; private final Integer maximumNumberTrees; private final Double featureBagFraction; + private final Integer numTopFeatureImportanceValues; private final String predictionFieldName; private final Double trainingPercent; private final Long randomizeSeed; - private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, - @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName, + private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, + @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, + @Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName, @Nullable Double trainingPercent, @Nullable Long randomizeSeed) { this.dependentVariable = Objects.requireNonNull(dependentVariable); this.lambda = lambda; @@ -96,6 +101,7 @@ private Regression(String dependentVariable, @Nullable Double lambda, @Nullable this.eta = eta; this.maximumNumberTrees = maximumNumberTrees; this.featureBagFraction = featureBagFraction; + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; this.predictionFieldName = predictionFieldName; this.trainingPercent = trainingPercent; this.randomizeSeed = randomizeSeed; @@ -130,6 +136,10 @@ public Double getFeatureBagFraction() { return featureBagFraction; } + public Integer getNumTopFeatureImportanceValues() { + return numTopFeatureImportanceValues; + } + public String getPredictionFieldName() { return predictionFieldName; } @@ -161,6 +171,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (featureBagFraction != null) { builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } if (predictionFieldName != null) { builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } @@ -176,8 +189,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public int hashCode() { - return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, randomizeSeed); + return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues, + predictionFieldName, trainingPercent, randomizeSeed); } @Override @@ -191,6 +204,7 @@ public boolean equals(Object o) { && Objects.equals(eta, that.eta) && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) && Objects.equals(featureBagFraction, that.featureBagFraction) + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues) && Objects.equals(predictionFieldName, that.predictionFieldName) && Objects.equals(trainingPercent, that.trainingPercent) && Objects.equals(randomizeSeed, that.randomizeSeed); @@ -208,6 +222,7 @@ public static class Builder { private Double eta; private Integer maximumNumberTrees; private Double featureBagFraction; + private Integer numTopFeatureImportanceValues; private String predictionFieldName; private Double trainingPercent; private Long randomizeSeed; @@ -241,6 +256,11 @@ public Builder setFeatureBagFraction(Double featureBagFraction) { return this; } + public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + public Builder setPredictionFieldName(String predictionFieldName) { this.predictionFieldName = predictionFieldName; return this; @@ -257,8 +277,8 @@ public Builder setRandomizeSeed(Long randomizeSeed) { } public Regression build() { - return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, randomizeSeed); + return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, + numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed); } } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 0664d49f76841..6fe08f8a507de 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1294,6 +1294,12 @@ public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception { .setPredictionFieldName("my_dependent_variable_prediction") .setTrainingPercent(80.0) .setRandomizeSeed(42L) + .setLambda(1.0) + .setGamma(1.0) + .setEta(1.0) + .setMaximumNumberTrees(10) + .setFeatureBagFraction(0.5) + .setNumTopFeatureImportanceValues(3) .build()) .setDescription("this is a regression") .build(); @@ -1331,6 +1337,12 @@ public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Excepti .setTrainingPercent(80.0) .setRandomizeSeed(42L) .setNumTopClasses(1) + .setLambda(1.0) + .setGamma(1.0) + .setEta(1.0) + .setMaximumNumberTrees(10) + .setFeatureBagFraction(0.5) + .setNumTopFeatureImportanceValues(3) .build()) .setDescription("this is a classification") .build(); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 0db9dbf222f49..b850b2e8b9f1a 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -2975,10 +2975,11 @@ public void testPutDataFrameAnalytics() throws Exception { .setEta(5.5) // <4> .setMaximumNumberTrees(50) // <5> .setFeatureBagFraction(0.4) // <6> - .setPredictionFieldName("my_prediction_field_name") // <7> - .setTrainingPercent(50.0) // <8> - .setRandomizeSeed(1234L) // <9> - .setNumTopClasses(1) // <10> + .setNumTopFeatureImportanceValues(3) // <7> + .setPredictionFieldName("my_prediction_field_name") // <8> + .setTrainingPercent(50.0) // <9> + .setRandomizeSeed(1234L) // <10> + .setNumTopClasses(1) // <11> .build(); // end::put-data-frame-analytics-classification @@ -2989,9 +2990,10 @@ public void testPutDataFrameAnalytics() throws Exception { .setEta(5.5) // <4> .setMaximumNumberTrees(50) // <5> .setFeatureBagFraction(0.4) // <6> - .setPredictionFieldName("my_prediction_field_name") // <7> - .setTrainingPercent(50.0) // <8> - .setRandomizeSeed(1234L) // <9> + .setNumTopFeatureImportanceValues(3) // <7> + .setPredictionFieldName("my_prediction_field_name") // <8> + .setTrainingPercent(50.0) // <9> + .setRandomizeSeed(1234L) // <10> .build(); // end::put-data-frame-analytics-regression @@ -3670,7 +3672,7 @@ public void testPutTrainedModel() throws Exception { } { PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig); - + // tag::put-trained-model-execute-listener ActionListener listener = new ActionListener<>() { @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java index 5ef8fdaef5a27..79d78c888880f 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java @@ -32,6 +32,7 @@ public static Classification randomClassification() { .setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true)) .setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000)) .setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false)) + .setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE)) .setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10)) .setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true)) .setRandomizeSeed(randomBoolean() ? null : randomLong()) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/RegressionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/RegressionTests.java index 02e41ecdff333..eedffb4740d78 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/RegressionTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/RegressionTests.java @@ -32,6 +32,7 @@ public static Regression randomRegression() { .setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true)) .setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000)) .setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false)) + .setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE)) .setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10)) .setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true)) .build(); diff --git a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc index 2152eff5c0850..4be2011340210 100644 --- a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc +++ b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc @@ -117,10 +117,11 @@ include-tagged::{doc-tests-file}[{api}-classification] <4> The applied shrinkage. A double in [0.001, 1]. <5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000]. <6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1]. -<7> The name of the prediction field in the results object. -<8> The percentage of training-eligible rows to be used in training. Defaults to 100%. -<9> The seed to be used by the random generator that picks which rows are used in training. -<10> The number of top classes to be reported in the results. Defaults to 2. +<7> If set, feature importance for the top most important features will be computed. +<8> The name of the prediction field in the results object. +<9> The percentage of training-eligible rows to be used in training. Defaults to 100%. +<10> The seed to be used by the random generator that picks which rows are used in training. +<11> The number of top classes to be reported in the results. Defaults to 2. ===== Regression @@ -137,9 +138,10 @@ include-tagged::{doc-tests-file}[{api}-regression] <4> The applied shrinkage. A double in [0.001, 1]. <5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000]. <6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1]. -<7> The name of the prediction field in the results object. -<8> The percentage of training-eligible rows to be used in training. Defaults to 100%. -<9> The seed to be used by the random generator that picks which rows are used in training. +<7> If set, feature importance for the top most important features will be computed. +<8> The name of the prediction field in the results object. +<9> The percentage of training-eligible rows to be used in training. Defaults to 100%. +<10> The seed to be used by the random generator that picks which rows are used in training. ==== Analyzed fields diff --git a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc index 24149372e0e99..18b3446bca56c 100644 --- a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc +++ b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc @@ -148,6 +148,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction-field-name] (Optional, long) include::{docdir}/ml/ml-shared.asciidoc[tag=randomize-seed] +`analysis`.`classification`.`num_top_feature_importance_values`:::: +(Optional, integer) +include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-feature-importance-values] + `analysis`.`classification`.`training_percent`:::: (Optional, integer) include::{docdir}/ml/ml-shared.asciidoc[tag=training-percent] @@ -227,6 +231,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=lambda] (Optional, string) include::{docdir}/ml/ml-shared.asciidoc[tag=prediction-field-name] +`analysis`.`regression`.`num_top_feature_importance_values`:::: +(Optional, integer) +include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-feature-importance-values] + `analysis`.`regression`.`training_percent`:::: (Optional, integer) include::{docdir}/ml/ml-shared.asciidoc[tag=training-percent] diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index 5454939af1e24..f25cfb94e8bed 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -639,6 +639,14 @@ end::include-model-definition[] tag::indices[] An array of index names. Wildcards are supported. For example: `["it_ops_metrics", "server*"]`. + +tag::num-top-feature-importance-values[] +Advanced configuration option. If set, feature importance for the top +most important features will be computed. Importance is calculated +using the SHAP (SHapley Additive exPlanations) method as described in +https://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf[Lundberg, S. M., & Lee, S.-I. A Unified Approach to Interpreting Model Predictions. In NeurIPS 2017.]. +end::num-top-feature-importance-values[] + + -- NOTE: If any indices are in remote clusters then `cluster.remote.connect` must diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java index 0f06b08444f53..ec563af73f520 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.analyses; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; @@ -34,6 +35,7 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { public static final ParseField ETA = new ParseField("eta"); public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees"); public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); static void declareFields(AbstractObjectParser parser) { parser.declareDouble(optionalConstructorArg(), LAMBDA); @@ -41,6 +43,7 @@ static void declareFields(AbstractObjectParser parser) { parser.declareDouble(optionalConstructorArg(), ETA); parser.declareInt(optionalConstructorArg(), MAXIMUM_NUMBER_TREES); parser.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION); + parser.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); } private final Double lambda; @@ -48,12 +51,14 @@ static void declareFields(AbstractObjectParser parser) { private final Double eta; private final Integer maximumNumberTrees; private final Double featureBagFraction; + private final Integer numTopFeatureImportanceValues; public BoostedTreeParams(@Nullable Double lambda, - @Nullable Double gamma, - @Nullable Double eta, - @Nullable Integer maximumNumberTrees, - @Nullable Double featureBagFraction) { + @Nullable Double gamma, + @Nullable Double eta, + @Nullable Integer maximumNumberTrees, + @Nullable Double featureBagFraction, + @Nullable Integer numTopFeatureImportanceValues) { if (lambda != null && lambda < 0) { throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName()); } @@ -69,15 +74,16 @@ public BoostedTreeParams(@Nullable Double lambda, if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1.0)) { throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", FEATURE_BAG_FRACTION.getPreferredName()); } + if (numTopFeatureImportanceValues != null && numTopFeatureImportanceValues < 0) { + throw ExceptionsHelper.badRequestException("[{}] must be a non-negative integer", + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()); + } this.lambda = lambda; this.gamma = gamma; this.eta = eta; this.maximumNumberTrees = maximumNumberTrees; this.featureBagFraction = featureBagFraction; - } - - public BoostedTreeParams() { - this(null, null, null, null, null); + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; } BoostedTreeParams(StreamInput in) throws IOException { @@ -86,6 +92,11 @@ public BoostedTreeParams() { eta = in.readOptionalDouble(); maximumNumberTrees = in.readOptionalVInt(); featureBagFraction = in.readOptionalDouble(); + if (in.getVersion().onOrAfter(Version.V_8_0_0)) { + numTopFeatureImportanceValues = in.readOptionalInt(); + } else { + numTopFeatureImportanceValues = null; + } } @Override @@ -95,6 +106,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalDouble(eta); out.writeOptionalVInt(maximumNumberTrees); out.writeOptionalDouble(featureBagFraction); + if (out.getVersion().onOrAfter(Version.V_8_0_0)) { + out.writeOptionalInt(numTopFeatureImportanceValues); + } } @Override @@ -114,6 +128,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (featureBagFraction != null) { builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } return builder; } @@ -134,6 +151,9 @@ Map getParams() { if (featureBagFraction != null) { params.put(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); } + if (numTopFeatureImportanceValues != null) { + params.put(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } return params; } @@ -146,11 +166,62 @@ public boolean equals(Object o) { && Objects.equals(gamma, that.gamma) && Objects.equals(eta, that.eta) && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) - && Objects.equals(featureBagFraction, that.featureBagFraction); + && Objects.equals(featureBagFraction, that.featureBagFraction) + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues); } @Override public int hashCode() { - return Objects.hash(lambda, gamma, eta, maximumNumberTrees, featureBagFraction); + return Objects.hash(lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private Double lambda; + private Double gamma; + private Double eta; + private Integer maximumNumberTrees; + private Double featureBagFraction; + private Integer numTopFeatureImportanceValues; + + private Builder() {} + + public Builder setLambda(Double lambda) { + this.lambda = lambda; + return this; + } + + public Builder setGamma(Double gamma) { + this.gamma = gamma; + return this; + } + + public Builder setEta(Double eta) { + this.eta = eta; + return this; + } + + public Builder setMaximumNumberTrees(Integer maximumNumberTrees) { + this.maximumNumberTrees = maximumNumberTrees; + return this; + } + + public Builder setFeatureBagFraction(Double featureBagFraction) { + this.featureBagFraction = featureBagFraction; + return this; + } + + public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + + public BoostedTreeParams build() { + return new BoostedTreeParams(lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues); + } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 0e68d13895e25..47a02786f5d07 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -50,11 +50,11 @@ private static ConstructingObjectParser createParser(boole lenient, a -> new Classification( (String) a[0], - new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]), - (String) a[6], - (Integer) a[7], - (Double) a[8], - (Long) a[9])); + new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]), + (String) a[7], + (Integer) a[8], + (Double) a[9], + (Long) a[10])); parser.declareString(constructorArg(), DEPENDENT_VARIABLE); BoostedTreeParams.declareFields(parser); parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); @@ -112,7 +112,7 @@ public Classification(String dependentVariable, } public Classification(String dependentVariable) { - this(dependentVariable, new BoostedTreeParams(), null, null, null, null); + this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null); } public Classification(StreamInput in) throws IOException { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index fe2927591312a..83174a9aebfe3 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -47,10 +47,10 @@ private static ConstructingObjectParser createParser(boolean l lenient, a -> new Regression( (String) a[0], - new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]), - (String) a[6], - (Double) a[7], - (Long) a[8])); + new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]), + (String) a[7], + (Double) a[8], + (Long) a[9])); parser.declareString(constructorArg(), DEPENDENT_VARIABLE); BoostedTreeParams.declareFields(parser); parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); @@ -85,7 +85,7 @@ public Regression(String dependentVariable, } public Regression(String dependentVariable) { - this(dependentVariable, new BoostedTreeParams(), null, null, null); + this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null); } public Regression(StreamInput in) throws IOException { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java index b9de87ef93de0..b64a12e087ea9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java @@ -471,6 +471,9 @@ public static void addDataFrameAnalyticsFields(XContentBuilder builder) throws I .startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName()) .field(TYPE, DOUBLE) .endObject() + .startObject(BoostedTreeParams.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()) + .field(TYPE, INTEGER) + .endObject() .startObject(Regression.PREDICTION_FIELD_NAME.getPreferredName()) .field(TYPE, KEYWORD) .endObject() @@ -499,6 +502,9 @@ public static void addDataFrameAnalyticsFields(XContentBuilder builder) throws I .startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName()) .field(TYPE, DOUBLE) .endObject() + .startObject(BoostedTreeParams.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()) + .field(TYPE, INTEGER) + .endObject() .startObject(Classification.PREDICTION_FIELD_NAME.getPreferredName()) .field(TYPE, KEYWORD) .endObject() diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index 8eacdcb0e78e4..968df76d5ed91 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -322,6 +322,7 @@ public final class ReservedFieldNames { BoostedTreeParams.ETA.getPreferredName(), BoostedTreeParams.MAXIMUM_NUMBER_TREES.getPreferredName(), BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName(), + BoostedTreeParams.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), ElasticsearchMappings.CONFIG_TYPE, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java index 145533df407cd..6f3aff88846d9 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java @@ -23,7 +23,7 @@ protected BoostedTreeParams doParseInstance(XContentParser parser) throws IOExce new ConstructingObjectParser<>( BoostedTreeParams.NAME, true, - a -> new BoostedTreeParams((Double) a[0], (Double) a[1], (Double) a[2], (Integer) a[3], (Double) a[4])); + a -> new BoostedTreeParams((Double) a[0], (Double) a[1], (Double) a[2], (Integer) a[3], (Double) a[4], (Integer) a[5])); BoostedTreeParams.declareFields(objParser); return objParser.apply(parser, null); } @@ -34,12 +34,14 @@ protected BoostedTreeParams createTestInstance() { } public static BoostedTreeParams createRandom() { - Double lambda = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true); - Double gamma = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true); - Double eta = randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true); - Integer maximumNumberTrees = randomBoolean() ? null : randomIntBetween(1, 2000); - Double featureBagFraction = randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false); - return new BoostedTreeParams(lambda, gamma, eta, maximumNumberTrees, featureBagFraction); + return BoostedTreeParams.builder() + .setLambda(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true)) + .setGamma(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true)) + .setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true)) + .setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000)) + .setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false)) + .setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE)) + .build(); } @Override @@ -49,57 +51,64 @@ protected Writeable.Reader instanceReader() { public void testConstructor_GivenNegativeLambda() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(-0.00001, 0.0, 0.5, 500, 0.3)); + () -> BoostedTreeParams.builder().setLambda(-0.00001).build()); assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double")); } public void testConstructor_GivenNegativeGamma() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, -0.00001, 0.5, 500, 0.3)); + () -> BoostedTreeParams.builder().setGamma(-0.00001).build()); assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double")); } public void testConstructor_GivenEtaIsZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.0, 500, 0.3)); + () -> BoostedTreeParams.builder().setEta(0.0).build()); assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]")); } public void testConstructor_GivenEtaIsGreaterThanOne() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 1.00001, 500, 0.3)); + () -> BoostedTreeParams.builder().setEta(1.00001).build()); assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]")); } public void testConstructor_GivenMaximumNumberTreesIsZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.5, 0, 0.3)); + () -> BoostedTreeParams.builder().setMaximumNumberTrees(0).build()); assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]")); } public void testConstructor_GivenMaximumNumberTreesIsGreaterThan2k() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.5, 2001, 0.3)); + () -> BoostedTreeParams.builder().setMaximumNumberTrees(2001).build()); assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]")); } public void testConstructor_GivenFeatureBagFractionIsLessThanZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.5, 500, -0.00001)); + () -> BoostedTreeParams.builder().setFeatureBagFraction(-0.00001).build()); assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]")); } public void testConstructor_GivenFeatureBagFractionIsGreaterThanOne() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.00001)); + () -> BoostedTreeParams.builder().setFeatureBagFraction(1.00001).build()); assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]")); } + + public void testConstructor_GivenTopFeatureImportanceValuesIsNegative() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> BoostedTreeParams.builder().setNumTopFeatureImportanceValues(-1).build()); + + assertThat(e.getMessage(), equalTo("[num_top_feature_importance_values] must be a non-negative integer")); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 1b988379fc218..64b14157cf613 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -34,7 +34,7 @@ public class ClassificationTests extends AbstractSerializingTestCase { - private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0); + private static final BoostedTreeParams BOOSTED_TREE_PARAMS = BoostedTreeParams.builder().build(); @Override protected Classification doParseInstance(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java index c7f89cc0413b5..83df5b44ced25 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java @@ -30,7 +30,7 @@ public class RegressionTests extends AbstractSerializingTestCase { - private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0); + private static final BoostedTreeParams BOOSTED_TREE_PARAMS = BoostedTreeParams.builder().build(); @Override protected Regression doParseInstance(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index d1e49169d75e7..ab130089df3f0 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.ml.integration; import com.google.common.collect.Ordering; - import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.admin.indices.get.GetIndexAction; import org.elasticsearch.action.admin.indices.get.GetIndexRequest; @@ -28,7 +27,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; @@ -86,7 +84,14 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws String predictedClassField = KEYWORD_FIELD + "_prediction"; indexData(sourceIndex, 300, 50, KEYWORD_FIELD); - DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD)); + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, + new Classification( + KEYWORD_FIELD, + BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(), + null, + null, + null, + null)); registerAnalytics(config); putAnalytics(config); @@ -104,6 +109,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES))); assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD))); assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES); + assertThat(resultsObject.keySet().stream().filter(k -> k.startsWith("feature_importance.")).findAny().isPresent(), is(true)); } assertProgress(jobId, 100, 100, 100, 100); @@ -178,7 +184,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(String jobId, sourceIndex, destIndex, null, - new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0, null)); + new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, numTopClasses, 50.0, null)); registerAnalytics(config); putAnalytics(config); @@ -414,7 +420,13 @@ public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exceptio String firstJobId = "classification_two_jobs_with_same_randomize_seed_1"; String firstJobDestIndex = firstJobId + "_dest"; - BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0); + BoostedTreeParams boostedTreeParams = BoostedTreeParams.builder() + .setLambda(1.0) + .setGamma(1.0) + .setEta(1.0) + .setFeatureBagFraction(1.0) + .setMaximumNumberTrees(1) + .build(); DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null, new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, null)); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index 2d790260dac12..3315727df57c6 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -18,7 +18,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.junit.After; @@ -53,7 +52,14 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws initialize("regression_single_numeric_feature_and_mixed_data_set"); indexData(sourceIndex, 300, 50); - DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD)); + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, + new Regression( + DEPENDENT_VARIABLE_FIELD, + BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(), + null, + null, + null) + ); registerAnalytics(config); putAnalytics(config); @@ -78,6 +84,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws assertThat(resultsObject.containsKey("variable_prediction"), is(true)); assertThat(resultsObject.containsKey("is_training"), is(true)); assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD))); + assertThat(resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD), is(true)); } assertProgress(jobId, 100, 100, 100, 100); @@ -141,7 +148,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception sourceIndex, destIndex, null, - new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, 50.0, null)); + new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null)); registerAnalytics(config); putAnalytics(config); @@ -244,7 +251,13 @@ public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exceptio String firstJobId = "regression_two_jobs_with_same_randomize_seed_1"; String firstJobDestIndex = firstJobId + "_dest"; - BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0); + BoostedTreeParams boostedTreeParams = BoostedTreeParams.builder() + .setLambda(1.0) + .setGamma(1.0) + .setEta(1.0) + .setFeatureBagFraction(1.0) + .setMaximumNumberTrees(1) + .build(); DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null));