Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HLRC: Add ML get categories API #33465

Merged
merged 6 commits into from
Sep 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.elasticsearch.client.ml.FlushJobRequest;
import org.elasticsearch.client.ml.ForecastJobRequest;
import org.elasticsearch.client.ml.GetBucketsRequest;
import org.elasticsearch.client.ml.GetCategoriesRequest;
import org.elasticsearch.client.ml.GetInfluencersRequest;
import org.elasticsearch.client.ml.GetJobRequest;
import org.elasticsearch.client.ml.GetJobStatsRequest;
Expand Down Expand Up @@ -194,6 +195,20 @@ static Request getBuckets(GetBucketsRequest getBucketsRequest) throws IOExceptio
return request;
}

static Request getCategories(GetCategoriesRequest getCategoriesRequest) throws IOException {
String endpoint = new EndpointBuilder()
.addPathPartAsIs("_xpack")
.addPathPartAsIs("ml")
.addPathPartAsIs("anomaly_detectors")
.addPathPart(getCategoriesRequest.getJobId())
.addPathPartAsIs("results")
.addPathPartAsIs("categories")
.build();
Request request = new Request(HttpGet.METHOD_NAME, endpoint);
request.setEntity(createEntity(getCategoriesRequest, REQUEST_BODY_CONTENT_TYPE));
return request;
}

static Request getOverallBuckets(GetOverallBucketsRequest getOverallBucketsRequest) throws IOException {
String endpoint = new EndpointBuilder()
.addPathPartAsIs("_xpack")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import org.elasticsearch.client.ml.FlushJobResponse;
import org.elasticsearch.client.ml.GetBucketsRequest;
import org.elasticsearch.client.ml.GetBucketsResponse;
import org.elasticsearch.client.ml.GetCategoriesRequest;
import org.elasticsearch.client.ml.GetCategoriesResponse;
import org.elasticsearch.client.ml.GetInfluencersRequest;
import org.elasticsearch.client.ml.GetInfluencersResponse;
import org.elasticsearch.client.ml.GetJobRequest;
Expand Down Expand Up @@ -474,6 +476,45 @@ public void getBucketsAsync(GetBucketsRequest request, RequestOptions options, A
Collections.emptySet());
}

/**
* Gets the categories for a Machine Learning Job.
* <p>
* For additional info
* see <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/ml-get-category.html">
* ML GET categories documentation</a>
*
* @param request The request
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know some methods add the @throws and others don't. I think we should still add the @throws entry even if it is a checked exception.

I have not been consistent about this myself either :(.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dimitris-athanasiou may have a different opinion around this ^, So I will defer to him.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right @benwtrent, we've been dropping the @throws clause in some of the methods in the client. We'll need to revisit and add them. I'll make a note to do that.

* @throws IOException when there is a serialization issue sending the request or receiving the response
*/
public GetCategoriesResponse getCategories(GetCategoriesRequest request, RequestOptions options) throws IOException {
return restHighLevelClient.performRequestAndParseEntity(request,
MLRequestConverters::getCategories,
options,
GetCategoriesResponse::fromXContent,
Collections.emptySet());
}

/**
* Gets the categories for a Machine Learning Job, notifies listener once the requested buckets are retrieved.
* <p>
* For additional info
* see <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/ml-get-category.html">
* ML GET categories documentation</a>
*
* @param request The request
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
* @param listener Listener to be notified upon request completion
*/
public void getCategoriesAsync(GetCategoriesRequest request, RequestOptions options, ActionListener<GetCategoriesResponse> listener) {
restHighLevelClient.performRequestAsyncAndParseEntity(request,
MLRequestConverters::getCategories,
options,
GetCategoriesResponse::fromXContent,
listener,
Collections.emptySet());
}

/**
* Gets overall buckets for a set of Machine Learning Jobs.
* <p>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml;

import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.client.ml.job.config.Job;
import org.elasticsearch.client.ml.job.util.PageParams;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

/**
* A request to retrieve categories of a given job
*/
public class GetCategoriesRequest extends ActionRequest implements ToXContentObject {


public static final ParseField CATEGORY_ID = new ParseField("category_id");

public static final ConstructingObjectParser<GetCategoriesRequest, Void> PARSER = new ConstructingObjectParser<>(
"get_categories_request", a -> new GetCategoriesRequest((String) a[0]));


static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), Job.ID);
PARSER.declareLong(GetCategoriesRequest::setCategoryId, CATEGORY_ID);
PARSER.declareObject(GetCategoriesRequest::setPageParams, PageParams.PARSER, PageParams.PAGE);
}

