diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java index 663e1ba639adf..98888c539c189 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java @@ -6,15 +6,23 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation; import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; import java.util.List; +import java.util.Objects; import java.util.Optional; +import java.util.function.Supplier; import java.util.stream.Collectors; /** @@ -27,37 +35,67 @@ public interface Evaluation extends ToXContentObject, NamedWriteable { */ String getName(); + /** + * Returns the field containing the actual value + */ + String getActualField(); + + /** + * Returns the field containing the predicted value + */ + String getPredictedField(); + /** * Returns the list of metrics to evaluate * @return list of metrics to evaluate */ List getMetrics(); + default List initMetrics(@Nullable List parsedMetrics, Supplier> defaultMetricsSupplier) { + List metrics = parsedMetrics == null ? defaultMetricsSupplier.get() : new ArrayList<>(parsedMetrics); + if (metrics.isEmpty()) { + throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", getName()); + } + Collections.sort(metrics, Comparator.comparing(EvaluationMetric::getName)); + return metrics; + } + /** * Builds the search required to collect data to compute the evaluation result * @param userProvidedQueryBuilder User-provided query that must be respected when collecting data */ - SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder); - - /** - * Builds the search that verifies existence of required fields and applies user-provided query - * @param requiredFields fields that must exist - * @param userProvidedQueryBuilder user-provided query - */ - default SearchSourceBuilder newSearchSourceBuilder(List requiredFields, QueryBuilder userProvidedQueryBuilder) { - BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); - for (String requiredField : requiredFields) { - boolQuery.filter(QueryBuilders.existsQuery(requiredField)); + default SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) { + Objects.requireNonNull(userProvidedQueryBuilder); + BoolQueryBuilder boolQuery = + QueryBuilders.boolQuery() + // Verify existence of required fields + .filter(QueryBuilders.existsQuery(getActualField())) + .filter(QueryBuilders.existsQuery(getPredictedField())) + // Apply user-provided query + .filter(userProvidedQueryBuilder); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery); + for (EvaluationMetric metric : getMetrics()) { + // Fetch aggregations requested by individual metrics + List aggs = metric.aggs(getActualField(), getPredictedField()); + aggs.forEach(searchSourceBuilder::aggregation); } - boolQuery.filter(userProvidedQueryBuilder); - return new SearchSourceBuilder().size(0).query(boolQuery); + return searchSourceBuilder; } /** * Processes {@link SearchResponse} from the search action * @param searchResponse response from the search action */ - void process(SearchResponse searchResponse); + default void process(SearchResponse searchResponse) { + Objects.requireNonNull(searchResponse); + if (searchResponse.getHits().getTotalHits().value == 0) { + throw ExceptionsHelper.badRequestException( + "No documents found containing both [{}, {}] fields", getActualField(), getPredictedField()); + } + for (EvaluationMetric metric : getMetrics()) { + metric.process(searchResponse.getAggregations()); + } + } /** * @return true iff all the metrics have their results computed diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java index 54934b64652c0..7a539d030dd44 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java @@ -5,9 +5,13 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import java.util.List; import java.util.Optional; /** @@ -20,6 +24,20 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable { */ String getName(); + /** + * Builds the aggregation that collect required data to compute the metric + * @param actualField the field that stores the actual value + * @param predictedField the field that stores the predicted value (class name or probability) + * @return the aggregations required to compute the metric + */ + List aggs(String actualField, String predictedField); + + /** + * Processes given aggregations as a step towards computing result + * @param aggs aggregations from {@link SearchResponse} + */ + void process(Aggregations aggs); + /** * Gets the evaluation result for this metric. * @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java index a90de52ea15a9..ee312ee7c7fd8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; -import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; @@ -13,17 +12,11 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.search.aggregations.AggregationBuilder; -import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; import java.util.List; import java.util.Objects; @@ -55,13 +48,13 @@ public static Classification fromXContent(XContentParser parser) { /** * The field containing the actual value - * The value of this field is assumed to be numeric + * The value of this field is assumed to be categorical */ private final String actualField; /** * The field containing the predicted value - * The value of this field is assumed to be numeric + * The value of this field is assumed to be categorical */ private final String predictedField; @@ -73,7 +66,11 @@ public static Classification fromXContent(XContentParser parser) { public Classification(String actualField, String predictedField, @Nullable List metrics) { this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD); this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD); - this.metrics = initMetrics(metrics); + this.metrics = initMetrics(metrics, Classification::defaultMetrics); + } + + private static List defaultMetrics() { + return Arrays.asList(new MulticlassConfusionMatrix()); } public Classification(StreamInput in) throws IOException { @@ -82,49 +79,24 @@ public Classification(StreamInput in) throws IOException { this.metrics = in.readNamedWriteableList(ClassificationMetric.class); } - private static List initMetrics(@Nullable List parsedMetrics) { - List metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics); - if (metrics.isEmpty()) { - throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName()); - } - Collections.sort(metrics, Comparator.comparing(ClassificationMetric::getName)); - return metrics; - } - - private static List defaultMetrics() { - return Arrays.asList(new MulticlassConfusionMatrix()); - } - @Override public String getName() { return NAME.getPreferredName(); } @Override - public List getMetrics() { - return metrics; + public String getActualField() { + return actualField; } @Override - public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) { - ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder"); - SearchSourceBuilder searchSourceBuilder = newSearchSourceBuilder(List.of(actualField, predictedField), userProvidedQueryBuilder); - for (ClassificationMetric metric : metrics) { - List aggs = metric.aggs(actualField, predictedField); - aggs.forEach(searchSourceBuilder::aggregation); - } - return searchSourceBuilder; + public String getPredictedField() { + return predictedField; } @Override - public void process(SearchResponse searchResponse) { - ExceptionsHelper.requireNonNull(searchResponse, "searchResponse"); - if (searchResponse.getHits().getTotalHits().value == 0) { - throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField); - } - for (ClassificationMetric metric : metrics) { - metric.process(searchResponse.getAggregations()); - } + public List getMetrics() { + return metrics; } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java index 220942a4838a5..a61ac9a702fa2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java @@ -5,26 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; -import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.search.aggregations.AggregationBuilder; -import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; -import java.util.List; - public interface ClassificationMetric extends EvaluationMetric { - - /** - * Builds the aggregation that collect required data to compute the metric - * @param actualField the field that stores the actual value - * @param predictedField the field that stores the predicted value - * @return the aggregations required to compute the metric - */ - List aggs(String actualField, String predictedField); - - /** - * Processes given aggregations as a step towards computing result - * @param aggs aggregations from {@link SearchResponse} - */ - void process(Aggregations aggs); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java index c5f1a7a2fde2a..ccf16a9618ec6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression; -import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; @@ -13,17 +12,11 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.search.aggregations.AggregationBuilder; -import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; import java.util.List; import java.util.Objects; @@ -73,7 +66,11 @@ public static Regression fromXContent(XContentParser parser) { public Regression(String actualField, String predictedField, @Nullable List metrics) { this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD); this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD); - this.metrics = initMetrics(metrics); + this.metrics = initMetrics(metrics, Regression::defaultMetrics); + } + + private static List defaultMetrics() { + return Arrays.asList(new MeanSquaredError(), new RSquared()); } public Regression(StreamInput in) throws IOException { @@ -82,49 +79,24 @@ public Regression(StreamInput in) throws IOException { this.metrics = in.readNamedWriteableList(RegressionMetric.class); } - private static List initMetrics(@Nullable List parsedMetrics) { - List metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics); - if (metrics.isEmpty()) { - throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName()); - } - Collections.sort(metrics, Comparator.comparing(RegressionMetric::getName)); - return metrics; - } - - private static List defaultMetrics() { - return Arrays.asList(new MeanSquaredError(), new RSquared()); - } - @Override public String getName() { return NAME.getPreferredName(); } @Override - public List getMetrics() { - return metrics; + public String getActualField() { + return actualField; } @Override - public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) { - ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder"); - SearchSourceBuilder searchSourceBuilder = newSearchSourceBuilder(List.of(actualField, predictedField), userProvidedQueryBuilder); - for (RegressionMetric metric : metrics) { - List aggs = metric.aggs(actualField, predictedField); - aggs.forEach(searchSourceBuilder::aggregation); - } - return searchSourceBuilder; + public String getPredictedField() { + return predictedField; } @Override - public void process(SearchResponse searchResponse) { - ExceptionsHelper.requireNonNull(searchResponse, "searchResponse"); - if (searchResponse.getHits().getTotalHits().value == 0) { - throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField); - } - for (RegressionMetric metric : metrics) { - metric.process(searchResponse.getAggregations()); - } + public List getMetrics() { + return metrics; } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java index 08dfbfab4aa75..5b46829b4c852 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java @@ -5,26 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression; -import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.search.aggregations.AggregationBuilder; -import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; -import java.util.List; - public interface RegressionMetric extends EvaluationMetric { - - /** - * Builds the aggregation that collect required data to compute the metric - * @param actualField the field that stores the actual value - * @param predictedField the field that stores the predicted value - * @return the aggregations required to compute the metric - */ - List aggs(String actualField, String predictedField); - - /** - * Processes given aggregations as a step towards computing result - * @param aggs aggregations from {@link SearchResponse} - */ - void process(Aggregations aggs); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java index 9ce186c524aa8..286a68314d713 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilders; @@ -18,10 +19,11 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.util.ArrayList; import java.util.List; import java.util.Optional; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric.actualIsTrueQuery; + abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric { public static final ParseField AT = new ParseField("at"); @@ -29,8 +31,8 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric protected final double[] thresholds; private EvaluationMetricResult result; - protected AbstractConfusionMatrixMetric(double[] thresholds) { - this.thresholds = ExceptionsHelper.requireNonNull(thresholds, AT); + protected AbstractConfusionMatrixMetric(List at) { + this.thresholds = ExceptionsHelper.requireNonNull(at, AT).stream().mapToDouble(Double::doubleValue).toArray(); if (thresholds.length == 0) { throw ExceptionsHelper.badRequestException("[" + getName() + "." + AT.getPreferredName() + "] must have at least one value"); } @@ -60,20 +62,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } @Override - public final List aggs(String actualField, List classInfos) { + public final List aggs(String actualField, String predictedProbabilityField) { if (result != null) { return List.of(); } - List aggs = new ArrayList<>(); - for (double threshold : thresholds) { - aggs.addAll(aggsAt(actualField, classInfos, threshold)); - } - return aggs; + return aggsAt(actualField, predictedProbabilityField); } @Override - public void process(ClassInfo classInfo, Aggregations aggs) { - result = evaluate(classInfo, aggs); + public void process(Aggregations aggs) { + result = evaluate(aggs); } @Override @@ -81,40 +79,43 @@ public Optional getResult() { return Optional.ofNullable(result); } - protected abstract List aggsAt(String labelField, List classInfos, double threshold); + protected abstract List aggsAt(String actualField, String predictedProbabilityField); + + protected abstract EvaluationMetricResult evaluate(Aggregations aggs); - protected abstract EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs); + enum Condition { + TP(true, true), + FP(false, true), + TN(false, false), + FN(true, false); - protected enum Condition { - TP, FP, TN, FN; + final boolean actual; + final boolean predicted; + + Condition(boolean actual, boolean predicted) { + this.actual = actual; + this.predicted = predicted; + } } - protected String aggName(ClassInfo classInfo, double threshold, Condition condition) { - return getName() + "_" + classInfo.getName() + "_at_" + threshold + "_" + condition.name(); + protected String aggName(double threshold, Condition condition) { + return getName() + "_at_" + threshold + "_" + condition.name(); } - protected AggregationBuilder buildAgg(ClassInfo classInfo, double threshold, Condition condition) { + protected AggregationBuilder buildAgg(String actualField, String predictedProbabilityField, double threshold, Condition condition) { BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); - switch (condition) { - case TP: - boolQuery.must(classInfo.matchingQuery()); - boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).gte(threshold)); - break; - case FP: - boolQuery.mustNot(classInfo.matchingQuery()); - boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).gte(threshold)); - break; - case TN: - boolQuery.mustNot(classInfo.matchingQuery()); - boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).lt(threshold)); - break; - case FN: - boolQuery.must(classInfo.matchingQuery()); - boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).lt(threshold)); - break; - default: - throw new IllegalArgumentException("Unknown enum value: " + condition); + QueryBuilder actualIsTrueQuery = actualIsTrueQuery(actualField); + QueryBuilder predictedIsTrueQuery = QueryBuilders.rangeQuery(predictedProbabilityField).gte(threshold); + if (condition.actual) { + boolQuery.must(actualIsTrueQuery); + } else { + boolQuery.mustNot(actualIsTrueQuery); + } + if (condition.predicted) { + boolQuery.must(predictedIsTrueQuery); + } else { + boolQuery.mustNot(predictedIsTrueQuery); } - return AggregationBuilders.filter(aggName(classInfo, threshold, condition), boolQuery); + return AggregationBuilders.filter(aggName(threshold, condition), boolQuery); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java index 188713b037127..40257ebce4cdb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java @@ -33,6 +33,8 @@ import java.util.Optional; import java.util.stream.IntStream; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric.actualIsTrueQuery; + /** * Area under the curve (AUC) of the receiver operating characteristic (ROC). * The ROC curve is a plot of the TPR (true positive rate) against @@ -66,6 +68,9 @@ public class AucRoc implements SoftClassificationMetric { private static final String PERCENTILES = "percentiles"; + private static final String TRUE_AGG_NAME = NAME.getPreferredName() + "_true"; + private static final String NON_TRUE_AGG_NAME = NAME.getPreferredName() + "_non_true"; + public static AucRoc fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } @@ -118,30 +123,39 @@ public int hashCode() { } @Override - public List aggs(String actualField, List classInfos) { + public List aggs(String actualField, String predictedProbabilityField) { if (result != null) { return List.of(); } double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray(); - List aggs = new ArrayList<>(); - for (ClassInfo classInfo : classInfos) { - AggregationBuilder percentilesForClassValueAgg = AggregationBuilders - .filter(evaluatedLabelAggName(classInfo), classInfo.matchingQuery()) + AggregationBuilder percentilesForClassValueAgg = + AggregationBuilders + .filter(TRUE_AGG_NAME, actualIsTrueQuery(actualField)) .subAggregation( - AggregationBuilders.percentiles(PERCENTILES).field(classInfo.getProbabilityField()).percentiles(percentiles)); - AggregationBuilder percentilesForRestAgg = AggregationBuilders - .filter(restLabelsAggName(classInfo), QueryBuilders.boolQuery().mustNot(classInfo.matchingQuery())) + AggregationBuilders.percentiles(PERCENTILES).field(predictedProbabilityField).percentiles(percentiles)); + AggregationBuilder percentilesForRestAgg = + AggregationBuilders + .filter(NON_TRUE_AGG_NAME, QueryBuilders.boolQuery().mustNot(actualIsTrueQuery(actualField))) .subAggregation( - AggregationBuilders.percentiles(PERCENTILES).field(classInfo.getProbabilityField()).percentiles(percentiles)); - aggs.add(percentilesForClassValueAgg); - aggs.add(percentilesForRestAgg); - } - return aggs; + AggregationBuilders.percentiles(PERCENTILES).field(predictedProbabilityField).percentiles(percentiles)); + return List.of(percentilesForClassValueAgg, percentilesForRestAgg); } @Override - public void process(ClassInfo classInfo, Aggregations aggs) { - result = evaluate(classInfo, aggs); + public void process(Aggregations aggs) { + Filter classAgg = aggs.get(TRUE_AGG_NAME); + Filter restAgg = aggs.get(NON_TRUE_AGG_NAME); + double[] tpPercentiles = + percentilesArray( + classAgg.getAggregations().get(PERCENTILES), + "[" + getName() + "] requires at least one actual_field to have the value [true]"); + double[] fpPercentiles = + percentilesArray( + restAgg.getAggregations().get(PERCENTILES), + "[" + getName() + "] requires at least one actual_field to have a different value than [true]"); + List aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles); + double aucRocScore = calculateAucScore(aucRocCurve); + result = new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList()); } @Override @@ -149,26 +163,6 @@ public Optional getResult() { return Optional.ofNullable(result); } - private String evaluatedLabelAggName(ClassInfo classInfo) { - return getName() + "_" + classInfo.getName(); - } - - private String restLabelsAggName(ClassInfo classInfo) { - return getName() + "_non_" + classInfo.getName(); - } - - private EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) { - Filter classAgg = aggs.get(evaluatedLabelAggName(classInfo)); - Filter restAgg = aggs.get(restLabelsAggName(classInfo)); - double[] tpPercentiles = percentilesArray(classAgg.getAggregations().get(PERCENTILES), - "[" + getName() + "] requires at least one actual_field to have the value [" + classInfo.getName() + "]"); - double[] fpPercentiles = percentilesArray(restAgg.getAggregations().get(PERCENTILES), - "[" + getName() + "] requires at least one actual_field to have a different value than [" + classInfo.getName() + "]"); - List aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles); - double aucRocScore = calculateAucScore(aucRocCurve); - return new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList()); - } - private static double[] percentilesArray(Percentiles percentiles, String errorIfUndefined) { double[] result = new double[99]; percentiles.forEach(percentile -> { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java index 30858107af089..67a635e078be2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; -import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; @@ -13,17 +12,11 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.search.aggregations.AggregationBuilder; -import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; import java.util.List; import java.util.Objects; @@ -74,16 +67,7 @@ public BinarySoftClassification(String actualField, String predictedProbabilityF @Nullable List metrics) { this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD); this.predictedProbabilityField = ExceptionsHelper.requireNonNull(predictedProbabilityField, PREDICTED_PROBABILITY_FIELD); - this.metrics = initMetrics(metrics); - } - - private static List initMetrics(@Nullable List parsedMetrics) { - List metrics = parsedMetrics == null ? defaultMetrics() : parsedMetrics; - if (metrics.isEmpty()) { - throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName()); - } - Collections.sort(metrics, Comparator.comparing(SoftClassificationMetric::getName)); - return metrics; + this.metrics = initMetrics(metrics, BinarySoftClassification::defaultMetrics); } private static List defaultMetrics() { @@ -100,6 +84,26 @@ public BinarySoftClassification(StreamInput in) throws IOException { this.metrics = in.readNamedWriteableList(SoftClassificationMetric.class); } + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public String getActualField() { + return actualField; + } + + @Override + public String getPredictedField() { + return predictedProbabilityField; + } + + @Override + public List getMetrics() { + return metrics; + } + @Override public String getWriteableName() { return NAME.getPreferredName(); @@ -142,60 +146,4 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(actualField, predictedProbabilityField, metrics); } - - @Override - public String getName() { - return NAME.getPreferredName(); - } - - @Override - public List getMetrics() { - return metrics; - } - - @Override - public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) { - ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder"); - SearchSourceBuilder searchSourceBuilder = - newSearchSourceBuilder(List.of(actualField, predictedProbabilityField), userProvidedQueryBuilder); - BinaryClassInfo binaryClassInfo = new BinaryClassInfo(); - for (SoftClassificationMetric metric : metrics) { - List aggs = metric.aggs(actualField, Collections.singletonList(binaryClassInfo)); - aggs.forEach(searchSourceBuilder::aggregation); - } - return searchSourceBuilder; - } - - @Override - public void process(SearchResponse searchResponse) { - ExceptionsHelper.requireNonNull(searchResponse, "searchResponse"); - if (searchResponse.getHits().getTotalHits().value == 0) { - throw ExceptionsHelper.badRequestException( - "No documents found containing both [{}, {}] fields", actualField, predictedProbabilityField); - } - BinaryClassInfo binaryClassInfo = new BinaryClassInfo(); - for (SoftClassificationMetric metric : metrics) { - metric.process(binaryClassInfo, searchResponse.getAggregations()); - } - } - - private class BinaryClassInfo implements SoftClassificationMetric.ClassInfo { - - private QueryBuilder matchingQuery = QueryBuilders.queryStringQuery(actualField + ": (1 OR true)"); - - @Override - public String getName() { - return String.valueOf(true); - } - - @Override - public QueryBuilder matchingQuery() { - return matchingQuery; - } - - @Override - public String getProbabilityField() { - return predictedProbabilityField; - } - } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java index 6fc05809245d3..d52468a0214b6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java @@ -37,7 +37,7 @@ public static ConfusionMatrix fromXContent(XContentParser parser) { } public ConfusionMatrix(List at) { - super(at.stream().mapToDouble(Double::doubleValue).toArray()); + super(at); } public ConfusionMatrix(StreamInput in) throws IOException { @@ -68,28 +68,29 @@ public int hashCode() { } @Override - protected List aggsAt(String labelField, List classInfos, double threshold) { + protected List aggsAt(String actualField, String predictedProbabilityField) { List aggs = new ArrayList<>(); - for (ClassInfo classInfo : classInfos) { - aggs.add(buildAgg(classInfo, threshold, Condition.TP)); - aggs.add(buildAgg(classInfo, threshold, Condition.FP)); - aggs.add(buildAgg(classInfo, threshold, Condition.TN)); - aggs.add(buildAgg(classInfo, threshold, Condition.FN)); + for (int i = 0; i < thresholds.length; i++) { + double threshold = thresholds[i]; + aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.TP)); + aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.FP)); + aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.TN)); + aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.FN)); } return aggs; } @Override - public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) { + public EvaluationMetricResult evaluate(Aggregations aggs) { long[] tp = new long[thresholds.length]; long[] fp = new long[thresholds.length]; long[] tn = new long[thresholds.length]; long[] fn = new long[thresholds.length]; for (int i = 0; i < thresholds.length; i++) { - Filter tpAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.TP)); - Filter fpAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.FP)); - Filter tnAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.TN)); - Filter fnAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.FN)); + Filter tpAgg = aggs.get(aggName(thresholds[i], Condition.TP)); + Filter fpAgg = aggs.get(aggName(thresholds[i], Condition.FP)); + Filter tnAgg = aggs.get(aggName(thresholds[i], Condition.TN)); + Filter fnAgg = aggs.get(aggName(thresholds[i], Condition.FN)); tp[i] = tpAgg.getDocCount(); fp[i] = fpAgg.getDocCount(); tn[i] = tnAgg.getDocCount(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java index a0fcda5f90c6f..80f838dd5d166 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java @@ -35,7 +35,7 @@ public static Precision fromXContent(XContentParser parser) { } public Precision(List at) { - super(at.stream().mapToDouble(Double::doubleValue).toArray()); + super(at); } public Precision(StreamInput in) throws IOException { @@ -66,22 +66,23 @@ public int hashCode() { } @Override - protected List aggsAt(String labelField, List classInfos, double threshold) { + protected List aggsAt(String actualField, String predictedProbabilityField) { List aggs = new ArrayList<>(); - for (ClassInfo classInfo : classInfos) { - aggs.add(buildAgg(classInfo, threshold, Condition.TP)); - aggs.add(buildAgg(classInfo, threshold, Condition.FP)); + for (int i = 0; i < thresholds.length; i++) { + double threshold = thresholds[i]; + aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.TP)); + aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.FP)); } return aggs; } @Override - public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) { + public EvaluationMetricResult evaluate(Aggregations aggs) { double[] precisions = new double[thresholds.length]; - for (int i = 0; i < precisions.length; i++) { + for (int i = 0; i < thresholds.length; i++) { double threshold = thresholds[i]; - Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP)); - Filter fpAgg = aggs.get(aggName(classInfo, threshold, Condition.FP)); + Filter tpAgg = aggs.get(aggName(threshold, Condition.TP)); + Filter fpAgg = aggs.get(aggName(threshold, Condition.FP)); long tp = tpAgg.getDocCount(); long fp = fpAgg.getDocCount(); precisions[i] = tp + fp == 0 ? 0.0 : (double) tp / (tp + fp); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java index 53b3f1a24a2f4..70bda8099db89 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java @@ -35,7 +35,7 @@ public static Recall fromXContent(XContentParser parser) { } public Recall(List at) { - super(at.stream().mapToDouble(Double::doubleValue).toArray()); + super(at); } public Recall(StreamInput in) throws IOException { @@ -66,22 +66,23 @@ public int hashCode() { } @Override - protected List aggsAt(String actualField, List classInfos, double threshold) { + protected List aggsAt(String actualField, String predictedProbabilityField) { List aggs = new ArrayList<>(); - for (ClassInfo classInfo : classInfos) { - aggs.add(buildAgg(classInfo, threshold, Condition.TP)); - aggs.add(buildAgg(classInfo, threshold, Condition.FN)); + for (int i = 0; i < thresholds.length; i++) { + double threshold = thresholds[i]; + aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.TP)); + aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.FN)); } return aggs; } @Override - public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) { + public EvaluationMetricResult evaluate(Aggregations aggs) { double[] recalls = new double[thresholds.length]; - for (int i = 0; i < recalls.length; i++) { + for (int i = 0; i < thresholds.length; i++) { double threshold = thresholds[i]; - Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP)); - Filter fnAgg = aggs.get(aggName(classInfo, threshold, Condition.FN)); + Filter tpAgg = aggs.get(aggName(threshold, Condition.TP)); + Filter fnAgg = aggs.get(aggName(threshold, Condition.FN)); long tp = tpAgg.getDocCount(); long fn = fnAgg.getDocCount(); recalls[i] = tp + fn == 0 ? 0.0 : (double) tp / (tp + fn); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java index a5b072632c22a..9a9c382caf9d1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java @@ -5,49 +5,13 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; -import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.search.aggregations.AggregationBuilder; -import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; -import java.util.List; - public interface SoftClassificationMetric extends EvaluationMetric { - /** - * The information of a specific class - */ - interface ClassInfo { - - /** - * Returns the class name - */ - String getName(); - - /** - * Returns a query that matches documents of the class - */ - QueryBuilder matchingQuery(); - - /** - * Returns the field that has the probability to be of the class - */ - String getProbabilityField(); + static QueryBuilder actualIsTrueQuery(String actualField) { + return QueryBuilders.queryStringQuery(actualField + ": (1 OR true)"); } - - /** - * Builds the aggregation that collect required data to compute the metric - * @param actualField the field that stores the actual class - * @param classInfos the information of each class to compute the metric for - * @return the aggregations required to compute the metric - */ - List aggs(String actualField, List classInfos); - - /** - * Processes given aggregations as a step towards computing result - * @param classInfo the class to calculate the metric for - * @param aggs aggregations from {@link SearchResponse} - */ - void process(ClassInfo classInfo, Aggregations aggs); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrixTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrixTests.java index 41f78051af420..cf54131af137e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrixTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrixTests.java @@ -49,22 +49,19 @@ public static ConfusionMatrix createRandom() { } public void testEvaluate() { - SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class); - when(classInfo.getName()).thenReturn("foo"); - Aggregations aggs = new Aggregations(Arrays.asList( - createFilterAgg("confusion_matrix_foo_at_0.25_TP", 1L), - createFilterAgg("confusion_matrix_foo_at_0.25_FP", 2L), - createFilterAgg("confusion_matrix_foo_at_0.25_TN", 3L), - createFilterAgg("confusion_matrix_foo_at_0.25_FN", 4L), - createFilterAgg("confusion_matrix_foo_at_0.5_TP", 5L), - createFilterAgg("confusion_matrix_foo_at_0.5_FP", 6L), - createFilterAgg("confusion_matrix_foo_at_0.5_TN", 7L), - createFilterAgg("confusion_matrix_foo_at_0.5_FN", 8L) + createFilterAgg("confusion_matrix_at_0.25_TP", 1L), + createFilterAgg("confusion_matrix_at_0.25_FP", 2L), + createFilterAgg("confusion_matrix_at_0.25_TN", 3L), + createFilterAgg("confusion_matrix_at_0.25_FN", 4L), + createFilterAgg("confusion_matrix_at_0.5_TP", 5L), + createFilterAgg("confusion_matrix_at_0.5_FP", 6L), + createFilterAgg("confusion_matrix_at_0.5_TN", 7L), + createFilterAgg("confusion_matrix_at_0.5_FN", 8L) )); ConfusionMatrix confusionMatrix = new ConfusionMatrix(Arrays.asList(0.25, 0.5)); - EvaluationMetricResult result = confusionMatrix.evaluate(classInfo, aggs); + EvaluationMetricResult result = confusionMatrix.evaluate(aggs); String expected = "{\"0.25\":{\"tp\":1,\"fp\":2,\"tn\":3,\"fn\":4},\"0.5\":{\"tp\":5,\"fp\":6,\"tn\":7,\"fn\":8}}"; assertThat(Strings.toString(result), equalTo(expected)); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/PrecisionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/PrecisionTests.java index c12156c39373e..58f2864fd0747 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/PrecisionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/PrecisionTests.java @@ -49,36 +49,30 @@ public static Precision createRandom() { } public void testEvaluate() { - SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class); - when(classInfo.getName()).thenReturn("foo"); - Aggregations aggs = new Aggregations(Arrays.asList( - createFilterAgg("precision_foo_at_0.25_TP", 1L), - createFilterAgg("precision_foo_at_0.25_FP", 4L), - createFilterAgg("precision_foo_at_0.5_TP", 3L), - createFilterAgg("precision_foo_at_0.5_FP", 1L), - createFilterAgg("precision_foo_at_0.75_TP", 5L), - createFilterAgg("precision_foo_at_0.75_FP", 0L) + createFilterAgg("precision_at_0.25_TP", 1L), + createFilterAgg("precision_at_0.25_FP", 4L), + createFilterAgg("precision_at_0.5_TP", 3L), + createFilterAgg("precision_at_0.5_FP", 1L), + createFilterAgg("precision_at_0.75_TP", 5L), + createFilterAgg("precision_at_0.75_FP", 0L) )); Precision precision = new Precision(Arrays.asList(0.25, 0.5, 0.75)); - EvaluationMetricResult result = precision.evaluate(classInfo, aggs); + EvaluationMetricResult result = precision.evaluate(aggs); String expected = "{\"0.25\":0.2,\"0.5\":0.75,\"0.75\":1.0}"; assertThat(Strings.toString(result), equalTo(expected)); } public void testEvaluate_GivenZeroTpAndFp() { - SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class); - when(classInfo.getName()).thenReturn("foo"); - Aggregations aggs = new Aggregations(Arrays.asList( - createFilterAgg("precision_foo_at_1.0_TP", 0L), - createFilterAgg("precision_foo_at_1.0_FP", 0L) + createFilterAgg("precision_at_1.0_TP", 0L), + createFilterAgg("precision_at_1.0_FP", 0L) )); Precision precision = new Precision(Arrays.asList(1.0)); - EvaluationMetricResult result = precision.evaluate(classInfo, aggs); + EvaluationMetricResult result = precision.evaluate(aggs); String expected = "{\"1.0\":0.0}"; assertThat(Strings.toString(result), equalTo(expected)); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/RecallTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/RecallTests.java index fc85b44f151d4..009805425cd88 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/RecallTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/RecallTests.java @@ -49,36 +49,30 @@ public static Recall createRandom() { } public void testEvaluate() { - SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class); - when(classInfo.getName()).thenReturn("foo"); - Aggregations aggs = new Aggregations(Arrays.asList( - createFilterAgg("recall_foo_at_0.25_TP", 1L), - createFilterAgg("recall_foo_at_0.25_FN", 4L), - createFilterAgg("recall_foo_at_0.5_TP", 3L), - createFilterAgg("recall_foo_at_0.5_FN", 1L), - createFilterAgg("recall_foo_at_0.75_TP", 5L), - createFilterAgg("recall_foo_at_0.75_FN", 0L) + createFilterAgg("recall_at_0.25_TP", 1L), + createFilterAgg("recall_at_0.25_FN", 4L), + createFilterAgg("recall_at_0.5_TP", 3L), + createFilterAgg("recall_at_0.5_FN", 1L), + createFilterAgg("recall_at_0.75_TP", 5L), + createFilterAgg("recall_at_0.75_FN", 0L) )); Recall recall = new Recall(Arrays.asList(0.25, 0.5, 0.75)); - EvaluationMetricResult result = recall.evaluate(classInfo, aggs); + EvaluationMetricResult result = recall.evaluate(aggs); String expected = "{\"0.25\":0.2,\"0.5\":0.75,\"0.75\":1.0}"; assertThat(Strings.toString(result), equalTo(expected)); } public void testEvaluate_GivenZeroTpAndFp() { - SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class); - when(classInfo.getName()).thenReturn("foo"); - Aggregations aggs = new Aggregations(Arrays.asList( - createFilterAgg("recall_foo_at_1.0_TP", 0L), - createFilterAgg("recall_foo_at_1.0_FN", 0L) + createFilterAgg("recall_at_1.0_TP", 0L), + createFilterAgg("recall_at_1.0_FN", 0L) )); Recall recall = new Recall(Arrays.asList(1.0)); - EvaluationMetricResult result = recall.evaluate(classInfo, aggs); + EvaluationMetricResult result = recall.evaluate(aggs); String expected = "{\"1.0\":0.0}"; assertThat(Strings.toString(result), equalTo(expected));