diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java index b3b2a3b6666a4..7f8486223928a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java @@ -105,28 +105,31 @@ public String[] getIndices() { return indices; } - public final void setIndices(List indices) { + public final Request setIndices(List indices) { ExceptionsHelper.requireNonNull(indices, INDEX); if (indices.isEmpty()) { throw ExceptionsHelper.badRequestException("At least one index must be specified"); } this.indices = indices.toArray(new String[indices.size()]); + return this; } public QueryBuilder getParsedQuery() { return Optional.ofNullable(queryProvider).orElseGet(QueryProvider::defaultQuery).getParsedQuery(); } - public final void setQueryProvider(QueryProvider queryProvider) { + public final Request setQueryProvider(QueryProvider queryProvider) { this.queryProvider = queryProvider; + return this; } public Evaluation getEvaluation() { return evaluation; } - public final void setEvaluation(Evaluation evaluation) { + public final Request setEvaluation(Evaluation evaluation) { this.evaluation = ExceptionsHelper.requireNonNull(evaluation, EVALUATION); + return this; } @Override @@ -203,6 +206,14 @@ public Response(String evaluationName, List metrics) { this.metrics = Objects.requireNonNull(metrics); } + public String getEvaluationName() { + return evaluationName; + } + + public List getMetrics() { + return metrics; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(evaluationName); @@ -214,7 +225,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.startObject(evaluationName); for (EvaluationMetricResult metric : metrics) { - builder.field(metric.getName(), metric); + builder.field(metric.getMetricName(), metric); } builder.endObject(); builder.endObject(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java index 70f31273aba16..663e1ba639adf 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java @@ -5,14 +5,17 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.builder.SearchSourceBuilder; import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; /** * Defines an evaluation @@ -24,16 +27,54 @@ public interface Evaluation extends ToXContentObject, NamedWriteable { */ String getName(); + /** + * Returns the list of metrics to evaluate + * @return list of metrics to evaluate + */ + List getMetrics(); + /** * Builds the search required to collect data to compute the evaluation result - * @param queryBuilder User-provided query that must be respected when collecting data + * @param userProvidedQueryBuilder User-provided query that must be respected when collecting data + */ + SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder); + + /** + * Builds the search that verifies existence of required fields and applies user-provided query + * @param requiredFields fields that must exist + * @param userProvidedQueryBuilder user-provided query + */ + default SearchSourceBuilder newSearchSourceBuilder(List requiredFields, QueryBuilder userProvidedQueryBuilder) { + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + for (String requiredField : requiredFields) { + boolQuery.filter(QueryBuilders.existsQuery(requiredField)); + } + boolQuery.filter(userProvidedQueryBuilder); + return new SearchSourceBuilder().size(0).query(boolQuery); + } + + /** + * Processes {@link SearchResponse} from the search action + * @param searchResponse response from the search action + */ + void process(SearchResponse searchResponse); + + /** + * @return true iff all the metrics have their results computed */ - SearchSourceBuilder buildSearch(QueryBuilder queryBuilder); + default boolean hasAllResults() { + return getMetrics().stream().map(EvaluationMetric::getResult).allMatch(Optional::isPresent); + } /** - * Computes the evaluation result - * @param searchResponse The search response required to compute the result - * @param listener A listener of the results + * Returns the list of evaluation results + * @return list of evaluation results */ - void evaluate(SearchResponse searchResponse, ActionListener> listener); + default List getResults() { + return getMetrics().stream() + .map(EvaluationMetric::getResult) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(Collectors.toList()); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java new file mode 100644 index 0000000000000..54934b64652c0 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.xcontent.ToXContentObject; + +import java.util.Optional; + +/** + * {@link EvaluationMetric} class represents a metric to evaluate. + */ +public interface EvaluationMetric extends ToXContentObject, NamedWriteable { + + /** + * Returns the name of the metric (which may differ to the writeable name) + */ + String getName(); + + /** + * Gets the evaluation result for this metric. + * @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise + */ + Optional getResult(); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetricResult.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetricResult.java index 36b8adf9d4ea3..06c7719a401a9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetricResult.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetricResult.java @@ -14,7 +14,7 @@ public interface EvaluationMetricResult extends ToXContentObject, NamedWriteable { /** - * Returns the name of the metric + * Returns the name of the metric (which may differ to the writeable name) */ - String getName(); + String getMetricName(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java index e48cb46b5c0a3..dc8de45f7bce7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java @@ -20,10 +20,12 @@ import java.io.IOException; import java.text.MessageFormat; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Objects; +import java.util.Optional; /** * Calculates the mean squared error between two known numerical fields. @@ -48,28 +50,34 @@ public static MeanSquaredError fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - public MeanSquaredError(StreamInput in) { + private EvaluationMetricResult result; - } - - public MeanSquaredError() { + public MeanSquaredError(StreamInput in) {} - } + public MeanSquaredError() {} @Override - public String getMetricName() { + public String getName() { return NAME.getPreferredName(); } @Override public List aggs(String actualField, String predictedField) { - return Collections.singletonList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField)))); + if (result != null) { + return Collections.emptyList(); + } + return Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField)))); } @Override - public EvaluationMetricResult evaluate(Aggregations aggs) { + public void process(Aggregations aggs) { NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME); - return value == null ? new Result(0.0) : new Result(value.value()); + result = value == null ? new Result(0.0) : new Result(value.value()); + } + + @Override + public Optional getResult() { + return Optional.ofNullable(result); } @Override @@ -121,7 +129,7 @@ public String getWriteableName() { } @Override - public String getName() { + public String getMetricName() { return NAME.getPreferredName(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java index a55306561833d..9307d5ae0ae46 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java @@ -23,9 +23,11 @@ import java.io.IOException; import java.text.MessageFormat; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Objects; +import java.util.Optional; /** * Calculates R-Squared between two known numerical fields. @@ -53,36 +55,42 @@ public static RSquared fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - public RSquared(StreamInput in) { + private EvaluationMetricResult result; - } - - public RSquared() { + public RSquared(StreamInput in) {} - } + public RSquared() {} @Override - public String getMetricName() { + public String getName() { return NAME.getPreferredName(); } @Override public List aggs(String actualField, String predictedField) { + if (result != null) { + return Collections.emptyList(); + } return Arrays.asList( AggregationBuilders.sum(SS_RES).script(new Script(buildScript(actualField, predictedField))), AggregationBuilders.extendedStats(ExtendedStatsAggregationBuilder.NAME + "_actual").field(actualField)); } @Override - public EvaluationMetricResult evaluate(Aggregations aggs) { + public void process(Aggregations aggs) { NumericMetricsAggregation.SingleValue residualSumOfSquares = aggs.get(SS_RES); ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual"); // extendedStats.getVariance() is the statistical sumOfSquares divided by count - return residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ? + result = residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ? new Result(0.0) : new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount()))); } + @Override + public Optional getResult() { + return Optional.ofNullable(result); + } + @Override public String getWriteableName() { return NAME.getPreferredName(); @@ -132,7 +140,7 @@ public String getWriteableName() { } @Override - public String getName() { + public String getMetricName() { return NAME.getPreferredName(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java index bb2540a8691b7..4741a033ae530 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; @@ -14,17 +13,15 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.List; @@ -86,19 +83,16 @@ public Regression(StreamInput in) throws IOException { } private static List initMetrics(@Nullable List parsedMetrics) { - List metrics = parsedMetrics == null ? defaultMetrics() : parsedMetrics; + List metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics); if (metrics.isEmpty()) { throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName()); } - Collections.sort(metrics, Comparator.comparing(RegressionMetric::getMetricName)); + Collections.sort(metrics, Comparator.comparing(RegressionMetric::getName)); return metrics; } private static List defaultMetrics() { - List defaultMetrics = new ArrayList<>(2); - defaultMetrics.add(new MeanSquaredError()); - defaultMetrics.add(new RSquared()); - return defaultMetrics; + return Arrays.asList(new MeanSquaredError(), new RSquared()); } @Override @@ -107,12 +101,15 @@ public String getName() { } @Override - public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) { - BoolQueryBuilder boolQuery = QueryBuilders.boolQuery() - .filter(QueryBuilders.existsQuery(actualField)) - .filter(QueryBuilders.existsQuery(predictedField)) - .filter(queryBuilder); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery); + public List getMetrics() { + return metrics; + } + + @Override + public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) { + ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder"); + SearchSourceBuilder searchSourceBuilder = + newSearchSourceBuilder(Arrays.asList(actualField, predictedField), userProvidedQueryBuilder); for (RegressionMetric metric : metrics) { List aggs = metric.aggs(actualField, predictedField); aggs.forEach(searchSourceBuilder::aggregation); @@ -121,18 +118,14 @@ public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) { } @Override - public void evaluate(SearchResponse searchResponse, ActionListener> listener) { - List results = new ArrayList<>(metrics.size()); + public void process(SearchResponse searchResponse) { + ExceptionsHelper.requireNonNull(searchResponse, "searchResponse"); if (searchResponse.getHits().getTotalHits().value == 0) { - listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", - actualField, - predictedField)); - return; + throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField); } for (RegressionMetric metric : metrics) { - results.add(metric.evaluate(searchResponse.getAggregations())); + metric.process(searchResponse.getAggregations()); } - listener.onResponse(results); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java index 1da48e2f305e6..08dfbfab4aa75 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java @@ -5,20 +5,14 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression; -import org.elasticsearch.common.io.stream.NamedWriteable; -import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.Aggregations; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import java.util.List; -public interface RegressionMetric extends ToXContentObject, NamedWriteable { - - /** - * Returns the name of the metric (which may differ to the writeable name) - */ - String getMetricName(); +public interface RegressionMetric extends EvaluationMetric { /** * Builds the aggregation that collect required data to compute the metric @@ -29,9 +23,8 @@ public interface RegressionMetric extends ToXContentObject, NamedWriteable { List aggs(String actualField, String predictedField); /** - * Calculates the metric result - * @param aggs the aggregations - * @return the metric result + * Processes given aggregations as a step towards computing result + * @param aggs aggregations from {@link SearchResponse} */ - EvaluationMetricResult evaluate(Aggregations aggs); + void process(Aggregations aggs); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java index facdcceea194f..45faec8512dc6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java @@ -13,27 +13,31 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Optional; abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric { public static final ParseField AT = new ParseField("at"); protected final double[] thresholds; + private EvaluationMetricResult result; protected AbstractConfusionMatrixMetric(double[] thresholds) { this.thresholds = ExceptionsHelper.requireNonNull(thresholds, AT); if (thresholds.length == 0) { - throw ExceptionsHelper.badRequestException("[" + getMetricName() + "." + AT.getPreferredName() - + "] must have at least one value"); + throw ExceptionsHelper.badRequestException("[" + getName() + "." + AT.getPreferredName() + "] must have at least one value"); } for (double threshold : thresholds) { if (threshold < 0 || threshold > 1.0) { - throw ExceptionsHelper.badRequestException("[" + getMetricName() + "." + AT.getPreferredName() + throw ExceptionsHelper.badRequestException("[" + getName() + "." + AT.getPreferredName() + "] values must be in [0.0, 1.0]"); } } @@ -58,6 +62,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public final List aggs(String actualField, List classInfos) { + if (result != null) { + return Collections.emptyList(); + } List aggs = new ArrayList<>(); for (double threshold : thresholds) { aggs.addAll(aggsAt(actualField, classInfos, threshold)); @@ -65,14 +72,26 @@ public final List aggs(String actualField, List c return aggs; } + @Override + public void process(ClassInfo classInfo, Aggregations aggs) { + result = evaluate(classInfo, aggs); + } + + @Override + public Optional getResult() { + return Optional.ofNullable(result); + } + protected abstract List aggsAt(String labelField, List classInfos, double threshold); + protected abstract EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs); + protected enum Condition { TP, FP, TN, FN; } protected String aggName(ClassInfo classInfo, double threshold, Condition condition) { - return getMetricName() + "_" + classInfo.getName() + "_at_" + threshold + "_" + condition.name(); + return getName() + "_" + classInfo.getName() + "_at_" + threshold + "_" + condition.name(); } protected AggregationBuilder buildAgg(ClassInfo classInfo, double threshold, Condition condition) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java index 228dac00bfb68..7f126b1ec2da4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java @@ -30,6 +30,7 @@ import java.util.Comparator; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.stream.IntStream; /** @@ -70,6 +71,7 @@ public static AucRoc fromXContent(XContentParser parser) { } private final boolean includeCurve; + private EvaluationMetricResult result; public AucRoc(Boolean includeCurve) { this.includeCurve = includeCurve == null ? false : includeCurve; @@ -98,7 +100,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } @Override - public String getMetricName() { + public String getName() { return NAME.getPreferredName(); } @@ -117,6 +119,9 @@ public int hashCode() { @Override public List aggs(String actualField, List classInfos) { + if (result != null) { + return Collections.emptyList(); + } double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray(); List aggs = new ArrayList<>(); for (ClassInfo classInfo : classInfos) { @@ -134,22 +139,31 @@ public List aggs(String actualField, List classIn return aggs; } + @Override + public void process(ClassInfo classInfo, Aggregations aggs) { + result = evaluate(classInfo, aggs); + } + + @Override + public Optional getResult() { + return Optional.ofNullable(result); + } + private String evaluatedLabelAggName(ClassInfo classInfo) { - return getMetricName() + "_" + classInfo.getName(); + return getName() + "_" + classInfo.getName(); } private String restLabelsAggName(ClassInfo classInfo) { - return getMetricName() + "_non_" + classInfo.getName(); + return getName() + "_non_" + classInfo.getName(); } - @Override - public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) { + private EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) { Filter classAgg = aggs.get(evaluatedLabelAggName(classInfo)); Filter restAgg = aggs.get(restLabelsAggName(classInfo)); double[] tpPercentiles = percentilesArray(classAgg.getAggregations().get(PERCENTILES), - "[" + getMetricName() + "] requires at least one actual_field to have the value [" + classInfo.getName() + "]"); + "[" + getName() + "] requires at least one actual_field to have the value [" + classInfo.getName() + "]"); double[] fpPercentiles = percentilesArray(restAgg.getAggregations().get(PERCENTILES), - "[" + getMetricName() + "] requires at least one actual_field to have a different value than [" + classInfo.getName() + "]"); + "[" + getName() + "] requires at least one actual_field to have a different value than [" + classInfo.getName() + "]"); List aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles); double aucRocScore = calculateAucScore(aucRocCurve); return new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList()); @@ -326,7 +340,7 @@ public String getWriteableName() { } @Override - public String getName() { + public String getMetricName() { return NAME.getPreferredName(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java index 20731eba5e83f..386919edec87b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; @@ -14,18 +13,14 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.aggregations.AggregationBuilder; -import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; @@ -87,17 +82,16 @@ private static List initMetrics(@Nullable List defaultMetrics() { - List defaultMetrics = new ArrayList<>(4); - defaultMetrics.add(new AucRoc(false)); - defaultMetrics.add(new Precision(Arrays.asList(0.25, 0.5, 0.75))); - defaultMetrics.add(new Recall(Arrays.asList(0.25, 0.5, 0.75))); - defaultMetrics.add(new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75))); - return defaultMetrics; + return Arrays.asList( + new AucRoc(false), + new Precision(Arrays.asList(0.25, 0.5, 0.75)), + new Recall(Arrays.asList(0.25, 0.5, 0.75)), + new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75))); } public BinarySoftClassification(StreamInput in) throws IOException { @@ -126,7 +120,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(METRICS.getPreferredName()); for (SoftClassificationMetric metric : metrics) { - builder.field(metric.getMetricName(), metric); + builder.field(metric.getName(), metric); } builder.endObject(); @@ -155,34 +149,34 @@ public String getName() { } @Override - public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) { - BoolQueryBuilder boolQuery = QueryBuilders.boolQuery() - .filter(QueryBuilders.existsQuery(actualField)) - .filter(QueryBuilders.existsQuery(predictedProbabilityField)) - .filter(queryBuilder); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery); + public List getMetrics() { + return metrics; + } + + @Override + public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) { + ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder"); + SearchSourceBuilder searchSourceBuilder = + newSearchSourceBuilder(Arrays.asList(actualField, predictedProbabilityField), userProvidedQueryBuilder); + BinaryClassInfo binaryClassInfo = new BinaryClassInfo(); for (SoftClassificationMetric metric : metrics) { - List aggs = metric.aggs(actualField, Collections.singletonList(new BinaryClassInfo())); + List aggs = metric.aggs(actualField, Collections.singletonList(binaryClassInfo)); aggs.forEach(searchSourceBuilder::aggregation); } return searchSourceBuilder; } @Override - public void evaluate(SearchResponse searchResponse, ActionListener> listener) { + public void process(SearchResponse searchResponse) { + ExceptionsHelper.requireNonNull(searchResponse, "searchResponse"); if (searchResponse.getHits().getTotalHits().value == 0) { - listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, - predictedProbabilityField)); - return; + throw ExceptionsHelper.badRequestException( + "No documents found containing both [{}, {}] fields", actualField, predictedProbabilityField); } - - List results = new ArrayList<>(); - Aggregations aggs = searchResponse.getAggregations(); BinaryClassInfo binaryClassInfo = new BinaryClassInfo(); for (SoftClassificationMetric metric : metrics) { - results.add(metric.evaluate(binaryClassInfo, aggs)); + metric.process(binaryClassInfo, searchResponse.getAggregations()); } - listener.onResponse(results); } private class BinaryClassInfo implements SoftClassificationMetric.ClassInfo { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java index 54f245962d515..6fc05809245d3 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java @@ -50,7 +50,7 @@ public String getWriteableName() { } @Override - public String getMetricName() { + public String getName() { return NAME.getPreferredName(); } @@ -132,7 +132,7 @@ public String getWriteableName() { } @Override - public String getName() { + public String getMetricName() { return NAME.getPreferredName(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java index d38a52bb203e8..a0fcda5f90c6f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java @@ -48,7 +48,7 @@ public String getWriteableName() { } @Override - public String getMetricName() { + public String getName() { return NAME.getPreferredName(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java index f7103aceedae0..53b3f1a24a2f4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java @@ -48,7 +48,7 @@ public String getWriteableName() { } @Override - public String getMetricName() { + public String getName() { return NAME.getPreferredName(); } @@ -68,7 +68,7 @@ public int hashCode() { @Override protected List aggsAt(String actualField, List classInfos, double threshold) { List aggs = new ArrayList<>(); - for (ClassInfo classInfo: classInfos) { + for (ClassInfo classInfo : classInfos) { aggs.add(buildAgg(classInfo, threshold, Condition.TP)); aggs.add(buildAgg(classInfo, threshold, Condition.FN)); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java index bd6b6e7db25a1..0ad99a83cf25b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java @@ -40,7 +40,7 @@ public String getWriteableName() { } @Override - public String getName() { + public String getMetricName() { return name; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java index dfb256e9b52f2..a5b072632c22a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java @@ -5,16 +5,15 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; -import org.elasticsearch.common.io.stream.NamedWriteable; -import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.Aggregations; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import java.util.List; -public interface SoftClassificationMetric extends ToXContentObject, NamedWriteable { +public interface SoftClassificationMetric extends EvaluationMetric { /** * The information of a specific class @@ -37,11 +36,6 @@ interface ClassInfo { String getProbabilityField(); } - /** - * Returns the name of the metric (which may differ to the writeable name) - */ - String getMetricName(); - /** * Builds the aggregation that collect required data to compute the metric * @param actualField the field that stores the actual class @@ -51,10 +45,9 @@ interface ClassInfo { List aggs(String actualField, List classInfos); /** - * Calculates the metric result for a given class + * Processes given aggregations as a step towards computing result * @param classInfo the class to calculate the metric for - * @param aggs the aggregations - * @return the metric result + * @param aggs aggregations from {@link SearchResponse} */ - EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs); + void process(ClassInfo classInfo, Aggregations aggs); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java index a22c499220ce6..2516b2fea94a5 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java @@ -49,8 +49,9 @@ public void testEvaluate() { )); MeanSquaredError mse = new MeanSquaredError(); - EvaluationMetricResult result = mse.evaluate(aggs); + mse.process(aggs); + EvaluationMetricResult result = mse.getResult().get(); String expected = "{\"error\":0.8123}"; assertThat(Strings.toString(result), equalTo(expected)); } @@ -61,7 +62,9 @@ public void testEvaluate_GivenMissingAggs() { )); MeanSquaredError mse = new MeanSquaredError(); - EvaluationMetricResult result = mse.evaluate(aggs); + mse.process(aggs); + + EvaluationMetricResult result = mse.getResult().get(); assertThat(result, equalTo(new MeanSquaredError.Result(0.0))); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquaredTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquaredTests.java index 2c04222700659..4913d232f74cc 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquaredTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquaredTests.java @@ -52,8 +52,9 @@ public void testEvaluate() { )); RSquared rSquared = new RSquared(); - EvaluationMetricResult result = rSquared.evaluate(aggs); + rSquared.process(aggs); + EvaluationMetricResult result = rSquared.getResult().get(); String expected = "{\"value\":0.9348643947690524}"; assertThat(Strings.toString(result), equalTo(expected)); } @@ -67,35 +68,48 @@ public void testEvaluateWithZeroCount() { )); RSquared rSquared = new RSquared(); - EvaluationMetricResult result = rSquared.evaluate(aggs); + rSquared.process(aggs); + + EvaluationMetricResult result = rSquared.getResult().get(); assertThat(result, equalTo(new RSquared.Result(0.0))); } public void testEvaluate_GivenMissingAggs() { - EvaluationMetricResult zeroResult = new RSquared.Result(0.0); Aggregations aggs = new Aggregations(Collections.singletonList( createSingleMetricAgg("some_other_single_metric_agg", 0.2377) )); RSquared rSquared = new RSquared(); - EvaluationMetricResult result = rSquared.evaluate(aggs); - assertThat(result, equalTo(zeroResult)); + rSquared.process(aggs); + + EvaluationMetricResult result = rSquared.getResult().get(); + assertThat(result, equalTo(new RSquared.Result(0.0))); + } - aggs = new Aggregations(Arrays.asList( + public void testEvaluate_GivenMissingExtendedStatsAgg() { + Aggregations aggs = new Aggregations(Arrays.asList( createSingleMetricAgg("some_other_single_metric_agg", 0.2377), createSingleMetricAgg("residual_sum_of_squares", 0.2377) )); - result = rSquared.evaluate(aggs); - assertThat(result, equalTo(zeroResult)); + RSquared rSquared = new RSquared(); + rSquared.process(aggs); - aggs = new Aggregations(Arrays.asList( + EvaluationMetricResult result = rSquared.getResult().get(); + assertThat(result, equalTo(new RSquared.Result(0.0))); + } + + public void testEvaluate_GivenMissingResidualSumOfSquaresAgg() { + Aggregations aggs = new Aggregations(Arrays.asList( createSingleMetricAgg("some_other_single_metric_agg", 0.2377), createExtendedStatsAgg("extended_stats_actual",100, 50) )); - result = rSquared.evaluate(aggs); - assertThat(result, equalTo(zeroResult)); + RSquared rSquared = new RSquared(); + rSquared.process(aggs); + + EvaluationMetricResult result = rSquared.getResult().get(); + assertThat(result, equalTo(new RSquared.Result(0.0))); } private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java index 7f089ab18cd9d..077998b66aed0 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; @@ -22,6 +23,7 @@ import java.util.List; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; public class RegressionTests extends AbstractSerializingTestCase { @@ -43,13 +45,7 @@ public static Regression createRandom() { if (randomBoolean()) { metrics.add(RSquaredTests.createRandom()); } - return new Regression(randomAlphaOfLength(10), - randomAlphaOfLength(10), - randomBoolean() ? - null : - metrics.isEmpty() ? - null : - metrics); + return new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics); } @Override @@ -74,7 +70,6 @@ public void testConstructor_GivenEmptyMetrics() { } public void testBuildSearch() { - Regression evaluation = new Regression("act", "prob", Arrays.asList(new MeanSquaredError())); QueryBuilder userProvidedQuery = QueryBuilders.boolQuery() .filter(QueryBuilders.termQuery("field_A", "some-value")) @@ -82,10 +77,15 @@ public void testBuildSearch() { QueryBuilder expectedSearchQuery = QueryBuilders.boolQuery() .filter(QueryBuilders.existsQuery("act")) - .filter(QueryBuilders.existsQuery("prob")) + .filter(QueryBuilders.existsQuery("pred")) .filter(QueryBuilders.boolQuery() .filter(QueryBuilders.termQuery("field_A", "some-value")) .filter(QueryBuilders.termQuery("field_B", "some-other-value"))); - assertThat(evaluation.buildSearch(userProvidedQuery).query(), equalTo(expectedSearchQuery)); + + Regression evaluation = new Regression("act", "pred", Arrays.asList(new MeanSquaredError())); + + SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery); + assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery)); + assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0)); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java index 6a589c0d055ca..e63e88f6f848f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; @@ -22,6 +23,7 @@ import java.util.List; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; public class BinarySoftClassificationTests extends AbstractSerializingTestCase { @@ -81,7 +83,6 @@ public void testConstructor_GivenEmptyMetrics() { } public void testBuildSearch() { - BinarySoftClassification evaluation = new BinarySoftClassification("act", "prob", Arrays.asList(new Precision(Arrays.asList(0.7)))); QueryBuilder userProvidedQuery = QueryBuilders.boolQuery() .filter(QueryBuilders.termQuery("field_A", "some-value")) @@ -93,6 +94,11 @@ public void testBuildSearch() { .filter(QueryBuilders.boolQuery() .filter(QueryBuilders.termQuery("field_A", "some-value")) .filter(QueryBuilders.termQuery("field_B", "some-other-value"))); - assertThat(evaluation.buildSearch(userProvidedQuery).query(), equalTo(expectedSearchQuery)); + + BinarySoftClassification evaluation = new BinarySoftClassification("act", "prob", Arrays.asList(new Precision(Arrays.asList(0.7)))); + + SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery); + assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery)); + assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0)); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java index 2ca09af7d33aa..5c48be663f117 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java @@ -12,12 +12,13 @@ import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.client.Client; import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor; import java.util.List; @@ -38,24 +39,64 @@ public TransportEvaluateDataFrameAction(TransportService transportService, Actio @Override protected void doExecute(Task task, EvaluateDataFrameAction.Request request, ActionListener listener) { - Evaluation evaluation = request.getEvaluation(); - SearchRequest searchRequest = new SearchRequest(request.getIndices()); - searchRequest.source(evaluation.buildSearch(request.getParsedQuery())); - - ActionListener> resultsListener = ActionListener.wrap( - results -> listener.onResponse(new EvaluateDataFrameAction.Response(evaluation.getName(), results)), + ActionListener> resultsListener = ActionListener.wrap( + unused -> { + EvaluateDataFrameAction.Response response = + new EvaluateDataFrameAction.Response(request.getEvaluation().getName(), request.getEvaluation().getResults()); + listener.onResponse(response); + }, listener::onFailure ); - client.execute(SearchAction.INSTANCE, searchRequest, ActionListener.wrap( - searchResponse -> threadPool.generic().execute(() -> { - try { - evaluation.evaluate(searchResponse, resultsListener); - } catch (Exception e) { - listener.onFailure(e); - }; - }), - listener::onFailure - )); + EvaluationExecutor evaluationExecutor = new EvaluationExecutor(threadPool, client, request); + evaluationExecutor.execute(resultsListener); + } + + /** + * {@link EvaluationExecutor} class allows for serial execution of evaluation steps. + * + * Each step consists of the following phases: + * 1. build search request with aggs requested by individual metrics + * 2. execute search action with the request built in (1.) + * 3. make all individual metrics process the search response obtained in (2.) + * 4. check if all the metrics have their results computed + * a) If so, call the final listener and finish + * b) Otherwise, add another step to the queue + * + * To avoid infinite loop it is essential that every metric *does* compute its result at some point. + * */ + private static final class EvaluationExecutor extends TypedChainTaskExecutor { + + private final Client client; + private final EvaluateDataFrameAction.Request request; + private final Evaluation evaluation; + + EvaluationExecutor(ThreadPool threadPool, Client client, EvaluateDataFrameAction.Request request) { + super(threadPool.generic(), unused -> true, unused -> true); + this.client = client; + this.request = request; + this.evaluation = request.getEvaluation(); + // Add one task only. Other tasks will be added as needed by the nextTask method itself. + add(nextTask()); + } + + private TypedChainTaskExecutor.ChainTask nextTask() { + return listener -> { + SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(request.getParsedQuery()); + SearchRequest searchRequest = new SearchRequest(request.getIndices()).source(searchSourceBuilder); + client.execute( + SearchAction.INSTANCE, + searchRequest, + ActionListener.wrap( + searchResponse -> { + evaluation.process(searchResponse); + if (evaluation.hasAllResults() == false) { + add(nextTask()); + } + listener.onResponse(null); + }, + listener::onFailure)); + }; + } } }