Skip to content

Commit 0a92f1d

Browse files
committed
[HLRC][ML] Add ML get model snapshots API (#35487)
Relates #29827
1 parent 209c165 commit 0a92f1d

File tree

12 files changed

+1037
-1
lines changed

12 files changed

+1037
-1
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.elasticsearch.client.ml.GetInfluencersRequest;
4444
import org.elasticsearch.client.ml.GetJobRequest;
4545
import org.elasticsearch.client.ml.GetJobStatsRequest;
46+
import org.elasticsearch.client.ml.GetModelSnapshotsRequest;
4647
import org.elasticsearch.client.ml.GetOverallBucketsRequest;
4748
import org.elasticsearch.client.ml.GetRecordsRequest;
4849
import org.elasticsearch.client.ml.OpenJobRequest;
@@ -361,6 +362,19 @@ static Request getCategories(GetCategoriesRequest getCategoriesRequest) throws I
361362
return request;
362363
}
363364

365+
static Request getModelSnapshots(GetModelSnapshotsRequest getModelSnapshotsRequest) throws IOException {
366+
String endpoint = new EndpointBuilder()
367+
.addPathPartAsIs("_xpack")
368+
.addPathPartAsIs("ml")
369+
.addPathPartAsIs("anomaly_detectors")
370+
.addPathPart(getModelSnapshotsRequest.getJobId())
371+
.addPathPartAsIs("model_snapshots")
372+
.build();
373+
Request request = new Request(HttpGet.METHOD_NAME, endpoint);
374+
request.setEntity(createEntity(getModelSnapshotsRequest, REQUEST_BODY_CONTENT_TYPE));
375+
return request;
376+
}
377+
364378
static Request getOverallBuckets(GetOverallBucketsRequest getOverallBucketsRequest) throws IOException {
365379
String endpoint = new EndpointBuilder()
366380
.addPathPartAsIs("_xpack")

client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
import org.elasticsearch.client.ml.GetJobResponse;
5050
import org.elasticsearch.client.ml.GetJobStatsRequest;
5151
import org.elasticsearch.client.ml.GetJobStatsResponse;
52+
import org.elasticsearch.client.ml.GetModelSnapshotsRequest;
53+
import org.elasticsearch.client.ml.GetModelSnapshotsResponse;
5254
import org.elasticsearch.client.ml.GetOverallBucketsRequest;
5355
import org.elasticsearch.client.ml.GetOverallBucketsResponse;
5456
import org.elasticsearch.client.ml.GetRecordsRequest;
@@ -897,6 +899,46 @@ public void getCategoriesAsync(GetCategoriesRequest request, RequestOptions opti
897899
Collections.emptySet());
898900
}
899901

