Skip to content

Commit 18ec001

Browse files
committed
[ML][Inference] add tags url param to GET (elastic#51330)
Adds a new URL parameter, `tags` to the GET _ml/inference/<model_id> endpoint. This parameter allows the list of models to be further reduced to those who contain all the provided tags.
1 parent ded7407 commit 18ec001

File tree

16 files changed

+177
-12
lines changed

16 files changed

+177
-12
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,9 @@ static Request getTrainedModels(GetTrainedModelsRequest getTrainedModelsRequest)
755755
params.putParam(GetTrainedModelsRequest.INCLUDE_MODEL_DEFINITION,
756756
Boolean.toString(getTrainedModelsRequest.getIncludeDefinition()));
757757
}
758+
if (getTrainedModelsRequest.getTags() != null) {
759+
params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));
760+
}
758761
Request request = new Request(HttpGet.METHOD_NAME, endpoint);
759762
request.addParameters(params.asMap());
760763
return request;

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.client.Validatable;
2323
import org.elasticsearch.client.ValidationException;
2424
import org.elasticsearch.client.core.PageParams;
25+
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
2526
import org.elasticsearch.common.Nullable;
2627

2728
import java.util.Arrays;
@@ -34,12 +35,14 @@ public class GetTrainedModelsRequest implements Validatable {
3435
public static final String ALLOW_NO_MATCH = "allow_no_match";
3536
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
3637
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
38+
public static final String TAGS = "tags";
3739

3840
private final List<String> ids;
3941
private Boolean allowNoMatch;
4042
private Boolean includeDefinition;
4143
private Boolean decompressDefinition;
4244
private PageParams pageParams;
45+
private List<String> tags;
4346

4447
/**
4548
* Helper method to create a request that will get ALL TrainedModelConfigs
@@ -111,6 +114,29 @@ public GetTrainedModelsRequest setDecompressDefinition(Boolean decompressDefinit
111114
return this;
112115
}
113116

117+
public List<String> getTags() {
118+
return tags;
119+
}
120+
121+
/**
122+
* The tags that the trained model must match. These correspond to {@link TrainedModelConfig#getTags()}.
123+
*
124+
* The models returned will match ALL tags supplied.
125+
* If none are provided, only the provided ids are used to find models
126+
* @param tags The tags to match when finding models
127+
*/
128+
public GetTrainedModelsRequest setTags(List<String> tags) {
129+
this.tags = tags;
130+
return this;
131+
}
132+
133+
/**
134+
* See {@link GetTrainedModelsRequest#setTags(List)}
135+
*/
136+
public GetTrainedModelsRequest setTags(String... tags) {
137+
return setTags(Arrays.asList(tags));
138+
}
139+
114140
@Override
115141
public Optional<ValidationException> validate() {
116142
if (ids == null || ids.isEmpty()) {

client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,7 @@ public void testGetTrainedModels() {
834834
.setAllowNoMatch(false)
835835
.setDecompressDefinition(true)
836836
.setIncludeDefinition(false)
837+
.setTags("tag1", "tag2")
837838
.setPageParams(new PageParams(100, 300));
838839

839840
Request request = MLRequestConverters.getTrainedModels(getRequest);
@@ -845,6 +846,7 @@ public void testGetTrainedModels() {
845846
hasEntry("size", "300"),
846847
hasEntry("allow_no_match", "false"),
847848
hasEntry("decompress_definition", "true"),
849+
hasEntry("tags", "tag1,tag2"),
848850
hasEntry("include_model_definition", "false")
849851
));
850852
assertNull(request.getEntity());

client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3587,8 +3587,10 @@ public void testGetTrainedModels() throws Exception {
35873587
.setPageParams(new PageParams(0, 1)) // <2>
35883588
.setIncludeDefinition(false) // <3>
35893589
.setDecompressDefinition(false) // <4>
3590-
.setAllowNoMatch(true); // <5>
3590+
.setAllowNoMatch(true) // <5>
3591+
.setTags("regression"); // <6>
35913592
// end::get-trained-models-request
3593+
request.setTags((List<String>)null);
35923594

35933595
// tag::get-trained-models-execute
35943596
GetTrainedModelsResponse response = client.machineLearning().getTrainedModels(request, RequestOptions.DEFAULT);

docs/java-rest/high-level/ml/get-trained-models.asciidoc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ include-tagged::{doc-tests-file}[{api}-request]
2929
<5> Allow empty response if no Trained Models match the provided ID patterns.
3030
If false, an error will be thrown if no Trained Models match the
3131
ID patterns.
32+
<6> An optional list of tags used to narrow the model search. A Trained Model
33+
can have many tags or none. The trained models in the response will
34+
contain all the provided tags.
3235

3336
include::../execution.asciidoc[]
3437

docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=include-model-definition]
7474
(Optional, integer)
7575
include::{docdir}/ml/ml-shared.asciidoc[tag=size]
7676

77+
`tags`::
78+
(Optional, string)
79+
include::{docdir}/ml/ml-shared.asciidoc[tag=tags]
7780

7881
[[ml-get-inference-response-codes]]
7982
==== {api-response-codes-title}
@@ -96,4 +99,4 @@ The following example gets configuration information for all the trained models:
9699
--------------------------------------------------
97100
GET _ml/inference/
98101
--------------------------------------------------
99-
// TEST[skip:TBD]
102+
// TEST[skip:TBD]

docs/reference/ml/ml-shared.asciidoc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,12 @@ to `false`. When `true`, only a single model must match the ID patterns
639639
provided, otherwise a bad request is returned.
640640
end::include-model-definition[]
641641

642+
tag::tags[]
643+
A comma delimited string of tags. A {infer} model can have many tags, or none.
644+
When supplied, only {infer} models that contain all the supplied tags are
645+
returned.
646+
end::tags[]
647+
642648
tag::indices[]
643649
An array of index names. Wildcards are supported. For example:
644650
`["it_ops_metrics", "server*"]`.

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.action;
77

8+
import org.elasticsearch.Version;
89
import org.elasticsearch.action.ActionType;
910
import org.elasticsearch.common.ParseField;
1011
import org.elasticsearch.common.io.stream.StreamInput;
@@ -33,18 +34,26 @@ public static class Request extends AbstractGetResourcesRequest {
3334

3435
public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition");
3536
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
37+
public static final ParseField TAGS = new ParseField("tags");
3638

3739
private final boolean includeModelDefinition;
40+
private final List<String> tags;
3841

39-
public Request(String id, boolean includeModelDefinition) {
42+
public Request(String id, boolean includeModelDefinition, List<String> tags) {
4043
setResourceId(id);
4144
setAllowNoResources(true);
4245
this.includeModelDefinition = includeModelDefinition;
46+
this.tags = tags == null ? Collections.emptyList() : tags;
4347
}
4448

4549
public Request(StreamInput in) throws IOException {
4650
super(in);
4751
this.includeModelDefinition = in.readBoolean();
52+
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
53+
this.tags = in.readStringList();
54+
} else {
55+
this.tags = Collections.emptyList();
56+
}
4857
}
4958

5059
@Override
@@ -56,15 +65,22 @@ public boolean isIncludeModelDefinition() {
5665
return includeModelDefinition;
5766
}
5867

68+
public List<String> getTags() {
69+
return tags;
70+
}
71+
5972
@Override
6073
public void writeTo(StreamOutput out) throws IOException {
6174
super.writeTo(out);
6275
out.writeBoolean(includeModelDefinition);
76+
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
77+
out.writeStringCollection(tags);
78+
}
6379
}
6480

6581
@Override
6682
public int hashCode() {
67-
return Objects.hash(super.hashCode(), includeModelDefinition);
83+
return Objects.hash(super.hashCode(), includeModelDefinition, tags);
6884
}
6985

7086
@Override
@@ -76,7 +92,7 @@ public boolean equals(Object obj) {
7692
return false;
7793
}
7894
Request other = (Request) obj;
79-
return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition;
95+
return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition && Objects.equals(tags, other.tags);
8096
}
8197
}
8298

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@ public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCas
1414

1515
@Override
1616
protected Request createTestInstance() {
17-
Request request = new Request(randomAlphaOfLength(20), randomBoolean());
17+
Request request = new Request(randomAlphaOfLength(20),
18+
randomBoolean(),
19+
randomBoolean() ? null :
20+
randomList(10, () -> randomAlphaOfLength(10)));
1821
request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
1922
return request;
2023
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
2121

2222
import java.util.Collections;
23+
import java.util.HashSet;
2324
import java.util.Set;
2425

2526

@@ -70,7 +71,11 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
7071
listener::onFailure
7172
);
7273

73-
provider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idExpansionListener);
74+
provider.expandIds(request.getResourceId(),
75+
request.isAllowNoResources(),
76+
request.getPageParams(),
77+
new HashSet<>(request.getTags()),
78+
idExpansionListener);
7479
}
7580

7681
}

0 commit comments

Comments
 (0)