private final String jobId;
private Long categoryId;
private PageParams pageParams;

/**
* Constructs a request to retrieve category information from a given job
* @param jobId id of the job from which to retrieve results
*/
public GetCategoriesRequest(String jobId) {
this.jobId = Objects.requireNonNull(jobId);
}

public String getJobId() {
return jobId;
}

public PageParams getPageParams() {
return pageParams;
}

public Long getCategoryId() {
return categoryId;
}

/**
* Sets the category id
* @param categoryId the category id
*/
public void setCategoryId(Long categoryId) {
this.categoryId = categoryId;
}

/**
* Sets the paging parameters
* @param pageParams the paging parameters
*/
public void setPageParams(PageParams pageParams) {
this.pageParams = pageParams;
}

@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(Job.ID.getPreferredName(), jobId);
if (categoryId != null) {
builder.field(CATEGORY_ID.getPreferredName(), categoryId);
}
if (pageParams != null) {
builder.field(PageParams.PAGE.getPreferredName(), pageParams);
}
builder.endObject();
return builder;
}

@Override
public boolean equals(Object obj) {
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
GetCategoriesRequest request = (GetCategoriesRequest) obj;
return Objects.equals(jobId, request.jobId)
&& Objects.equals(categoryId, request.categoryId)
&& Objects.equals(pageParams, request.pageParams);
}

@Override
public int hashCode() {
return Objects.hash(jobId, categoryId, pageParams);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml;

import org.elasticsearch.client.ml.job.results.CategoryDefinition;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

/**
* A response containing the requested categories
*/
public class GetCategoriesResponse extends AbstractResultResponse<CategoryDefinition> {

public static final ParseField CATEGORIES = new ParseField("categories");

@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<GetCategoriesResponse, Void> PARSER =
new ConstructingObjectParser<>("get_categories_response", true,
a -> new GetCategoriesResponse((List<CategoryDefinition>) a[0], (long) a[1]));

static {
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), CategoryDefinition.PARSER, CATEGORIES);
PARSER.declareLong(ConstructingObjectParser.constructorArg(), COUNT);
}

public static GetCategoriesResponse fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null);
}

GetCategoriesResponse(List<CategoryDefinition> categories, long count) {
super(CATEGORIES, categories, count);
}

/**
* The retrieved categories
* @return the retrieved categories
*/
public List<CategoryDefinition> categories() {
return results;
}

@Override
public int hashCode() {
return Objects.hash(count, results);
}

@Override
public boolean equals(Object obj) {
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
GetCategoriesResponse other = (GetCategoriesResponse) obj;
return count == other.count && Objects.equals(results, other.results);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.elasticsearch.client.ml.FlushJobRequest;
import org.elasticsearch.client.ml.ForecastJobRequest;
import org.elasticsearch.client.ml.GetBucketsRequest;
import org.elasticsearch.client.ml.GetCategoriesRequest;
import org.elasticsearch.client.ml.GetInfluencersRequest;
import org.elasticsearch.client.ml.GetJobRequest;
import org.elasticsearch.client.ml.GetJobStatsRequest;
Expand Down Expand Up @@ -220,6 +221,21 @@ public void testGetBuckets() throws IOException {
}
}

public void testGetCategories() throws IOException {
String jobId = randomAlphaOfLength(10);
GetCategoriesRequest getCategoriesRequest = new GetCategoriesRequest(jobId);
getCategoriesRequest.setPageParams(new PageParams(100, 300));


Request request = MLRequestConverters.getCategories(getCategoriesRequest);
assertEquals(HttpGet.METHOD_NAME, request.getMethod());
assertEquals("/_xpack/ml/anomaly_detectors/" + jobId + "/results/categories", request.getEndpoint());
try (XContentParser parser = createParser(JsonXContent.jsonXContent, request.getEntity().getContent())) {
GetCategoriesRequest parsedRequest = GetCategoriesRequest.PARSER.apply(parser, null);
assertThat(parsedRequest, equalTo(getCategoriesRequest));
}
}

public void testGetOverallBuckets() throws IOException {
String jobId = randomAlphaOfLength(10);
GetOverallBucketsRequest getOverallBucketsRequest = new GetOverallBucketsRequest(jobId);
Expand Down
Loading