902+
/**
903+
* Gets the snapshots for a Machine Learning Job.
904+
* <p>
905+
* For additional info
906+
* see <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/ml-get-snapshot.html">
907+
* ML GET model snapshots documentation</a>
908+
*
909+
* @param request The request
910+
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
911+
* @throws IOException when there is a serialization issue sending the request or receiving the response
912+
*/
913+
public GetModelSnapshotsResponse getModelSnapshots(GetModelSnapshotsRequest request, RequestOptions options) throws IOException {
914+
return restHighLevelClient.performRequestAndParseEntity(request,
915+
MLRequestConverters::getModelSnapshots,
916+
options,
917+
GetModelSnapshotsResponse::fromXContent,
918+
Collections.emptySet());
919+
}
920+
921+
/**
922+
* Gets the snapshots for a Machine Learning Job, notifies listener once the requested snapshots are retrieved.
923+
* <p>
924+
* For additional info
925+
* see <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/ml-get-snapshot.html">
926+
* ML GET model snapshots documentation</a>
927+
*
928+
* @param request The request
929+
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
930+
* @param listener Listener to be notified upon request completion
931+
*/
932+
public void getModelSnapshotsAsync(GetModelSnapshotsRequest request, RequestOptions options,
933+
ActionListener<GetModelSnapshotsResponse> listener) {
934+
restHighLevelClient.performRequestAsyncAndParseEntity(request,
935+
MLRequestConverters::getModelSnapshots,
936+
options,
937+
GetModelSnapshotsResponse::fromXContent,
938+
listener,
939+
Collections.emptySet());
940+
}
941+
900942
/**
901943
* Gets overall buckets for a set of Machine Learning Jobs.
902944
* <p>
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml;
20+
21+
import org.elasticsearch.action.ActionRequest;
22+
import org.elasticsearch.action.ActionRequestValidationException;
23+
import org.elasticsearch.client.ml.job.config.Job;
24+
import org.elasticsearch.client.ml.job.util.PageParams;
25+
import org.elasticsearch.common.ParseField;
26+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
27+
import org.elasticsearch.common.xcontent.ToXContentObject;
28+
import org.elasticsearch.common.xcontent.XContentBuilder;
29+
30+
import java.io.IOException;
31+
import java.util.Objects;
32+
33+
/**
34+
* A request to retrieve information about model snapshots for a given job
35+
*/
36+
public class GetModelSnapshotsRequest extends ActionRequest implements ToXContentObject {
37+
38+
39+
public static final ParseField SNAPSHOT_ID = new ParseField("snapshot_id");
40+
public static final ParseField SORT = new ParseField("sort");
41+
public static final ParseField START = new ParseField("start");
42+
public static final ParseField END = new ParseField("end");
43+
public static final ParseField DESC = new ParseField("desc");
44+
45+
public static final ConstructingObjectParser<GetModelSnapshotsRequest, Void> PARSER = new ConstructingObjectParser<>(
46+
"get_model_snapshots_request", a -> new GetModelSnapshotsRequest((String) a[0]));
47+
48+
49+
static {
50+
PARSER.declareString(ConstructingObjectParser.constructorArg(), Job.ID);
51+
PARSER.declareString(GetModelSnapshotsRequest::setSnapshotId, SNAPSHOT_ID);
52+
PARSER.declareString(GetModelSnapshotsRequest::setSort, SORT);
53+
PARSER.declareStringOrNull(GetModelSnapshotsRequest::setStart, START);
54+
PARSER.declareStringOrNull(GetModelSnapshotsRequest::setEnd, END);
55+
PARSER.declareBoolean(GetModelSnapshotsRequest::setDesc, DESC);
56+
PARSER.declareObject(GetModelSnapshotsRequest::setPageParams, PageParams.PARSER, PageParams.PAGE);
57+
}
58+
59+
private final String jobId;
60+
private String snapshotId;
61+
private String sort;
62+
private String start;
63+
private String end;
64+
private Boolean desc;
65+
private PageParams pageParams;
66+
67+
/**
68+
* Constructs a request to retrieve snapshot information from a given job
69+
* @param jobId id of the job from which to retrieve results
70+
*/
71+
public GetModelSnapshotsRequest(String jobId) {
72+
this.jobId = Objects.requireNonNull(jobId);
73+
}
74+
75+
public String getJobId() {
76+
return jobId;
77+
}
78+
79+
public String getSnapshotId() {
80+
return snapshotId;
81+
}
82+
83+
/**
84+
* Sets the id of the snapshot to retrieve.
85+
* @param snapshotId the snapshot id
86+
*/
87+
public void setSnapshotId(String snapshotId) {
88+
this.snapshotId = snapshotId;
89+
}
90+
91+
public String getSort() {
92+
return sort;
93+
}
94+
95+
/**
96+
* Sets the value of "sort".
97+
* Specifies the snapshot field to sort on.
98+
* @param sort value of "sort".
99+
*/
100+
public void setSort(String sort) {
101+
this.sort = sort;
102+
}
103+
104+
public PageParams getPageParams() {
105+
return pageParams;
106+
}
107+
108+
/**
109+
* Sets the paging parameters
110+
* @param pageParams the paging parameters
111+
*/
112+
public void setPageParams(PageParams pageParams) {
113+
this.pageParams = pageParams;
114+
}
115+
116+
public String getStart() {
117+
return start;
118+
}
119+
120+
/**
121+
* Sets the value of "start" which is a timestamp.
122+
* Only snapshots whose timestamp is on or after the "start" value will be returned.
123+
* @param start String representation of a timestamp; may be an epoch seconds, epoch millis or an ISO string
124+
*/
125+
public void setStart(String start) {
126+
this.start = start;
127+
}
128+
129+
130+
public String getEnd() {
131+
return end;
132+
}
133+
134+
/**
135+
* Sets the value of "end" which is a timestamp.
136+
* Only snapshots whose timestamp is before the "end" value will be returned.
137+
* @param end String representation of a timestamp; may be an epoch seconds, epoch millis or an ISO string
138+
*/
139+
public void setEnd(String end) {
140+
this.end = end;
141+
}
142+
143+
public Boolean getDesc() {
144+
return desc;
145+
}
146+
147+
/**
148+
* Sets the value of "desc".
149+
* Specifies the sorting order.
150+
* @param desc value of "desc"
151+
*/
152+
public void setDesc(boolean desc) {
153+
this.desc = desc;
154+
}
155+
156+
@Override
157+
public ActionRequestValidationException validate() {
158+
return null;
159+
}
160+
161+
@Override
162+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
163+
builder.startObject();
164+
builder.field(Job.ID.getPreferredName(), jobId);
165+
if (snapshotId != null) {
166+
builder.field(SNAPSHOT_ID.getPreferredName(), snapshotId);
167+
}
168+
if (sort != null) {
169+
builder.field(SORT.getPreferredName(), sort);
170+
}
171+
if (start != null) {
172+
builder.field(START.getPreferredName(), start);
173+
}
174+
if (end != null) {
175+
builder.field(END.getPreferredName(), end);
176+
}
177+
if (desc != null) {
178+
builder.field(DESC.getPreferredName(), desc);
179+
}
180+
if (pageParams != null) {
181+
builder.field(PageParams.PAGE.getPreferredName(), pageParams);
182+
} builder.endObject();
183+
return builder;
184+
}
185+
186+
@Override
187+
public boolean equals(Object obj) {
188+
if (obj == null) {
189+
return false;
190+
}
191+
if (getClass() != obj.getClass()) {
192+
return false;
193+
}
194+
GetModelSnapshotsRequest request = (GetModelSnapshotsRequest) obj;
195+
return Objects.equals(jobId, request.jobId)
196+
&& Objects.equals(snapshotId, request.snapshotId)
197+
&& Objects.equals(sort, request.sort)
198+
&& Objects.equals(start, request.start)
199+
&& Objects.equals(end, request.end)
200+
&& Objects.equals(desc, request.desc)
201+
&& Objects.equals(pageParams, request.pageParams);
202+
}
203+
204+
@Override
205+
public int hashCode() {
206+
return Objects.hash(jobId, snapshotId, pageParams, start, end, sort, desc);
207+
}
208+
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml;
20+
21+
import org.elasticsearch.client.ml.job.process.ModelSnapshot;
22+
import org.elasticsearch.common.ParseField;
23+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
24+
import org.elasticsearch.common.xcontent.XContentParser;
25+
26+
import java.io.IOException;
27+
import java.util.List;
28+
import java.util.Objects;
29+
import java.util.stream.Collectors;
30+
31+
/**
32+
* A response containing the requested snapshots
33+
*/
34+
public class GetModelSnapshotsResponse extends AbstractResultResponse<ModelSnapshot> {
35+
36+
public static final ParseField SNAPSHOTS = new ParseField("model_snapshots");
37+
38+
@SuppressWarnings("unchecked")
39+
public static final ConstructingObjectParser<GetModelSnapshotsResponse, Void> PARSER =
40+
new ConstructingObjectParser<>("get_model_snapshots_response", true,
41+
a -> new GetModelSnapshotsResponse((List<ModelSnapshot.Builder>) a[0], (long) a[1]));
42+
43+
static {
44+
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), ModelSnapshot.PARSER, SNAPSHOTS);
45+
PARSER.declareLong(ConstructingObjectParser.constructorArg(), COUNT);
46+
}
47+
48+
public static GetModelSnapshotsResponse fromXContent(XContentParser parser) throws IOException {
49+
return PARSER.parse(parser, null);
50+
}
51+
52+
GetModelSnapshotsResponse(List<ModelSnapshot.Builder> snapshotBuilders, long count) {
53+
super(SNAPSHOTS, snapshotBuilders.stream().map(ModelSnapshot.Builder::build).collect(Collectors.toList()), count);
54+
}
55+
56+
/**
57+
* The retrieved snapshots
58+
* @return the retrieved snapshots
59+
*/
60+
public List<ModelSnapshot> snapshots() {
61+
return results;
62+
}
63+
64+
@Override
65+
public int hashCode() {
66+
return Objects.hash(count, results);
67+
}
68+
69+
@Override
70+
public boolean equals(Object obj) {
71+
if (obj == null) {
72+
return false;
73+
}
74+
if (getClass() != obj.getClass()) {
75+
return false;
76+
}
77+
GetModelSnapshotsResponse other = (GetModelSnapshotsResponse) obj;
78+
return count == other.count && Objects.equals(results, other.results);
79+
}
80+
}

0 commit comments

Comments
 (0)