From bfb390e4e4257c2337c5df0d820f70a50620bbb9 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 2 Sep 2020 15:47:41 -0400 Subject: [PATCH 1/7] [ML] Add new inference//_metadata API for model training metadata --- .../client/MLRequestConverters.java | 26 + .../client/MachineLearningClient.java | 45 ++ .../ml/GetTrainedModelsMetadataRequest.java | 99 +++ .../ml/GetTrainedModelsMetadataResponse.java | 86 ++ .../metadata/TotalFeatureImportance.java | 208 +++++ .../metadata/TrainedModelMetadata.java | 91 +++ .../client/MachineLearningIT.java | 64 ++ .../MlClientDocumentationIT.java | 79 ++ .../GetTrainedModelsMetadataRequestTests.java | 39 + .../metadata/TotalFeatureImportanceTests.java | 63 ++ .../metadata/TrainedModelMetadataTests.java | 53 ++ .../ml/get-trained-models-metadata.asciidoc | 45 ++ .../high-level/supported-apis.asciidoc | 2 + ...-inference-trained-model-metadata.asciidoc | 736 ++++++++++++++++++ .../ml/df-analytics/apis/index.asciidoc | 1 + docs/reference/ml/ml-shared.asciidoc | 17 + .../GetTrainedModelsMetadataAction.java | 112 +++ .../ml/qa/ml-with-security/build.gradle | 2 + .../xpack/ml/MachineLearning.java | 7 +- ...ansportGetTrainedModelsMetadataAction.java | 99 +++ .../RestGetTrainedModelsMetadataAction.java | 55 ++ .../api/ml.get_trained_models_metadata.json | 51 ++ .../test/ml/inference_metadata.yml | 107 +++ 23 files changed, 2086 insertions(+), 1 deletion(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsMetadataRequest.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsMetadataResponse.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java create mode 
100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetTrainedModelsMetadataRequestTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TrainedModelMetadataTests.java create mode 100644 docs/java-rest/high-level/ml/get-trained-models-metadata.asciidoc create mode 100644 docs/reference/ml/df-analytics/apis/get-inference-trained-model-metadata.asciidoc create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsMetadataAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsMetadataAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsMetadataAction.java create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models_metadata.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_metadata.yml diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java index 46c57b4a40cdd..99e4cd4fec7e8 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java @@ -61,6 +61,7 @@ import org.elasticsearch.client.ml.GetModelSnapshotsRequest; import org.elasticsearch.client.ml.GetOverallBucketsRequest; import org.elasticsearch.client.ml.GetRecordsRequest; +import org.elasticsearch.client.ml.GetTrainedModelsMetadataRequest; import org.elasticsearch.client.ml.GetTrainedModelsRequest; import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest; import 
org.elasticsearch.client.ml.MlInfoRequest; @@ -819,6 +820,31 @@ static Request getTrainedModelsStats(GetTrainedModelsStatsRequest getTrainedMode return request; } + static Request getTrainedModelsMetadata(GetTrainedModelsMetadataRequest getTrainedModelsMetadataRequest) { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "inference") + .addPathPart(Strings.collectionToCommaDelimitedString(getTrainedModelsMetadataRequest.getIds())) + .addPathPart("_metadata") + .build(); + RequestConverters.Params params = new RequestConverters.Params(); + if (getTrainedModelsMetadataRequest.getPageParams() != null) { + PageParams pageParams = getTrainedModelsMetadataRequest.getPageParams(); + if (pageParams.getFrom() != null) { + params.putParam(PageParams.FROM.getPreferredName(), pageParams.getFrom().toString()); + } + if (pageParams.getSize() != null) { + params.putParam(PageParams.SIZE.getPreferredName(), pageParams.getSize().toString()); + } + } + if (getTrainedModelsMetadataRequest.getAllowNoMatch() != null) { + params.putParam(GetTrainedModelsMetadataRequest.ALLOW_NO_MATCH, + Boolean.toString(getTrainedModelsMetadataRequest.getAllowNoMatch())); + } + Request request = new Request(HttpGet.METHOD_NAME, endpoint); + request.addParameters(params.asMap()); + return request; + } + static Request deleteTrainedModel(DeleteTrainedModelRequest deleteRequest) { String endpoint = new EndpointBuilder() .addPathPartAsIs("_ml", "inference") diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java index 6c74dd1e800cf..dc75e19e59bd0 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java @@ -77,6 +77,8 @@ import org.elasticsearch.client.ml.GetOverallBucketsResponse; import 
org.elasticsearch.client.ml.GetRecordsRequest; import org.elasticsearch.client.ml.GetRecordsResponse; +import org.elasticsearch.client.ml.GetTrainedModelsMetadataRequest; +import org.elasticsearch.client.ml.GetTrainedModelsMetadataResponse; import org.elasticsearch.client.ml.GetTrainedModelsRequest; import org.elasticsearch.client.ml.GetTrainedModelsResponse; import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest; @@ -2519,6 +2521,49 @@ public Cancellable getTrainedModelsStatsAsync(GetTrainedModelsStatsRequest reque Collections.emptySet()); } + /** + * Gets trained model metadata + *

+ * For additional info + * see + * GET Trained Model Metadata documentation + * + * @param request The {@link GetTrainedModelsMetadataRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return {@link GetTrainedModelsMetadataResponse} response object + */ + public GetTrainedModelsMetadataResponse getTrainedModelsMetadata(GetTrainedModelsMetadataRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::getTrainedModelsMetadata, + options, + GetTrainedModelsMetadataResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Gets trained model metadata asynchronously and notifies listener upon completion + *

+ * For additional info + * see + * GET Trained Model Metadata documentation + * + * @param request The {@link GetTrainedModelsMetadataRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + * @return cancellable that may be used to cancel the request + */ + public Cancellable getTrainedModelsMetadataAsync(GetTrainedModelsMetadataRequest request, + RequestOptions options, + ActionListener listener) { + return restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::getTrainedModelsMetadata, + options, + GetTrainedModelsMetadataResponse::fromXContent, + listener, + Collections.emptySet()); + } + /** * Deletes the given Trained Model *

diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsMetadataRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsMetadataRequest.java new file mode 100644 index 0000000000000..4f3a85bf5e3b7 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsMetadataRequest.java @@ -0,0 +1,99 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ValidationException; +import org.elasticsearch.client.core.PageParams; +import org.elasticsearch.common.Nullable; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +public class GetTrainedModelsMetadataRequest implements Validatable { + + public static final String ALLOW_NO_MATCH = "allow_no_match"; + + private final List ids; + private Boolean allowNoMatch; + private PageParams pageParams; + + public static GetTrainedModelsMetadataRequest getAllTrainedModelsMetadataRequest() { + return new GetTrainedModelsMetadataRequest("_all"); + } + + public GetTrainedModelsMetadataRequest(String... 
ids) { + this.ids = Arrays.asList(ids); + } + + public List getIds() { + return ids; + } + + public Boolean getAllowNoMatch() { + return allowNoMatch; + } + + /** + * Whether to ignore if a wildcard expression matches no trained models. + * + * @param allowNoMatch If this is {@code false}, then an error is returned when a wildcard (or {@code _all}) + * does not match any trained models + */ + public GetTrainedModelsMetadataRequest setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + return this; + } + + public PageParams getPageParams() { + return pageParams; + } + + public GetTrainedModelsMetadataRequest setPageParams(@Nullable PageParams pageParams) { + this.pageParams = pageParams; + return this; + } + + @Override + public Optional validate() { + if (ids == null || ids.isEmpty()) { + return Optional.of(ValidationException.withError("trained model id must not be null")); + } + return Optional.empty(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + GetTrainedModelsMetadataRequest other = (GetTrainedModelsMetadataRequest) o; + return Objects.equals(ids, other.ids) + && Objects.equals(allowNoMatch, other.allowNoMatch) + && Objects.equals(pageParams, other.pageParams); + } + + @Override + public int hashCode() { + return Objects.hash(ids, allowNoMatch, pageParams); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsMetadataResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsMetadataResponse.java new file mode 100644 index 0000000000000..c0495ceeaae3f --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsMetadataResponse.java @@ -0,0 +1,86 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.inference.trainedmodel.metadata.TrainedModelMetadata; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class GetTrainedModelsMetadataResponse { + + public static final ParseField TRAINED_MODELS_METADATA = new ParseField("trained_models_metadata"); + public static final ParseField COUNT = new ParseField("count"); + + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "get_trained_models_metadata", + true, + args -> new GetTrainedModelsMetadataResponse((List) args[0], (Long) args[1])); + + static { + PARSER.declareObjectArray(constructorArg(), (p, c) -> TrainedModelMetadata.fromXContent(p), TRAINED_MODELS_METADATA); + PARSER.declareLong(constructorArg(), COUNT); + } + + public static GetTrainedModelsMetadataResponse fromXContent(final XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final List trainedModelsMetadata; + private 
final Long count; + + + public GetTrainedModelsMetadataResponse(List trainedModelsMetadata, Long count) { + this.trainedModelsMetadata = trainedModelsMetadata; + this.count = count; + } + + public List getTrainedModelsMetadata() { + return trainedModelsMetadata; + } + + /** + * @return The total count of the trained models that matched the ID pattern. + */ + public Long getCount() { + return count; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + GetTrainedModelsMetadataResponse other = (GetTrainedModelsMetadataResponse) o; + return Objects.equals(this.trainedModelsMetadata, other.trainedModelsMetadata) && Objects.equals(this.count, other.count); + } + + @Override + public int hashCode() { + return Objects.hash(trainedModelsMetadata, count); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java new file mode 100644 index 0000000000000..882dc046d6d64 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java @@ -0,0 +1,208 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.inference.trainedmodel.metadata; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParseException; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +public class TotalFeatureImportance implements ToXContentObject { + + private static final String NAME = "total_feature_importance"; + public static final ParseField FEATURE_NAME = new ParseField("feature_name"); + public static final ParseField IMPORTANCE = new ParseField("importance"); + public static final ParseField CLASSES = new ParseField("classes"); + public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude"); + public static final ParseField MIN = new ParseField("min"); + public static final ParseField MAX = new ParseField("max"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, + true, + a -> new TotalFeatureImportance((String)a[0], (Importance)a[1], (List)a[2])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME); + 
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), Importance.PARSER, IMPORTANCE); + PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), ClassImportance.PARSER, CLASSES); + } + + public static TotalFeatureImportance fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public final String featureName; + public final Importance importance; + public final List classImportances; + + TotalFeatureImportance(String featureName, @Nullable Importance importance, @Nullable List classImportances) { + this.featureName = featureName; + this.importance = importance; + this.classImportances = classImportances == null ? Collections.emptyList() : classImportances; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FEATURE_NAME.getPreferredName(), featureName); + if (importance != null) { + builder.field(IMPORTANCE.getPreferredName(), importance); + } + if (classImportances.isEmpty() == false) { + builder.field(CLASSES.getPreferredName(), classImportances); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TotalFeatureImportance that = (TotalFeatureImportance) o; + return Objects.equals(that.importance, importance) + && Objects.equals(featureName, that.featureName) + && Objects.equals(classImportances, that.classImportances); + } + + @Override + public int hashCode() { + return Objects.hash(featureName, importance, classImportances); + } + + public static class Importance implements ToXContentObject { + private static final String NAME = "importance"; + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, + true, + a -> new Importance((double)a[0], (double)a[1], (double)a[2])); + + static { + 
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE); + PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MIN); + PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MAX); + } + + private final double meanMagnitude; + private final double min; + private final double max; + + public Importance(double meanMagnitude, double min, double max) { + this.meanMagnitude = meanMagnitude; + this.min = min; + this.max = max; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Importance that = (Importance) o; + return Double.compare(that.meanMagnitude, meanMagnitude) == 0 && + Double.compare(that.min, min) == 0 && + Double.compare(that.max, max) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(meanMagnitude, min, max); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude); + builder.field(MIN.getPreferredName(), min); + builder.field(MAX.getPreferredName(), max); + builder.endObject(); + return builder; + } + } + + public static class ClassImportance implements ToXContentObject { + private static final String NAME = "total_class_importance"; + + public static final ParseField CLASS_NAME = new ParseField("class_name"); + public static final ParseField IMPORTANCE = new ParseField("importance"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, + true, + a -> new ClassImportance(a[0], (Importance)a[1])); + + static { + PARSER.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return p.text(); + } else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) { + return p.numberValue(); + } else if (p.currentToken() == 
XContentParser.Token.VALUE_BOOLEAN) { + return p.booleanValue(); + } + throw new XContentParseException("Unsupported token [" + p.currentToken() + "]"); + }, CLASS_NAME, ObjectParser.ValueType.VALUE); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), Importance.PARSER, IMPORTANCE); + } + + public static ClassImportance fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public final Object className; + public final Importance importance; + + ClassImportance(Object className, Importance importance) { + this.className = className; + this.importance = importance; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASS_NAME.getPreferredName(), className); + builder.field(IMPORTANCE.getPreferredName(), importance); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClassImportance that = (ClassImportance) o; + return Objects.equals(that.importance, importance) && Objects.equals(className, that.className); + } + + @Override + public int hashCode() { + return Objects.hash(className, importance); + } + + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java new file mode 100644 index 0000000000000..5e9ccf1669c2c --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java @@ -0,0 +1,91 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. 
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.inference.trainedmodel.metadata; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +public class TrainedModelMetadata implements ToXContentObject { + + public static final String NAME = "trained_model_metadata"; + public static final ParseField TOTAL_FEATURE_IMPORTANCE = new ParseField("total_feature_importance"); + public static final ParseField MODEL_ID = new ParseField("model_id"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, + true, + a -> new TrainedModelMetadata((String)a[0], (List)a[1])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID); + PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), TotalFeatureImportance.PARSER, TOTAL_FEATURE_IMPORTANCE); + } + + public static TrainedModelMetadata fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final List totalFeatureImportances; + private final String 
modelId; + + public TrainedModelMetadata(String modelId, List totalFeatureImportances) { + this.modelId = Objects.requireNonNull(modelId); + this.totalFeatureImportances = Collections.unmodifiableList(totalFeatureImportances); + } + + public String getModelId() { + return modelId; + } + + public List getTotalFeatureImportances() { + return totalFeatureImportances; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TrainedModelMetadata that = (TrainedModelMetadata) o; + return Objects.equals(totalFeatureImportances, that.totalFeatureImportances) && + Objects.equals(modelId, that.modelId); + } + + @Override + public int hashCode() { + return Objects.hash(totalFeatureImportances, modelId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MODEL_ID.getPreferredName(), modelId); + builder.field(TOTAL_FEATURE_IMPORTANCE.getPreferredName(), totalFeatureImportances); + builder.endObject(); + return builder; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 4c5427d8cda7e..8095f06448e90 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -78,6 +78,8 @@ import org.elasticsearch.client.ml.GetJobStatsResponse; import org.elasticsearch.client.ml.GetModelSnapshotsRequest; import org.elasticsearch.client.ml.GetModelSnapshotsResponse; +import org.elasticsearch.client.ml.GetTrainedModelsMetadataRequest; +import org.elasticsearch.client.ml.GetTrainedModelsMetadataResponse; import org.elasticsearch.client.ml.GetTrainedModelsRequest; import org.elasticsearch.client.ml.GetTrainedModelsResponse; import 
org.elasticsearch.client.ml.GetTrainedModelsStatsRequest; @@ -165,6 +167,7 @@ import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork; +import org.elasticsearch.client.ml.inference.trainedmodel.metadata.TrainedModelMetadata; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.AnalysisLimits; import org.elasticsearch.client.ml.job.config.DataDescription; @@ -187,6 +190,7 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.junit.After; import java.io.IOException; @@ -2392,6 +2396,66 @@ public void testGetTrainedModelsStats() throws Exception { } } + public void testGetTrainedModelsMetadata() throws Exception { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String modelIdPrefix = "a-get-trained-model-metadata-"; + int numberOfModels = 5; + for (int i = 0; i < numberOfModels; ++i) { + String modelId = modelIdPrefix + i; + putTrainedModel(modelId); + IndexRequest indexRequest = new IndexRequest(InferenceIndexConstants.LATEST_INDEX_NAME).id("trained_model_metadata-" + modelId); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + indexRequest.source("{\"model_id\":\"" + modelId + "\", \"doc_type\": \"trained_model_metadata\",\n" + + " \"total_feature_importance\": [\n" + + " {\n" + + " \"feature_name\": \"foo\",\n" + + " \"importance\": {\n" + + " \"mean_magnitude\": 6.0,\n" + + " \"min\": -3.0,\n" + + " \"max\": 3.0\n" + + " }\n" + + " },\n" + + " {\n" + + " \"feature_name\": \"bar\",\n" + + " \"importance\": {\n" + + " \"mean_magnitude\": 5.0,\n" + + " \"min\": -2.0,\n" + + " \"max\": 3.0\n" + + 
" }\n" + + " }\n" + + " ]}", XContentType.JSON); + highLevelClient().index(indexRequest, RequestOptions.DEFAULT); + } + + { + GetTrainedModelsMetadataResponse getTrainedModelsMetadataResponse = execute( + GetTrainedModelsMetadataRequest.getAllTrainedModelsMetadataRequest(), + machineLearningClient::getTrainedModelsMetadata, machineLearningClient::getTrainedModelsMetadataAsync); + assertThat(getTrainedModelsMetadataResponse.getTrainedModelsMetadata(), hasSize(numberOfModels)); + assertThat(getTrainedModelsMetadataResponse.getCount(), equalTo(5L)); + } + { + GetTrainedModelsMetadataResponse getTrainedModelsMetadataResponse = execute( + new GetTrainedModelsMetadataRequest(modelIdPrefix + 4, modelIdPrefix + 2, modelIdPrefix + 3), + machineLearningClient::getTrainedModelsMetadata, machineLearningClient::getTrainedModelsMetadataAsync); + assertThat(getTrainedModelsMetadataResponse.getTrainedModelsMetadata(), hasSize(3)); + assertThat(getTrainedModelsMetadataResponse.getCount(), equalTo(3L)); + } + { + GetTrainedModelsMetadataResponse getTrainedModelsMetadataResponse = execute( + new GetTrainedModelsMetadataRequest(modelIdPrefix + "*").setPageParams(new PageParams(1, 2)), + machineLearningClient::getTrainedModelsMetadata, machineLearningClient::getTrainedModelsMetadataAsync); + assertThat(getTrainedModelsMetadataResponse.getTrainedModelsMetadata(), hasSize(2)); + assertThat(getTrainedModelsMetadataResponse.getCount(), equalTo(5L)); + assertThat( + getTrainedModelsMetadataResponse.getTrainedModelsMetadata() + .stream() + .map(TrainedModelMetadata::getModelId) + .collect(Collectors.toList()), + containsInAnyOrder(modelIdPrefix + 1, modelIdPrefix + 2)); + } + } + public void testDeleteTrainedModel() throws Exception { MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); String modelId = "delete-trained-model-test"; diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index a676e3a5a7f85..4a3ec2c7c8021 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -91,6 +91,8 @@ import org.elasticsearch.client.ml.GetOverallBucketsResponse; import org.elasticsearch.client.ml.GetRecordsRequest; import org.elasticsearch.client.ml.GetRecordsResponse; +import org.elasticsearch.client.ml.GetTrainedModelsMetadataRequest; +import org.elasticsearch.client.ml.GetTrainedModelsMetadataResponse; import org.elasticsearch.client.ml.GetTrainedModelsRequest; import org.elasticsearch.client.ml.GetTrainedModelsResponse; import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest; @@ -182,6 +184,7 @@ import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.client.ml.inference.trainedmodel.metadata.TrainedModelMetadata; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.AnalysisLimits; import org.elasticsearch.client.ml.job.config.DataDescription; @@ -214,6 +217,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.junit.After; import java.io.IOException; @@ -3868,6 +3872,81 @@ public void onFailure(Exception e) { } } + public void testGetTrainedModelsMetadata() throws Exception { + String modelId = "my-trained-model"; + putTrainedModel(modelId); + IndexRequest indexRequest = new 
IndexRequest(InferenceIndexConstants.LATEST_INDEX_NAME).id("trained_model_metadata-" + modelId); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + indexRequest.source("{\"model_id\":\"" + modelId + "\", \"doc_type\": \"trained_model_metadata\",\n" + + " \"total_feature_importance\": [\n" + + " {\n" + + " \"feature_name\": \"foo\",\n" + + " \"importance\": {\n" + + " \"mean_magnitude\": 6.0,\n" + + " \"min\": -3.0,\n" + + " \"max\": 3.0\n" + + " }\n" + + " },\n" + + " {\n" + + " \"feature_name\": \"bar\",\n" + + " \"importance\": {\n" + + " \"mean_magnitude\": 5.0,\n" + + " \"min\": -2.0,\n" + + " \"max\": 3.0\n" + + " }\n" + + " }\n" + + " ]}", XContentType.JSON); + highLevelClient().index(indexRequest, RequestOptions.DEFAULT); + RestHighLevelClient client = highLevelClient(); + { + // tag::get-trained-models-metadata-request + GetTrainedModelsMetadataRequest request = + new GetTrainedModelsMetadataRequest("my-trained-model") // <1> + .setPageParams(new PageParams(0, 1)) // <2> + .setAllowNoMatch(true); // <3> + // end::get-trained-models-metadata-request + + // tag::get-trained-models-metadata-execute + GetTrainedModelsMetadataResponse response = + client.machineLearning().getTrainedModelsMetadata(request, RequestOptions.DEFAULT); + // end::get-trained-models-metadata-execute + + // tag::get-trained-models-metadata-response + List models = response.getTrainedModelsMetadata(); + // end::get-trained-models-metadata-response + + assertThat(models, hasSize(1)); + } + { + GetTrainedModelsMetadataRequest request = new GetTrainedModelsMetadataRequest("my-trained-model"); + + // tag::get-trained-models-metadata-execute-listener + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(GetTrainedModelsMetadataResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::get-trained-models-metadata-execute-listener + + // Replace the empty listener by a blocking 
listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::get-trained-models-metadata-execute-async + client.machineLearning() + .getTrainedModelsMetadataAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::get-trained-models-metadata-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + } + public void testDeleteTrainedModel() throws Exception { RestHighLevelClient client = highLevelClient(); { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetTrainedModelsMetadataRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetTrainedModelsMetadataRequestTests.java new file mode 100644 index 0000000000000..8f74376d8e0d0 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetTrainedModelsMetadataRequestTests.java @@ -0,0 +1,39 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +public class GetTrainedModelsMetadataRequestTests extends ESTestCase { + + public void testValidate_Ok() { + assertEquals(Optional.empty(), new GetTrainedModelsMetadataRequest("valid-id").validate()); + assertEquals(Optional.empty(), new GetTrainedModelsMetadataRequest("").validate()); + } + + public void testValidate_Failure() { + assertThat(new GetTrainedModelsMetadataRequest(new String[0]).validate().get().getMessage(), + containsString("trained model id must not be null")); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java new file mode 100644 index 0000000000000..eef5c3bae21d2 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java @@ -0,0 +1,63 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.inference.trainedmodel.metadata; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +public class TotalFeatureImportanceTests extends AbstractXContentTestCase { + + + public static TotalFeatureImportance randomInstance() { + return new TotalFeatureImportance( + randomAlphaOfLength(10), + randomBoolean() ? null : randomImportance(), + randomBoolean() ? + null : + Stream.generate(() -> new TotalFeatureImportance.ClassImportance(randomAlphaOfLength(10), randomImportance())) + .limit(randomIntBetween(1, 10)) + .collect(Collectors.toList()) + ); + } + + private static TotalFeatureImportance.Importance randomImportance() { + return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble()); + } + + @Override + protected TotalFeatureImportance createTestInstance() { + return randomInstance(); + } + + @Override + protected TotalFeatureImportance doParseInstance(XContentParser parser) throws IOException { + return TotalFeatureImportance.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TrainedModelMetadataTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TrainedModelMetadataTests.java new file mode 100644 index 0000000000000..5c5db4adbf857 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TrainedModelMetadataTests.java @@ -0,0 +1,53 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. 
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel.metadata; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +public class TrainedModelMetadataTests extends AbstractXContentTestCase { + + + public static TrainedModelMetadata randomInstance() { + return new TrainedModelMetadata( + randomAlphaOfLength(10), + Stream.generate(TotalFeatureImportanceTests::randomInstance).limit(randomIntBetween(1, 10)).collect(Collectors.toList())); + } + + @Override + protected TrainedModelMetadata createTestInstance() { + return randomInstance(); + } + + @Override + protected TrainedModelMetadata doParseInstance(XContentParser parser) throws IOException { + return TrainedModelMetadata.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + +} diff --git a/docs/java-rest/high-level/ml/get-trained-models-metadata.asciidoc b/docs/java-rest/high-level/ml/get-trained-models-metadata.asciidoc new file mode 100644 index 0000000000000..b3bc7f6f83ce8 --- /dev/null +++ b/docs/java-rest/high-level/ml/get-trained-models-metadata.asciidoc @@ -0,0 +1,45 @@ +-- +:api: get-trained-models-metadata +:request: GetTrainedModelsMetadataRequest +:response: 
GetTrainedModelsMetadataResponse +-- +[role="xpack"] +[id="{upid}-{api}"] +=== Get Trained Models Metadata API + +experimental[] + +Retrieves training metadata for one or more trained models. +The API accepts a +{request}+ object and returns a +{response}+. + +[id="{upid}-{api}-request"] +==== Get Trained Models Metadata request + +A +{request}+ requires either a Trained Model ID, a comma-separated list of +IDs, or the special wildcard `_all` to get metadata for all Trained Models. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +-------------------------------------------------- +<1> Constructing a new GET request referencing an existing Trained Model +<2> Set the paging parameters +<3> Allow empty response if no Trained Models match the provided ID patterns. + If false, an error will be thrown if no Trained Models match the + ID patterns. + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ contains the metadata +for the requested Trained Model. + +NOTE: the Trained Model will only have training metadata if +it was trained in the current cluster by data frame analytics. 
+ +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-response] +-------------------------------------------------- diff --git a/docs/java-rest/high-level/supported-apis.asciidoc b/docs/java-rest/high-level/supported-apis.asciidoc index e127aff2440ed..2bd77ef371b2e 100644 --- a/docs/java-rest/high-level/supported-apis.asciidoc +++ b/docs/java-rest/high-level/supported-apis.asciidoc @@ -331,6 +331,7 @@ The Java High Level REST Client supports the following Machine Learning APIs: * <<{upid}-get-trained-models>> * <<{upid}-put-trained-model>> * <<{upid}-get-trained-models-stats>> +* <<{upid}-get-trained-models-metadata>> * <<{upid}-delete-trained-model>> * <<{upid}-put-filter>> * <<{upid}-get-filters>> @@ -389,6 +390,7 @@ include::ml/explain-data-frame-analytics.asciidoc[] include::ml/get-trained-models.asciidoc[] include::ml/put-trained-model.asciidoc[] include::ml/get-trained-models-stats.asciidoc[] +include::ml/get-trained-models-metadata.asciidoc[] include::ml/delete-trained-model.asciidoc[] include::ml/put-filter.asciidoc[] include::ml/get-filters.asciidoc[] diff --git a/docs/reference/ml/df-analytics/apis/get-inference-trained-model-metadata.asciidoc b/docs/reference/ml/df-analytics/apis/get-inference-trained-model-metadata.asciidoc new file mode 100644 index 0000000000000..71d59b0609072 --- /dev/null +++ b/docs/reference/ml/df-analytics/apis/get-inference-trained-model-metadata.asciidoc @@ -0,0 +1,736 @@ +[role="xpack"] +[testenv="basic"] +[[get-inference-metadata]] += Get {infer} trained model metadata API +[subs="attributes"] +++++ +Get {infer} trained model metadata +++++ + +Retrieves training metadata information for trained {infer} models. 
+ +experimental[] + + +[[ml-get-inference-metadata-request]] +== {api-request-title} + +`GET _ml/inference/_metadata` + + +`GET _ml/inference/_all/_metadata` + + +`GET _ml/inference/<model_id>/_metadata` + + +`GET _ml/inference/<model_id>,<model_id_2>/_metadata` + + +`GET _ml/inference/<model_id_pattern*>,<model_id_2>/_metadata` + + +[[ml-get-inference-metadata-prereq]] +== {api-prereq-title} + +If the {es} {security-features} are enabled, you must have the following +privileges: + +* cluster: `monitor_ml` + +For more information, see <<security-privileges>> and {ml-docs-setup-privileges}. + +[[ml-get-inference-metadata-desc]] +== {api-description-title} + +You can get training metadata information for multiple trained models in a single API +request by using a comma-separated list of model IDs or a wildcard expression. + + +[[ml-get-inference-metadata-path-params]] +== {api-path-parms-title} + +`<model_id>`:: +(Optional, string) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] + + +[[ml-get-inference-metadata-query-params]] +== {api-query-parms-title} + +`allow_no_match`:: +(Optional, boolean) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=allow-no-match] + +`from`:: +(Optional, integer) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=from] + +`size`:: +(Optional, integer) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=size] + +[role="child_attributes"] +[[ml-get-inference-metadata-results]] +== {api-response-body-title} + +`count`:: +(integer) +The total number of trained model metadata objects that matched the requested ID patterns. +Could be higher than the number of items in the `trained_models_metadata` array as the +size of the array is restricted by the supplied `size` parameter. + +`trained_models_metadata`:: +(array) +An array of trained model metadata objects, which are sorted by the `model_id` value in +ascending order.
++ +.Properties of trained model metadata +[%collapsible%open] +==== +`model_id`::: +(string) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] + +`total_feature_importance`::: +(array) +An array of the total feature importance for each training feature used from +the training data set. ++ +.Properties of total feature importance +[%collapsible%open] +===== + +`feature_name`::: +(string) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-feature-name] + +`importance`::: +(object) +A collection of feature importance statistics related to the training data set for this particular feature. ++ +.Properties of feature importance +[%collapsible%open] +====== +`mean_magnitude`::: +(double) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-magnitude] + +`max`::: +(double) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-max] + +`min`::: +(double) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-min] + +====== + +`classes`::: +(array) +If the trained model is a classification model, feature importance statistics are gathered +per target class value. ++ +.Properties of class feature importance +[%collapsible%open] + +====== + +`class_name`::: +(string) +The target class value. Could be a string, boolean, or number. + +`importance`::: +(object) +A collection of feature importance statistics related to the training data set for this particular feature.
++ +.Properties of feature importance +[%collapsible%open] +======= +`mean_magnitude`::: +(double) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-magnitude] + +`max`::: +(double) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-max] + +`min`::: +(double) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-min] + +======= + +====== + +===== + +==== + +[[ml-get-inference-metadata-response-codes]] +== {api-response-codes-title} + +`404` (Missing resources):: + If `allow_no_match` is `false`, this code indicates that there are no + resources that match the request or only partial matches for the request. + +[[ml-get-inference-metadata-example]] +== {api-examples-title} + +The following example gets training metadata for all the trained models: + +[source,console] +-------------------------------------------------- +GET _ml/inference/_metadata +-------------------------------------------------- +// TEST[skip:TBD] + + +The API returns the following results: + +[source,console-result] +---- +{ + "count" : 2, + "trained_models_metadata" : [ + { + "model_id" : "avg_price_prediction-1599149443166", + "total_feature_importance" : [ + { + "feature_name" : "Origin", + "importance" : { + "mean_magnitude" : 25.862683737654795, + "min" : -188.93284143727874, + "max" : 162.8783518094679 + } + }, + { + "feature_name" : "FlightTimeMin", + "importance" : { + "mean_magnitude" : 62.2776962970226, + "min" : -421.09965377789365, + "max" : 92.53225055842458 + } + }, + { + "feature_name" : "DestAirportID", + "importance" : { + "mean_magnitude" : 14.392208812683114, + "min" : -103.91799718753263, + "max" : 122.63483137469528 + } + }, + { + "feature_name" : "Dest", + "importance" : { + "mean_magnitude" : 9.06018758454092, + "min" : -64.35766760965463, + "max" : 60.60458858708342 + } + }, + { + "feature_name" : "Carrier", + "importance" : { + "mean_magnitude" : 1.5131352419114026, +
"min" : -7.980966972560515, + "max" : 14.407986213341761 + } + }, + { + "feature_name" : "Cancelled", + "importance" : { + "mean_magnitude" : 0.17951893871195423, + "min" : -4.543996246002224, + "max" : 2.058692610259091 + } + } + ] + }, + { + "model_id" : "dest_weather_prediction-1599149568413", + "total_feature_importance" : [ + { + "feature_name" : "dayOfWeek", + "classes" : [ + { + "class_name" : "Clear", + "importance" : { + "mean_magnitude" : 0.0033597810869050483, + "min" : -0.034589509802599394, + "max" : 0.013677011897069439 + } + }, + { + "class_name" : "Cloudy", + "importance" : { + "mean_magnitude" : 0.003549516620011909, + "min" : -0.06736294734141816, + "max" : 0.088650519638185 + } + }, + { + "class_name" : "Sunny", + "importance" : { + "mean_magnitude" : 0.004471474339413112, + "min" : -0.08060353377909144, + "max" : 0.1045130657148837 + } + }, + { + "class_name" : "Hail", + "importance" : { + "mean_magnitude" : 0.007072062864425885, + "min" : -0.05044235221609796, + "max" : 0.038623432806435085 + } + }, + { + "class_name" : "Heavy Fog", + "importance" : { + "mean_magnitude" : 0.0065139540458721236, + "min" : -0.018591621408001358, + "max" : 0.03543735929759353 + } + }, + { + "class_name" : "Thunder & Lightning", + "importance" : { + "mean_magnitude" : 0.0066567969304509155, + "min" : -0.059528507259167134, + "max" : 0.03628958395628503 + } + }, + { + "class_name" : "Rain", + "importance" : { + "mean_magnitude" : 0.0038758238618025985, + "min" : -0.07831548102713791, + "max" : 0.05696179640413974 + } + }, + { + "class_name" : "Damaging Wind", + "importance" : { + "mean_magnitude" : 0.0113605093018583, + "min" : -0.053221001966268555, + "max" : 0.07941614599243701 + } + } + ] + }, + { + "feature_name" : "OriginWeather", + "classes" : [ + { + "class_name" : "Clear", + "importance" : { + "mean_magnitude" : 0.06392838735937217, + "min" : -0.44856958621186466, + "max" : 0.42934605429030326 + } + }, + { + "class_name" : "Cloudy", + "importance" : { + 
"mean_magnitude" : 0.06548971082123245, + "min" : -0.46779188319269366, + "max" : 0.3269667496467847 + } + }, + { + "class_name" : "Sunny", + "importance" : { + "mean_magnitude" : 0.07518686267467585, + "min" : -0.4986094320082847, + "max" : 0.3404830347301714 + } + }, + { + "class_name" : "Hail", + "importance" : { + "mean_magnitude" : 0.06737171171799335, + "min" : -0.3852291982170536, + "max" : 0.49072939677488925 + } + }, + { + "class_name" : "Heavy Fog", + "importance" : { + "mean_magnitude" : 0.07489704791170221, + "min" : -0.37509109477738595, + "max" : 0.4972213932685191 + } + }, + { + "class_name" : "Thunder & Lightning", + "importance" : { + "mean_magnitude" : 0.0648427860499252, + "min" : -0.31609773937218777, + "max" : 0.4794272326778727 + } + }, + { + "class_name" : "Rain", + "importance" : { + "mean_magnitude" : 0.06543521178087627, + "min" : -0.4873899965919118, + "max" : 0.3184995190663039 + } + }, + { + "class_name" : "Damaging Wind", + "importance" : { + "mean_magnitude" : 0.06431059549346163, + "min" : -0.36484599829883496, + "max" : 0.5219289190181048 + } + } + ] + }, + { + "feature_name" : "DistanceMiles", + "classes" : [ + { + "class_name" : "Clear", + "importance" : { + "mean_magnitude" : 0.009925586997272603, + "min" : -0.08532219495991693, + "max" : 0.09596509857596312 + } + }, + { + "class_name" : "Cloudy", + "importance" : { + "mean_magnitude" : 0.011475354752337826, + "min" : -0.1458542416877272, + "max" : 0.04519866025045433 + } + }, + { + "class_name" : "Sunny", + "importance" : { + "mean_magnitude" : 0.017019258694374224, + "min" : -0.1343486867487768, + "max" : 0.0676285675758164 + } + }, + { + "class_name" : "Hail", + "importance" : { + "mean_magnitude" : 0.009749089329601059, + "min" : -0.09006103474994831, + "max" : 0.09877363346016879 + } + }, + { + "class_name" : "Heavy Fog", + "importance" : { + "mean_magnitude" : 0.01993894846496605, + "min" : -0.127674403426739, + "max" : 0.09865171214017159 + } + }, + { + "class_name" : 
"Thunder & Lightning", + "importance" : { + "mean_magnitude" : 0.009096504130883604, + "min" : -0.07760903289433295, + "max" : 0.13980510402261984 + } + }, + { + "class_name" : "Rain", + "importance" : { + "mean_magnitude" : 0.012863797177808878, + "min" : -0.14990727286801117, + "max" : 0.11846604888692423 + } + }, + { + "class_name" : "Damaging Wind", + "importance" : { + "mean_magnitude" : 0.007294683069837484, + "min" : -0.06705195480257278, + "max" : 0.10814781489010294 + } + } + ] + }, + { + "feature_name" : "Dest", + "classes" : [ + { + "class_name" : "Clear", + "importance" : { + "mean_magnitude" : 0.03089277559418831, + "min" : -0.17919680669966132, + "max" : 0.167862580693223 + } + }, + { + "class_name" : "Cloudy", + "importance" : { + "mean_magnitude" : 0.03658129232621677, + "min" : -0.155731965346974, + "max" : 0.17282348853252674 + } + }, + { + "class_name" : "Sunny", + "importance" : { + "mean_magnitude" : 0.02229143893127944, + "min" : -0.15452751654480057, + "max" : 0.14704877249575932 + } + }, + { + "class_name" : "Hail", + "importance" : { + "mean_magnitude" : 0.013451198623124975, + "min" : -0.1287421471051356, + "max" : 0.13386295301827233 + } + }, + { + "class_name" : "Heavy Fog", + "importance" : { + "mean_magnitude" : 0.036670909357557686, + "min" : -0.33043825398026216, + "max" : 0.1572790533667184 + } + }, + { + "class_name" : "Thunder & Lightning", + "importance" : { + "mean_magnitude" : 0.028455360035722868, + "min" : -0.15675473094058942, + "max" : 0.24669584585029877 + } + }, + { + "class_name" : "Rain", + "importance" : { + "mean_magnitude" : 0.026467930340461993, + "min" : -0.16937095134002386, + "max" : 0.17950423023794768 + } + }, + { + "class_name" : "Damaging Wind", + "importance" : { + "mean_magnitude" : 0.034070565798177105, + "min" : -0.13793993299792928, + "max" : 0.24210805402819272 + } + } + ] + }, + { + "feature_name" : "FlightDelayType", + "classes" : [ + { + "class_name" : "Clear", + "importance" : { + "mean_magnitude" : 
7.813653012241237E-5, + "min" : -0.0045297876576326414, + "max" : 2.5684363007195195E-4 + } + }, + { + "class_name" : "Cloudy", + "importance" : { + "mean_magnitude" : 9.785459733302982E-5, + "min" : -3.2165915172928015E-4, + "max" : 0.005672897766861862 + } + }, + { + "class_name" : "Sunny", + "importance" : { + "mean_magnitude" : 5.032178416705874E-4, + "min" : -0.0016541340774815561, + "max" : 0.02917291009376577 + } + }, + { + "class_name" : "Hail", + "importance" : { + "mean_magnitude" : 3.63823768372713E-5, + "min" : -1.1959299604027548E-4, + "max" : 0.002109185566528496 + } + }, + { + "class_name" : "Heavy Fog", + "importance" : { + "mean_magnitude" : 0.0016573267593305596, + "min" : -0.09607975024381649, + "max" : 0.005447820890113283 + } + }, + { + "class_name" : "Thunder & Lightning", + "importance" : { + "mean_magnitude" : 3.4449616677222154E-4, + "min" : -0.01997138191191523, + "max" : 0.0011323979434591043 + } + }, + { + "class_name" : "Rain", + "importance" : { + "mean_magnitude" : 4.584110411208903E-4, + "min" : -0.0015068490459212258, + "max" : 0.026575337718974523 + } + }, + { + "class_name" : "Damaging Wind", + "importance" : { + "mean_magnitude" : 9.840935992634046E-4, + "min" : -0.00323482719247201, + "max" : 0.057050588667233865 + } + } + ] + }, + { + "feature_name" : "Carrier", + "classes" : [ + { + "class_name" : "Clear", + "importance" : { + "mean_magnitude" : 0.0012995391420598974, + "min" : -0.013357673650648202, + "max" : 0.023657332277930914 + } + }, + { + "class_name" : "Cloudy", + "importance" : { + "mean_magnitude" : 0.0014234549202827422, + "min" : -0.033921439965674836, + "max" : 0.023421720514381957 + } + }, + { + "class_name" : "Sunny", + "importance" : { + "mean_magnitude" : 9.447133417656426E-4, + "min" : -0.0075229254876170775, + "max" : 0.023581652672635692 + } + }, + { + "class_name" : "Hail", + "importance" : { + "mean_magnitude" : 8.183424711633941E-4, + "min" : -0.030286317042075447, + "max" : 0.013849212287877942 + } + }, 
+ { + "class_name" : "Heavy Fog", + "importance" : { + "mean_magnitude" : 0.0018377370367460909, + "min" : -0.04887551961182662, + "max" : 0.03160356785017506 + } + }, + { + "class_name" : "Thunder & Lightning", + "importance" : { + "mean_magnitude" : 0.0034037867051541835, + "min" : -0.06663044595498299, + "max" : 0.059412001243895896 + } + }, + { + "class_name" : "Rain", + "importance" : { + "mean_magnitude" : 0.001136804656696952, + "min" : -0.02694085974354727, + "max" : 0.02252112557408638 + } + }, + { + "class_name" : "Damaging Wind", + "importance" : { + "mean_magnitude" : 0.003395020854090613, + "min" : -0.025753019025261337, + "max" : 0.019039925129197034 + } + } + ] + }, + { + "feature_name" : "Cancelled", + "classes" : [ + { + "class_name" : "Clear", + "importance" : { + "mean_magnitude" : 0.1904078803726547, + "min" : -1.029665279625464, + "max" : 0.14991369698887452 + } + }, + { + "class_name" : "Cloudy", + "importance" : { + "mean_magnitude" : 0.15793833367609444, + "min" : -0.9111582524995818, + "max" : 0.13245342733344687 + } + }, + { + "class_name" : "Sunny", + "importance" : { + "mean_magnitude" : 0.129943186439382, + "min" : -0.797881090878788, + "max" : 0.11748185308604056 + } + }, + { + "class_name" : "Hail", + "importance" : { + "mean_magnitude" : 0.13058547362388567, + "min" : -0.11956949426762936, + "max" : 0.8212854176126191 + } + }, + { + "class_name" : "Heavy Fog", + "importance" : { + "mean_magnitude" : 0.17295029075472754, + "min" : -0.14916418886135357, + "max" : 0.9957120466372638 + } + }, + { + "class_name" : "Thunder & Lightning", + "importance" : { + "mean_magnitude" : 0.16999279667691086, + "min" : -0.14367032363243962, + "max" : 0.9779631405571727 + } + }, + { + "class_name" : "Rain", + "importance" : { + "mean_magnitude" : 0.16103302679957027, + "min" : -0.9461364792231481, + "max" : 0.1418523995372967 + } + }, + { + "class_name" : "Damaging Wind", + "importance" : { + "mean_magnitude" : 0.16579386623218015, + "min" : 
-0.1393484508575312, + "max" : 0.9577442733774155 + } + } + ] + } + ] + } + ] +} +---- +// NOTCONSOLE diff --git a/docs/reference/ml/df-analytics/apis/index.asciidoc b/docs/reference/ml/df-analytics/apis/index.asciidoc index 421393a1b53e1..22096077fc944 100644 --- a/docs/reference/ml/df-analytics/apis/index.asciidoc +++ b/docs/reference/ml/df-analytics/apis/index.asciidoc @@ -16,6 +16,7 @@ include::get-dfanalytics.asciidoc[leveloffset=+2] include::get-dfanalytics-stats.asciidoc[leveloffset=+2] include::get-inference-trained-model.asciidoc[leveloffset=+2] include::get-inference-trained-model-stats.asciidoc[leveloffset=+2] +include::get-inference-trained-model-metadata.asciidoc[leveloffset=+2] //SET/START/STOP include::start-dfanalytics.asciidoc[leveloffset=+2] include::stop-dfanalytics.asciidoc[leveloffset=+2] diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index 2a3dea8aca4d9..751968049dd3f 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -766,6 +766,23 @@ prediction. Defaults to the `results_field` value of the {dfanalytics-job} that used to train the model, which defaults to `_prediction`. end::inference-config-results-field-processor[] +tag::inference-metadata-feature-importance-feature-name[] +The training feature name for which this importance was calculated. +end::inference-metadata-feature-importance-feature-name[] +tag::inference-metadata-feature-importance-magnitude[] +The average magnitude of this feature across all the training data. +This value is the average of the absolute values of the importance +for this feature. +end::inference-metadata-feature-importance-magnitude[] +tag::inference-metadata-feature-importance-max[] +The maximum importance value across all the training data for this +feature. 
+end::inference-metadata-feature-importance-max[] +tag::inference-metadata-feature-importance-min[] +The minimum importance value across all the training data for this +feature. +end::inference-metadata-feature-importance-min[] + tag::influencers[] A comma separated list of influencer field names. Typically these can be the by, over, or partition fields that are used in the detector configuration. You might diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsMetadataAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsMetadataAction.java new file mode 100644 index 0000000000000..6a39330d0f0a3 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsMetadataAction.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionType; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest; +import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse; +import org.elasticsearch.xpack.core.action.util.QueryPage; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; + +import java.io.IOException; +import java.util.Comparator; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +public class GetTrainedModelsMetadataAction extends ActionType { + + public static final GetTrainedModelsMetadataAction INSTANCE = new GetTrainedModelsMetadataAction(); + public static final String NAME = "cluster:monitor/xpack/ml/inference/metadata/get"; + + private GetTrainedModelsMetadataAction() { + super(NAME, GetTrainedModelsMetadataAction.Response::new); + } + + public static class Request extends AbstractGetResourcesRequest { + + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + + public Request() { + setAllowNoResources(true); + } + + public Request(String id) { + setResourceId(id); + setAllowNoResources(true); + } + + public Request(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getResourceIdField() { + return TrainedModelConfig.MODEL_ID.getPreferredName(); + } + + } + + public static class Response extends AbstractGetResourcesResponse { + + public static final ParseField RESULTS_FIELD = new ParseField("trained_models_metadata"); + + public Response(StreamInput in) throws IOException { + super(in); + } + + public Response(QueryPage trainedModels) { + super(trainedModels); + } + + @Override + protected Reader getReader() { + return 
TrainedModelMetadata::new; + } + + public static class Builder { + + private long totalModelCount; + private Set expandedIds; + private Map trainedModelMetadataMap; + + public Builder setTotalModelCount(long totalModelCount) { + this.totalModelCount = totalModelCount; + return this; + } + + public Builder setExpandedIds(Set expandedIds) { + this.expandedIds = expandedIds; + return this; + } + + public Set getExpandedIds() { + return this.expandedIds; + } + + public Builder setTrainedModelMetadata(Map modelMetadataByModelId) { + this.trainedModelMetadataMap = modelMetadataByModelId; + return this; + } + + public Response build() { + return new Response(new QueryPage<>( + expandedIds.stream() + .map(trainedModelMetadataMap::get) + .filter(Objects::nonNull) + .sorted(Comparator.comparing(TrainedModelMetadata::getModelId)) + .collect(Collectors.toList()), + totalModelCount, + RESULTS_FIELD)); + } + } + } + +} diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 93f943d7ddfc8..1ec2bd8ac745c 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -142,6 +142,8 @@ yamlRestTest { 'ml/inference_crud/Test put ensemble with tree where tree has out of bounds feature_names index', 'ml/inference_crud/Test put model with empty input.field_names', 'ml/inference_crud/Test PUT model where target type and inference config mismatch', + 'ml/inference_metadata/Test get given missing trained model metadata', + 'ml/inference_metadata/Test get given expression without matches and allow_no_match is false', 'ml/inference_processor/Test create processor with missing mandatory fields', 'ml/inference_stats_crud/Test get stats given missing trained model', 'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false', diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index e9e0fd09ea876..4c55fd4e7af96 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -113,6 +113,7 @@ import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction; import org.elasticsearch.xpack.core.ml.action.GetRecordsAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsMetadataAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction; import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction; @@ -190,6 +191,7 @@ import org.elasticsearch.xpack.ml.action.TransportGetOverallBucketsAction; import org.elasticsearch.xpack.ml.action.TransportGetRecordsAction; import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsAction; +import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsMetadataAction; import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsStatsAction; import org.elasticsearch.xpack.ml.action.TransportInternalInferModelAction; import org.elasticsearch.xpack.ml.action.TransportIsolateDatafeedAction; @@ -308,6 +310,7 @@ import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction; import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction; +import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsMetadataAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction; import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAction; import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction; @@ -863,6 +866,7 @@ public List getRestHandlers(Settings settings, RestController 
restC new RestDeleteTrainedModelAction(), new RestGetTrainedModelsStatsAction(), new RestPutTrainedModelAction(), + new RestGetTrainedModelsMetadataAction(), // CAT Handlers new RestCatJobsAction(), new RestCatTrainedModelsAction(), @@ -946,7 +950,8 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(DeleteTrainedModelAction.INSTANCE, TransportDeleteTrainedModelAction.class), new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class), new ActionHandler<>(PutTrainedModelAction.INSTANCE, TransportPutTrainedModelAction.class), - usageAction, + new ActionHandler<>(GetTrainedModelsMetadataAction.INSTANCE, TransportGetTrainedModelsMetadataAction.class), + usageAction, infoAction); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsMetadataAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsMetadataAction.java new file mode 100644 index 0000000000000..30b2a5138c703 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsMetadataAction.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.action.AbstractTransportGetResourcesAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsMetadataAction; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; + + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; + +public class TransportGetTrainedModelsMetadataAction extends AbstractTransportGetResourcesAction< + TrainedModelMetadata, + GetTrainedModelsMetadataAction.Request, + GetTrainedModelsMetadataAction.Response> { + + + @Inject + public TransportGetTrainedModelsMetadataAction(TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry) { + super(GetTrainedModelsMetadataAction.NAME, + transportService, + actionFilters, + GetTrainedModelsMetadataAction.Request::new, + client, + xContentRegistry); + } + @Override + protected ParseField getResultsField() { + return GetTrainedModelsMetadataAction.Response.RESULTS_FIELD; + } + + @Override + 
protected String[] getIndices() { + return new String[] { InferenceIndexConstants.INDEX_PATTERN }; + } + + @Override + protected TrainedModelMetadata parse(XContentParser parser) { + return TrainedModelMetadata.LENIENT_PARSER.apply(parser, null); + } + + @Override + protected ResourceNotFoundException notFoundException(String resourceId) { + return new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, resourceId)); + } + + @Override + protected void doExecute(Task task, GetTrainedModelsMetadataAction.Request request, + ActionListener listener) { + searchResources(request, ActionListener.wrap( + queryPage -> listener.onResponse(new GetTrainedModelsMetadataAction.Response(queryPage)), + listener::onFailure + )); + } + + @Nullable + protected QueryBuilder additionalQuery() { + return QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelMetadata.NAME); + } + + @Override + protected String executionOrigin() { + return ML_ORIGIN; + } + + @Override + protected String extractIdFromResource(TrainedModelMetadata config) { + return config.getModelId(); + } + + @Override + protected SearchSourceBuilder customSearchOptions(SearchSourceBuilder searchSourceBuilder) { + return searchSourceBuilder.sort("_index", SortOrder.DESC); + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsMetadataAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsMetadataAction.java new file mode 100644 index 0000000000000..861563e165e77 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsMetadataAction.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.rest.inference; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.common.Strings; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsMetadataAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.GET; +import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request.ALLOW_NO_MATCH; + +public class RestGetTrainedModelsMetadataAction extends BaseRestHandler { + + @Override + public List routes() { + return List.of( + new Route(GET, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/_metadata"), + new Route(GET, MachineLearning.BASE_PATH + "inference/_metadata")); + } + + @Override + public String getName() { + return "ml_get_trained_models_metadata_action"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String modelId = restRequest.param(TrainedModelMetadata.MODEL_ID.getPreferredName()); + if (Strings.isNullOrEmpty(modelId)) { + modelId = Metadata.ALL; + } + GetTrainedModelsMetadataAction.Request request = new GetTrainedModelsMetadataAction.Request(modelId); + if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) { + request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), 
PageParams.DEFAULT_FROM), + restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); + } + request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources())); + return channel -> client.execute(GetTrainedModelsMetadataAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } + +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models_metadata.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models_metadata.json new file mode 100644 index 0000000000000..c37aec84649e3 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models_metadata.json @@ -0,0 +1,51 @@ +{ + "ml.get_trained_models_metadata":{ + "documentation":{ + "url":"https://www.elastic.co/guide/en/elasticsearch/reference/current/get-inference-metadata.html", + "description":"Retrieves metadata related to the trained model." + }, + "stability":"experimental", + "url":{ + "paths":[ + { + "path":"/_ml/inference/{model_id}/_metadata", + "methods":[ + "GET" + ], + "parts":{ + "model_id":{ + "type":"string", + "description":"The ID of the trained models to fetch" + } + } + }, + { + "path":"/_ml/inference/_metadata", + "methods":[ + "GET" + ] + } + ] + }, + "params":{ + "allow_no_match":{ + "type":"boolean", + "required":false, + "description":"Whether to ignore if a wildcard expression matches no trained models. 
(This includes `_all` string or when no trained models have been specified)", + "default":true + }, + "from":{ + "required":false, + "type":"int", + "description":"skips a number of trained models", + "default":0 + }, + "size":{ + "required":false, + "type":"int", + "description":"specifies a max number of trained models to get", + "default":100 + } + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_metadata.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_metadata.yml new file mode 100644 index 0000000000000..4ebe7811b7b96 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_metadata.yml @@ -0,0 +1,107 @@ +setup: + - skip: + features: + - headers + - allowed_warnings + - do: + allowed_warnings: + - "index [.ml-inference-000003] matches multiple legacy templates [.ml-inference-000003, global], composable templates will only match a single template" + headers: + Content-Type: application/json + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + index: + id: trained_model_metadata-a-regression-model0 + index: .ml-inference-000003 + body: + model_id: "a-regression-model0" + doc_type: "trained_model_metadata" + total_feature_importance: + - { feature_name: "foo", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 }} + - { feature_name: "bar", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 }} + - do: + allowed_warnings: + - "index [.ml-inference-000003] matches multiple legacy templates [.ml-inference-000003, global], composable templates will only match a single template" + headers: + Content-Type: application/json + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. 
the test setup superuser + index: + id: trained_model_metadata-a-regression-model1 + index: .ml-inference-000003 + body: + model_id: "a-regression-model1" + doc_type: "trained_model_metadata" + total_feature_importance: + - { feature_name: "foo", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 } } + - { feature_name: "bar", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 } } + - do: + allowed_warnings: + - "index [.ml-inference-000003] matches multiple legacy templates [.ml-inference-000003, global], composable templates will only match a single template" + headers: + Content-Type: application/json + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + index: + id: trained_model_metadata-a-classification-model + index: .ml-inference-000003 + body: + model_id: "a-classification-model" + doc_type: "trained_model_metadata" + total_feature_importance: + - { feature_name: "foo", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 } } + - { feature_name: "bar", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 } } + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. 
the test setup superuser + indices.refresh: { } + +--- +"Test get given missing trained model metadata": + + - do: + catch: missing + ml.get_trained_models_metadata: + model_id: "missing-trained-model" +--- +"Test get given expression without matches and allow_no_match is false": + + - do: + catch: missing + ml.get_trained_models_metadata: + model_id: "missing-trained-model*" + allow_no_match: false + +--- +"Test get given expression without matches and allow_no_match is true": + + - do: + ml.get_trained_models_metadata: + model_id: "missing-trained-model*" + allow_no_match: true + - match: { count: 0 } + - match: { trained_models_metadata: [] } +--- +"Test get models metadata": + - do: + ml.get_trained_models_metadata: + model_id: "*" + size: 3 + - match: { count: 3 } + - length: { trained_models_metadata: 3 } + + - do: + ml.get_trained_models_metadata: + model_id: "a-regression*" + - match: { count: 2 } + - length: { trained_models_metadata: 2 } + - match: { trained_models_metadata.0.model_id: "a-regression-model0" } + - match: { trained_models_metadata.1.model_id: "a-regression-model1" } + + - do: + ml.get_trained_models_metadata: + model_id: "*" + from: 0 + size: 2 + - match: { count: 3 } + - length: { trained_models_metadata: 2 } + - match: { trained_models_metadata.0.model_id: "a-classification-model" } + - match: { trained_models_metadata.1.model_id: "a-regression-model0" } From 90547d3233668a212eeb1a19d0c4090777716843 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Tue, 8 Sep 2020 10:32:45 -0400 Subject: [PATCH 2/7] Apply suggestions from code review Co-authored-by: Lisa Cawley --- .../ml/get-trained-models-metadata.asciidoc | 18 +++++++++--------- ...t-inference-trained-model-metadata.asciidoc | 10 +++++----- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/java-rest/high-level/ml/get-trained-models-metadata.asciidoc b/docs/java-rest/high-level/ml/get-trained-models-metadata.asciidoc index b3bc7f6f83ce8..97323c4f83092 100644 --- 
a/docs/java-rest/high-level/ml/get-trained-models-metadata.asciidoc +++ b/docs/java-rest/high-level/ml/get-trained-models-metadata.asciidoc @@ -5,7 +5,7 @@ -- [role="xpack"] [id="{upid}-{api}"] -=== Get Trained Models Metadata API +=== Get trained models metadata API experimental[] @@ -13,19 +13,19 @@ Retrieves training metadata for one or more trained models. The API accepts a +{request}+ object and returns a +{response}+. [id="{upid}-{api}-request"] -==== Get Trained Models Metadata request +==== Get trained models metadata request -A +{request}+ requires either a Trained Model ID, a comma-separated list of -IDs, or the special wildcard `_all` to get metadata for all Trained Models. +A +{request}+ requires either a trained model ID, a comma-separated list of +IDs, or the special wildcard `_all` to get metadata for all trained models. ["source","java",subs="attributes,callouts,macros"] -------------------------------------------------- include-tagged::{doc-tests-file}[{api}-request] -------------------------------------------------- -<1> Constructing a new GET request referencing an existing Trained Model +<1> Constructing a new GET request referencing an existing trained model <2> Set the paging parameters -<3> Allow empty response if no Trained Models match the provided ID patterns. - If false, an error will be thrown if no Trained Models match the +<3> Allow empty response if no trained models match the provided ID patterns. + If false, an error will be thrown if no trained models match the ID patterns. include::../execution.asciidoc[] @@ -34,9 +34,9 @@ include::../execution.asciidoc[] ==== Response The returned +{response}+ contains the metadata -for the requested Trained Model. +for the requested trained model. -NOTE: the Trained Model will only have training metadata if +NOTE: The trained model has training metadata only if it was trained in the current cluster by data frame analytics. 
["source","java",subs="attributes,callouts,macros"] diff --git a/docs/reference/ml/df-analytics/apis/get-inference-trained-model-metadata.asciidoc b/docs/reference/ml/df-analytics/apis/get-inference-trained-model-metadata.asciidoc index 71d59b0609072..3be3b523666df 100644 --- a/docs/reference/ml/df-analytics/apis/get-inference-trained-model-metadata.asciidoc +++ b/docs/reference/ml/df-analytics/apis/get-inference-trained-model-metadata.asciidoc @@ -1,13 +1,13 @@ [role="xpack"] [testenv="basic"] [[get-inference-metadata]] -= Get {infer} trained model metadata API += Get trained model metadata API [subs="attributes"] ++++ Get {infer} trained model metadata ++++ -Retrieves training metadata information for trained {infer} models. +Retrieves training metadata information for trained models. experimental[] @@ -39,7 +39,7 @@ For more information, see <> and {ml-docs-setup-privileges} [[ml-get-inference-metadata-desc]] == {api-description-title} -You can get training metadata information for multiple trained models in a single API +You can get metadata for multiple trained models in a single API request by using a comma-separated list of model IDs or a wildcard expression. @@ -56,7 +56,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] `allow_no_match`:: (Optional, boolean) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=allow-no-match] +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=allow-no-match-models] `from`:: (Optional, integer) @@ -124,7 +124,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-impo `classes`::: (array) -If the trained model is a classification model, feature importance stastics are gathered +If the trained model is a classification model, feature importance statistics are gathered per target class value. 
+ .Properties of class feature importance From 4e9147ba11b2304affa5931a1e1b4a3af5046eb5 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 9 Sep 2020 08:02:49 -0400 Subject: [PATCH 3/7] addressing pr comments --- ...-inference-trained-model-metadata.asciidoc | 460 +----------------- 1 file changed, 4 insertions(+), 456 deletions(-) diff --git a/docs/reference/ml/df-analytics/apis/get-inference-trained-model-metadata.asciidoc b/docs/reference/ml/df-analytics/apis/get-inference-trained-model-metadata.asciidoc index 3be3b523666df..090094f383612 100644 --- a/docs/reference/ml/df-analytics/apis/get-inference-trained-model-metadata.asciidoc +++ b/docs/reference/ml/df-analytics/apis/get-inference-trained-model-metadata.asciidoc @@ -60,11 +60,11 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=allow-no-match-models] `from`:: (Optional, integer) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=from] +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=from-models] `size`:: (Optional, integer) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=size] +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=size-models] [role="child_attributes"] [[ml-get-inference-metadata-results]] @@ -200,46 +200,7 @@ The API returns the following results: "max" : 162.8783518094679 } }, - { - "feature_name" : "FlightTimeMin", - "importance" : { - "mean_magnitude" : 62.2776962970226, - "min" : -421.09965377789365, - "max" : 92.53225055842458 - } - }, - { - "feature_name" : "DestAirportID", - "importance" : { - "mean_magnitude" : 14.392208812683114, - "min" : -103.91799718753263, - "max" : 122.63483137469528 - } - }, - { - "feature_name" : "Dest", - "importance" : { - "mean_magnitude" : 9.06018758454092, - "min" : -64.35766760965463, - "max" : 60.60458858708342 - } - }, - { - "feature_name" : "Carrier", - "importance" : { - "mean_magnitude" : 1.5131352419114026, - "min" : -7.980966972560515, - "max" : 14.407986213341761 - } - }, - { - "feature_name" 
: "Cancelled", - "importance" : { - "mean_magnitude" : 0.17951893871195423, - "min" : -4.543996246002224, - "max" : 2.058692610259091 - } - } + ... ] }, { @@ -314,420 +275,7 @@ The API returns the following results: } ] }, - { - "feature_name" : "OriginWeather", - "classes" : [ - { - "class_name" : "Clear", - "importance" : { - "mean_magnitude" : 0.06392838735937217, - "min" : -0.44856958621186466, - "max" : 0.42934605429030326 - } - }, - { - "class_name" : "Cloudy", - "importance" : { - "mean_magnitude" : 0.06548971082123245, - "min" : -0.46779188319269366, - "max" : 0.3269667496467847 - } - }, - { - "class_name" : "Sunny", - "importance" : { - "mean_magnitude" : 0.07518686267467585, - "min" : -0.4986094320082847, - "max" : 0.3404830347301714 - } - }, - { - "class_name" : "Hail", - "importance" : { - "mean_magnitude" : 0.06737171171799335, - "min" : -0.3852291982170536, - "max" : 0.49072939677488925 - } - }, - { - "class_name" : "Heavy Fog", - "importance" : { - "mean_magnitude" : 0.07489704791170221, - "min" : -0.37509109477738595, - "max" : 0.4972213932685191 - } - }, - { - "class_name" : "Thunder & Lightning", - "importance" : { - "mean_magnitude" : 0.0648427860499252, - "min" : -0.31609773937218777, - "max" : 0.4794272326778727 - } - }, - { - "class_name" : "Rain", - "importance" : { - "mean_magnitude" : 0.06543521178087627, - "min" : -0.4873899965919118, - "max" : 0.3184995190663039 - } - }, - { - "class_name" : "Damaging Wind", - "importance" : { - "mean_magnitude" : 0.06431059549346163, - "min" : -0.36484599829883496, - "max" : 0.5219289190181048 - } - } - ] - }, - { - "feature_name" : "DistanceMiles", - "classes" : [ - { - "class_name" : "Clear", - "importance" : { - "mean_magnitude" : 0.009925586997272603, - "min" : -0.08532219495991693, - "max" : 0.09596509857596312 - } - }, - { - "class_name" : "Cloudy", - "importance" : { - "mean_magnitude" : 0.011475354752337826, - "min" : -0.1458542416877272, - "max" : 0.04519866025045433 - } - }, - { - "class_name" 
: "Sunny", - "importance" : { - "mean_magnitude" : 0.017019258694374224, - "min" : -0.1343486867487768, - "max" : 0.0676285675758164 - } - }, - { - "class_name" : "Hail", - "importance" : { - "mean_magnitude" : 0.009749089329601059, - "min" : -0.09006103474994831, - "max" : 0.09877363346016879 - } - }, - { - "class_name" : "Heavy Fog", - "importance" : { - "mean_magnitude" : 0.01993894846496605, - "min" : -0.127674403426739, - "max" : 0.09865171214017159 - } - }, - { - "class_name" : "Thunder & Lightning", - "importance" : { - "mean_magnitude" : 0.009096504130883604, - "min" : -0.07760903289433295, - "max" : 0.13980510402261984 - } - }, - { - "class_name" : "Rain", - "importance" : { - "mean_magnitude" : 0.012863797177808878, - "min" : -0.14990727286801117, - "max" : 0.11846604888692423 - } - }, - { - "class_name" : "Damaging Wind", - "importance" : { - "mean_magnitude" : 0.007294683069837484, - "min" : -0.06705195480257278, - "max" : 0.10814781489010294 - } - } - ] - }, - { - "feature_name" : "Dest", - "classes" : [ - { - "class_name" : "Clear", - "importance" : { - "mean_magnitude" : 0.03089277559418831, - "min" : -0.17919680669966132, - "max" : 0.167862580693223 - } - }, - { - "class_name" : "Cloudy", - "importance" : { - "mean_magnitude" : 0.03658129232621677, - "min" : -0.155731965346974, - "max" : 0.17282348853252674 - } - }, - { - "class_name" : "Sunny", - "importance" : { - "mean_magnitude" : 0.02229143893127944, - "min" : -0.15452751654480057, - "max" : 0.14704877249575932 - } - }, - { - "class_name" : "Hail", - "importance" : { - "mean_magnitude" : 0.013451198623124975, - "min" : -0.1287421471051356, - "max" : 0.13386295301827233 - } - }, - { - "class_name" : "Heavy Fog", - "importance" : { - "mean_magnitude" : 0.036670909357557686, - "min" : -0.33043825398026216, - "max" : 0.1572790533667184 - } - }, - { - "class_name" : "Thunder & Lightning", - "importance" : { - "mean_magnitude" : 0.028455360035722868, - "min" : -0.15675473094058942, - "max" : 
0.24669584585029877 - } - }, - { - "class_name" : "Rain", - "importance" : { - "mean_magnitude" : 0.026467930340461993, - "min" : -0.16937095134002386, - "max" : 0.17950423023794768 - } - }, - { - "class_name" : "Damaging Wind", - "importance" : { - "mean_magnitude" : 0.034070565798177105, - "min" : -0.13793993299792928, - "max" : 0.24210805402819272 - } - } - ] - }, - { - "feature_name" : "FlightDelayType", - "classes" : [ - { - "class_name" : "Clear", - "importance" : { - "mean_magnitude" : 7.813653012241237E-5, - "min" : -0.0045297876576326414, - "max" : 2.5684363007195195E-4 - } - }, - { - "class_name" : "Cloudy", - "importance" : { - "mean_magnitude" : 9.785459733302982E-5, - "min" : -3.2165915172928015E-4, - "max" : 0.005672897766861862 - } - }, - { - "class_name" : "Sunny", - "importance" : { - "mean_magnitude" : 5.032178416705874E-4, - "min" : -0.0016541340774815561, - "max" : 0.02917291009376577 - } - }, - { - "class_name" : "Hail", - "importance" : { - "mean_magnitude" : 3.63823768372713E-5, - "min" : -1.1959299604027548E-4, - "max" : 0.002109185566528496 - } - }, - { - "class_name" : "Heavy Fog", - "importance" : { - "mean_magnitude" : 0.0016573267593305596, - "min" : -0.09607975024381649, - "max" : 0.005447820890113283 - } - }, - { - "class_name" : "Thunder & Lightning", - "importance" : { - "mean_magnitude" : 3.4449616677222154E-4, - "min" : -0.01997138191191523, - "max" : 0.0011323979434591043 - } - }, - { - "class_name" : "Rain", - "importance" : { - "mean_magnitude" : 4.584110411208903E-4, - "min" : -0.0015068490459212258, - "max" : 0.026575337718974523 - } - }, - { - "class_name" : "Damaging Wind", - "importance" : { - "mean_magnitude" : 9.840935992634046E-4, - "min" : -0.00323482719247201, - "max" : 0.057050588667233865 - } - } - ] - }, - { - "feature_name" : "Carrier", - "classes" : [ - { - "class_name" : "Clear", - "importance" : { - "mean_magnitude" : 0.0012995391420598974, - "min" : -0.013357673650648202, - "max" : 0.023657332277930914 - } - 
}, - { - "class_name" : "Cloudy", - "importance" : { - "mean_magnitude" : 0.0014234549202827422, - "min" : -0.033921439965674836, - "max" : 0.023421720514381957 - } - }, - { - "class_name" : "Sunny", - "importance" : { - "mean_magnitude" : 9.447133417656426E-4, - "min" : -0.0075229254876170775, - "max" : 0.023581652672635692 - } - }, - { - "class_name" : "Hail", - "importance" : { - "mean_magnitude" : 8.183424711633941E-4, - "min" : -0.030286317042075447, - "max" : 0.013849212287877942 - } - }, - { - "class_name" : "Heavy Fog", - "importance" : { - "mean_magnitude" : 0.0018377370367460909, - "min" : -0.04887551961182662, - "max" : 0.03160356785017506 - } - }, - { - "class_name" : "Thunder & Lightning", - "importance" : { - "mean_magnitude" : 0.0034037867051541835, - "min" : -0.06663044595498299, - "max" : 0.059412001243895896 - } - }, - { - "class_name" : "Rain", - "importance" : { - "mean_magnitude" : 0.001136804656696952, - "min" : -0.02694085974354727, - "max" : 0.02252112557408638 - } - }, - { - "class_name" : "Damaging Wind", - "importance" : { - "mean_magnitude" : 0.003395020854090613, - "min" : -0.025753019025261337, - "max" : 0.019039925129197034 - } - } - ] - }, - { - "feature_name" : "Cancelled", - "classes" : [ - { - "class_name" : "Clear", - "importance" : { - "mean_magnitude" : 0.1904078803726547, - "min" : -1.029665279625464, - "max" : 0.14991369698887452 - } - }, - { - "class_name" : "Cloudy", - "importance" : { - "mean_magnitude" : 0.15793833367609444, - "min" : -0.9111582524995818, - "max" : 0.13245342733344687 - } - }, - { - "class_name" : "Sunny", - "importance" : { - "mean_magnitude" : 0.129943186439382, - "min" : -0.797881090878788, - "max" : 0.11748185308604056 - } - }, - { - "class_name" : "Hail", - "importance" : { - "mean_magnitude" : 0.13058547362388567, - "min" : -0.11956949426762936, - "max" : 0.8212854176126191 - } - }, - { - "class_name" : "Heavy Fog", - "importance" : { - "mean_magnitude" : 0.17295029075472754, - "min" : 
-0.14916418886135357, - "max" : 0.9957120466372638 - } - }, - { - "class_name" : "Thunder & Lightning", - "importance" : { - "mean_magnitude" : 0.16999279667691086, - "min" : -0.14367032363243962, - "max" : 0.9779631405571727 - } - }, - { - "class_name" : "Rain", - "importance" : { - "mean_magnitude" : 0.16103302679957027, - "min" : -0.9461364792231481, - "max" : 0.1418523995372967 - } - }, - { - "class_name" : "Damaging Wind", - "importance" : { - "mean_magnitude" : 0.16579386623218015, - "min" : -0.1393484508575312, - "max" : 0.9577442733774155 - } - } - ] - } + ... ] } ] From 02abb2bd8bcaf2d369860fe7dc927d66646cfbd8 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 16 Sep 2020 13:18:41 -0400 Subject: [PATCH 4/7] adjusting api changes to include feature importance via flag --- .../client/MLRequestConverters.java | 32 +- .../client/MachineLearningClient.java | 45 --- .../client/ml/GetTrainedModelsRequest.java | 34 ++- .../client/MLRequestConvertersTests.java | 4 +- .../client/MachineLearningIT.java | 78 +---- .../MlClientDocumentationIT.java | 90 +----- .../high-level/ml/get-trained-models.asciidoc | 10 +- ...-inference-trained-model-metadata.asciidoc | 284 ------------------ .../apis/get-inference-trained-model.asciidoc | 117 ++++++-- .../ml/df-analytics/apis/index.asciidoc | 1 - .../ml/action/GetTrainedModelsAction.java | 66 +++- .../core/ml/inference/TrainedModelConfig.java | 24 +- .../metadata/TotalFeatureImportance.java | 54 ++-- .../metadata/TrainedModelMetadata.java | 4 + .../xpack/core/ml/job/messages/Messages.java | 2 +- .../action/GetTrainedModelsRequestTests.java | 32 +- .../ChunkedTrainedModelPersisterIT.java | 12 +- .../integration/TrainedModelProviderIT.java | 30 +- .../xpack/ml/MachineLearning.java | 5 - .../TransportGetTrainedModelsAction.java | 26 +- .../TransportInternalInferModelAction.java | 2 +- .../loadingservice/ModelLoadingService.java | 6 +- .../persistence/TrainedModelProvider.java | 148 
++++++--- .../inference/RestGetTrainedModelsAction.java | 16 +- .../RestGetTrainedModelsMetadataAction.java | 55 ---- .../ModelLoadingServiceTests.java | 18 +- .../TrainedModelProviderTests.java | 4 +- .../LangIdentNeuralNetworkInferenceTests.java | 2 +- .../api/ml.get_trained_models.json | 7 +- .../rest-api-spec/test/ml/inference_crud.yml | 34 ++- .../test/ml/inference_metadata.yml | 107 ------- 31 files changed, 527 insertions(+), 822 deletions(-) delete mode 100644 docs/reference/ml/df-analytics/apis/get-inference-trained-model-metadata.asciidoc delete mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsMetadataAction.java delete mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_metadata.yml diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java index 99e4cd4fec7e8..758a170087865 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java @@ -61,7 +61,6 @@ import org.elasticsearch.client.ml.GetModelSnapshotsRequest; import org.elasticsearch.client.ml.GetOverallBucketsRequest; import org.elasticsearch.client.ml.GetRecordsRequest; -import org.elasticsearch.client.ml.GetTrainedModelsMetadataRequest; import org.elasticsearch.client.ml.GetTrainedModelsRequest; import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest; import org.elasticsearch.client.ml.MlInfoRequest; @@ -780,9 +779,9 @@ static Request getTrainedModels(GetTrainedModelsRequest getTrainedModelsRequest) params.putParam(GetTrainedModelsRequest.DECOMPRESS_DEFINITION, Boolean.toString(getTrainedModelsRequest.getDecompressDefinition())); } - if (getTrainedModelsRequest.getIncludeDefinition() != null) { - 
params.putParam(GetTrainedModelsRequest.INCLUDE_MODEL_DEFINITION, - Boolean.toString(getTrainedModelsRequest.getIncludeDefinition())); + if (getTrainedModelsRequest.getIncludes().isEmpty() == false) { + params.putParam(GetTrainedModelsRequest.INCLUDE, + Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getIncludes())); } if (getTrainedModelsRequest.getTags() != null) { params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags())); @@ -820,31 +819,6 @@ static Request getTrainedModelsStats(GetTrainedModelsStatsRequest getTrainedMode return request; } - static Request getTrainedModelsMetadata(GetTrainedModelsMetadataRequest getTrainedModelsMetadataRequest) { - String endpoint = new EndpointBuilder() - .addPathPartAsIs("_ml", "inference") - .addPathPart(Strings.collectionToCommaDelimitedString(getTrainedModelsMetadataRequest.getIds())) - .addPathPart("_metadata") - .build(); - RequestConverters.Params params = new RequestConverters.Params(); - if (getTrainedModelsMetadataRequest.getPageParams() != null) { - PageParams pageParams = getTrainedModelsMetadataRequest.getPageParams(); - if (pageParams.getFrom() != null) { - params.putParam(PageParams.FROM.getPreferredName(), pageParams.getFrom().toString()); - } - if (pageParams.getSize() != null) { - params.putParam(PageParams.SIZE.getPreferredName(), pageParams.getSize().toString()); - } - } - if (getTrainedModelsMetadataRequest.getAllowNoMatch() != null) { - params.putParam(GetTrainedModelsMetadataRequest.ALLOW_NO_MATCH, - Boolean.toString(getTrainedModelsMetadataRequest.getAllowNoMatch())); - } - Request request = new Request(HttpGet.METHOD_NAME, endpoint); - request.addParameters(params.asMap()); - return request; - } - static Request deleteTrainedModel(DeleteTrainedModelRequest deleteRequest) { String endpoint = new EndpointBuilder() .addPathPartAsIs("_ml", "inference") diff --git 
a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java index dc75e19e59bd0..6c74dd1e800cf 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java @@ -77,8 +77,6 @@ import org.elasticsearch.client.ml.GetOverallBucketsResponse; import org.elasticsearch.client.ml.GetRecordsRequest; import org.elasticsearch.client.ml.GetRecordsResponse; -import org.elasticsearch.client.ml.GetTrainedModelsMetadataRequest; -import org.elasticsearch.client.ml.GetTrainedModelsMetadataResponse; import org.elasticsearch.client.ml.GetTrainedModelsRequest; import org.elasticsearch.client.ml.GetTrainedModelsResponse; import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest; @@ -2521,49 +2519,6 @@ public Cancellable getTrainedModelsStatsAsync(GetTrainedModelsStatsRequest reque Collections.emptySet()); } - /** - * Gets trained model metadata - *

- * For additional info - * see - * GET Trained Model Metadata documentation - * - * @param request The {@link GetTrainedModelsMetadataRequest} - * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized - * @return {@link GetTrainedModelsMetadataResponse} response object - */ - public GetTrainedModelsMetadataResponse getTrainedModelsMetadata(GetTrainedModelsMetadataRequest request, - RequestOptions options) throws IOException { - return restHighLevelClient.performRequestAndParseEntity(request, - MLRequestConverters::getTrainedModelsMetadata, - options, - GetTrainedModelsMetadataResponse::fromXContent, - Collections.emptySet()); - } - - /** - * Gets trained model metadata asynchronously and notifies listener upon completion - *

- * For additional info - * see - * GET Trained Model Metadata documentation - * - * @param request The {@link GetTrainedModelsMetadataRequest} - * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized - * @param listener Listener to be notified upon request completion - * @return cancellable that may be used to cancel the request - */ - public Cancellable getTrainedModelsMetadataAsync(GetTrainedModelsMetadataRequest request, - RequestOptions options, - ActionListener listener) { - return restHighLevelClient.performRequestAsyncAndParseEntity(request, - MLRequestConverters::getTrainedModelsMetadata, - options, - GetTrainedModelsMetadataResponse::fromXContent, - listener, - Collections.emptySet()); - } - /** * Deletes the given Trained Model *

diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java index ca0284de84d6e..29fb67b3e75ad 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java @@ -26,21 +26,26 @@ import org.elasticsearch.common.Nullable; import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Set; public class GetTrainedModelsRequest implements Validatable { + private static final String DEFINITION = "definition"; + private static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance"; public static final String ALLOW_NO_MATCH = "allow_no_match"; - public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition"; public static final String FOR_EXPORT = "for_export"; public static final String DECOMPRESS_DEFINITION = "decompress_definition"; public static final String TAGS = "tags"; + public static final String INCLUDE = "include"; private final List ids; private Boolean allowNoMatch; - private Boolean includeDefinition; + private Set includes = new HashSet<>(); private Boolean decompressDefinition; private Boolean forExport; private PageParams pageParams; @@ -86,19 +91,32 @@ public GetTrainedModelsRequest setPageParams(@Nullable PageParams pageParams) { return this; } - public Boolean getIncludeDefinition() { - return includeDefinition; + public Set getIncludes() { + return Collections.unmodifiableSet(includes); + } + + public GetTrainedModelsRequest includeDefinition() { + this.includes.add(DEFINITION); + return this; + } + + public GetTrainedModelsRequest includeTotalFeatureImportance() { + this.includes.add(TOTAL_FEATURE_IMPORTANCE); + return this; } 
/** * Whether to include the full model definition. * * The full model definition can be very large. - * + * @deprecated Use {@link GetTrainedModelsRequest#includeDefinition()} * @param includeDefinition If {@code true}, the definition is included. */ + @Deprecated public GetTrainedModelsRequest setIncludeDefinition(Boolean includeDefinition) { - this.includeDefinition = includeDefinition; + if (includeDefinition != null && includeDefinition) { + return this.includeDefinition(); + } return this; } @@ -173,13 +191,13 @@ public boolean equals(Object o) { return Objects.equals(ids, other.ids) && Objects.equals(allowNoMatch, other.allowNoMatch) && Objects.equals(decompressDefinition, other.decompressDefinition) - && Objects.equals(includeDefinition, other.includeDefinition) + && Objects.equals(includes, other.includes) && Objects.equals(forExport, other.forExport) && Objects.equals(pageParams, other.pageParams); } @Override public int hashCode() { - return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition, forExport); + return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includes, forExport); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java index 99b81258c75f2..140c8cd641b49 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java @@ -894,7 +894,7 @@ public void testGetTrainedModels() { GetTrainedModelsRequest getRequest = new GetTrainedModelsRequest(modelId1, modelId2, modelId3) .setAllowNoMatch(false) .setDecompressDefinition(true) - .setIncludeDefinition(false) + .includeDefinition() .setTags("tag1", "tag2") .setPageParams(new PageParams(100, 300)); @@ -908,7 +908,7 @@ public void testGetTrainedModels() { 
hasEntry("allow_no_match", "false"), hasEntry("decompress_definition", "true"), hasEntry("tags", "tag1,tag2"), - hasEntry("include_model_definition", "false") + hasEntry("include", "definition") )); assertNull(request.getEntity()); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 8095f06448e90..f5cb128b21f1e 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -78,8 +78,6 @@ import org.elasticsearch.client.ml.GetJobStatsResponse; import org.elasticsearch.client.ml.GetModelSnapshotsRequest; import org.elasticsearch.client.ml.GetModelSnapshotsResponse; -import org.elasticsearch.client.ml.GetTrainedModelsMetadataRequest; -import org.elasticsearch.client.ml.GetTrainedModelsMetadataResponse; import org.elasticsearch.client.ml.GetTrainedModelsRequest; import org.elasticsearch.client.ml.GetTrainedModelsResponse; import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest; @@ -167,7 +165,6 @@ import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork; -import org.elasticsearch.client.ml.inference.trainedmodel.metadata.TrainedModelMetadata; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.AnalysisLimits; import org.elasticsearch.client.ml.job.config.DataDescription; @@ -190,7 +187,6 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; -import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.junit.After; import java.io.IOException; @@ 
-2231,7 +2227,10 @@ public void testGetTrainedModels() throws Exception { { GetTrainedModelsResponse getTrainedModelsResponse = execute( - new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(true).setIncludeDefinition(true), + new GetTrainedModelsRequest(modelIdPrefix + 0) + .setDecompressDefinition(true) + .includeDefinition() + .includeTotalFeatureImportance(), machineLearningClient::getTrainedModels, machineLearningClient::getTrainedModelsAsync); @@ -2242,7 +2241,10 @@ public void testGetTrainedModels() throws Exception { assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0)); getTrainedModelsResponse = execute( - new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(true), + new GetTrainedModelsRequest(modelIdPrefix + 0) + .setDecompressDefinition(false) + .includeTotalFeatureImportance() + .includeDefinition(), machineLearningClient::getTrainedModels, machineLearningClient::getTrainedModelsAsync); @@ -2253,7 +2255,9 @@ public void testGetTrainedModels() throws Exception { assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0)); getTrainedModelsResponse = execute( - new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(false), + new GetTrainedModelsRequest(modelIdPrefix + 0) + .setDecompressDefinition(false) + .includeDefinition(), machineLearningClient::getTrainedModels, machineLearningClient::getTrainedModelsAsync); assertThat(getTrainedModelsResponse.getCount(), equalTo(1L)); @@ -2396,66 +2400,6 @@ public void testGetTrainedModelsStats() throws Exception { } } - public void testGetTrainedModelsMetadata() throws Exception { - MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); - String modelIdPrefix = "a-get-trained-model-metadata-"; - int numberOfModels = 5; - for (int i = 0; i < numberOfModels; ++i) { - String modelId = 
modelIdPrefix + i; - putTrainedModel(modelId); - IndexRequest indexRequest = new IndexRequest(InferenceIndexConstants.LATEST_INDEX_NAME).id("trained_model_metadata-" + modelId); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - indexRequest.source("{\"model_id\":\"" + modelId + "\", \"doc_type\": \"trained_model_metadata\",\n" + - " \"total_feature_importance\": [\n" + - " {\n" + - " \"feature_name\": \"foo\",\n" + - " \"importance\": {\n" + - " \"mean_magnitude\": 6.0,\n" + - " \"min\": -3.0,\n" + - " \"max\": 3.0\n" + - " }\n" + - " },\n" + - " {\n" + - " \"feature_name\": \"bar\",\n" + - " \"importance\": {\n" + - " \"mean_magnitude\": 5.0,\n" + - " \"min\": -2.0,\n" + - " \"max\": 3.0\n" + - " }\n" + - " }\n" + - " ]}", XContentType.JSON); - highLevelClient().index(indexRequest, RequestOptions.DEFAULT); - } - - { - GetTrainedModelsMetadataResponse getTrainedModelsMetadataResponse = execute( - GetTrainedModelsMetadataRequest.getAllTrainedModelsMetadataRequest(), - machineLearningClient::getTrainedModelsMetadata, machineLearningClient::getTrainedModelsMetadataAsync); - assertThat(getTrainedModelsMetadataResponse.getTrainedModelsMetadata(), hasSize(numberOfModels)); - assertThat(getTrainedModelsMetadataResponse.getCount(), equalTo(5L)); - } - { - GetTrainedModelsMetadataResponse getTrainedModelsMetadataResponse = execute( - new GetTrainedModelsMetadataRequest(modelIdPrefix + 4, modelIdPrefix + 2, modelIdPrefix + 3), - machineLearningClient::getTrainedModelsMetadata, machineLearningClient::getTrainedModelsMetadataAsync); - assertThat(getTrainedModelsMetadataResponse.getTrainedModelsMetadata(), hasSize(3)); - assertThat(getTrainedModelsMetadataResponse.getCount(), equalTo(3L)); - } - { - GetTrainedModelsMetadataResponse getTrainedModelsMetadataResponse = execute( - new GetTrainedModelsMetadataRequest(modelIdPrefix + "*").setPageParams(new PageParams(1, 2)), - machineLearningClient::getTrainedModelsMetadata, 
machineLearningClient::getTrainedModelsMetadataAsync); - assertThat(getTrainedModelsMetadataResponse.getTrainedModelsMetadata(), hasSize(2)); - assertThat(getTrainedModelsMetadataResponse.getCount(), equalTo(5L)); - assertThat( - getTrainedModelsMetadataResponse.getTrainedModelsMetadata() - .stream() - .map(TrainedModelMetadata::getModelId) - .collect(Collectors.toList()), - containsInAnyOrder(modelIdPrefix + 1, modelIdPrefix + 2)); - } - } - public void testDeleteTrainedModel() throws Exception { MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); String modelId = "delete-trained-model-test"; diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 4a3ec2c7c8021..6d47c4e3dd5fa 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -91,8 +91,6 @@ import org.elasticsearch.client.ml.GetOverallBucketsResponse; import org.elasticsearch.client.ml.GetRecordsRequest; import org.elasticsearch.client.ml.GetRecordsResponse; -import org.elasticsearch.client.ml.GetTrainedModelsMetadataRequest; -import org.elasticsearch.client.ml.GetTrainedModelsMetadataResponse; import org.elasticsearch.client.ml.GetTrainedModelsRequest; import org.elasticsearch.client.ml.GetTrainedModelsResponse; import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest; @@ -184,7 +182,6 @@ import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; -import org.elasticsearch.client.ml.inference.trainedmodel.metadata.TrainedModelMetadata; import 
org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.AnalysisLimits; import org.elasticsearch.client.ml.job.config.DataDescription; @@ -217,7 +214,6 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.tasks.TaskId; -import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.junit.After; import java.io.IOException; @@ -3698,11 +3694,12 @@ public void testGetTrainedModels() throws Exception { // tag::get-trained-models-request GetTrainedModelsRequest request = new GetTrainedModelsRequest("my-trained-model") // <1> .setPageParams(new PageParams(0, 1)) // <2> - .setIncludeDefinition(false) // <3> - .setDecompressDefinition(false) // <4> - .setAllowNoMatch(true) // <5> - .setTags("regression") // <6> - .setForExport(false); // <7> + .includeDefinition() // <3> + .includeTotalFeatureImportance() // <4> + .setDecompressDefinition(false) // <5> + .setAllowNoMatch(true) // <6> + .setTags("regression") // <7> + .setForExport(false); // <8> // end::get-trained-models-request request.setTags((List)null); @@ -3872,81 +3869,6 @@ public void onFailure(Exception e) { } } - public void testGetTrainedModelsMetadata() throws Exception { - String modelId = "my-trained-model"; - putTrainedModel(modelId); - IndexRequest indexRequest = new IndexRequest(InferenceIndexConstants.LATEST_INDEX_NAME).id("trained_model_metadata-" + modelId); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - indexRequest.source("{\"model_id\":\"" + modelId + "\", \"doc_type\": \"trained_model_metadata\",\n" + - " \"total_feature_importance\": [\n" + - " {\n" + - " \"feature_name\": \"foo\",\n" + - " \"importance\": {\n" + - " \"mean_magnitude\": 6.0,\n" + - " \"min\": -3.0,\n" + - " \"max\": 3.0\n" + - " }\n" + - " },\n" + - " {\n" + - " \"feature_name\": \"bar\",\n" + - " \"importance\": {\n" + - " 
\"mean_magnitude\": 5.0,\n" + - " \"min\": -2.0,\n" + - " \"max\": 3.0\n" + - " }\n" + - " }\n" + - " ]}", XContentType.JSON); - highLevelClient().index(indexRequest, RequestOptions.DEFAULT); - RestHighLevelClient client = highLevelClient(); - { - // tag::get-trained-models-metadata-request - GetTrainedModelsMetadataRequest request = - new GetTrainedModelsMetadataRequest("my-trained-model") // <1> - .setPageParams(new PageParams(0, 1)) // <2> - .setAllowNoMatch(true); // <3> - // end::get-trained-models-metadata-request - - // tag::get-trained-models-metadata-execute - GetTrainedModelsMetadataResponse response = - client.machineLearning().getTrainedModelsMetadata(request, RequestOptions.DEFAULT); - // end::get-trained-models-metadata-execute - - // tag::get-trained-models-metadata-response - List models = response.getTrainedModelsMetadata(); - // end::get-trained-models-metadata-response - - assertThat(models, hasSize(1)); - } - { - GetTrainedModelsMetadataRequest request = new GetTrainedModelsMetadataRequest("my-trained-model"); - - // tag::get-trained-models-metadata-execute-listener - ActionListener listener = new ActionListener<>() { - @Override - public void onResponse(GetTrainedModelsMetadataResponse response) { - // <1> - } - - @Override - public void onFailure(Exception e) { - // <2> - } - }; - // end::get-trained-models-metadata-execute-listener - - // Replace the empty listener by a blocking listener in test - CountDownLatch latch = new CountDownLatch(1); - listener = new LatchedActionListener<>(listener, latch); - - // tag::get-trained-models-metadata-execute-async - client.machineLearning() - .getTrainedModelsMetadataAsync(request, RequestOptions.DEFAULT, listener); // <1> - // end::get-trained-models-metadata-execute-async - - assertTrue(latch.await(30L, TimeUnit.SECONDS)); - } - } - public void testDeleteTrainedModel() throws Exception { RestHighLevelClient client = highLevelClient(); { diff --git 
a/docs/java-rest/high-level/ml/get-trained-models.asciidoc b/docs/java-rest/high-level/ml/get-trained-models.asciidoc index ffaea526f016c..c0276290cfc15 100644 --- a/docs/java-rest/high-level/ml/get-trained-models.asciidoc +++ b/docs/java-rest/high-level/ml/get-trained-models.asciidoc @@ -25,14 +25,16 @@ include-tagged::{doc-tests-file}[{api}-request] <1> Constructing a new GET request referencing an existing Trained Model <2> Set the paging parameters <3> Indicate if the complete model definition should be included -<4> Should the definition be fully decompressed on GET -<5> Allow empty response if no Trained Models match the provided ID patterns. +<4> Indicate if the total feature importance for the features used in training + should be included in the model `metadata` field. +<5> Should the definition be fully decompressed on GET +<6> Allow empty response if no Trained Models match the provided ID patterns. If false, an error will be thrown if no Trained Models match the ID patterns. -<6> An optional list of tags used to narrow the model search. A Trained Model +<7> An optional list of tags used to narrow the model search. A Trained Model can have many tags or none. The trained models in the response will contain all the provided tags. -<7> Optional boolean value indicating if certain fields should be removed on +<8> Optional boolean value indicating if certain fields should be removed on retrieval. This is useful for getting the trained model in a format that can then be put into another cluster. 
diff --git a/docs/reference/ml/df-analytics/apis/get-inference-trained-model-metadata.asciidoc b/docs/reference/ml/df-analytics/apis/get-inference-trained-model-metadata.asciidoc deleted file mode 100644 index 090094f383612..0000000000000 --- a/docs/reference/ml/df-analytics/apis/get-inference-trained-model-metadata.asciidoc +++ /dev/null @@ -1,284 +0,0 @@ -[role="xpack"] -[testenv="basic"] -[[get-inference-metadata]] -= Get trained model metadata API -[subs="attributes"] -++++ -Get {infer} trained model metadata -++++ - -Retrieves training metadata information for trained models. - -experimental[] - - -[[ml-get-inference-metadata-request]] -== {api-request-title} - -`GET _ml/inference/_metadata` + - -`GET _ml/inference/_all/_metadata` + - -`GET _ml/inference//_metadata` + - -`GET _ml/inference/,/_metadata` + - -`GET _ml/inference/,/_metadata` - - -[[ml-get-inference-metadata-prereq]] -== {api-prereq-title} - -If the {es} {security-features} are enabled, you must have the following -privileges: - -* cluster: `monitor_ml` - -For more information, see <> and {ml-docs-setup-privileges}. - -[[ml-get-inference-metadata-desc]] -== {api-description-title} - -You can get metadata for multiple trained models in a single API -request by using a comma-separated list of model IDs or a wildcard expression. 
- - -[[ml-get-inference-metadata-path-params]] -== {api-path-parms-title} - -``:: -(Optional, string) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] - - -[[ml-get-inference-metadata-query-params]] -== {api-query-parms-title} - -`allow_no_match`:: -(Optional, boolean) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=allow-no-match-models] - -`from`:: -(Optional, integer) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=from-models] - -`size`:: -(Optional, integer) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=size-models] - -[role="child_attributes"] -[[ml-get-inference-metadata-results]] -== {api-response-body-title} - -`count`:: -(integer) -The total number of trained model metadata objects that matched the requested ID patterns. -Could be higher than the number of items in the `trained_models_metadata` array as the -size of the array is restricted by the supplied `size` parameter. - -`trained_models_metadata`:: -(array) -An array of trained model metadata objects, which are sorted by the `model_id` value in -ascending order. -+ -.Properties of trained model metadata -[%collapsible%open] -==== -`model_id`::: -(string) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] - -`total_feature_importance`::: -(array) -An array of the total feature importance for each training feature used from -the training data set. -+ -.Properties of total feature importance -[%collapsible%open] -===== - -`feature_name`::: -(string) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-feature-name] - -`importance`::: -(object) -A collection of feature importance statistics related to the training data set for this particular feature. 
-+ -.Properties of feature importance -[%collapsible%open] -====== -`mean_magnitude`::: -(double) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-magnitude] - -`max`::: -(int) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-max] - -`min`::: -(int) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-min] - -====== - -`classes`::: -(array) -If the trained model is a classification model, feature importance statistics are gathered -per target class value. -+ -.Properties of class feature importance -[%collapsible%open] - -====== - -`class_name`::: -(string) -The target class value. Could be a string, boolean, or number. - -`importance`::: -(object) -A collection of feature importance statistics related to the training data set for this particular feature. -+ -.Properties of feature importance -[%collapsible%open] -======= -`mean_magnitude`::: -(double) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-magnitude] - -`max`::: -(int) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-max] - -`min`::: -(int) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-min] - -======= - -====== - -===== - -==== - -[[ml-get-inference-metadata-response-codes]] -== {api-response-codes-title} - -`404` (Missing resources):: - If `allow_no_match` is `false`, this code indicates that there are no - resources that match the request or only partial matches for the request. 
- -[[ml-get-inference-metadata-example]] -== {api-examples-title} - -The following example gets training metadata for all the trained models: - -[source,console] --------------------------------------------------- -GET _ml/inference/_metadata --------------------------------------------------- -// TEST[skip:TBD] - - -The API returns the following results: - -[source,console-result] ----- -{ - "count" : 2, - "trained_models_metadata" : [ - { - "model_id" : "avg_price_prediction-1599149443166", - "total_feature_importance" : [ - { - "feature_name" : "Origin", - "importance" : { - "mean_magnitude" : 25.862683737654795, - "min" : -188.93284143727874, - "max" : 162.8783518094679 - } - }, - ... - ] - }, - { - "model_id" : "dest_weather_prediction-1599149568413", - "total_feature_importance" : [ - { - "feature_name" : "dayOfWeek", - "classes" : [ - { - "class_name" : "Clear", - "importance" : { - "mean_magnitude" : 0.0033597810869050483, - "min" : -0.034589509802599394, - "max" : 0.013677011897069439 - } - }, - { - "class_name" : "Cloudy", - "importance" : { - "mean_magnitude" : 0.003549516620011909, - "min" : -0.06736294734141816, - "max" : 0.088650519638185 - } - }, - { - "class_name" : "Sunny", - "importance" : { - "mean_magnitude" : 0.004471474339413112, - "min" : -0.08060353377909144, - "max" : 0.1045130657148837 - } - }, - { - "class_name" : "Hail", - "importance" : { - "mean_magnitude" : 0.007072062864425885, - "min" : -0.05044235221609796, - "max" : 0.038623432806435085 - } - }, - { - "class_name" : "Heavy Fog", - "importance" : { - "mean_magnitude" : 0.0065139540458721236, - "min" : -0.018591621408001358, - "max" : 0.03543735929759353 - } - }, - { - "class_name" : "Thunder & Lightning", - "importance" : { - "mean_magnitude" : 0.0066567969304509155, - "min" : -0.059528507259167134, - "max" : 0.03628958395628503 - } - }, - { - "class_name" : "Rain", - "importance" : { - "mean_magnitude" : 0.0038758238618025985, - "min" : -0.07831548102713791, - "max" : 
0.05696179640413974 - } - }, - { - "class_name" : "Damaging Wind", - "importance" : { - "mean_magnitude" : 0.0113605093018583, - "min" : -0.053221001966268555, - "max" : 0.07941614599243701 - } - } - ] - }, - ... - ] - } - ] -} ----- -// NOTCONSOLE diff --git a/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc b/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc index e9cb170c1f4d0..553c615304f0e 100644 --- a/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc +++ b/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc @@ -29,19 +29,19 @@ experimental[] [[ml-get-inference-prereq]] == {api-prereq-title} -If the {es} {security-features} are enabled, you must have the following +If the {es} {security-features} are enabled, you must have the following privileges: * cluster: `monitor_ml` - -For more information, see <> and + +For more information, see <> and {ml-docs-setup-privileges}. [[ml-get-inference-desc]] == {api-description-title} -You can get information for multiple trained models in a single API request by +You can get information for multiple trained models in a single API request by using a comma-separated list of model IDs or a wildcard expression. @@ -49,7 +49,7 @@ using a comma-separated list of model IDs or a wildcard expression. == {api-path-parms-title} ``:: -(Optional, string) +(Optional, string) include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] @@ -57,12 +57,12 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] == {api-query-parms-title} `allow_no_match`:: -(Optional, boolean) +(Optional, boolean) include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=allow-no-match-models] `decompress_definition`:: (Optional, boolean) -Specifies whether the included model definition should be returned as a JSON map +Specifies whether the included model definition should be returned as a JSON map (`true`) or in a custom compressed format (`false`). Defaults to `true`. 
`for_export`:: @@ -72,17 +72,20 @@ retrieval. This allows the model to be in an acceptable format to be retrieved and then added to another cluster. Default is false. `from`:: -(Optional, integer) +(Optional, integer) include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=from-models] -`include_model_definition`:: -(Optional, boolean) -Specifies whether the model definition is returned in the response. Defaults to -`false`. When `true`, only a single model must match the ID patterns provided. -Otherwise, a bad request is returned. +`include`:: +(Optional, string) +A comma-delimited string of optional fields to include in the response body. +Valid options are: + - `definition`: to include the model definition + - `total_feature_importance`: to include the total feature importance for the + training feature sets. This field will be available in the `metadata` field. +Defaults to empty, which means that no optional fields are included. `size`:: -(Optional, integer) +(Optional, integer) include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=size-models] `tags`:: @@ -95,7 +98,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=tags] `trained_model_configs`:: (array) -An array of trained model resources, which are sorted by the `model_id` value in +An array of trained model resources, which are sorted by the `model_id` value in ascending order. + .Properties of trained model resources @@ -133,8 +136,86 @@ The license level of the trained model. `metadata`::: (object) -An object containing metadata about the trained model. For example, models +An object containing metadata about the trained model. For example, models created by {dfanalytics} contain `analysis_config` and `input` objects. +.Properties of metadata +[%collapsible%open] +===== +`total_feature_importance`::: +(array) +An array of the total feature importance for each training feature used from +the training data set.
This array of objects is returned if {dfanalytics} trained +the model and the request includes `total_feature_importance` in the `include` +request parameter. ++ +.Properties of total feature importance +[%collapsible%open] +====== + +`feature_name`::: +(string) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-feature-name] + +`importance`::: +(object) +A collection of feature importance statistics related to the training data set for this particular feature. ++ +.Properties of feature importance +[%collapsible%open] +======= +`mean_magnitude`::: +(double) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-magnitude] + +`max`::: +(double) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-max] + +`min`::: +(double) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-min] + +======= + +`classes`::: +(array) +If the trained model is a classification model, feature importance statistics are gathered +per target class value. ++ +.Properties of class feature importance +[%collapsible%open] + +======= + +`class_name`::: +(string) +The target class value. Could be a string, boolean, or number. + +`importance`::: +(object) +A collection of feature importance statistics related to the training data set for this particular feature. ++ +.Properties of feature importance +[%collapsible%open] +======== +`mean_magnitude`::: +(double) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-magnitude] + +`max`::: +(double) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-max] + +`min`::: +(double) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-min] + +======== + +======= + +====== +===== `model_id`::: (string) @@ -154,13 +235,13 @@ The {es} version number in which the trained model was created.
== {api-response-codes-title} `400`:: - If `include_model_definition` is `true`, this code indicates that more than + If `include_model_definition` is `true`, this code indicates that more than one models match the ID pattern. `404` (Missing resources):: If `allow_no_match` is `false`, this code indicates that there are no resources that match the request or only partial matches for the request. - + [[ml-get-inference-example]] == {api-examples-title} diff --git a/docs/reference/ml/df-analytics/apis/index.asciidoc b/docs/reference/ml/df-analytics/apis/index.asciidoc index 22096077fc944..421393a1b53e1 100644 --- a/docs/reference/ml/df-analytics/apis/index.asciidoc +++ b/docs/reference/ml/df-analytics/apis/index.asciidoc @@ -16,7 +16,6 @@ include::get-dfanalytics.asciidoc[leveloffset=+2] include::get-dfanalytics-stats.asciidoc[leveloffset=+2] include::get-inference-trained-model.asciidoc[leveloffset=+2] include::get-inference-trained-model-stats.asciidoc[leveloffset=+2] -include::get-inference-trained-model-metadata.asciidoc[leveloffset=+2] //SET/START/STOP include::start-dfanalytics.asciidoc[leveloffset=+2] include::stop-dfanalytics.asciidoc[leveloffset=+2] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java index 823ead709ce3e..4cabd05a2a99a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java @@ -5,19 +5,24 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.set.Sets; 
import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest; import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse; import org.elasticsearch.xpack.core.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Objects; +import java.util.Set; public class GetTrainedModelsAction extends ActionType { @@ -31,23 +36,60 @@ private GetTrainedModelsAction() { public static class Request extends AbstractGetResourcesRequest { - public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition"); + static final String DEFINITION = "definition"; + static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance"; + private static final Set KNOWN_INCLUDES; + static { + HashSet includes = new HashSet<>(2, 1.0f); + includes.add(DEFINITION); + includes.add(TOTAL_FEATURE_IMPORTANCE); + KNOWN_INCLUDES = Collections.unmodifiableSet(includes); + } + public static final ParseField INCLUDE = new ParseField("include"); + public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition"; public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); public static final ParseField TAGS = new ParseField("tags"); - private final boolean includeModelDefinition; + private final Set includes; private final List tags; + @Deprecated public Request(String id, boolean includeModelDefinition, List tags) { setResourceId(id); setAllowNoResources(true); - this.includeModelDefinition = includeModelDefinition; this.tags = tags == null ? 
Collections.emptyList() : tags; + if (includeModelDefinition) { + this.includes = new HashSet<>(Collections.singletonList(DEFINITION)); + } else { + this.includes = Collections.emptySet(); + } + } + + public Request(String id, List tags, Set includes) { + setResourceId(id); + setAllowNoResources(true); + this.tags = tags == null ? Collections.emptyList() : tags; + this.includes = includes == null ? Collections.emptySet() : includes; + Set unknownIncludes = Sets.difference(this.includes, KNOWN_INCLUDES); + if (unknownIncludes.isEmpty() == false) { + throw ExceptionsHelper.badRequestException( + "unknown [include] parameters {}. Valid options are {}", + unknownIncludes, + KNOWN_INCLUDES); + } } public Request(StreamInput in) throws IOException { super(in); - this.includeModelDefinition = in.readBoolean(); + if (in.getVersion().onOrAfter(Version.V_7_10_0)) { + this.includes = in.readSet(StreamInput::readString); + } else { + Set includes = new HashSet<>(); + if (in.readBoolean()) { + includes.add(DEFINITION); + } + this.includes = includes; + } this.tags = in.readStringList(); } @@ -57,7 +99,11 @@ public String getResourceIdField() { } public boolean isIncludeModelDefinition() { - return includeModelDefinition; + return this.includes.contains(DEFINITION); + } + + public boolean isIncludeTotalFeatureImportance() { + return this.includes.contains(TOTAL_FEATURE_IMPORTANCE); } public List getTags() { @@ -67,13 +113,17 @@ public List getTags() { @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeBoolean(includeModelDefinition); + if (out.getVersion().onOrAfter(Version.V_7_10_0)) { + out.writeCollection(this.includes, StreamOutput::writeString); + } else { + out.writeBoolean(this.includes.contains(DEFINITION)); + } out.writeStringCollection(tags); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), includeModelDefinition, tags); + return Objects.hash(super.hashCode(), includes, tags); } @Override @@ 
-85,7 +135,7 @@ public boolean equals(Object obj) { return false; } Request other = (Request) obj; - return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition && Objects.equals(tags, other.tags); + return super.equals(obj) && this.includes.equals(other.includes) && Objects.equals(tags, other.tags); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 1ffa164af6857..9ae67880215a1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportance; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlStrings; @@ -39,6 +40,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.stream.Collectors; import static org.elasticsearch.action.ValidateActions.addValidationError; @@ -51,6 +53,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { public static final int CURRENT_DEFINITION_COMPRESSION_VERSION = 1; public static final String DECOMPRESS_DEFINITION = "decompress_definition"; public static final String FOR_EXPORT = "for_export"; + public static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance"; + private static final Set 
RESERVED_METADATA_FIELDS = Collections.singleton(TOTAL_FEATURE_IMPORTANCE); private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage"; @@ -408,7 +412,7 @@ public Builder(TrainedModelConfig config) { this.definition = config.definition == null ? null : new LazyModelDefinition(config.definition); this.description = config.getDescription(); this.tags = config.getTags(); - this.metadata = config.getMetadata(); + this.metadata = config.getMetadata() == null ? null : new HashMap<>(config.getMetadata()); this.input = config.getInput(); this.estimatedOperations = config.estimatedOperations; this.estimatedHeapMemory = config.estimatedHeapMemory; @@ -460,6 +464,18 @@ public Builder setMetadata(Map metadata) { return this; } + public Builder setFeatureImportance(List totalFeatureImportance) { + if (totalFeatureImportance == null) { + return this; + } + if (this.metadata == null) { + this.metadata = new HashMap<>(); + } + this.metadata.put(TOTAL_FEATURE_IMPORTANCE, + totalFeatureImportance.stream().map(TotalFeatureImportance::asMap).collect(Collectors.toList())); + return this; + } + public Builder setParsedDefinition(TrainedModelDefinition.Builder definition) { if (definition == null) { return this; @@ -616,6 +632,12 @@ public Builder validate(boolean forCreation) { ESTIMATED_OPERATIONS.getPreferredName(), validationException); validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException); + if (metadata != null) { + validationException = checkIllegalSetting( + metadata.get(TOTAL_FEATURE_IMPORTANCE), + METADATA.getPreferredName() + "." 
+ TOTAL_FEATURE_IMPORTANCE, + validationException); + } } if (validationException != null) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java index 9f2df2b7512e6..8676af6ff5ca0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java @@ -20,8 +20,11 @@ import java.io.IOException; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; public class TotalFeatureImportance implements ToXContentObject, Writeable { @@ -81,16 +84,7 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(FEATURE_NAME.getPreferredName(), featureName); - if (importance != null) { - builder.field(IMPORTANCE.getPreferredName(), importance); - } - if (classImportances.isEmpty() == false) { - builder.field(CLASSES.getPreferredName(), classImportances); - } - builder.endObject(); - return builder; + return builder.map(asMap()); } @Override @@ -103,6 +97,18 @@ public boolean equals(Object o) { && Objects.equals(classImportances, that.classImportances); } + public Map asMap() { + Map map = new LinkedHashMap<>(); + map.put(FEATURE_NAME.getPreferredName(), featureName); + if (importance != null) { + map.put(IMPORTANCE.getPreferredName(), importance.asMap()); + } + if (classImportances.isEmpty() == false) { + map.put(CLASSES.getPreferredName(), 
classImportances.stream().map(ClassImportance::asMap).collect(Collectors.toList())); + } + return map; + } + @Override public int hashCode() { return Objects.hash(featureName, importance, classImportances); @@ -165,12 +171,15 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude); - builder.field(MIN.getPreferredName(), min); - builder.field(MAX.getPreferredName(), max); - builder.endObject(); - return builder; + return builder.map(asMap()); + } + + private Map asMap() { + Map map = new LinkedHashMap<>(); + map.put(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude); + map.put(MIN.getPreferredName(), min); + map.put(MAX.getPreferredName(), max); + return map; } } @@ -229,11 +238,14 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(CLASS_NAME.getPreferredName(), className); - builder.field(IMPORTANCE.getPreferredName(), importance); - builder.endObject(); - return builder; + return builder.map(asMap()); + } + + private Map asMap() { + Map map = new LinkedHashMap<>(); + map.put(CLASS_NAME.getPreferredName(), className); + map.put(IMPORTANCE.getPreferredName(), importance.asMap()); + return map; } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java index dc3e8fc54d998..dd2662cf4002e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java @@ -53,6 +53,10 @@ public static String docId(String modelId) { return NAME + "-" + modelId; } + public static String modelId(String docId) { + return docId.substring(NAME.length() + 1); + } + private final List totalFeatureImportances; private final String modelId; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 6a7965a01b2b4..4d18a8f1006a9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -103,7 +103,7 @@ public final class Messages { public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION = "Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]"; public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]"; - public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata [{0}]"; + public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata {0}"; public static final String INFERENCE_CANNOT_DELETE_MODEL = "Unable to delete model [{0}]"; public static final String MODEL_DEFINITION_TRUNCATED = diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java index 7955117e11759..a75d7fb8f6a3d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java @@ 
-5,19 +5,28 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request; -public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCase { +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class GetTrainedModelsRequestTests extends AbstractBWCWireSerializationTestCase { @Override protected Request createTestInstance() { Request request = new Request(randomAlphaOfLength(20), - randomBoolean(), randomBoolean() ? null : - randomList(10, () -> randomAlphaOfLength(10))); + randomList(10, () -> randomAlphaOfLength(10)), + randomBoolean() ? null : + Stream.generate(() -> randomFrom(Request.DEFINITION, Request.TOTAL_FEATURE_IMPORTANCE)) + .limit(4) + .collect(Collectors.toSet())); request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100))); return request; } @@ -26,4 +35,19 @@ protected Request createTestInstance() { protected Writeable.Reader instanceReader() { return Request::new; } + + @Override + protected Request mutateInstanceForVersion(Request instance, Version version) { + if (version.before(Version.V_7_10_0)) { + Set includes = new HashSet<>(); + if (instance.isIncludeModelDefinition()) { + includes.add(Request.DEFINITION); + } + return new Request( + instance.getResourceId(), + instance.getTags(), + includes); + } + return instance; + } } diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java 
b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java index b9549842333d7..defde90095d01 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java @@ -42,11 +42,13 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasKey; import static org.hamcrest.Matchers.startsWith; public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase { @@ -95,19 +97,21 @@ public void testStoreModelViaChunkedPersister() throws IOException { trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture); Tuple> ids = getIdsFuture.actionGet(); assertThat(ids.v1(), equalTo(1L)); + String inferenceModelId = ids.v2().iterator().next(); PlainActionFuture getTrainedModelFuture = new PlainActionFuture<>(); - trainedModelProvider.getTrainedModel(ids.v2().iterator().next(), true, getTrainedModelFuture); + trainedModelProvider.getTrainedModel(inferenceModelId, true, true, getTrainedModelFuture); TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet(); assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition)); assertThat(storedConfig.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations())); assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed())); + assertThat(storedConfig.getMetadata(), hasKey("total_feature_importance")); - PlainActionFuture getTrainedMetadataFuture = new PlainActionFuture<>(); - 
trainedModelProvider.getTrainedModelMetadata(ids.v2().iterator().next(), getTrainedMetadataFuture); + PlainActionFuture> getTrainedMetadataFuture = new PlainActionFuture<>(); + trainedModelProvider.getTrainedModelMetadata(Collections.singletonList(inferenceModelId), getTrainedMetadataFuture); - TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet(); + TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet().get(inferenceModelId); assertThat(storedMetadata.getModelId(), startsWith(modelId)); assertThat(storedMetadata.getTotalFeatureImportances(), equalTo(modelMetadata.getFeatureImportances())); } diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index 6c9b5634b9448..ab8ff4b02b869 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -89,7 +89,10 @@ public void testGetTrainedModelConfig() throws Exception { assertThat(exceptionHolder.get(), is(nullValue())); AtomicReference getConfigHolder = new AtomicReference<>(); - blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); + blockingCall( + listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener), + getConfigHolder, + exceptionHolder); getConfigHolder.get().ensureParsedDefinition(xContentRegistry()); assertThat(getConfigHolder.get(), is(not(nullValue()))); assertThat(getConfigHolder.get(), equalTo(config)); @@ -120,7 +123,10 @@ public void testGetTrainedModelConfigWithoutDefinition() throws Exception { assertThat(exceptionHolder.get(), is(nullValue())); AtomicReference getConfigHolder = new AtomicReference<>(); - 
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, false, listener), getConfigHolder, exceptionHolder); + blockingCall(listener -> + trainedModelProvider.getTrainedModel(modelId, false, false, listener), + getConfigHolder, + exceptionHolder); getConfigHolder.get().ensureParsedDefinition(xContentRegistry()); assertThat(getConfigHolder.get(), is(not(nullValue()))); assertThat(getConfigHolder.get(), equalTo(copyWithoutDefinition)); @@ -131,7 +137,10 @@ public void testGetMissingTrainingModelConfig() throws Exception { String modelId = "test-get-missing-trained-model-config"; AtomicReference getConfigHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); + blockingCall( + listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener), + getConfigHolder, + exceptionHolder); assertThat(exceptionHolder.get(), is(not(nullValue()))); assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); @@ -153,7 +162,10 @@ public void testGetMissingTrainingModelConfigDefinition() throws Exception { .actionGet(); AtomicReference getConfigHolder = new AtomicReference<>(); - blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); + blockingCall( + listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener), + getConfigHolder, + exceptionHolder); assertThat(exceptionHolder.get(), is(not(nullValue()))); assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); @@ -192,7 +204,10 @@ public void testGetTruncatedModelDeprecatedDefinition() throws Exception { } AtomicReference getConfigHolder = new AtomicReference<>(); - blockingCall(listener -> 
trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); + blockingCall( + listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener), + getConfigHolder, + exceptionHolder); assertThat(getConfigHolder.get(), is(nullValue())); assertThat(exceptionHolder.get(), is(not(nullValue()))); assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); @@ -237,7 +252,10 @@ public void testGetTruncatedModelDefinition() throws Exception { } } AtomicReference getConfigHolder = new AtomicReference<>(); - blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); + blockingCall( + listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener), + getConfigHolder, + exceptionHolder); assertThat(getConfigHolder.get(), is(nullValue())); assertThat(exceptionHolder.get(), is(not(nullValue()))); assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 4c55fd4e7af96..da4955817fec5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -113,7 +113,6 @@ import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction; import org.elasticsearch.xpack.core.ml.action.GetRecordsAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; -import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsMetadataAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction; import 
org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction; @@ -191,7 +190,6 @@ import org.elasticsearch.xpack.ml.action.TransportGetOverallBucketsAction; import org.elasticsearch.xpack.ml.action.TransportGetRecordsAction; import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsAction; -import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsMetadataAction; import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsStatsAction; import org.elasticsearch.xpack.ml.action.TransportInternalInferModelAction; import org.elasticsearch.xpack.ml.action.TransportIsolateDatafeedAction; @@ -310,7 +308,6 @@ import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction; import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction; -import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsMetadataAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction; import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAction; import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction; @@ -866,7 +863,6 @@ public List getRestHandlers(Settings settings, RestController restC new RestDeleteTrainedModelAction(), new RestGetTrainedModelsStatsAction(), new RestPutTrainedModelAction(), - new RestGetTrainedModelsMetadataAction(), // CAT Handlers new RestCatJobsAction(), new RestCatTrainedModelsAction(), @@ -950,7 +946,6 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(DeleteTrainedModelAction.INSTANCE, TransportDeleteTrainedModelAction.class), new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class), new ActionHandler<>(PutTrainedModelAction.INSTANCE, TransportPutTrainedModelAction.class), - new ActionHandler<>(GetTrainedModelsMetadataAction.INSTANCE, TransportGetTrainedModelsMetadataAction.class), usageAction, infoAction); } 
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java index 1ffc13b8b1196..ba3edf91f91bc 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java @@ -57,15 +57,25 @@ protected void doExecute(Task task, Request request, ActionListener li } if (request.isIncludeModelDefinition()) { - provider.getTrainedModel(totalAndIds.v2().iterator().next(), true, ActionListener.wrap( - config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()), - listener::onFailure - )); + provider.getTrainedModel( + totalAndIds.v2().iterator().next(), + true, + request.isIncludeTotalFeatureImportance(), + ActionListener.wrap( + config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()), + listener::onFailure + ) + ); } else { - provider.getTrainedModels(totalAndIds.v2(), request.isAllowNoResources(), ActionListener.wrap( - configs -> listener.onResponse(responseBuilder.setModels(configs).build()), - listener::onFailure - )); + provider.getTrainedModels( + totalAndIds.v2(), + request.isAllowNoResources(), + request.isIncludeTotalFeatureImportance(), + ActionListener.wrap( + configs -> listener.onResponse(responseBuilder.setModels(configs).build()), + listener::onFailure + ) + ); } }, listener::onFailure diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java index c94c668a87bde..2483a2cffaa3b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java @@ -82,7 +82,7 @@ protected void doExecute(Task task, Request request, ActionListener li responseBuilder.setLicensed(true); this.modelLoadingService.getModelForPipeline(request.getModelId(), getModelListener); } else { - trainedModelProvider.getTrainedModel(request.getModelId(), false, ActionListener.wrap( + trainedModelProvider.getTrainedModel(request.getModelId(), false, false, ActionListener.wrap( trainedModelConfig -> { responseBuilder.setLicensed(licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel())); if (licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()) || request.isPreviouslyLicensed()) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 1c48ce9f1415b..838f06aacb708 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -270,7 +270,7 @@ private boolean loadModelIfNecessary(String modelId, Consumer consumer, ActionLi } private void loadModel(String modelId, Consumer consumer) { - provider.getTrainedModel(modelId, false, ActionListener.wrap( + provider.getTrainedModel(modelId, false, false, ActionListener.wrap( trainedModelConfig -> { trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId); provider.getTrainedModelForInference(modelId, ActionListener.wrap( @@ -306,7 +306,7 @@ private void loadWithoutCaching(String modelId, ActionListener model // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called // by a simulated pipeline logger.trace(() 
-> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId)); - provider.getTrainedModel(modelId, false, ActionListener.wrap( + provider.getTrainedModel(modelId, false, false, ActionListener.wrap( trainedModelConfig -> { // Verify we can pull the model into memory without causing OOM trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId); @@ -434,7 +434,7 @@ private void cacheEvictionListener(RemovalNotification logger.trace(() -> new ParameterizedMessage("Persisting stats for evicted model [{}]", notification.getValue().model.getModelId())); - + // If the model is no longer referenced, flush the stats to persist as soon as possible notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false); } finally { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index f119774a96880..209b398ce7143 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -88,9 +88,11 @@ import java.util.Collection; import java.util.Collections; import java.util.Comparator; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.TreeSet; import java.util.stream.Collectors; @@ -234,14 +236,14 @@ public void storeTrainedModelMetadata(TrainedModelMetadata trainedModelMetadata, )); } - public void getTrainedModelMetadata(String modelId, ActionListener listener) { + public void getTrainedModelMetadata(Collection modelIds, ActionListener> listener) { SearchRequest searchRequest = 
client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders .boolQuery() - .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId)) + .filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelIds)) .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelMetadata.NAME)))) - .setSize(1) + .setSize(10_000) // First find the latest index .addSort("_index", SortOrder.DESC) .request(); @@ -249,18 +251,20 @@ public void getTrainedModelMetadata(String modelId, ActionListener { if (searchResponse.getHits().getHits().length == 0) { listener.onFailure(new ResourceNotFoundException( - Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId))); + Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelIds))); return; } - List metadataList = handleHits(searchResponse.getHits().getHits(), - modelId, - this::parseMetadataLenientlyFromSource); - listener.onResponse(metadataList.get(0)); + HashMap map = new HashMap<>(); + for (SearchHit hit : searchResponse.getHits().getHits()) { + String modelId = TrainedModelMetadata.modelId(Objects.requireNonNull(hit.getId())); + map.putIfAbsent(modelId, parseMetadataLenientlyFromSource(hit.getSourceRef(), modelId)); + } + listener.onResponse(map); }, e -> { if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { listener.onFailure(new ResourceNotFoundException( - Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId))); + Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelIds))); return; } listener.onFailure(e); @@ -370,7 +374,7 @@ public void getTrainedModelForInference(final String modelId, final ActionListen // TODO Change this when we get more than just langIdent stored if (MODELS_STORED_AS_RESOURCE.contains(modelId)) { try { - TrainedModelConfig config = loadModelFromResource(modelId, false).ensureParsedDefinition(xContentRegistry); + 
TrainedModelConfig config = loadModelFromResource(modelId, false).build().ensureParsedDefinition(xContentRegistry); assert config.getModelDefinition().getTrainedModel() instanceof LangIdentNeuralNetwork; listener.onResponse( InferenceDefinition.builder() @@ -433,18 +437,50 @@ public void getTrainedModelForInference(final String modelId, final ActionListen )); } - public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener listener) { + public void getTrainedModel(final String modelId, + final boolean includeDefinition, + final boolean includeTotalFeatureImportance, + final ActionListener finalListener) { if (MODELS_STORED_AS_RESOURCE.contains(modelId)) { try { - listener.onResponse(loadModelFromResource(modelId, includeDefinition == false)); + finalListener.onResponse(loadModelFromResource(modelId, includeDefinition == false).build()); return; } catch (ElasticsearchException ex) { - listener.onFailure(ex); + finalListener.onFailure(ex); return; } } + ActionListener getTrainedModelListener = ActionListener.wrap( + modelBuilder -> { + if (includeTotalFeatureImportance == false) { + finalListener.onResponse(modelBuilder.build()); + return; + } + this.getTrainedModelMetadata(Collections.singletonList(modelId), ActionListener.wrap( + metadata -> { + TrainedModelMetadata modelMetadata = metadata.get(modelId); + if (modelMetadata != null) { + modelBuilder.setFeatureImportance(modelMetadata.getTotalFeatureImportances()); + } + finalListener.onResponse(modelBuilder.build()); + }, + failure -> { + // total feature importance is not necessary for a model to be valid + // we shouldn't fail if it is not found + if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) { + finalListener.onResponse(modelBuilder.build()); + return; + } + finalListener.onFailure(failure); + } + )); + + }, + finalListener::onFailure + ); + QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders .idsQuery() 
.addIds(modelId)); @@ -482,11 +518,11 @@ public void getTrainedModel(final String modelId, final boolean includeDefinitio try { builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseInferenceDocLenientlyFromSource); } catch (ResourceNotFoundException ex) { - listener.onFailure(new ResourceNotFoundException( + getTrainedModelListener.onFailure(new ResourceNotFoundException( Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); return; } catch (Exception ex) { - listener.onFailure(ex); + getTrainedModelListener.onFailure(ex); return; } @@ -499,22 +535,22 @@ public void getTrainedModel(final String modelId, final boolean includeDefinitio String compressedString = getDefinitionFromDocs(docs, modelId); builder.setDefinitionFromString(compressedString); } catch (ElasticsearchException elasticsearchException) { - listener.onFailure(elasticsearchException); + getTrainedModelListener.onFailure(elasticsearchException); return; } } catch (ResourceNotFoundException ex) { - listener.onFailure(new ResourceNotFoundException( + getTrainedModelListener.onFailure(new ResourceNotFoundException( Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); return; } catch (Exception ex) { - listener.onFailure(ex); + getTrainedModelListener.onFailure(ex); return; } } - listener.onResponse(builder.build()); + getTrainedModelListener.onResponse(builder); }, - listener::onFailure + getTrainedModelListener::onFailure ); executeAsyncWithOrigin(client, @@ -531,7 +567,10 @@ public void getTrainedModel(final String modelId, final boolean includeDefinitio * This does no expansion on the ids. * It assumes that there are fewer than 10k. 
*/ - public void getTrainedModels(Set modelIds, boolean allowNoResources, final ActionListener> listener) { + public void getTrainedModels(Set modelIds, + boolean allowNoResources, + boolean includeTotalFeatureImportance, + final ActionListener> finalListener) { QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0]))); SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) @@ -540,23 +579,65 @@ public void getTrainedModels(Set modelIds, boolean allowNoResources, fin .setQuery(queryBuilder) .setSize(modelIds.size()) .request(); - List configs = new ArrayList<>(modelIds.size()); + List configs = new ArrayList<>(modelIds.size()); Set modelsInIndex = Sets.difference(modelIds, MODELS_STORED_AS_RESOURCE); Set modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds); for(String modelId : modelsAsResource) { try { configs.add(loadModelFromResource(modelId, true)); } catch (ElasticsearchException ex) { - listener.onFailure(ex); + finalListener.onFailure(ex); return; } } if (modelsInIndex.isEmpty()) { - configs.sort(Comparator.comparing(TrainedModelConfig::getModelId)); - listener.onResponse(configs); + finalListener.onResponse(configs.stream() + .map(TrainedModelConfig.Builder::build) + .sorted(Comparator.comparing(TrainedModelConfig::getModelId)) + .collect(Collectors.toList())); return; } + ActionListener> getTrainedModelListener = ActionListener.wrap( + modelBuilders -> { + if (includeTotalFeatureImportance == false) { + finalListener.onResponse(modelBuilders.stream() + .map(TrainedModelConfig.Builder::build) + .sorted(Comparator.comparing(TrainedModelConfig::getModelId)) + .collect(Collectors.toList())); + return; + } + this.getTrainedModelMetadata(modelIds, ActionListener.wrap( + metadata -> + finalListener.onResponse(modelBuilders.stream() + .map(builder -> { + TrainedModelMetadata modelMetadata = metadata.get(builder.getModelId()); + if 
(modelMetadata != null) { + builder.setFeatureImportance(modelMetadata.getTotalFeatureImportances()); + } + return builder.build(); + }) + .sorted(Comparator.comparing(TrainedModelConfig::getModelId)) + .collect(Collectors.toList())), + failure -> { + // total feature importance is not necessary for a model to be valid + // we shouldn't fail if it is not found + if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) { + finalListener.onResponse(modelBuilders.stream() + .map(TrainedModelConfig.Builder::build) + .sorted(Comparator.comparing(TrainedModelConfig::getModelId)) + .collect(Collectors.toList())); + return; + + } + finalListener.onFailure(failure); + } + )); + + }, + finalListener::onFailure + ); + ActionListener configSearchHandler = ActionListener.wrap( searchResponse -> { Set observedIds = new HashSet<>( @@ -567,12 +648,12 @@ public void getTrainedModels(Set modelIds, boolean allowNoResources, fin try { if (observedIds.contains(searchHit.getId()) == false) { configs.add( - parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId()).build() + parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId()) ); observedIds.add(searchHit.getId()); } } catch (IOException ex) { - listener.onFailure( + getTrainedModelListener.onFailure( ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ex, searchHit.getId())); return; } @@ -582,14 +663,13 @@ public void getTrainedModels(Set modelIds, boolean allowNoResources, fin // Otherwise, treat it as if it was never expanded to begin with. 
Set missingConfigs = Sets.difference(modelIds, observedIds); if (missingConfigs.isEmpty() == false && allowNoResources == false) { - listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs)); + getTrainedModelListener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs)); return; } // Ensure sorted even with the injection of locally resourced models - configs.sort(Comparator.comparing(TrainedModelConfig::getModelId)); - listener.onResponse(configs); + getTrainedModelListener.onResponse(configs); }, - listener::onFailure + getTrainedModelListener::onFailure ); executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, configSearchHandler); @@ -638,7 +718,7 @@ public void expandIds(String idExpression, foundResourceIds = new HashSet<>(); for(String resourceId : matchedResourceIds) { // Does the model as a resource have all the tags? - if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) { + if (Sets.newHashSet(loadModelFromResource(resourceId, true).build().getTags()).containsAll(tags)) { foundResourceIds.add(resourceId); } } @@ -832,7 +912,7 @@ static QueryBuilder buildExpandIdsQuery(String[] tokens, Collection tags return QueryBuilders.constantScoreQuery(boolQueryBuilder); } - TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) { + TrainedModelConfig.Builder loadModelFromResource(String modelId, boolean nullOutDefinition) { URL resource = getClass().getResource(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT); if (resource == null) { logger.error("[{}] presumed stored as a resource but not found", modelId); @@ -847,7 +927,7 @@ TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefiniti if (nullOutDefinition) { builder.clearDefinition(); } - return builder.build(); + return builder; } catch (IOException ioEx) { logger.error(new ParameterizedMessage("[{}] failed 
to parse model definition", modelId), ioEx); throw ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ioEx, modelId); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java index 417554a0a24b8..279a70893809f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java @@ -25,6 +25,7 @@ import java.io.IOException; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -55,12 +56,17 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient if (Strings.isNullOrEmpty(modelId)) { modelId = Metadata.ALL; } - boolean includeModelDefinition = restRequest.paramAsBoolean( - GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION.getPreferredName(), - false - ); List tags = asList(restRequest.paramAsStringArray(TrainedModelConfig.TAGS.getPreferredName(), Strings.EMPTY_ARRAY)); - GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition, tags); + Set includes = new HashSet<>( + asList( + restRequest.paramAsStringArray( + GetTrainedModelsAction.Request.INCLUDE.getPreferredName(), + Strings.EMPTY_ARRAY))); + final GetTrainedModelsAction.Request request = restRequest.hasParam(GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION) ? 
+ new GetTrainedModelsAction.Request(modelId, + restRequest.paramAsBoolean(GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION, false), + tags) : + new GetTrainedModelsAction.Request(modelId, tags, includes); if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) { request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsMetadataAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsMetadataAction.java deleted file mode 100644 index 861563e165e77..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsMetadataAction.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. 
- */ -package org.elasticsearch.xpack.ml.rest.inference; - -import org.elasticsearch.client.node.NodeClient; -import org.elasticsearch.cluster.metadata.Metadata; -import org.elasticsearch.common.Strings; -import org.elasticsearch.rest.BaseRestHandler; -import org.elasticsearch.rest.RestRequest; -import org.elasticsearch.rest.action.RestToXContentListener; -import org.elasticsearch.xpack.core.action.util.PageParams; -import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsMetadataAction; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; -import org.elasticsearch.xpack.ml.MachineLearning; - -import java.io.IOException; -import java.util.List; - -import static org.elasticsearch.rest.RestRequest.Method.GET; -import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request.ALLOW_NO_MATCH; - -public class RestGetTrainedModelsMetadataAction extends BaseRestHandler { - - @Override - public List routes() { - return List.of( - new Route(GET, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/_metadata"), - new Route(GET, MachineLearning.BASE_PATH + "inference/_metadata")); - } - - @Override - public String getName() { - return "ml_get_trained_models_metadata_action"; - } - - @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - String modelId = restRequest.param(TrainedModelMetadata.MODEL_ID.getPreferredName()); - if (Strings.isNullOrEmpty(modelId)) { - modelId = Metadata.ALL; - } - GetTrainedModelsMetadataAction.Request request = new GetTrainedModelsMetadataAction.Request(modelId); - if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) { - request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), 
PageParams.DEFAULT_FROM), - restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); - } - request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources())); - return channel -> client.execute(GetTrainedModelsMetadataAction.INSTANCE, request, new RestToXContentListener<>(channel)); - } - -} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index 5b67f68337e38..61a92a1be1698 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -437,9 +437,9 @@ public void testCircuitBreakerBreak() throws Exception { // the loading occurred or which models are currently in the cache due to evictions. 
// Verify that we have at least loaded all three assertBusy(() -> { - verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), any()); - verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), any()); - verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), eq(false), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), eq(false), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), eq(false), any()); }); assertBusy(() -> { assertThat(circuitBreaker.getUsed(), equalTo(10L)); @@ -553,10 +553,10 @@ private void withTrainedModel(String modelId, long size) { }).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any()); doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3]; listener.onResponse(trainedModelConfig); return null; - }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any()); + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any()); } @SuppressWarnings("unchecked") @@ -564,20 +564,20 @@ private void withMissingModel(String modelId) { if (randomBoolean()) { doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3]; listener.onFailure(new ResourceNotFoundException( Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); return null; - }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any()); + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any()); } else { TrainedModelConfig 
trainedModelConfig = mock(TrainedModelConfig.class); when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(0L); doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3]; listener.onResponse(trainedModelConfig); return null; - }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any()); + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any()); doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java index aee4c43f22769..037b9ccc93e65 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java @@ -57,14 +57,14 @@ public void testGetModelThatExistsAsResource() throws Exception { TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); for(String modelId : TrainedModelProvider.MODELS_STORED_AS_RESOURCE) { PlainActionFuture future = new PlainActionFuture<>(); - trainedModelProvider.getTrainedModel(modelId, true, future); + trainedModelProvider.getTrainedModel(modelId, true, false, future); TrainedModelConfig configWithDefinition = future.actionGet(); assertThat(configWithDefinition.getModelId(), equalTo(modelId)); assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue()))); PlainActionFuture futureNoDefinition = new PlainActionFuture<>(); - 
trainedModelProvider.getTrainedModel(modelId, false, futureNoDefinition); + trainedModelProvider.getTrainedModel(modelId, false, false, futureNoDefinition); TrainedModelConfig configWithoutDefinition = futureNoDefinition.actionGet(); assertThat(configWithoutDefinition.getModelId(), equalTo(modelId)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java index bc6061b3d921c..8eaa8f9f7c458 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java @@ -33,7 +33,7 @@ public void testLangInference() throws Exception { TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); PlainActionFuture future = new PlainActionFuture<>(); // Should be OK as we don't make any client calls - trainedModelProvider.getTrainedModel("lang_ident_model_1", true, future); + trainedModelProvider.getTrainedModel("lang_ident_model_1", true, false, future); TrainedModelConfig config = future.actionGet(); config.ensureParsedDefinition(xContentRegistry()); diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json index 168d233c8e37e..a30cd14a7522b 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json @@ -34,11 +34,10 @@ "description":"Whether to ignore if a wildcard expression matches no trained models. 
(This includes `_all` string or when no trained models have been specified)", "default":true }, - "include_model_definition":{ - "type":"boolean", + "include":{ + "type":"string", "required":false, - "description":"Should the full model definition be included in the results. These definitions can be large. So be cautious when including them. Defaults to false.", - "default":false + "description":"A comma-separated list of fields to optionally include. Valid options are 'definition' and 'total_feature_importance'. Default is none." }, "decompress_definition":{ "type":"boolean", diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index 5a437fc41665b..e9dd55060020c 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -1,6 +1,24 @@ setup: - skip: - features: headers + features: + - headers + - allowed_warnings + - do: + allowed_warnings: + - "index [.ml-inference-000003] matches multiple legacy templates [.ml-inference-000003, global], composable templates will only match a single template" + headers: + Content-Type: application/json + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + index: + id: trained_model_metadata-a-regression-model-0 + index: .ml-inference-000003 + body: + model_id: "a-regression-model-0" + doc_type: "trained_model_metadata" + total_feature_importance: + - { feature_name: "foo", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 }} + - { feature_name: "bar", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 }} + - do: headers: Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e.
the test setup superuser @@ -548,6 +566,20 @@ setup: - match: { count: 12 } - match: { trained_model_configs.0.model_id: "a-regression-model-1" } --- +"Test get models with include total feature importance": + - do: + ml.get_trained_models: + model_id: "a-regression-model-*" + include: "total_feature_importance" + - match: { count: 2 } + - length: { trained_model_configs: 2 } + - match: { trained_model_configs.0.model_id: "a-regression-model-0" } + - is_true: trained_model_configs.0.metadata.total_feature_importance + - length: { trained_model_configs.0.metadata.total_feature_importance: 2 } + - match: { trained_model_configs.1.model_id: "a-regression-model-1" } + - is_false: trained_model_configs.1.metadata.total_feature_importance + +--- "Test delete given unused trained model": - do: ml.delete_trained_model: diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_metadata.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_metadata.yml deleted file mode 100644 index 4ebe7811b7b96..0000000000000 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_metadata.yml +++ /dev/null @@ -1,107 +0,0 @@ -setup: - - skip: - features: - - headers - - allowed_warnings - - do: - allowed_warnings: - - "index [.ml-inference-000003] matches multiple legacy templates [.ml-inference-000003, global], composable templates will only match a single template" - headers: - Content-Type: application/json - Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. 
the test setup superuser - index: - id: trained_model_metadata-a-regression-model0 - index: .ml-inference-000003 - body: - model_id: "a-regression-model0" - doc_type: "trained_model_metadata" - total_feature_importance: - - { feature_name: "foo", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 }} - - { feature_name: "bar", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 }} - - do: - allowed_warnings: - - "index [.ml-inference-000003] matches multiple legacy templates [.ml-inference-000003, global], composable templates will only match a single template" - headers: - Content-Type: application/json - Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser - index: - id: trained_model_metadata-a-regression-model1 - index: .ml-inference-000003 - body: - model_id: "a-regression-model1" - doc_type: "trained_model_metadata" - total_feature_importance: - - { feature_name: "foo", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 } } - - { feature_name: "bar", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 } } - - do: - allowed_warnings: - - "index [.ml-inference-000003] matches multiple legacy templates [.ml-inference-000003, global], composable templates will only match a single template" - headers: - Content-Type: application/json - Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser - index: - id: trained_model_metadata-a-classification-model - index: .ml-inference-000003 - body: - model_id: "a-classification-model" - doc_type: "trained_model_metadata" - total_feature_importance: - - { feature_name: "foo", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 } } - - { feature_name: "bar", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 } } - - - do: - headers: - Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. 
the test setup superuser - indices.refresh: { } - ---- -"Test get given missing trained model metadata": - - - do: - catch: missing - ml.get_trained_models_metadata: - model_id: "missing-trained-model" ---- -"Test get given expression without matches and allow_no_match is false": - - - do: - catch: missing - ml.get_trained_models_metadata: - model_id: "missing-trained-model*" - allow_no_match: false - ---- -"Test get given expression without matches and allow_no_match is true": - - - do: - ml.get_trained_models_metadata: - model_id: "missing-trained-model*" - allow_no_match: true - - match: { count: 0 } - - match: { trained_models_metadata: [] } ---- -"Test get models metadata": - - do: - ml.get_trained_models_metadata: - model_id: "*" - size: 3 - - match: { count: 3 } - - length: { trained_models_metadata: 3 } - - - do: - ml.get_trained_models_metadata: - model_id: "a-regression*" - - match: { count: 2 } - - length: { trained_models_metadata: 2 } - - match: { trained_models_metadata.0.model_id: "a-regression-model0" } - - match: { trained_models_metadata.1.model_id: "a-regression-model1" } - - - do: - ml.get_trained_models_metadata: - model_id: "*" - from: 0 - size: 2 - - match: { count: 3 } - - length: { trained_models_metadata: 2 } - - match: { trained_models_metadata.0.model_id: "a-classification-model" } - - match: { trained_models_metadata.1.model_id: "a-regression-model0" } From 181200279350579e5a8f892faaa25b02f71452d6 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 16 Sep 2020 13:26:39 -0400 Subject: [PATCH 5/7] fixing test --- .../test/java/org/elasticsearch/client/MachineLearningIT.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index f5cb128b21f1e..2198ea5ee9162 100644 --- 
a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -2256,8 +2256,7 @@ public void testGetTrainedModels() throws Exception { getTrainedModelsResponse = execute( new GetTrainedModelsRequest(modelIdPrefix + 0) - .setDecompressDefinition(false) - .includeDefinition(), + .setDecompressDefinition(false), machineLearningClient::getTrainedModels, machineLearningClient::getTrainedModelsAsync); assertThat(getTrainedModelsResponse.getCount(), equalTo(1L)); From ed50912feb97d444f9ac153c1f26b2dbc425191c Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 16 Sep 2020 13:53:22 -0400 Subject: [PATCH 6/7] test fix --- .../src/test/resources/rest-api-spec/test/ml/inference_crud.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index e9dd55060020c..31bda1d9a271e 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -856,7 +856,7 @@ setup: ml.get_trained_models: model_id: "a-regression-model-1" for_export: true - include_model_definition: true + include: "definition" decompress_definition: false - match: { trained_model_configs.0.description: "empty model for tests" } From 7ec181355914a17dc8b3cf5dcf248bff8323ae66 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 17 Sep 2020 07:45:52 -0400 Subject: [PATCH 7/7] addressing pr comments --- .../high-level/ml/get-trained-models.asciidoc | 16 ++++++++-------- .../apis/get-inference-trained-model.asciidoc | 9 +++++---- .../plugin/ml/qa/ml-with-security/build.gradle | 2 -- .../persistence/TrainedModelProvider.java | 2 -- 4 files 
changed, 13 insertions(+), 16 deletions(-) diff --git a/docs/java-rest/high-level/ml/get-trained-models.asciidoc b/docs/java-rest/high-level/ml/get-trained-models.asciidoc index c0276290cfc15..275b4c54292b5 100644 --- a/docs/java-rest/high-level/ml/get-trained-models.asciidoc +++ b/docs/java-rest/high-level/ml/get-trained-models.asciidoc @@ -22,28 +22,28 @@ IDs, or the special wildcard `_all` to get all trained models. -------------------------------------------------- include-tagged::{doc-tests-file}[{api}-request] -------------------------------------------------- -<1> Constructing a new GET request referencing an existing Trained Model +<1> Constructing a new GET request referencing an existing trained model <2> Set the paging parameters <3> Indicate if the complete model definition should be included <4> Indicate if the total feature importance for the features used in training should be included in the model `metadata` field. <5> Should the definition be fully decompressed on GET -<6> Allow empty response if no Trained Models match the provided ID patterns. - If false, an error will be thrown if no Trained Models match the +<6> Allow empty response if no trained models match the provided ID patterns. + If false, an error will be thrown if no trained models match the ID patterns. -<7> An optional list of tags used to narrow the model search. A Trained Model +<7> An optional list of tags used to narrow the model search. A trained model can have many tags or none. The trained models in the response will contain all the provided tags. -<8> Optional boolean value indicating if certain fields should be removed on - retrieval. This is useful for getting the trained model in a format that - can then be put into another cluster. +<8> Optional boolean value for requesting the trained model in a format that can + then be put into another cluster. Certain fields that can only be set when + the model is imported are removed. 
include::../execution.asciidoc[] [id="{upid}-{api}-response"] ==== Response -The returned +{response}+ contains the requested Trained Model. +The returned +{response}+ contains the requested trained model. ["source","java",subs="attributes,callouts,macros"] -------------------------------------------------- diff --git a/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc b/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc index 553c615304f0e..6ba3b2e68449a 100644 --- a/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc +++ b/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc @@ -79,9 +79,10 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=from-models] (Optional, string) A comma delimited string of optional fields to include in the response body. Valid options are: - - definition: to include the model definition - - total_feature_importance: to include the total feature importance for the - training feature sets. This field will be available in the `metadata` field. + - `definition`: Includes the model definition + - `total_feature_importance`: Includes the total feature importance for the + training data set. This field is available in the `metadata` field in the + response body. Default is empty, indicating including no optional fields. `size`:: @@ -143,7 +144,7 @@ created by {dfanalytics} contain `analysis_config` and `input` objects. ===== `total_feature_importance`::: (array) -An array of the total feature importance for each training feature used from +An array of the total feature importance for each feature used from the training data set. This array of objects is returned if {dfanalytics} trained the model and the request includes `total_feature_importance` in the `include` request parameter. 
diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 5f4405125ffc8..e0d9b4afd0681 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -142,8 +142,6 @@ yamlRestTest { 'ml/inference_crud/Test put ensemble with tree where tree has out of bounds feature_names index', 'ml/inference_crud/Test put model with empty input.field_names', 'ml/inference_crud/Test PUT model where target type and inference config mismatch', - 'ml/inference_metadata/Test get given missing trained model metadata', - 'ml/inference_metadata/Test get given expression without matches and allow_no_match is false', 'ml/inference_processor/Test create processor with missing mandatory fields', 'ml/inference_stats_crud/Test get stats given missing trained model', 'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false', diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 209b398ce7143..82c64bb1203d1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -628,12 +628,10 @@ public void getTrainedModels(Set modelIds, .sorted(Comparator.comparing(TrainedModelConfig::getModelId)) .collect(Collectors.toList())); return; - } finalListener.onFailure(failure); } )); - }, finalListener::onFailure );