ML: Add support for single bucket aggs in Datafeeds #37544

Merged
@@ -894,6 +894,44 @@ public void testLookbackWithoutPermissionsAndRollup() throws Exception {
"action [indices:admin/xpack/rollup/search] is unauthorized for user [ml_admin_plus_data]\""));
}

public void testLookbackWithSingleBucketAgg() throws Exception {
String jobId = "aggs-date-histogram-with-single-bucket-agg-job";
Request createJobRequest = new Request("PUT", MachineLearning.BASE_PATH + "anomaly_detectors/" + jobId);
createJobRequest.setJsonEntity("{\n"
+ " \"description\": \"Aggs job\",\n"
+ " \"analysis_config\": {\n"
+ " \"bucket_span\": \"3600s\",\n"
+ " \"summary_count_field_name\": \"doc_count\",\n"
+ " \"detectors\": [\n"
+ " {\n"
+ " \"function\": \"mean\",\n"
+ " \"field_name\": \"responsetime\""
+ " }\n"
+ " ]\n"
+ " },\n"
+ " \"data_description\": {\"time_field\": \"time stamp\"}\n"
+ "}");
client().performRequest(createJobRequest);

String datafeedId = "datafeed-" + jobId;
String aggregations = "{\"time stamp\":{\"date_histogram\":{\"field\":\"time stamp\",\"interval\":\"1h\"},"
+ "\"aggregations\":{"
+ "\"time stamp\":{\"max\":{\"field\":\"time stamp\"}},"
+ "\"airlineFilter\":{\"filter\":{\"term\": {\"airline\":\"AAA\"}},"
+ " \"aggregations\":{\"responsetime\":{\"avg\":{\"field\":\"responsetime\"}}}}}}}";
new DatafeedBuilder(datafeedId, jobId, "airline-data-aggs", "response").setAggregations(aggregations).build();
openJob(client(), jobId);

startDatafeedAndWaitUntilStopped(datafeedId);
waitUntilJobIsClosed(jobId);
Response jobStatsResponse = client().performRequest(new Request("GET",
MachineLearning.BASE_PATH + "anomaly_detectors/" + jobId + "/_stats"));
String jobStatsResponseAsString = EntityUtils.toString(jobStatsResponse.getEntity());
assertThat(jobStatsResponseAsString, containsString("\"input_record_count\":2"));
assertThat(jobStatsResponseAsString, containsString("\"processed_record_count\":2"));
assertThat(jobStatsResponseAsString, containsString("\"missing_field_count\":0"));
}

public void testRealtime() throws Exception {
String jobId = "job-realtime-1";
createJob(jobId, "airline");
@@ -13,6 +13,7 @@
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregation;
import org.elasticsearch.search.aggregations.bucket.histogram.Histogram;
import org.elasticsearch.search.aggregations.metrics.Max;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
@@ -93,18 +94,30 @@ private void processAggs(long docCount, List<Aggregation> aggregations) throws I

List<Aggregation> leafAggregations = new ArrayList<>();
List<MultiBucketsAggregation> bucketAggregations = new ArrayList<>();
List<SingleBucketAggregation> singleBucketAggregations = new ArrayList<>();

// Sort into leaf and bucket aggregations.
// The leaf aggregations will be processed first.
for (Aggregation agg : aggregations) {
if (agg instanceof MultiBucketsAggregation) {
bucketAggregations.add((MultiBucketsAggregation)agg);
} else if (agg instanceof SingleBucketAggregation) {
singleBucketAggregations.add((SingleBucketAggregation)agg);
} else {
leafAggregations.add(agg);
}
}
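// Worked example from the test above: inside each "time stamp" date_histogram bucket,
// this sorting yields one leaf aggregation (the max on "time stamp"), no multi-bucket
// aggregations, and one single-bucket aggregation (the "airlineFilter" filter).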

// If on the current level (indicated via bucketAggregations) or one of the next levels (singleBucketAggregations)
// we have more than 1 `MultiBucketsAggregation`, we should error out.
// We need to make the check this way as each of the items in `singleBucketAggregations` is treated as a separate branch
// in the recursive handling of this method.
int bucketAggLevelCount = Math.max(bucketAggregations.size(), (int)singleBucketAggregations.stream()
.flatMap(s -> asList(s.getAggregations()).stream())
.filter(MultiBucketsAggregation.class::isInstance)
.count());

if (bucketAggLevelCount > 1) {
throw new IllegalArgumentException("Multiple bucket aggregations at the same level are not supported");
}
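// Illustration (hypothetical field names, not from the test): two multi-bucket
// aggregations as siblings, e.g.
//   "aggregations": {
//     "byAirline": { "terms": { "field": "airline" } },
//     "bySourceIp": { "terms": { "field": "source_ip" } }
//   }
// would make bucketAggLevelCount == 2 and trip the exception above.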

@@ -137,6 +150,12 @@ private void processAggs(long docCount, List<Aggregation> aggregations) throws I
}
}
}
noMoreBucketsToProcess = singleBucketAggregations.isEmpty() && noMoreBucketsToProcess;
// We support more than one `SingleBucketAggregation` at each level; each one needs to be handled
// recursively.
for (SingleBucketAggregation singleBucketAggregation : singleBucketAggregations) {
processAggs(singleBucketAggregation.getDocCount(), asList(singleBucketAggregation.getAggregations()));
}

// If there are no more bucket aggregations to process we've reached the end
// and it's time to write the doc
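To make the recursion shape concrete, here is a minimal, self-contained sketch under simplified, hypothetical types (Agg, Leaf, SingleBucket, and MultiBucket are illustrative, not the Elasticsearch classes): multi-bucket aggregations fan out once per bucket, while each single-bucket aggregation becomes an independent branch whose own doc count replaces the parent's, as processAggs does above with singleBucketAggregation.getDocCount().

    import java.util.List;

    final class AggWalkSketch {

        interface Agg { String name(); }

        record Leaf(String name, double value) implements Agg {}
        record SingleBucket(String name, long docCount, List<Agg> subAggs) implements Agg {}
        record MultiBucket(String name, List<SingleBucket> buckets) implements Agg {}

        static void process(long docCount, List<Agg> aggs) {
            for (Agg agg : aggs) {
                if (agg instanceof MultiBucket mb) {
                    // multi-bucket: one recursive call per bucket (e.g. per date_histogram interval)
                    for (SingleBucket bucket : mb.buckets()) {
                        process(bucket.docCount(), bucket.subAggs());
                    }
                } else if (agg instanceof SingleBucket sb) {
                    // single-bucket: a separate branch; sb's own doc count, not the parent's, flows down
                    process(sb.docCount(), sb.subAggs());
                } else if (agg instanceof Leaf leaf) {
                    System.out.println(leaf.name() + " = " + leaf.value() + " (doc_count=" + docCount + ")");
                }
            }
        }

        public static void main(String[] args) {
            // a filter matching 2 of the bucket's 4 docs: the leaf is reported with doc_count=2
            process(4, List.of(new SingleBucket("airlineFilter", 2,
                    List.of(new Leaf("responsetime", 132.5)))));
        }
    }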