From 4366d5856404064d98cec863d0f8e9a09ea8ea0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Wed, 30 Sep 2020 12:55:52 +0200 Subject: [PATCH] [7.x] [ML] Implement AucRoc metric for classification (#60502) (#63051) --- .../apis/evaluate-dfanalytics.asciidoc | 27 +- .../ml/dataframe/analyses/Classification.java | 11 +- .../ml/dataframe/evaluation/Evaluation.java | 105 ++++-- .../evaluation/EvaluationFields.java | 141 ++++++++ .../evaluation/EvaluationMetric.java | 13 +- .../evaluation/EvaluationParameters.java | 2 +- .../MlEvaluationNamedXContentProvider.java | 83 ++--- .../classification/AbstractAucRoc.java | 317 ++++++++++++++++++ .../evaluation/classification/Accuracy.java | 18 +- .../evaluation/classification/AucRoc.java | 228 +++++++++++++ .../classification/Classification.java | 103 ++++-- .../MulticlassConfusionMatrix.java | 13 +- .../evaluation/classification/Precision.java | 13 +- .../evaluation/classification/Recall.java | 13 +- .../AbstractConfusionMatrixMetric.java | 14 +- .../evaluation/outlierdetection/AucRoc.java | 277 +++------------ .../outlierdetection/OutlierDetection.java | 58 ++-- .../evaluation/regression/Huber.java | 13 +- .../regression/MeanSquaredError.java | 13 +- .../MeanSquaredLogarithmicError.java | 13 +- .../evaluation/regression/RSquared.java | 13 +- .../evaluation/regression/Regression.java | 54 ++- .../EvaluateDataFrameActionResponseTests.java | 48 ++- .../analyses/ClassificationTests.java | 27 +- .../evaluation/EvaluationFieldsTests.java | 48 +++ .../classification/AbstractAucRocTests.java | 105 ++++++ .../classification/AccuracyTests.java | 6 +- .../classification/AucRocResultTests.java | 46 +++ .../classification/AucRocTests.java | 34 ++ .../classification/ClassificationTests.java | 126 ++++++- .../MulticlassConfusionMatrixTests.java | 11 +- .../classification/PrecisionTests.java | 6 +- .../classification/RecallTests.java | 6 +- .../outlierdetection/AucRocTests.java | 93 ----- .../OutlierDetectionTests.java | 14 + .../regression/RegressionTests.java | 14 + .../ml/qa/ml-with-security/build.gradle | 5 + .../ClassificationEvaluationIT.java | 96 ++++-- .../ml/integration/ClassificationIT.java | 25 +- .../OutlierDetectionEvaluationIT.java | 114 +++++++ .../ml/dataframe/DestinationIndexTests.java | 16 +- .../test/ml/evaluate_data_frame.yml | 217 +++++++++++- 42 files changed, 2007 insertions(+), 592 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationFields.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRoc.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRoc.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationFieldsTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRocTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRocResultTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRocTests.java create mode 100644 x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionEvaluationIT.java diff --git a/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc b/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc index da72a588af1d1..42562cdb8e9a9 100644 --- a/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc +++ b/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc @@ -88,7 +88,7 @@ the probability that each document is an outlier. `auc_roc`::: (Optional, object) The AUC ROC (area under the curve of the receiver operating characteristic) score and optionally the curve. Default value is - {"includes_curve": false}. + {"include_curve": false}. `confusion_matrix`::: (Optional, object) Set the different thresholds of the {olscore} at where @@ -153,9 +153,14 @@ belongs. The data type of this field must be categorical. `predicted_field`:: - (Required, string) The field in the `index` that contains the predicted value, + (Optional, string) The field in the `index` which contains the predicted value, in other words the results of the {classanalysis}. +`top_classes_field`:: + (Optional, string) The field of the `index` which is an array of documents + of the form `{ "class_name": XXX, "class_probability": YYY }`. + This field must be defined as `nested` in the mappings. + `metrics`:: (Optional, object) Specifies the metrics that are used for the evaluation. Available metrics: @@ -163,6 +168,24 @@ belongs. `accuracy`::: (Optional, object) Accuracy of predictions (per-class and overall). + `auc_roc`::: + (Optional, object) The AUC ROC (area under the curve of the receiver + operating characteristic) score and optionally the curve. + It is calculated for a specific class (provided as "class_name") + treated as positive. + + `class_name`:::: + (Required, string) Name of the only class that will be treated as + positive during AUC ROC calculation. Other classes will be treated as + negative ("one-vs-all" strategy). Documents which do not have `class_name` + in the list of their top classes will not be taken into account for evaluation. + The number of documents taken into account is returned in the evaluation result + (`auc_roc.doc_count` field). + + `include_curve`:::: + (Optional, boolean) Whether or not the curve should be returned in + addition to the score. Default value is false. + `multiclass_confusion_matrix`::: (Optional, object) Multiclass confusion matrix. diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 5c78213060d97..6f1777d55af18 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -394,7 +394,16 @@ public Map getExplicitlyMappedFields(Map mapping return additionalProperties; } additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping); - additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping); + + Map topClassesProperties = new HashMap<>(); + topClassesProperties.put("class_name", dependentVariableMapping); + topClassesProperties.put("class_probability", Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName())); + + Map topClassesMapping = new HashMap<>(); + topClassesMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE); + topClassesMapping.put("properties", topClassesProperties); + + additionalProperties.put(resultsFieldName + ".top_classes", topClassesMapping); return additionalProperties; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java index 9fdba68d4cda7..9d53777f91428 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation; +import org.apache.lucene.search.join.ScoreMode; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.collect.Tuple; @@ -21,11 +22,16 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; +import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.function.Supplier; -import java.util.stream.Collectors; + +import static java.util.stream.Collectors.joining; +import static java.util.stream.Collectors.toList; +import static java.util.stream.Collectors.toSet; /** * Defines an evaluation @@ -38,14 +44,9 @@ public interface Evaluation extends ToXContentObject, NamedWriteable { String getName(); /** - * Returns the field containing the actual value - */ - String getActualField(); - - /** - * Returns the field containing the predicted value + * Returns the collection of fields required by evaluation */ - String getPredictedField(); + EvaluationFields getFields(); /** * Returns the list of metrics to evaluate @@ -59,27 +60,74 @@ default List initMetrics(@Nullable List parse throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", getName()); } Collections.sort(metrics, Comparator.comparing(EvaluationMetric::getName)); + checkRequiredFieldsAreSet(metrics); return metrics; } + default void checkRequiredFieldsAreSet(List metrics) { + assert (metrics == null || metrics.isEmpty()) == false; + for (Tuple requiredField : getFields().listPotentiallyRequiredFields()) { + String fieldDescriptor = requiredField.v1(); + String field = requiredField.v2(); + if (field == null) { + String metricNamesString = + metrics.stream() + .filter(m -> m.getRequiredFields().contains(fieldDescriptor)) + .map(EvaluationMetric::getName) + .collect(joining(", ")); + if (metricNamesString.isEmpty() == false) { + throw ExceptionsHelper.badRequestException( + "[{}] must define [{}] as required by the following metrics [{}]", + getName(), fieldDescriptor, metricNamesString); + } + } + } + } + /** * Builds the search required to collect data to compute the evaluation result * @param userProvidedQueryBuilder User-provided query that must be respected when collecting data */ default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBuilder userProvidedQueryBuilder) { Objects.requireNonNull(userProvidedQueryBuilder); - BoolQueryBuilder boolQuery = - QueryBuilders.boolQuery() - // Verify existence of required fields - .filter(QueryBuilders.existsQuery(getActualField())) - .filter(QueryBuilders.existsQuery(getPredictedField())) - // Apply user-provided query - .filter(userProvidedQueryBuilder); + Set requiredFields = new HashSet<>(getRequiredFields()); + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + if (getFields().getActualField() != null && requiredFields.contains(getFields().getActualField())) { + // Verify existence of the actual field if required + boolQuery.filter(QueryBuilders.existsQuery(getFields().getActualField())); + } + if (getFields().getPredictedField() != null && requiredFields.contains(getFields().getPredictedField())) { + // Verify existence of the predicted field if required + boolQuery.filter(QueryBuilders.existsQuery(getFields().getPredictedField())); + } + if (getFields().getPredictedClassField() != null && requiredFields.contains(getFields().getPredictedClassField())) { + assert getFields().getTopClassesField() != null; + // Verify existence of the predicted class name field if required + QueryBuilder predictedClassFieldExistsQuery = QueryBuilders.existsQuery(getFields().getPredictedClassField()); + boolQuery.filter( + QueryBuilders.nestedQuery(getFields().getTopClassesField(), predictedClassFieldExistsQuery, ScoreMode.None) + .ignoreUnmapped(true)); + } + if (getFields().getPredictedProbabilityField() != null && requiredFields.contains(getFields().getPredictedProbabilityField())) { + // Verify existence of the predicted probability field if required + QueryBuilder predictedProbabilityFieldExistsQuery = QueryBuilders.existsQuery(getFields().getPredictedProbabilityField()); + // predicted probability field may be either nested (just like in case of classification evaluation) or non-nested (just like + // in case of outlier detection evaluation). Here we support both modes. + if (getFields().isPredictedProbabilityFieldNested()) { + assert getFields().getTopClassesField() != null; + boolQuery.filter( + QueryBuilders.nestedQuery(getFields().getTopClassesField(), predictedProbabilityFieldExistsQuery, ScoreMode.None) + .ignoreUnmapped(true)); + } else { + boolQuery.filter(predictedProbabilityFieldExistsQuery); + } + } + // Apply user-provided query + boolQuery.filter(userProvidedQueryBuilder); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery); for (EvaluationMetric metric : getMetrics()) { // Fetch aggregations requested by individual metrics - Tuple, List> aggs = - metric.aggs(parameters, getActualField(), getPredictedField()); + Tuple, List> aggs = metric.aggs(parameters, getFields()); aggs.v1().forEach(searchSourceBuilder::aggregation); aggs.v2().forEach(searchSourceBuilder::aggregation); } @@ -93,14 +141,31 @@ default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBu default void process(SearchResponse searchResponse) { Objects.requireNonNull(searchResponse); if (searchResponse.getHits().getTotalHits().value == 0) { - throw ExceptionsHelper.badRequestException( - "No documents found containing both [{}, {}] fields", getActualField(), getPredictedField()); + String requiredFieldsString = String.join(", ", getRequiredFields()); + throw ExceptionsHelper.badRequestException("No documents found containing all the required fields [{}]", requiredFieldsString); } for (EvaluationMetric metric : getMetrics()) { metric.process(searchResponse.getAggregations()); } } + /** + * @return list of fields which are required by at least one of the metrics + */ + default List getRequiredFields() { + Set requiredFieldDescriptors = + getMetrics().stream() + .map(EvaluationMetric::getRequiredFields) + .flatMap(Set::stream) + .collect(toSet()); + List requiredFields = + getFields().listPotentiallyRequiredFields().stream() + .filter(f -> requiredFieldDescriptors.contains(f.v1())) + .map(Tuple::v2) + .collect(toList()); + return requiredFields; + } + /** * @return true iff all the metrics have their results computed */ @@ -117,6 +182,6 @@ default List getResults() { .map(EvaluationMetric::getResult) .filter(Optional::isPresent) .map(Optional::get) - .collect(Collectors.toList()); + .collect(toList()); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationFields.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationFields.java new file mode 100644 index 0000000000000..6df04b71de67d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationFields.java @@ -0,0 +1,141 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.collect.Tuple; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +/** + * Encapsulates fields needed by evaluation. + */ +public final class EvaluationFields { + + public static final ParseField ACTUAL_FIELD = new ParseField("actual_field"); + public static final ParseField PREDICTED_FIELD = new ParseField("predicted_field"); + public static final ParseField TOP_CLASSES_FIELD = new ParseField("top_classes_field"); + public static final ParseField PREDICTED_CLASS_FIELD = new ParseField("predicted_class_field"); + public static final ParseField PREDICTED_PROBABILITY_FIELD = new ParseField("predicted_probability_field"); + + /** + * The field containing the actual value + */ + private final String actualField; + + /** + * The field containing the predicted value + */ + private final String predictedField; + + /** + * The field containing the array of top classes + */ + private final String topClassesField; + + /** + * The field containing the predicted class name value + */ + private final String predictedClassField; + + /** + * The field containing the predicted probability value in [0.0, 1.0] + */ + private final String predictedProbabilityField; + + /** + * Whether the {@code predictedProbabilityField} should be treated as nested (e.g.: when used in exists queries). + */ + private final boolean predictedProbabilityFieldNested; + + public EvaluationFields(@Nullable String actualField, + @Nullable String predictedField, + @Nullable String topClassesField, + @Nullable String predictedClassField, + @Nullable String predictedProbabilityField, + boolean predictedProbabilityFieldNested) { + + this.actualField = actualField; + this.predictedField = predictedField; + this.topClassesField = topClassesField; + this.predictedClassField = predictedClassField; + this.predictedProbabilityField = predictedProbabilityField; + this.predictedProbabilityFieldNested = predictedProbabilityFieldNested; + } + + /** + * Returns the field containing the actual value + */ + public String getActualField() { + return actualField; + } + + /** + * Returns the field containing the predicted value + */ + public String getPredictedField() { + return predictedField; + } + + /** + * Returns the field containing the array of top classes + */ + public String getTopClassesField() { + return topClassesField; + } + + /** + * Returns the field containing the predicted class name value + */ + public String getPredictedClassField() { + return predictedClassField; + } + + /** + * Returns the field containing the predicted probability value in [0.0, 1.0] + */ + public String getPredictedProbabilityField() { + return predictedProbabilityField; + } + + /** + * Returns whether the {@code predictedProbabilityField} should be treated as nested (e.g.: when used in exists queries). + */ + public boolean isPredictedProbabilityFieldNested() { + return predictedProbabilityFieldNested; + } + + public List> listPotentiallyRequiredFields() { + return Arrays.asList( + Tuple.tuple(ACTUAL_FIELD.getPreferredName(), actualField), + Tuple.tuple(PREDICTED_FIELD.getPreferredName(), predictedField), + Tuple.tuple(TOP_CLASSES_FIELD.getPreferredName(), topClassesField), + Tuple.tuple(PREDICTED_CLASS_FIELD.getPreferredName(), predictedClassField), + Tuple.tuple(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField)); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + EvaluationFields that = (EvaluationFields) o; + return Objects.equals(that.actualField, this.actualField) + && Objects.equals(that.predictedField, this.predictedField) + && Objects.equals(that.topClassesField, this.topClassesField) + && Objects.equals(that.predictedClassField, this.predictedClassField) + && Objects.equals(that.predictedProbabilityField, this.predictedProbabilityField) + && Objects.equals(that.predictedProbabilityFieldNested, this.predictedProbabilityFieldNested); + } + + @Override + public int hashCode() { + return Objects.hash( + actualField, predictedField, topClassesField, predictedClassField, predictedProbabilityField, predictedProbabilityFieldNested); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java index a5c3d657f55e9..5d8ab826b46b9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java @@ -15,6 +15,7 @@ import java.util.List; import java.util.Optional; +import java.util.Set; /** * {@link EvaluationMetric} class represents a metric to evaluate. @@ -26,16 +27,18 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable { */ String getName(); + /** + * Returns the set of fields that this metric requires in order to be calculated. + */ + Set getRequiredFields(); + /** * Builds the aggregation that collect required data to compute the metric * @param parameters settings that may be needed by aggregations - * @param actualField the field that stores the actual value - * @param predictedField the field that stores the predicted value (class name or probability) + * @param fields fields that may be needed by aggregations * @return the aggregations required to compute the metric */ - Tuple, List> aggs(EvaluationParameters parameters, - String actualField, - String predictedField); + Tuple, List> aggs(EvaluationParameters parameters, EvaluationFields fields); /** * Processes given aggregations as a step towards computing result diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationParameters.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationParameters.java index e834efc7e67f7..b3ffcdc24ecbe 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationParameters.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationParameters.java @@ -8,7 +8,7 @@ /** * Encapsulates parameters needed by evaluation. */ -public class EvaluationParameters { +public final class EvaluationParameters { /** * Maximum number of buckets allowed in any single search request. diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index 009b4fa22890d..c7ae0a3848775 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -10,13 +10,13 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.plugins.spi.NamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AucRoc; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.ConfusionMatrix; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.OutlierDetection; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Precision; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.ScoreByThresholdResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; @@ -63,19 +63,28 @@ public List getNamedXContentParsers() { // Outlier detection metrics new NamedXContentRegistry.Entry(EvaluationMetric.class, - new ParseField(registeredMetricName(OutlierDetection.NAME, AucRoc.NAME)), - AucRoc::fromXContent), + new ParseField( + registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AucRoc.NAME)), + org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AucRoc::fromXContent), new NamedXContentRegistry.Entry(EvaluationMetric.class, - new ParseField(registeredMetricName(OutlierDetection.NAME, Precision.NAME)), - Precision::fromXContent), + new ParseField( + registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Precision.NAME)), + org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Precision::fromXContent), new NamedXContentRegistry.Entry(EvaluationMetric.class, - new ParseField(registeredMetricName(OutlierDetection.NAME, Recall.NAME)), - Recall::fromXContent), + new ParseField( + registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall.NAME)), + org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall::fromXContent), new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(registeredMetricName(OutlierDetection.NAME, ConfusionMatrix.NAME)), ConfusionMatrix::fromXContent), // Classification metrics + new NamedXContentRegistry.Entry(EvaluationMetric.class, + new ParseField(registeredMetricName(Classification.NAME, AucRoc.NAME)), + AucRoc::fromXContent), new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME)), MulticlassConfusionMatrix::fromXContent), @@ -83,15 +92,11 @@ public List getNamedXContentParsers() { new ParseField(registeredMetricName(Classification.NAME, Accuracy.NAME)), Accuracy::fromXContent), new NamedXContentRegistry.Entry(EvaluationMetric.class, - new ParseField( - registeredMetricName( - Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME)), - org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision::fromXContent), + new ParseField(registeredMetricName(Classification.NAME, Precision.NAME)), + Precision::fromXContent), new NamedXContentRegistry.Entry(EvaluationMetric.class, - new ParseField( - registeredMetricName( - Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME)), - org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall::fromXContent), + new ParseField(registeredMetricName(Classification.NAME, Recall.NAME)), + Recall::fromXContent), // Regression metrics new NamedXContentRegistry.Entry(EvaluationMetric.class, @@ -124,17 +129,23 @@ public static List getNamedWriteables() { // Evaluation metrics new NamedWriteableRegistry.Entry(EvaluationMetric.class, - registeredMetricName(OutlierDetection.NAME, AucRoc.NAME), - AucRoc::new), + registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AucRoc.NAME), + org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AucRoc::new), new NamedWriteableRegistry.Entry(EvaluationMetric.class, - registeredMetricName(OutlierDetection.NAME, Precision.NAME), - Precision::new), + registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Precision.NAME), + org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Precision::new), new NamedWriteableRegistry.Entry(EvaluationMetric.class, - registeredMetricName(OutlierDetection.NAME, Recall.NAME), - Recall::new), + registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall.NAME), + org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall::new), new NamedWriteableRegistry.Entry(EvaluationMetric.class, registeredMetricName(OutlierDetection.NAME, ConfusionMatrix.NAME), ConfusionMatrix::new), + new NamedWriteableRegistry.Entry(EvaluationMetric.class, + registeredMetricName(Classification.NAME, AucRoc.NAME), + AucRoc::new), new NamedWriteableRegistry.Entry(EvaluationMetric.class, registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME), MulticlassConfusionMatrix::new), @@ -142,13 +153,11 @@ public static List getNamedWriteables() { registeredMetricName(Classification.NAME, Accuracy.NAME), Accuracy::new), new NamedWriteableRegistry.Entry(EvaluationMetric.class, - registeredMetricName( - Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME), - org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision::new), + registeredMetricName(Classification.NAME, Precision.NAME), + Precision::new), new NamedWriteableRegistry.Entry(EvaluationMetric.class, - registeredMetricName( - Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME), - org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall::new), + registeredMetricName(Classification.NAME, Recall.NAME), + Recall::new), new NamedWriteableRegistry.Entry(EvaluationMetric.class, registeredMetricName(Regression.NAME, MeanSquaredError.NAME), MeanSquaredError::new), @@ -163,15 +172,15 @@ public static List getNamedWriteables() { RSquared::new), // Evaluation metrics results - new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, - registeredMetricName(OutlierDetection.NAME, AucRoc.NAME), - AucRoc.Result::new), new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, registeredMetricName(OutlierDetection.NAME, ScoreByThresholdResult.NAME), ScoreByThresholdResult::new), new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, registeredMetricName(OutlierDetection.NAME, ConfusionMatrix.NAME), ConfusionMatrix.Result::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, + registeredMetricName(Classification.NAME, AucRoc.NAME), + AucRoc.Result::new), new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME), MulticlassConfusionMatrix.Result::new), @@ -179,13 +188,11 @@ public static List getNamedWriteables() { registeredMetricName(Classification.NAME, Accuracy.NAME), Accuracy.Result::new), new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, - registeredMetricName( - Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME), - org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.Result::new), + registeredMetricName(Classification.NAME, Precision.NAME), + Precision.Result::new), new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, - registeredMetricName( - Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME), - org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.Result::new), + registeredMetricName(Classification.NAME, Recall.NAME), + Recall.Result::new), new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, registeredMetricName(Regression.NAME, MeanSquaredError.NAME), MeanSquaredError.Result::new), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRoc.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRoc.java new file mode 100644 index 0000000000000..30a7a55c3edd3 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRoc.java @@ -0,0 +1,317 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.Version; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.search.aggregations.metrics.Percentiles; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; + +/** + * Area under the curve (AUC) of the receiver operating characteristic (ROC). + * The ROC curve is a plot of the TPR (true positive rate) against + * the FPR (false positive rate) over a varying threshold. + * + * This particular implementation is making use of ES aggregations + * to calculate the curve. It then uses the trapezoidal rule to calculate + * the AUC. + * + * In particular, in order to calculate the ROC, we get percentiles of TP + * and FP against the predicted probability. We call those Rate-Threshold + * curves. We then scan ROC points from each Rate-Threshold curve against the + * other using interpolation. This gives us an approximation of the ROC curve + * that has the advantage of being efficient and resilient to some edge cases. + * + * When this is used for multi-class classification, it will calculate the ROC + * curve of each class versus the rest. + */ +public abstract class AbstractAucRoc implements EvaluationMetric { + + public static final ParseField NAME = new ParseField("auc_roc"); + + protected AbstractAucRoc() {} + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + protected static double[] percentilesArray(Percentiles percentiles) { + double[] result = new double[99]; + percentiles.forEach(percentile -> { + if (Double.isNaN(percentile.getValue())) { + throw ExceptionsHelper.badRequestException( + "[{}] requires at all the percentiles values to be finite numbers", NAME.getPreferredName()); + } + result[((int) percentile.getPercent()) - 1] = percentile.getValue(); + }); + return result; + } + + /** + * Visible for testing + */ + protected static List buildAucRocCurve(double[] tpPercentiles, double[] fpPercentiles) { + assert tpPercentiles.length == fpPercentiles.length; + assert tpPercentiles.length == 99; + + List aucRocCurve = new ArrayList<>(); + aucRocCurve.add(new AucRocPoint(0.0, 0.0, 1.0)); + aucRocCurve.add(new AucRocPoint(1.0, 1.0, 0.0)); + RateThresholdCurve tpCurve = new RateThresholdCurve(tpPercentiles, true); + RateThresholdCurve fpCurve = new RateThresholdCurve(fpPercentiles, false); + aucRocCurve.addAll(tpCurve.scanPoints(fpCurve)); + aucRocCurve.addAll(fpCurve.scanPoints(tpCurve)); + Collections.sort(aucRocCurve); + return aucRocCurve; + } + + /** + * Visible for testing + */ + protected static double calculateAucScore(List rocCurve) { + // Calculates AUC based on the trapezoid rule + double aucRoc = 0.0; + for (int i = 1; i < rocCurve.size(); i++) { + AucRocPoint left = rocCurve.get(i - 1); + AucRocPoint right = rocCurve.get(i); + aucRoc += (right.fpr - left.fpr) * (right.tpr + left.tpr) / 2; + } + return aucRoc; + } + + private static class RateThresholdCurve { + + private final double[] percentiles; + private final boolean isTp; + + private RateThresholdCurve(double[] percentiles, boolean isTp) { + this.percentiles = percentiles; + this.isTp = isTp; + } + + private double getRate(int index) { + return 1 - 0.01 * (index + 1); + } + + private double getThreshold(int index) { + return percentiles[index]; + } + + private double interpolateRate(double threshold) { + int binarySearchResult = Arrays.binarySearch(percentiles, threshold); + if (binarySearchResult >= 0) { + return getRate(binarySearchResult); + } else { + int right = (binarySearchResult * -1) -1; + int left = right - 1; + if (right >= percentiles.length) { + return 0.0; + } else if (left < 0) { + return 1.0; + } else { + double rightRate = getRate(right); + double leftRate = getRate(left); + return interpolate(threshold, percentiles[left], leftRate, percentiles[right], rightRate); + } + } + } + + private List scanPoints(RateThresholdCurve againstCurve) { + List points = new ArrayList<>(); + for (int index = 0; index < percentiles.length; index++) { + double rate = getRate(index); + double scannedThreshold = getThreshold(index); + double againstRate = againstCurve.interpolateRate(scannedThreshold); + AucRocPoint point; + if (isTp) { + point = new AucRocPoint(rate, againstRate, scannedThreshold); + } else { + point = new AucRocPoint(againstRate, rate, scannedThreshold); + } + points.add(point); + } + return points; + } + } + + public static final class AucRocPoint implements Comparable, ToXContentObject, Writeable { + + private static final String TPR = "tpr"; + private static final String FPR = "fpr"; + private static final String THRESHOLD = "threshold"; + + private final double tpr; + private final double fpr; + private final double threshold; + + AucRocPoint(double tpr, double fpr, double threshold) { + this.tpr = tpr; + this.fpr = fpr; + this.threshold = threshold; + } + + private AucRocPoint(StreamInput in) throws IOException { + this.tpr = in.readDouble(); + this.fpr = in.readDouble(); + this.threshold = in.readDouble(); + } + + @Override + public int compareTo(AucRocPoint o) { + return Comparator.comparingDouble((AucRocPoint p) -> p.threshold).reversed() + .thenComparing(p -> p.fpr) + .thenComparing(p -> p.tpr) + .compare(this, o); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(tpr); + out.writeDouble(fpr); + out.writeDouble(threshold); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TPR, tpr); + builder.field(FPR, fpr); + builder.field(THRESHOLD, threshold); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AucRocPoint that = (AucRocPoint) o; + return tpr == that.tpr + && fpr == that.fpr + && threshold == that.threshold; + } + + @Override + public int hashCode() { + return Objects.hash(tpr, fpr, threshold); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + private static double interpolate(double x, double x1, double y1, double x2, double y2) { + return y1 + (x - x1) * (y2 - y1) / (x2 - x1); + } + + public static class Result implements EvaluationMetricResult { + + private static final String SCORE = "score"; + private static final String DOC_COUNT = "doc_count"; + private static final String CURVE = "curve"; + + private final double score; + private final Long docCount; + private final List curve; + + public Result(double score, Long docCount, List curve) { + this.score = score; + this.docCount = docCount; + this.curve = Objects.requireNonNull(curve); + } + + public Result(StreamInput in) throws IOException { + this.score = in.readDouble(); + if (in.getVersion().onOrAfter(Version.V_7_10_0)) { + this.docCount = in.readOptionalLong(); + } else { + this.docCount = null; + } + this.curve = in.readList(AucRocPoint::new); + } + + public double getScore() { + return score; + } + + public Long getDocCount() { + return docCount; + } + + public List getCurve() { + return Collections.unmodifiableList(curve); + } + + @Override + public String getWriteableName() { + return registeredMetricName(Classification.NAME, NAME); + } + + @Override + public String getMetricName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(score); + if (out.getVersion().onOrAfter(Version.V_7_10_0)) { + out.writeOptionalLong(docCount); + } + out.writeList(curve); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(SCORE, score); + if (docCount != null) { + builder.field(DOC_COUNT, docCount); + } + if (curve.isEmpty() == false) { + builder.field(CURVE, curve); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return score == that.score + && Objects.equals(docCount, that.docCount) + && Objects.equals(curve, that.curve); + } + + @Override + public int hashCode() { + return Objects.hash(score, docCount, curve); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java index 440a46540cecb..96d249326fbe6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; @@ -22,6 +23,7 @@ import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; @@ -33,6 +35,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Set; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; @@ -95,21 +98,24 @@ public String getName() { return NAME.getPreferredName(); } + @Override + public Set getRequiredFields() { + return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName()); + } + @Override public final Tuple, List> aggs(EvaluationParameters parameters, - String actualField, - String predictedField) { + EvaluationFields fields) { // Store given {@code actualField} for the purpose of generating error message in {@code process}. - this.actualField.trySet(actualField); + this.actualField.trySet(fields.getActualField()); List aggs = new ArrayList<>(); List pipelineAggs = new ArrayList<>(); if (overallAccuracy.get() == null) { - Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField); + Script script = PainlessScripts.buildIsEqualScript(fields.getActualField(), fields.getPredictedField()); aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(script)); } if (result.get() == null) { - Tuple, List> matrixAggs = - matrix.aggs(parameters, actualField, predictedField); + Tuple, List> matrixAggs = matrix.aggs(parameters, fields); aggs.addAll(matrixAggs.v1()); pipelineAggs.addAll(matrixAggs.v2()); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRoc.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRoc.java new file mode 100644 index 0000000000000..e67361e457c04 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRoc.java @@ -0,0 +1,228 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; +import org.elasticsearch.search.aggregations.bucket.filter.Filter; +import org.elasticsearch.search.aggregations.bucket.nested.Nested; +import org.elasticsearch.search.aggregations.metrics.Percentiles; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.IntStream; + +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; + +/** + * Area under the curve (AUC) of the receiver operating characteristic (ROC). + * The ROC curve is a plot of the TPR (true positive rate) against + * the FPR (false positive rate) over a varying threshold. + * + * This particular implementation is making use of ES aggregations + * to calculate the curve. It then uses the trapezoidal rule to calculate + * the AUC. + * + * In particular, in order to calculate the ROC, we get percentiles of TP + * and FP against the predicted probability. We call those Rate-Threshold + * curves. We then scan ROC points from each Rate-Threshold curve against the + * other using interpolation. This gives us an approximation of the ROC curve + * that has the advantage of being efficient and resilient to some edge cases. + * + * When this is used for multi-class classification, it will calculate the ROC + * curve of each class versus the rest. + */ +public class AucRoc extends AbstractAucRoc { + + public static final ParseField INCLUDE_CURVE = new ParseField("include_curve"); + public static final ParseField CLASS_NAME = new ParseField("class_name"); + + public static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME.getPreferredName(), a -> new AucRoc((Boolean) a[0], (String) a[1])); + + static { + PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), INCLUDE_CURVE); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), CLASS_NAME); + } + + private static final String TRUE_AGG_NAME = NAME.getPreferredName() + "_true"; + private static final String NON_TRUE_AGG_NAME = NAME.getPreferredName() + "_non_true"; + private static final String NESTED_AGG_NAME = "nested"; + private static final String NESTED_FILTER_AGG_NAME = "nested_filter"; + private static final String PERCENTILES_AGG_NAME = "percentiles"; + + public static AucRoc fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final boolean includeCurve; + private final String className; + private final SetOnce fields = new SetOnce<>(); + private final SetOnce result = new SetOnce<>(); + + public AucRoc(Boolean includeCurve, String className) { + this.includeCurve = includeCurve == null ? false : includeCurve; + this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME.getPreferredName()); + } + + public AucRoc(StreamInput in) throws IOException { + this.includeCurve = in.readBoolean(); + this.className = in.readOptionalString(); + } + + @Override + public String getWriteableName() { + return registeredMetricName(Classification.NAME, NAME); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeBoolean(includeCurve); + out.writeOptionalString(className); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INCLUDE_CURVE.getPreferredName(), includeCurve); + if (className != null) { + builder.field(CLASS_NAME.getPreferredName(), className); + } + builder.endObject(); + return builder; + } + + @Override + public Set getRequiredFields() { + return Sets.newHashSet( + EvaluationFields.ACTUAL_FIELD.getPreferredName(), + EvaluationFields.PREDICTED_CLASS_FIELD.getPreferredName(), + EvaluationFields.PREDICTED_PROBABILITY_FIELD.getPreferredName()); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AucRoc that = (AucRoc) o; + return includeCurve == that.includeCurve + && Objects.equals(className, that.className); + } + + @Override + public int hashCode() { + return Objects.hash(includeCurve, className); + } + + @Override + public Tuple, List> aggs(EvaluationParameters parameters, + EvaluationFields fields) { + if (result.get() != null) { + return Tuple.tuple(Arrays.asList(), Arrays.asList()); + } + // Store given {@code fields} for the purpose of generating error messages in {@code process}. + this.fields.trySet(fields); + + double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray(); + AggregationBuilder percentilesAgg = + AggregationBuilders + .percentiles(PERCENTILES_AGG_NAME) + .field(fields.getPredictedProbabilityField()) + .percentiles(percentiles); + AggregationBuilder nestedAgg = + AggregationBuilders + .nested(NESTED_AGG_NAME, fields.getTopClassesField()) + .subAggregation( + AggregationBuilders + .filter(NESTED_FILTER_AGG_NAME, QueryBuilders.termQuery(fields.getPredictedClassField(), className)) + .subAggregation(percentilesAgg)); + QueryBuilder actualIsTrueQuery = QueryBuilders.termQuery(fields.getActualField(), className); + AggregationBuilder percentilesForClassValueAgg = + AggregationBuilders + .filter(TRUE_AGG_NAME, actualIsTrueQuery) + .subAggregation(nestedAgg); + AggregationBuilder percentilesForRestAgg = + AggregationBuilders + .filter(NON_TRUE_AGG_NAME, QueryBuilders.boolQuery().mustNot(actualIsTrueQuery)) + .subAggregation(nestedAgg); + return Tuple.tuple( + Arrays.asList(percentilesForClassValueAgg, percentilesForRestAgg), + Arrays.asList()); + } + + @Override + public void process(Aggregations aggs) { + if (result.get() != null) { + return; + } + Filter classAgg = aggs.get(TRUE_AGG_NAME); + Nested classNested = classAgg.getAggregations().get(NESTED_AGG_NAME); + Filter classNestedFilter = classNested.getAggregations().get(NESTED_FILTER_AGG_NAME); + if (classAgg.getDocCount() == 0) { + throw ExceptionsHelper.badRequestException( + "[{}] requires at least one [{}] to have the value [{}]", + getName(), fields.get().getActualField(), className); + } + if (classNestedFilter.getDocCount() == 0) { + throw ExceptionsHelper.badRequestException( + "[{}] requires at least one [{}] to have the value [{}]", + getName(), fields.get().getPredictedClassField(), className); + } + Percentiles classPercentiles = classNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME); + double[] tpPercentiles = percentilesArray(classPercentiles); + + Filter restAgg = aggs.get(NON_TRUE_AGG_NAME); + Nested restNested = restAgg.getAggregations().get(NESTED_AGG_NAME); + Filter restNestedFilter = restNested.getAggregations().get(NESTED_FILTER_AGG_NAME); + if (restAgg.getDocCount() == 0) { + throw ExceptionsHelper.badRequestException( + "[{}] requires at least one [{}] to have a different value than [{}]", + getName(), fields.get().getActualField(), className); + } + if (restNestedFilter.getDocCount() == 0) { + throw ExceptionsHelper.badRequestException( + "[{}] requires at least one [{}] to have the value [{}]", + getName(), fields.get().getPredictedClassField(), className); + } + Percentiles restPercentiles = restNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME); + double[] fpPercentiles = percentilesArray(restPercentiles); + + List aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles); + double aucRocScore = calculateAucScore(aucRocCurve); + result.set( + new Result( + aucRocScore, + classNestedFilter.getDocCount() + restNestedFilter.getDocCount(), + includeCurve ? aucRocCurve : Collections.emptyList())); + } + + @Override + public Optional getResult() { + return Optional.ofNullable(result.get()); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java index fb8014697555e..ab0fc45461f1c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; @@ -13,6 +14,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -21,6 +23,9 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.ACTUAL_FIELD; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.PREDICTED_FIELD; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.TOP_CLASSES_FIELD; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; /** @@ -30,17 +35,22 @@ public class Classification implements Evaluation { public static final ParseField NAME = new ParseField("classification"); - private static final ParseField ACTUAL_FIELD = new ParseField("actual_field"); - private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field"); private static final ParseField METRICS = new ParseField("metrics"); + private static final String DEFAULT_TOP_CLASSES_FIELD = "ml.top_classes"; + private static final String DEFAULT_PREDICTED_CLASS_FIELD_SUFFIX = ".class_name"; + private static final String DEFAULT_PREDICTED_PROBABILITY_FIELD_SUFFIX = ".class_probability"; + @SuppressWarnings("unchecked") - public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - NAME.getPreferredName(), a -> new Classification((String) a[0], (String) a[1], (List) a[2])); + public static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + NAME.getPreferredName(), + a -> new Classification((String) a[0], (String) a[1], (String) a[2], (List) a[3])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD); - PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTED_FIELD); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), TOP_CLASSES_FIELD); PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME.getPreferredName(), n), c), METRICS); } @@ -50,25 +60,35 @@ public static Classification fromXContent(XContentParser parser) { } /** - * The field containing the actual value - * The value of this field is assumed to be categorical - */ - private final String actualField; - - /** - * The field containing the predicted value - * The value of this field is assumed to be categorical + * The collection of fields in the index being evaluated. + * fields.getActualField() is assumed to be a ground truth label. + * fields.getPredictedField() is assumed to be a predicted label. + * fields.getPredictedClassField() and fields.getPredictedProbabilityField() are assumed to be properties under the same nested field. */ - private final String predictedField; + private final EvaluationFields fields; /** * The list of metrics to calculate */ private final List metrics; - public Classification(String actualField, String predictedField, @Nullable List metrics) { - this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD); - this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD); + public Classification(String actualField, + @Nullable String predictedField, + @Nullable String topClassesField, + @Nullable List metrics) { + if (topClassesField == null) { + topClassesField = DEFAULT_TOP_CLASSES_FIELD; + } + String predictedClassField = topClassesField + DEFAULT_PREDICTED_CLASS_FIELD_SUFFIX; + String predictedProbabilityField = topClassesField + DEFAULT_PREDICTED_PROBABILITY_FIELD_SUFFIX; + this.fields = + new EvaluationFields( + ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD), + predictedField, + topClassesField, + predictedClassField, + predictedProbabilityField, + true); this.metrics = initMetrics(metrics, Classification::defaultMetrics); } @@ -77,8 +97,18 @@ private static List defaultMetrics() { } public Classification(StreamInput in) throws IOException { - this.actualField = in.readString(); - this.predictedField = in.readString(); + if (in.getVersion().onOrAfter(Version.V_7_10_0)) { + this.fields = + new EvaluationFields( + in.readString(), + in.readOptionalString(), + in.readOptionalString(), + in.readOptionalString(), + in.readOptionalString(), + true); + } else { + this.fields = new EvaluationFields(in.readString(), in.readString(), null, null, null, true); + } this.metrics = in.readNamedWriteableList(EvaluationMetric.class); } @@ -88,13 +118,8 @@ public String getName() { } @Override - public String getActualField() { - return actualField; - } - - @Override - public String getPredictedField() { - return predictedField; + public EvaluationFields getFields() { + return fields; } @Override @@ -109,17 +134,28 @@ public String getWriteableName() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeString(actualField); - out.writeString(predictedField); + out.writeString(fields.getActualField()); + if (out.getVersion().onOrAfter(Version.V_7_10_0)) { + out.writeOptionalString(fields.getPredictedField()); + out.writeOptionalString(fields.getTopClassesField()); + out.writeOptionalString(fields.getPredictedClassField()); + out.writeOptionalString(fields.getPredictedProbabilityField()); + } else { + out.writeString(fields.getPredictedField()); + } out.writeNamedWriteableList(metrics); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ACTUAL_FIELD.getPreferredName(), actualField); - builder.field(PREDICTED_FIELD.getPreferredName(), predictedField); - + builder.field(ACTUAL_FIELD.getPreferredName(), fields.getActualField()); + if (fields.getPredictedField() != null) { + builder.field(PREDICTED_FIELD.getPreferredName(), fields.getPredictedField()); + } + if (fields.getTopClassesField() != null) { + builder.field(TOP_CLASSES_FIELD.getPreferredName(), fields.getTopClassesField()); + } builder.startObject(METRICS.getPreferredName()); for (EvaluationMetric metric : metrics) { builder.field(metric.getName(), metric); @@ -135,13 +171,12 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Classification that = (Classification) o; - return Objects.equals(that.actualField, this.actualField) - && Objects.equals(that.predictedField, this.predictedField) + return Objects.equals(that.fields, this.fields) && Objects.equals(that.metrics, this.metrics); } @Override public int hashCode() { - return Objects.hash(actualField, predictedField, metrics); + return Objects.hash(fields, metrics); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index 13c08098776f5..efc90aea33396 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -27,6 +28,7 @@ import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter; import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.aggregations.metrics.Cardinality; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; @@ -39,6 +41,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; import static java.util.Comparator.comparing; @@ -125,10 +128,16 @@ public int getSize() { return size; } + @Override + public Set getRequiredFields() { + return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName()); + } + @Override public final Tuple, List> aggs(EvaluationParameters parameters, - String actualField, - String predictedField) { + EvaluationFields fields) { + String actualField = fields.getActualField(); + String predictedField = fields.getPredictedField(); if (topActualClassNames.get() == null && actualClassesCardinality.get() == null) { // This is step 1 return Tuple.tuple( Arrays.asList( diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java index b90bfd8cce6c6..2367dd9880d96 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; @@ -28,6 +29,7 @@ import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter; import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; @@ -40,6 +42,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; @@ -90,10 +93,16 @@ public String getName() { return NAME.getPreferredName(); } + @Override + public Set getRequiredFields() { + return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName()); + } + @Override public final Tuple, List> aggs(EvaluationParameters parameters, - String actualField, - String predictedField) { + EvaluationFields fields) { + String actualField = fields.getActualField(); + String predictedField = fields.getPredictedField(); // Store given {@code actualField} for the purpose of generating error message in {@code process}. this.actualField.trySet(actualField); if (topActualClassNames.get() == null) { // This is step 1 diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java index 24319608150b7..3f8ebd2a557a8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; @@ -25,6 +26,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders; import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; @@ -37,6 +39,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Set; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; @@ -84,10 +87,16 @@ public String getName() { return NAME.getPreferredName(); } + @Override + public Set getRequiredFields() { + return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName()); + } + @Override public final Tuple, List> aggs(EvaluationParameters parameters, - String actualField, - String predictedField) { + EvaluationFields fields) { + String actualField = fields.getActualField(); + String predictedField = fields.getPredictedField(); // Store given {@code actualField} for the purpose of generating error message in {@code process}. this.actualField.trySet(actualField); if (result.get() != null) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AbstractConfusionMatrixMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AbstractConfusionMatrixMetric.java index 32f27f56138de..831054b784192 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AbstractConfusionMatrixMetric.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AbstractConfusionMatrixMetric.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; @@ -17,6 +18,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; @@ -26,6 +28,7 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.Set; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.OutlierDetection.actualIsTrueQuery; @@ -66,13 +69,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } + @Override + public Set getRequiredFields() { + return Sets.newHashSet( + EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_PROBABILITY_FIELD.getPreferredName()); + } + @Override public Tuple, List> aggs(EvaluationParameters parameters, - String actualField, - String predictedProbabilityField) { + EvaluationFields fields) { if (result != null) { return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } + String actualField = fields.getActualField(); + String predictedProbabilityField = fields.getPredictedProbabilityField(); return Tuple.tuple(aggsAt(actualField, predictedProbabilityField), Collections.emptyList()); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AucRoc.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AucRoc.java index 2e7547aea8b6c..98698dad5329e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AucRoc.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AucRoc.java @@ -5,14 +5,13 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.common.ParseField; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.ConstructingObjectParser; -import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.query.QueryBuilders; @@ -21,20 +20,19 @@ import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.bucket.filter.Filter; -import org.elasticsearch.search.aggregations.metrics.Percentiles; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AbstractAucRoc; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.Comparator; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.stream.IntStream; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; @@ -58,30 +56,28 @@ * When this is used for multi-class classification, it will calculate the ROC * curve of each class versus the rest. */ -public class AucRoc implements EvaluationMetric { - - public static final ParseField NAME = new ParseField("auc_roc"); +public class AucRoc extends AbstractAucRoc { public static final ParseField INCLUDE_CURVE = new ParseField("include_curve"); - public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), - a -> new AucRoc((Boolean) a[0])); + public static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME.getPreferredName(), a -> new AucRoc((Boolean) a[0])); static { PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), INCLUDE_CURVE); } - private static final String PERCENTILES = "percentiles"; - private static final String TRUE_AGG_NAME = NAME.getPreferredName() + "_true"; private static final String NON_TRUE_AGG_NAME = NAME.getPreferredName() + "_non_true"; + private static final String PERCENTILES_AGG_NAME = "percentiles"; public static AucRoc fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } private final boolean includeCurve; - private EvaluationMetricResult result; + private final SetOnce fields = new SetOnce<>(); + private final SetOnce result = new SetOnce<>(); public AucRoc(Boolean includeCurve) { this.includeCurve = includeCurve == null ? false : includeCurve; @@ -110,8 +106,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } @Override - public String getName() { - return NAME.getPreferredName(); + public Set getRequiredFields() { + return Sets.newHashSet( + EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_PROBABILITY_FIELD.getPreferredName()); } @Override @@ -119,7 +116,7 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; AucRoc that = (AucRoc) o; - return Objects.equals(includeCurve, that.includeCurve); + return includeCurve == that.includeCurve; } @Override @@ -129,22 +126,29 @@ public int hashCode() { @Override public Tuple, List> aggs(EvaluationParameters parameters, - String actualField, - String predictedProbabilityField) { - if (result != null) { + EvaluationFields fields) { + if (result.get() != null) { return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } + // Store given {@code fields} for the purpose of generating error messages in {@code process}. + this.fields.trySet(fields); + + String actualField = fields.getActualField(); + String predictedProbabilityField = fields.getPredictedProbabilityField(); double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray(); + AggregationBuilder percentilesAgg = + AggregationBuilders + .percentiles(PERCENTILES_AGG_NAME) + .field(predictedProbabilityField) + .percentiles(percentiles); AggregationBuilder percentilesForClassValueAgg = AggregationBuilders .filter(TRUE_AGG_NAME, actualIsTrueQuery(actualField)) - .subAggregation( - AggregationBuilders.percentiles(PERCENTILES).field(predictedProbabilityField).percentiles(percentiles)); + .subAggregation(percentilesAgg); AggregationBuilder percentilesForRestAgg = AggregationBuilders .filter(NON_TRUE_AGG_NAME, QueryBuilders.boolQuery().mustNot(actualIsTrueQuery(actualField))) - .subAggregation( - AggregationBuilders.percentiles(PERCENTILES).field(predictedProbabilityField).percentiles(percentiles)); + .subAggregation(percentilesAgg); return Tuple.tuple( Arrays.asList(percentilesForClassValueAgg, percentilesForRestAgg), Collections.emptyList()); @@ -152,216 +156,33 @@ public Tuple, List> aggs(Ev @Override public void process(Aggregations aggs) { + if (result.get() != null) { + return; + } Filter classAgg = aggs.get(TRUE_AGG_NAME); + if (classAgg.getDocCount() == 0) { + throw ExceptionsHelper.badRequestException( + "[{}] requires at least one [{}] to have the value [{}]", getName(), fields.get().getActualField(), "true"); + } + double[] tpPercentiles = percentilesArray(classAgg.getAggregations().get(PERCENTILES_AGG_NAME)); Filter restAgg = aggs.get(NON_TRUE_AGG_NAME); - double[] tpPercentiles = - percentilesArray( - classAgg.getAggregations().get(PERCENTILES), - "[" + getName() + "] requires at least one actual_field to have the value [true]"); - double[] fpPercentiles = - percentilesArray( - restAgg.getAggregations().get(PERCENTILES), - "[" + getName() + "] requires at least one actual_field to have a different value than [true]"); + if (restAgg.getDocCount() == 0) { + throw ExceptionsHelper.badRequestException( + "[{}] requires at least one [{}] to have a different value than [{}]", getName(), fields.get().getActualField(), "true"); + } + double[] fpPercentiles = percentilesArray(restAgg.getAggregations().get(PERCENTILES_AGG_NAME)); + List aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles); double aucRocScore = calculateAucScore(aucRocCurve); - result = new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList()); + result.set( + new Result( + aucRocScore, + classAgg.getDocCount() + restAgg.getDocCount(), + includeCurve ? aucRocCurve : Collections.emptyList())); } @Override public Optional getResult() { - return Optional.ofNullable(result); - } - - private static double[] percentilesArray(Percentiles percentiles, String errorIfUndefined) { - double[] result = new double[99]; - percentiles.forEach(percentile -> { - if (Double.isNaN(percentile.getValue())) { - throw ExceptionsHelper.badRequestException(errorIfUndefined); - } - result[((int) percentile.getPercent()) - 1] = percentile.getValue(); - }); - return result; - } - - /** - * Visible for testing - */ - static List buildAucRocCurve(double[] tpPercentiles, double[] fpPercentiles) { - assert tpPercentiles.length == fpPercentiles.length; - assert tpPercentiles.length == 99; - - List aucRocCurve = new ArrayList<>(); - aucRocCurve.add(new AucRocPoint(0.0, 0.0, 1.0)); - aucRocCurve.add(new AucRocPoint(1.0, 1.0, 0.0)); - RateThresholdCurve tpCurve = new RateThresholdCurve(tpPercentiles, true); - RateThresholdCurve fpCurve = new RateThresholdCurve(fpPercentiles, false); - aucRocCurve.addAll(tpCurve.scanPoints(fpCurve)); - aucRocCurve.addAll(fpCurve.scanPoints(tpCurve)); - Collections.sort(aucRocCurve); - return aucRocCurve; - } - - /** - * Visible for testing - */ - static double calculateAucScore(List rocCurve) { - // Calculates AUC based on the trapezoid rule - double aucRoc = 0.0; - for (int i = 1; i < rocCurve.size(); i++) { - AucRocPoint left = rocCurve.get(i - 1); - AucRocPoint right = rocCurve.get(i); - aucRoc += (right.fpr - left.fpr) * (right.tpr + left.tpr) / 2; - } - return aucRoc; - } - - private static class RateThresholdCurve { - - private final double[] percentiles; - private final boolean isTp; - - private RateThresholdCurve(double[] percentiles, boolean isTp) { - this.percentiles = percentiles; - this.isTp = isTp; - } - - private double getRate(int index) { - return 1 - 0.01 * (index + 1); - } - - private double getThreshold(int index) { - return percentiles[index]; - } - - private double interpolateRate(double threshold) { - int binarySearchResult = Arrays.binarySearch(percentiles, threshold); - if (binarySearchResult >= 0) { - return getRate(binarySearchResult); - } else { - int right = (binarySearchResult * -1) -1; - int left = right - 1; - if (right >= percentiles.length) { - return 0.0; - } else if (left < 0) { - return 1.0; - } else { - double rightRate = getRate(right); - double leftRate = getRate(left); - return interpolate(threshold, percentiles[left], leftRate, percentiles[right], rightRate); - } - } - } - - private List scanPoints(RateThresholdCurve againstCurve) { - List points = new ArrayList<>(); - for (int index = 0; index < percentiles.length; index++) { - double rate = getRate(index); - double scannedThreshold = getThreshold(index); - double againstRate = againstCurve.interpolateRate(scannedThreshold); - AucRocPoint point; - if (isTp) { - point = new AucRocPoint(rate, againstRate, scannedThreshold); - } else { - point = new AucRocPoint(againstRate, rate, scannedThreshold); - } - points.add(point); - } - return points; - } - } - - public static final class AucRocPoint implements Comparable, ToXContentObject, Writeable { - double tpr; - double fpr; - double threshold; - - private AucRocPoint(double tpr, double fpr, double threshold) { - this.tpr = tpr; - this.fpr = fpr; - this.threshold = threshold; - } - - private AucRocPoint(StreamInput in) throws IOException { - this.tpr = in.readDouble(); - this.fpr = in.readDouble(); - this.threshold = in.readDouble(); - } - - @Override - public int compareTo(AucRocPoint o) { - return Comparator.comparingDouble((AucRocPoint p) -> p.threshold).reversed() - .thenComparing(p -> p.fpr) - .thenComparing(p -> p.tpr) - .compare(this, o); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeDouble(tpr); - out.writeDouble(fpr); - out.writeDouble(threshold); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("tpr", tpr); - builder.field("fpr", fpr); - builder.field("threshold", threshold); - builder.endObject(); - return builder; - } - - @Override - public String toString() { - return Strings.toString(this); - } - } - - private static double interpolate(double x, double x1, double y1, double x2, double y2) { - return y1 + (x - x1) * (y2 - y1) / (x2 - x1); - } - - public static class Result implements EvaluationMetricResult { - - private final double score; - private final List curve; - - public Result(double score, List curve) { - this.score = score; - this.curve = Objects.requireNonNull(curve); - } - - public Result(StreamInput in) throws IOException { - this.score = in.readDouble(); - this.curve = in.readList(AucRocPoint::new); - } - - @Override - public String getWriteableName() { - return registeredMetricName(OutlierDetection.NAME, NAME); - } - - @Override - public String getMetricName() { - return NAME.getPreferredName(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeDouble(score); - out.writeList(curve); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field("score", score); - if (curve.isEmpty() == false) { - builder.field("curve", curve); - } - builder.endObject(); - return builder; - } + return Optional.ofNullable(result.get()); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetection.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetection.java index 3250272b03d16..73057cb06f08c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetection.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetection.java @@ -15,6 +15,7 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -23,6 +24,8 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.ACTUAL_FIELD; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.PREDICTED_PROBABILITY_FIELD; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; /** @@ -32,8 +35,6 @@ public class OutlierDetection implements Evaluation { public static final ParseField NAME = new ParseField("outlier_detection", "binary_soft_classification"); - private static final ParseField ACTUAL_FIELD = new ParseField("actual_field"); - private static final ParseField PREDICTED_PROBABILITY_FIELD = new ParseField("predicted_probability_field"); private static final ParseField METRICS = new ParseField("metrics"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -50,30 +51,34 @@ public static OutlierDetection fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - static QueryBuilder actualIsTrueQuery(String actualField) { + public static QueryBuilder actualIsTrueQuery(String actualField) { return QueryBuilders.queryStringQuery(actualField + ": (1 OR true)"); } /** - * The field where the actual class is marked up. - * The value of this field is assumed to either be 1 or 0, or true or false. + * The collection of fields in the index being evaluated. + * fields.getActualField() is assumed to either be 1 or 0, or true or false. + * fields.getPredictedProbabilityField() is assumed to be a number in [0.0, 1.0]. + * Other fields are not needed by this evaluation. */ - private final String actualField; - - /** - * The field of the predicted probability in [0.0, 1.0]. - */ - private final String predictedProbabilityField; + private final EvaluationFields fields; /** * The list of metrics to calculate */ private final List metrics; - public OutlierDetection(String actualField, String predictedProbabilityField, + public OutlierDetection(String actualField, + String predictedProbabilityField, @Nullable List metrics) { - this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD); - this.predictedProbabilityField = ExceptionsHelper.requireNonNull(predictedProbabilityField, PREDICTED_PROBABILITY_FIELD); + this.fields = + new EvaluationFields( + ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD), + null, + null, + null, + ExceptionsHelper.requireNonNull(predictedProbabilityField, PREDICTED_PROBABILITY_FIELD), + false); this.metrics = initMetrics(metrics, OutlierDetection::defaultMetrics); } @@ -86,8 +91,7 @@ private static List defaultMetrics() { } public OutlierDetection(StreamInput in) throws IOException { - this.actualField = in.readString(); - this.predictedProbabilityField = in.readString(); + this.fields = new EvaluationFields(in.readString(), null, null, null, in.readString(), false); this.metrics = in.readNamedWriteableList(EvaluationMetric.class); } @@ -97,13 +101,8 @@ public String getName() { } @Override - public String getActualField() { - return actualField; - } - - @Override - public String getPredictedField() { - return predictedProbabilityField; + public EvaluationFields getFields() { + return fields; } @Override @@ -118,16 +117,16 @@ public String getWriteableName() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeString(actualField); - out.writeString(predictedProbabilityField); + out.writeString(fields.getActualField()); + out.writeString(fields.getPredictedProbabilityField()); out.writeNamedWriteableList(metrics); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ACTUAL_FIELD.getPreferredName(), actualField); - builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField); + builder.field(ACTUAL_FIELD.getPreferredName(), fields.getActualField()); + builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), fields.getPredictedProbabilityField()); builder.startObject(METRICS.getPreferredName()); for (EvaluationMetric metric : metrics) { @@ -144,13 +143,12 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; OutlierDetection that = (OutlierDetection) o; - return Objects.equals(actualField, that.actualField) - && Objects.equals(predictedProbabilityField, that.predictedProbabilityField) + return Objects.equals(fields, that.fields) && Objects.equals(metrics, that.metrics); } @Override public int hashCode() { - return Objects.hash(actualField, predictedProbabilityField, metrics); + return Objects.hash(fields, metrics); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java index 7be8946b2939e..978ac0c74cded 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; @@ -20,6 +21,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; @@ -31,6 +33,7 @@ import java.util.List; import java.util.Locale; import java.util.Optional; +import java.util.Set; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; @@ -86,13 +89,19 @@ public String getName() { return NAME.getPreferredName(); } + @Override + public Set getRequiredFields() { + return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName()); + } + @Override public Tuple, List> aggs(EvaluationParameters parameters, - String actualField, - String predictedField) { + EvaluationFields fields) { if (result != null) { return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } + String actualField = fields.getActualField(); + String predictedField = fields.getPredictedField(); return Tuple.tuple( Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField, delta * delta)))), Collections.emptyList()); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java index 2637109646a25..4f6dbe28f828d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; @@ -19,6 +20,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; @@ -31,6 +33,7 @@ import java.util.Locale; import java.util.Objects; import java.util.Optional; +import java.util.Set; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; @@ -70,13 +73,19 @@ public String getName() { return NAME.getPreferredName(); } + @Override + public Set getRequiredFields() { + return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName()); + } + @Override public Tuple, List> aggs(EvaluationParameters parameters, - String actualField, - String predictedField) { + EvaluationFields fields) { if (result != null) { return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } + String actualField = fields.getActualField(); + String predictedField = fields.getPredictedField(); return Tuple.tuple( Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField)))), Collections.emptyList()); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicError.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicError.java index af2af28ce0490..d87004513dae9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicError.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicError.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; @@ -20,6 +21,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; @@ -31,6 +33,7 @@ import java.util.List; import java.util.Locale; import java.util.Optional; +import java.util.Set; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; @@ -85,13 +88,19 @@ public String getName() { return NAME.getPreferredName(); } + @Override + public Set getRequiredFields() { + return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName()); + } + @Override public Tuple, List> aggs(EvaluationParameters parameters, - String actualField, - String predictedField) { + EvaluationFields fields) { if (result != null) { return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } + String actualField = fields.getActualField(); + String predictedField = fields.getPredictedField(); return Tuple.tuple( Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField, offset)))), Collections.emptyList()); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java index 16a4b358a623b..32125ff78b954 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; @@ -20,6 +21,7 @@ import org.elasticsearch.search.aggregations.metrics.ExtendedStats; import org.elasticsearch.search.aggregations.metrics.ExtendedStatsAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; @@ -32,6 +34,7 @@ import java.util.Locale; import java.util.Objects; import java.util.Optional; +import java.util.Set; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; @@ -74,13 +77,19 @@ public String getName() { return NAME.getPreferredName(); } + @Override + public Set getRequiredFields() { + return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName()); + } + @Override public Tuple, List> aggs(EvaluationParameters parameters, - String actualField, - String predictedField) { + EvaluationFields fields) { if (result != null) { return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } + String actualField = fields.getActualField(); + String predictedField = fields.getPredictedField(); return Tuple.tuple( Arrays.asList( AggregationBuilders.sum(SS_RES).script(new Script(buildScript(actualField, predictedField))), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java index cc32ea4049282..a90e6821255d7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -21,6 +22,8 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.ACTUAL_FIELD; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.PREDICTED_FIELD; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; /** @@ -30,8 +33,6 @@ public class Regression implements Evaluation { public static final ParseField NAME = new ParseField("regression"); - private static final ParseField ACTUAL_FIELD = new ParseField("actual_field"); - private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field"); private static final ParseField METRICS = new ParseField("metrics"); @SuppressWarnings("unchecked") @@ -50,16 +51,12 @@ public static Regression fromXContent(XContentParser parser) { } /** - * The field containing the actual value - * The value of this field is assumed to be numeric + * The collection of fields in the index being evaluated. + * fields.getActualField() is assumed to be numeric. + * fields.getPredictedField() is assumed to be numeric. + * Other fields are not needed by this evaluation. */ - private final String actualField; - - /** - * The field containing the predicted value - * The value of this field is assumed to be numeric - */ - private final String predictedField; + private final EvaluationFields fields; /** * The list of metrics to calculate @@ -67,8 +64,14 @@ public static Regression fromXContent(XContentParser parser) { private final List metrics; public Regression(String actualField, String predictedField, @Nullable List metrics) { - this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD); - this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD); + this.fields = + new EvaluationFields( + ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD), + ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD), + null, + null, + null, + false); this.metrics = initMetrics(metrics, Regression::defaultMetrics); } @@ -77,8 +80,7 @@ private static List defaultMetrics() { } public Regression(StreamInput in) throws IOException { - this.actualField = in.readString(); - this.predictedField = in.readString(); + this.fields = new EvaluationFields(in.readString(), in.readString(), null, null, null, false); this.metrics = in.readNamedWriteableList(EvaluationMetric.class); } @@ -88,13 +90,8 @@ public String getName() { } @Override - public String getActualField() { - return actualField; - } - - @Override - public String getPredictedField() { - return predictedField; + public EvaluationFields getFields() { + return fields; } @Override @@ -109,16 +106,16 @@ public String getWriteableName() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeString(actualField); - out.writeString(predictedField); + out.writeString(fields.getActualField()); + out.writeString(fields.getPredictedField()); out.writeNamedWriteableList(metrics); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ACTUAL_FIELD.getPreferredName(), actualField); - builder.field(PREDICTED_FIELD.getPreferredName(), predictedField); + builder.field(ACTUAL_FIELD.getPreferredName(), fields.getActualField()); + builder.field(PREDICTED_FIELD.getPreferredName(), fields.getPredictedField()); builder.startObject(METRICS.getPreferredName()); for (EvaluationMetric metric : metrics) { @@ -135,13 +132,12 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Regression that = (Regression) o; - return Objects.equals(that.actualField, this.actualField) - && Objects.equals(that.predictedField, this.predictedField) + return Objects.equals(that.fields, this.fields) && Objects.equals(that.metrics, this.metrics); } @Override public int hashCode() { - return Objects.hash(actualField, predictedField, metrics); + return Objects.hash(fields, metrics); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java index 02767366c679e..12bb665f32b34 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java @@ -11,13 +11,14 @@ import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction.Response; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRocResultTests; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AccuracyResultTests; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixResultTests; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.PrecisionResultTests; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared; import java.util.Arrays; @@ -25,6 +26,10 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializingTestCase { + private static final String OUTLIER_DETECTION = "outlier_detection"; + private static final String CLASSIFICATION = "classification"; + private static final String REGRESSION = "regression"; + @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables()); @@ -32,18 +37,35 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { @Override protected Response createTestInstance() { - String evaluationName = randomAlphaOfLength(10); - List metrics = - Arrays.asList( - AccuracyResultTests.createRandom(), - PrecisionResultTests.createRandom(), - RecallResultTests.createRandom(), - MulticlassConfusionMatrixResultTests.createRandom(), - new MeanSquaredError.Result(randomDouble()), - new MeanSquaredLogarithmicError.Result(randomDouble()), - new Huber.Result(randomDouble()), - new RSquared.Result(randomDouble())); - return new Response(evaluationName, randomSubsetOf(metrics)); + String evaluationName = randomFrom(OUTLIER_DETECTION, CLASSIFICATION, REGRESSION); + List metrics; + switch (evaluationName) { + case OUTLIER_DETECTION: + metrics = randomSubsetOf( + Arrays.asList( + AucRocResultTests.createRandom())); + break; + case CLASSIFICATION: + metrics = randomSubsetOf( + Arrays.asList( + AucRocResultTests.createRandom(), + AccuracyResultTests.createRandom(), + PrecisionResultTests.createRandom(), + RecallResultTests.createRandom(), + MulticlassConfusionMatrixResultTests.createRandom())); + break; + case REGRESSION: + metrics = randomSubsetOf( + Arrays.asList( + new MeanSquaredError.Result(randomDouble()), + new MeanSquaredLogarithmicError.Result(randomDouble()), + new Huber.Result(randomDouble()), + new RSquared.Result(randomDouble()))); + break; + default: + throw new AssertionError("Please add missing \"case\" variant to the \"switch\" statement"); + } + return new Response(evaluationName, metrics); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 861e349317003..ca963b2e1390f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -44,7 +44,6 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; @@ -366,15 +365,27 @@ public void testGetExplicitlyMappedFields() { assertThat( new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"), equalTo(Collections.singletonMap("results.feature_importance", Classification.FEATURE_IMPORTANCE_MAPPING))); + Map expectedTopClassesMapping = new HashMap() {{ + put("type", "nested"); + put("properties", new HashMap() {{ + put("class_name", Collections.singletonMap("bar", "baz")); + put("class_probability", Collections.singletonMap("type", "double")); + }}); + }}; Map explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields( Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")), "results"); - assertThat(explicitlyMappedFields, - allOf( - hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")), - hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz")))); + assertThat(explicitlyMappedFields, hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz"))); + assertThat(explicitlyMappedFields, hasEntry("results.top_classes", expectedTopClassesMapping)); assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", Classification.FEATURE_IMPORTANCE_MAPPING)); + expectedTopClassesMapping = new HashMap() {{ + put("type", "nested"); + put("properties", new HashMap() {{ + put("class_name", Collections.singletonMap("type", "long")); + put("class_probability", Collections.singletonMap("type", "double")); + }}); + }}; explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields( new HashMap() {{ put("foo", new HashMap() {{ @@ -384,10 +395,8 @@ public void testGetExplicitlyMappedFields() { put("bar", Collections.singletonMap("type", "long")); }}, "results"); - assertThat(explicitlyMappedFields, - allOf( - hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")), - hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long")))); + assertThat(explicitlyMappedFields, hasEntry("results.foo_prediction", Collections.singletonMap("type", "long"))); + assertThat(explicitlyMappedFields, hasEntry("results.top_classes", expectedTopClassesMapping)); assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", Classification.FEATURE_IMPORTANCE_MAPPING)); assertThat( diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationFieldsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationFieldsTests.java new file mode 100644 index 0000000000000..d591b138b3846 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationFieldsTests.java @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation; + +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.test.ESTestCase; + +import static java.util.stream.Collectors.toList; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; + +public class EvaluationFieldsTests extends ESTestCase { + + public void testConstructorAndGetters() { + EvaluationFields fields = new EvaluationFields("a", "b", "c", "d", "e", true); + assertThat(fields.getActualField(), is(equalTo("a"))); + assertThat(fields.getPredictedField(), is(equalTo("b"))); + assertThat(fields.getTopClassesField(), is(equalTo("c"))); + assertThat(fields.getPredictedClassField(), is(equalTo("d"))); + assertThat(fields.getPredictedProbabilityField(), is(equalTo("e"))); + assertThat(fields.isPredictedProbabilityFieldNested(), is(true)); + } + + public void testConstructorAndGetters_WithNullValues() { + EvaluationFields fields = new EvaluationFields("a", null, "c", null, "e", true); + assertThat(fields.getActualField(), is(equalTo("a"))); + assertThat(fields.getPredictedField(), is(nullValue())); + assertThat(fields.getTopClassesField(), is(equalTo("c"))); + assertThat(fields.getPredictedClassField(), is(nullValue())); + assertThat(fields.getPredictedProbabilityField(), is(equalTo("e"))); + assertThat(fields.isPredictedProbabilityFieldNested(), is(true)); + } + + public void testListPotentiallyRequiredFields() { + EvaluationFields fields = new EvaluationFields("a", "b", "c", "d", "e", randomBoolean()); + assertThat(fields.listPotentiallyRequiredFields().stream().map(Tuple::v2).collect(toList()), contains("a", "b", "c", "d", "e")); + } + + public void testListPotentiallyRequiredFields_WithNullValues() { + EvaluationFields fields = new EvaluationFields("a", null, "c", null, "e", randomBoolean()); + assertThat(fields.listPotentiallyRequiredFields().stream().map(Tuple::v2).collect(toList()), contains("a", null, "c", null, "e")); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRocTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRocTests.java new file mode 100644 index 0000000000000..fef8418edb9d7 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRocTests.java @@ -0,0 +1,105 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +public class AbstractAucRocTests extends ESTestCase { + + public void testCalculateAucScore_GivenZeroPercentiles() { + double[] tpPercentiles = zeroPercentiles(); + double[] fpPercentiles = zeroPercentiles(); + + List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); + double aucRocScore = AucRoc.calculateAucScore(curve); + + assertThat(aucRocScore, closeTo(0.5, 0.01)); + } + + public void testCalculateAucScore_GivenRandomTpPercentilesAndZeroFpPercentiles() { + double[] tpPercentiles = randomPercentiles(); + double[] fpPercentiles = zeroPercentiles(); + + List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); + double aucRocScore = AucRoc.calculateAucScore(curve); + + assertThat(aucRocScore, closeTo(1.0, 0.1)); + } + + public void testCalculateAucScore_GivenZeroTpPercentilesAndRandomFpPercentiles() { + double[] tpPercentiles = zeroPercentiles(); + double[] fpPercentiles = randomPercentiles(); + + List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); + double aucRocScore = AucRoc.calculateAucScore(curve); + + assertThat(aucRocScore, closeTo(0.0, 0.1)); + } + + public void testCalculateAucScore_GivenRandomPercentiles() { + for (int i = 0; i < 20; i++) { + double[] tpPercentiles = randomPercentiles(); + double[] fpPercentiles = randomPercentiles(); + + List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); + double aucRocScore = AucRoc.calculateAucScore(curve); + + List inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles); + double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve); + + assertThat(aucRocScore, greaterThanOrEqualTo(0.0)); + assertThat(aucRocScore, lessThanOrEqualTo(1.0)); + assertThat(inverseAucRocScore, greaterThanOrEqualTo(0.0)); + assertThat(inverseAucRocScore, lessThanOrEqualTo(1.0)); + assertThat(aucRocScore + inverseAucRocScore, closeTo(1.0, 0.05)); + } + } + + public void testCalculateAucScore_GivenPrecalculated() { + double[] tpPercentiles = new double[99]; + double[] fpPercentiles = new double[99]; + + double[] tpSimplified = new double[] { 0.3, 0.6, 0.5 , 0.8 }; + double[] fpSimplified = new double[] { 0.1, 0.3, 0.5 , 0.5 }; + + for (int i = 0; i < tpPercentiles.length; i++) { + int simplifiedIndex = i / 25; + tpPercentiles[i] = tpSimplified[simplifiedIndex]; + fpPercentiles[i] = fpSimplified[simplifiedIndex]; + } + + List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); + double aucRocScore = AucRoc.calculateAucScore(curve); + + List inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles); + double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve); + + assertThat(aucRocScore, closeTo(0.8, 0.05)); + assertThat(inverseAucRocScore, closeTo(0.2, 0.05)); + } + + public static double[] zeroPercentiles() { + double[] percentiles = new double[99]; + Arrays.fill(percentiles, 0.0); + return percentiles; + } + + public static double[] randomPercentiles() { + double[] percentiles = new double[99]; + for (int i = 0; i < percentiles.length; i++) { + percentiles[i] = randomDouble(); + } + Arrays.sort(percentiles); + return percentiles; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java index 35a9a85d135a3..7d89d961fe12d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result; @@ -32,6 +33,7 @@ public class AccuracyTests extends AbstractSerializingTestCase { private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100); + private static final EvaluationFields EVALUATION_FIELDS = new EvaluationFields("foo", "bar", null, null, null, true); @Override protected Accuracy doParseInstance(XContentParser parser) throws IOException { @@ -88,7 +90,7 @@ public void testProcess() { Accuracy accuracy = new Accuracy(); accuracy.process(aggs); - assertThat(accuracy.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty())); + assertThat(accuracy.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS), isTuple(empty(), empty())); Result result = accuracy.getResult().get(); assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); @@ -130,7 +132,7 @@ public void testProcess_GivenCardinalityTooHigh() { mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5))); Accuracy accuracy = new Accuracy(); - accuracy.aggs(EVALUATION_PARAMETERS, "foo", "bar"); + accuracy.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs)); assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high")); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRocResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRocResultTests.java new file mode 100644 index 0000000000000..1a03e2e0c2c78 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRocResultTests.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AbstractAucRoc.AucRocPoint; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AbstractAucRoc.Result; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class AucRocResultTests extends AbstractWireSerializingTestCase { + + public static Result createRandom() { + double score = randomDoubleBetween(0.0, 1.0, true); + Long docCount = randomBoolean() ? randomLong() : null; + List curve = + Stream + .generate(() -> new AucRocPoint(randomDouble(), randomDouble(), randomDouble())) + .limit(randomIntBetween(0, 20)) + .collect(Collectors.toList()); + return new Result(score, docCount, curve); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables()); + } + + @Override + protected Result createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return Result::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRocTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRocTests.java new file mode 100644 index 0000000000000..403a837c5605b --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRocTests.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; + +public class AucRocTests extends AbstractSerializingTestCase { + + @Override + protected AucRoc doParseInstance(XContentParser parser) throws IOException { + return AucRoc.PARSER.apply(parser, null); + } + + @Override + protected AucRoc createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return AucRoc::new; + } + + public static AucRoc createRandom() { + return new AucRoc(randomBoolean() ? randomBoolean() : null, randomAlphaOfLength(randomIntBetween(2, 10))); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java index ed1789c3d3875..5466d2fa088ba 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java @@ -6,12 +6,14 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.join.ScoreMode; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; @@ -23,6 +25,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; @@ -33,6 +36,7 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.Set; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent; @@ -61,10 +65,17 @@ public static Classification createRandom() { randomSubsetOf( Arrays.asList( AccuracyTests.createRandom(), + AucRocTests.createRandom(), PrecisionTests.createRandom(), RecallTests.createRandom(), MulticlassConfusionMatrixTests.createRandom())); - return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics); + boolean usesAucRoc = metrics.stream().map(EvaluationMetric::getName).anyMatch(n -> AucRoc.NAME.getPreferredName().equals(n)); + return new Classification( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + // If AucRoc is to be calculated, the top_classes field is required + (usesAucRoc || randomBoolean()) ? randomAlphaOfLength(10) : null, + metrics.isEmpty() ? null : metrics); } @Override @@ -82,13 +93,35 @@ protected Writeable.Reader instanceReader() { return Classification::new; } + public void testConstructor_GivenMissingField() { + FakeClassificationMetric metric = new FakeClassificationMetric("fake"); + ElasticsearchStatusException e = + expectThrows( + ElasticsearchStatusException.class, + () -> new Classification("foo", null, null, Collections.singletonList(metric))); + assertThat( + e.getMessage(), + is(equalTo("[classification] must define [predicted_field] as required by the following metrics [fake]"))); + } + public void testConstructor_GivenEmptyMetrics() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", "bar", Collections.emptyList())); + () -> new Classification("foo", "bar", "results", Collections.emptyList())); assertThat(e.getMessage(), equalTo("[classification] must have one or more metrics")); } - public void testBuildSearch() { + public void testGetFields() { + Classification evaluation = new Classification("foo", "bar", "results", null); + EvaluationFields fields = evaluation.getFields(); + assertThat(fields.getActualField(), is(equalTo("foo"))); + assertThat(fields.getPredictedField(), is(equalTo("bar"))); + assertThat(fields.getTopClassesField(), is(equalTo("results"))); + assertThat(fields.getPredictedClassField(), is(equalTo("results.class_name"))); + assertThat(fields.getPredictedProbabilityField(), is(equalTo("results.class_probability"))); + assertThat(fields.isPredictedProbabilityFieldNested(), is(true)); + } + + public void testBuildSearch_WithDefaultNonRequiredNestedFields() { QueryBuilder userProvidedQuery = QueryBuilders.boolQuery() .filter(QueryBuilders.termQuery("field_A", "some-value")) @@ -101,7 +134,78 @@ public void testBuildSearch() { .filter(QueryBuilders.termQuery("field_A", "some-value")) .filter(QueryBuilders.termQuery("field_B", "some-other-value"))); - Classification evaluation = new Classification("act", "pred", Arrays.asList(new MulticlassConfusionMatrix())); + Classification evaluation = new Classification("act", "pred", null, Arrays.asList(new MulticlassConfusionMatrix())); + + SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery); + assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery)); + assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0)); + } + + public void testBuildSearch_WithExplicitNonRequiredNestedFields() { + QueryBuilder userProvidedQuery = + QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("field_A", "some-value")) + .filter(QueryBuilders.termQuery("field_B", "some-other-value")); + QueryBuilder expectedSearchQuery = + QueryBuilders.boolQuery() + .filter(QueryBuilders.existsQuery("act")) + .filter(QueryBuilders.existsQuery("pred")) + .filter(QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("field_A", "some-value")) + .filter(QueryBuilders.termQuery("field_B", "some-other-value"))); + + Classification evaluation = new Classification("act", "pred", "results", Arrays.asList(new MulticlassConfusionMatrix())); + + SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery); + assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery)); + assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0)); + } + + public void testBuildSearch_WithDefaultRequiredNestedFields() { + QueryBuilder userProvidedQuery = + QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("field_A", "some-value")) + .filter(QueryBuilders.termQuery("field_B", "some-other-value")); + QueryBuilder expectedSearchQuery = + QueryBuilders.boolQuery() + .filter(QueryBuilders.existsQuery("act")) + .filter( + QueryBuilders.nestedQuery("ml.top_classes", QueryBuilders.existsQuery("ml.top_classes.class_name"), ScoreMode.None) + .ignoreUnmapped(true)) + .filter( + QueryBuilders.nestedQuery( + "ml.top_classes", QueryBuilders.existsQuery("ml.top_classes.class_probability"), ScoreMode.None) + .ignoreUnmapped(true)) + .filter(QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("field_A", "some-value")) + .filter(QueryBuilders.termQuery("field_B", "some-other-value"))); + + Classification evaluation = new Classification("act", "pred", null, Arrays.asList(new AucRoc(false, "some-value"))); + + SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery); + assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery)); + assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0)); + } + + public void testBuildSearch_WithExplicitRequiredNestedFields() { + QueryBuilder userProvidedQuery = + QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("field_A", "some-value")) + .filter(QueryBuilders.termQuery("field_B", "some-other-value")); + QueryBuilder expectedSearchQuery = + QueryBuilders.boolQuery() + .filter(QueryBuilders.existsQuery("act")) + .filter( + QueryBuilders.nestedQuery("results", QueryBuilders.existsQuery("results.class_name"), ScoreMode.None) + .ignoreUnmapped(true)) + .filter( + QueryBuilders.nestedQuery("results", QueryBuilders.existsQuery("results.class_probability"), ScoreMode.None) + .ignoreUnmapped(true)) + .filter(QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("field_A", "some-value")) + .filter(QueryBuilders.termQuery("field_B", "some-other-value"))); + + Classification evaluation = new Classification("act", "pred", "results", Arrays.asList(new AucRoc(false, "some-value"))); SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery); assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery)); @@ -114,7 +218,7 @@ public void testProcess_MultipleMetricsWithDifferentNumberOfSteps() { EvaluationMetric metric3 = new FakeClassificationMetric("fake_metric_3", 4); EvaluationMetric metric4 = new FakeClassificationMetric("fake_metric_4", 5); - Classification evaluation = new Classification("act", "pred", Arrays.asList(metric1, metric2, metric3, metric4)); + Classification evaluation = new Classification("act", "pred", null, Arrays.asList(metric1, metric2, metric3, metric4)); assertThat(metric1.getResult(), isEmpty()); assertThat(metric2.getResult(), isEmpty()); assertThat(metric3.getResult(), isEmpty()); @@ -183,6 +287,10 @@ private static class FakeClassificationMetric implements EvaluationMetric { private int currentStepIndex; private EvaluationMetricResult result; + FakeClassificationMetric(String name) { + this(name, 1); + } + FakeClassificationMetric(String name, int numSteps) { this.name = name; this.numSteps = numSteps; @@ -198,10 +306,14 @@ public String getWriteableName() { return name; } + @Override + public Set getRequiredFields() { + return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName()); + } + @Override public Tuple, List> aggs(EvaluationParameters parameters, - String actualField, - String predictedField) { + EvaluationFields fields) { return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java index e6662c0429bde..12b8aae95d196 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; @@ -37,6 +38,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase { private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100); + private static final EvaluationFields EVALUATION_FIELDS = new EvaluationFields("foo", "bar", null, null, null, true); @Override protected MulticlassConfusionMatrix doParseInstance(XContentParser parser) throws IOException { @@ -83,7 +85,8 @@ public void testConstructor_SizeValidationFailures() { public void testAggs() { MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(); - Tuple, List> aggs = confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"); + Tuple, List> aggs = + confusionMatrix.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS); assertThat(aggs, isTuple(not(empty()), empty())); assertThat(confusionMatrix.getResult(), isEmpty()); } @@ -119,7 +122,7 @@ public void testProcess() { MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null); confusionMatrix.process(aggs); - assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty())); + assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS), isTuple(empty(), empty())); Result result = confusionMatrix.getResult().get(); assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat( @@ -162,7 +165,7 @@ public void testProcess_OtherClassesCountGreaterThanZero() { MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null); confusionMatrix.process(aggs); - assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty())); + assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS), isTuple(empty(), empty())); Result result = confusionMatrix.getResult().get(); assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat( @@ -246,7 +249,7 @@ public void testProcess_MoreThanTwoStepsNeeded() { confusionMatrix.process(aggsStep2); confusionMatrix.process(aggsStep3); - assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty())); + assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS), isTuple(empty(), empty())); Result result = confusionMatrix.getResult().get(); assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat( diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionTests.java index 81c734863408d..028680f4bf131 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import java.io.IOException; @@ -28,6 +29,7 @@ public class PrecisionTests extends AbstractSerializingTestCase { private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100); + private static final EvaluationFields EVALUATION_FIELDS = new EvaluationFields("foo", "bar", null, null, null, true); @Override protected Precision doParseInstance(XContentParser parser) throws IOException { @@ -64,7 +66,7 @@ public void testProcess() { Precision precision = new Precision(); precision.process(aggs); - assertThat(precision.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty())); + assertThat(precision.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS), isTuple(empty(), empty())); assertThat(precision.getResult().get(), equalTo(new Precision.Result(Collections.emptyList(), 0.8123))); } @@ -114,7 +116,7 @@ public void testProcess_GivenCardinalityTooHigh() { Aggregations aggs = new Aggregations(Collections.singletonList(mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME, Collections.emptyList(), 1))); Precision precision = new Precision(); - precision.aggs(EVALUATION_PARAMETERS, "foo", "bar"); + precision.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> precision.process(aggs)); assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high")); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallTests.java index efced860b9192..e5236fb704d84 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import java.io.IOException; @@ -27,6 +28,7 @@ public class RecallTests extends AbstractSerializingTestCase { private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100); + private static final EvaluationFields EVALUATION_FIELDS = new EvaluationFields("foo", "bar", null, null, null, true); @Override protected Recall doParseInstance(XContentParser parser) throws IOException { @@ -62,7 +64,7 @@ public void testProcess() { Recall recall = new Recall(); recall.process(aggs); - assertThat(recall.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty())); + assertThat(recall.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS), isTuple(empty(), empty())); assertThat(recall.getResult().get(), equalTo(new Recall.Result(Collections.emptyList(), 0.8123))); } @@ -113,7 +115,7 @@ public void testProcess_GivenCardinalityTooHigh() { mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME, Collections.emptyList(), 1), mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123))); Recall recall = new Recall(); - recall.aggs(EVALUATION_PARAMETERS, "foo", "bar"); + recall.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> recall.process(aggs)); assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high")); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AucRocTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AucRocTests.java index 610b4830b57d2..6d8a472d97538 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AucRocTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AucRocTests.java @@ -10,12 +10,6 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import java.io.IOException; -import java.util.Arrays; -import java.util.List; - -import static org.hamcrest.Matchers.closeTo; -import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.hamcrest.Matchers.lessThanOrEqualTo; public class AucRocTests extends AbstractSerializingTestCase { @@ -37,91 +31,4 @@ protected Writeable.Reader instanceReader() { public static AucRoc createRandom() { return new AucRoc(randomBoolean() ? randomBoolean() : null); } - - public void testCalculateAucScore_GivenZeroPercentiles() { - double[] tpPercentiles = zeroPercentiles(); - double[] fpPercentiles = zeroPercentiles(); - - List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); - double aucRocScore = AucRoc.calculateAucScore(curve); - - assertThat(aucRocScore, closeTo(0.5, 0.01)); - } - - public void testCalculateAucScore_GivenRandomTpPercentilesAndZeroFpPercentiles() { - double[] tpPercentiles = randomPercentiles(); - double[] fpPercentiles = zeroPercentiles(); - - List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); - double aucRocScore = AucRoc.calculateAucScore(curve); - - assertThat(aucRocScore, closeTo(1.0, 0.1)); - } - - public void testCalculateAucScore_GivenZeroTpPercentilesAndRandomFpPercentiles() { - double[] tpPercentiles = zeroPercentiles(); - double[] fpPercentiles = randomPercentiles(); - - List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); - double aucRocScore = AucRoc.calculateAucScore(curve); - - assertThat(aucRocScore, closeTo(0.0, 0.1)); - } - - public void testCalculateAucScore_GivenRandomPercentiles() { - for (int i = 0; i < 20; i++) { - double[] tpPercentiles = randomPercentiles(); - double[] fpPercentiles = randomPercentiles(); - - List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); - double aucRocScore = AucRoc.calculateAucScore(curve); - - List inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles); - double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve); - - assertThat(aucRocScore, greaterThanOrEqualTo(0.0)); - assertThat(aucRocScore, lessThanOrEqualTo(1.0)); - assertThat(inverseAucRocScore, greaterThanOrEqualTo(0.0)); - assertThat(inverseAucRocScore, lessThanOrEqualTo(1.0)); - assertThat(aucRocScore + inverseAucRocScore, closeTo(1.0, 0.05)); - } - } - - public void testCalculateAucScore_GivenPrecalculated() { - double[] tpPercentiles = new double[99]; - double[] fpPercentiles = new double[99]; - - double[] tpSimplified = new double[] { 0.3, 0.6, 0.5 , 0.8 }; - double[] fpSimplified = new double[] { 0.1, 0.3, 0.5 , 0.5 }; - - for (int i = 0; i < tpPercentiles.length; i++) { - int simplifiedIndex = i / 25; - tpPercentiles[i] = tpSimplified[simplifiedIndex]; - fpPercentiles[i] = fpSimplified[simplifiedIndex]; - } - - List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); - double aucRocScore = AucRoc.calculateAucScore(curve); - - List inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles); - double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve); - - assertThat(aucRocScore, closeTo(0.8, 0.05)); - assertThat(inverseAucRocScore, closeTo(0.2, 0.05)); - } - - public static double[] zeroPercentiles() { - double[] percentiles = new double[99]; - Arrays.fill(percentiles, 0.0); - return percentiles; - } - - public static double[] randomPercentiles() { - double[] percentiles = new double[99]; - for (int i = 0; i < percentiles.length; i++) { - percentiles[i] = randomDouble(); - } - Arrays.sort(percentiles); - return percentiles; - } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java index cc560e6495927..c0b72dbe1c234 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; @@ -26,6 +27,8 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; public class OutlierDetectionTests extends AbstractSerializingTestCase { @@ -86,6 +89,17 @@ public void testConstructor_GivenEmptyMetrics() { assertThat(e.getMessage(), equalTo("[outlier_detection] must have one or more metrics")); } + public void testGetFields() { + OutlierDetection evaluation = new OutlierDetection("foo", "bar", null); + EvaluationFields fields = evaluation.getFields(); + assertThat(fields.getActualField(), is(equalTo("foo"))); + assertThat(fields.getPredictedField(), is(nullValue())); + assertThat(fields.getTopClassesField(), is(nullValue())); + assertThat(fields.getPredictedClassField(), is(nullValue())); + assertThat(fields.getPredictedProbabilityField(), is(equalTo("bar"))); + assertThat(fields.isPredictedProbabilityFieldNested(), is(false)); + } + public void testBuildSearch() { QueryBuilder userProvidedQuery = QueryBuilders.boolQuery() diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java index 26dff097b1b32..c8fc2d5d67d55 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; @@ -26,6 +27,8 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; public class RegressionTests extends AbstractSerializingTestCase { @@ -73,6 +76,17 @@ public void testConstructor_GivenEmptyMetrics() { assertThat(e.getMessage(), equalTo("[regression] must have one or more metrics")); } + public void testGetFields() { + Regression evaluation = new Regression("foo", "bar", null); + EvaluationFields fields = evaluation.getFields(); + assertThat(fields.getActualField(), is(equalTo("foo"))); + assertThat(fields.getPredictedField(), is(equalTo("bar"))); + assertThat(fields.getTopClassesField(), is(nullValue())); + assertThat(fields.getPredictedClassField(), is(nullValue())); + assertThat(fields.getPredictedProbabilityField(), is(nullValue())); + assertThat(fields.isPredictedProbabilityFieldNested(), is(false)); + } + public void testBuildSearch() { QueryBuilder userProvidedQuery = QueryBuilders.boolQuery() diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index b8cab82f7aaef..a9fd5f1305037 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -115,6 +115,11 @@ yamlRestTest { 'ml/evaluate_data_frame/Test classification given evaluation with empty metrics', 'ml/evaluate_data_frame/Test classification given missing actual_field', 'ml/evaluate_data_frame/Test classification given missing predicted_field', + 'ml/evaluate_data_frame/Test classification given missing top_classes_field', + 'ml/evaluate_data_frame/Test classification auc_roc given actual_field is never equal to fish', + 'ml/evaluate_data_frame/Test classification auc_roc given predicted_class_field is never equal to mouse', + 'ml/evaluate_data_frame/Test classification auc_roc with missing class_name', + 'ml/evaluate_data_frame/Test classification accuracy with missing predicted_field', 'ml/evaluate_data_frame/Test regression given evaluation with empty metrics', 'ml/evaluate_data_frame/Test regression given missing actual_field', 'ml/evaluate_data_frame/Test regression given missing predicted_field', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index d718a0f27908a..b16c8ce67b241 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision; @@ -24,13 +25,17 @@ import org.junit.Before; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.stream.IntStream; import static java.util.stream.Collectors.toList; +import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; @@ -38,16 +43,19 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase { - private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index"; + static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index"; - private static final String ANIMAL_NAME_KEYWORD_FIELD = "animal_name_keyword"; - private static final String ANIMAL_NAME_PREDICTION_KEYWORD_FIELD = "animal_name_keyword_prediction"; - private static final String NO_LEGS_KEYWORD_FIELD = "no_legs_keyword"; - private static final String NO_LEGS_INTEGER_FIELD = "no_legs_integer"; - private static final String NO_LEGS_PREDICTION_INTEGER_FIELD = "no_legs_integer_prediction"; - private static final String IS_PREDATOR_KEYWORD_FIELD = "predator_keyword"; - private static final String IS_PREDATOR_BOOLEAN_FIELD = "predator_boolean"; - private static final String IS_PREDATOR_PREDICTION_BOOLEAN_FIELD = "predator_boolean_prediction"; + static final String ANIMAL_NAME_KEYWORD_FIELD = "animal_name_keyword"; + static final String ANIMAL_NAME_PREDICTION_KEYWORD_FIELD = "animal_name_keyword_prediction"; + static final String ANIMAL_NAME_PREDICTION_PROB_FIELD = "animal_name_prediction_prob"; + static final String NO_LEGS_KEYWORD_FIELD = "no_legs_keyword"; + static final String NO_LEGS_INTEGER_FIELD = "no_legs_integer"; + static final String NO_LEGS_PREDICTION_INTEGER_FIELD = "no_legs_integer_prediction"; + static final String IS_PREDATOR_KEYWORD_FIELD = "predator_keyword"; + static final String IS_PREDATOR_BOOLEAN_FIELD = "predator_boolean"; + static final String IS_PREDATOR_PREDICTION_BOOLEAN_FIELD = "predator_boolean_prediction"; + static final String IS_PREDATOR_PREDICTION_PROBABILITY_FIELD = "predator_prediction_probability"; + static final String ML_TOP_CLASSES_FIELD = "ml_results"; @Before public void setup() { @@ -67,7 +75,8 @@ public void cleanup() { public void testEvaluate_DefaultMetrics() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, null)); + ANIMALS_DATA_INDEX, + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, null, null)); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat( @@ -82,6 +91,7 @@ public void testEvaluate_AllMetrics() { new Classification( ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, + null, Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); @@ -116,6 +126,7 @@ public void testEvaluate_AllMetrics_KeywordField_CaseSensitivity() { new Classification( actualField, predictedField, + null, Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); @@ -139,9 +150,37 @@ public void testEvaluate_AllMetrics_KeywordField_CaseSensitivity() { assertThat(recallResult.getAvgRecall(), equalTo(0.0)); } + private AucRoc.Result evaluateAucRoc(boolean includeCurve) { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, + new Classification(ANIMAL_NAME_KEYWORD_FIELD, null, ML_TOP_CLASSES_FIELD, Arrays.asList(new AucRoc(includeCurve, "cat")))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + AucRoc.Result aucrocResult = (AucRoc.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(aucrocResult.getMetricName(), equalTo(AucRoc.NAME.getPreferredName())); + return aucrocResult; + } + + public void testEvaluate_AucRoc_DoNotIncludeCurve() { + AucRoc.Result aucrocResult = evaluateAucRoc(false); + assertThat(aucrocResult.getScore(), is(closeTo(0.5, 0.0001))); + assertThat(aucrocResult.getDocCount(), is(equalTo(75L))); + assertThat(aucrocResult.getCurve(), hasSize(0)); + } + + public void testEvaluate_AucRoc_IncludeCurve() { + AucRoc.Result aucrocResult = evaluateAucRoc(true); + assertThat(aucrocResult.getScore(), is(closeTo(0.5, 0.0001))); + assertThat(aucrocResult.getDocCount(), is(equalTo(75L))); + assertThat(aucrocResult.getCurve(), hasSize(greaterThan(0))); + } + private Accuracy.Result evaluateAccuracy(String actualField, String predictedField) { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Accuracy()))); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, null, Arrays.asList(new Accuracy()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -260,7 +299,7 @@ public void testEvaluate_Accuracy_FieldTypeMismatch() { private Precision.Result evaluatePrecision(String actualField, String predictedField) { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Precision()))); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, null, Arrays.asList(new Precision()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -354,13 +393,13 @@ public void testEvaluate_Precision_CardinalityTooHigh() { ElasticsearchStatusException.class, () -> evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new Precision())))); + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, null, Arrays.asList(new Precision())))); assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high")); } private Recall.Result evaluateRecall(String actualField, String predictedField) { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Recall()))); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, null, Arrays.asList(new Recall()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -469,7 +508,7 @@ public void testEvaluate_Recall_CardinalityTooHigh() { ElasticsearchStatusException.class, () -> evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new Recall())))); + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, null, Arrays.asList(new Recall())))); assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high")); } @@ -478,7 +517,10 @@ private void evaluateMulticlassConfusionMatrix() { evaluateDataFrame( ANIMALS_DATA_INDEX, new Classification( - ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new MulticlassConfusionMatrix()))); + ANIMAL_NAME_KEYWORD_FIELD, + ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, + null, + Arrays.asList(new MulticlassConfusionMatrix()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -561,6 +603,7 @@ public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() { new Classification( ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, + null, Arrays.asList(new MulticlassConfusionMatrix(3, null)))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); @@ -595,7 +638,7 @@ public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() { assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L)); } - private static void createAnimalsIndex(String indexName) { + static void createAnimalsIndex(String indexName) { client().admin().indices().prepareCreate(indexName) .addMapping("_doc", ANIMAL_NAME_KEYWORD_FIELD, "type=keyword", @@ -605,28 +648,41 @@ private static void createAnimalsIndex(String indexName) { NO_LEGS_PREDICTION_INTEGER_FIELD, "type=integer", IS_PREDATOR_KEYWORD_FIELD, "type=keyword", IS_PREDATOR_BOOLEAN_FIELD, "type=boolean", - IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, "type=boolean") + IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, "type=boolean", + IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, "type=double", + ML_TOP_CLASSES_FIELD, "type=nested") .get(); } - private static void indexAnimalsData(String indexName) { + static void indexAnimalsData(String indexName) { List animalNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox"); BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); for (int i = 0; i < animalNames.size(); i++) { for (int j = 0; j < animalNames.size(); j++) { for (int k = 0; k < j + 1; k++) { + List topClasses = + IntStream + .range(0, 5) + .mapToObj(ix -> new HashMap() {{ + put("class_name", animalNames.get(ix)); + put("class_probability", 0.4 - 0.1 * ix); + }}) + .collect(toList()); bulkRequestBuilder.add( new IndexRequest(indexName) .source( ANIMAL_NAME_KEYWORD_FIELD, animalNames.get(i), ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, animalNames.get((i + j) % animalNames.size()), + ANIMAL_NAME_PREDICTION_PROB_FIELD, animalNames.get((i + j) % animalNames.size()), NO_LEGS_KEYWORD_FIELD, String.valueOf(i + 1), NO_LEGS_INTEGER_FIELD, i + 1, NO_LEGS_PREDICTION_INTEGER_FIELD, j + 1, IS_PREDATOR_KEYWORD_FIELD, String.valueOf(i % 2 == 0), IS_PREDATOR_BOOLEAN_FIELD, i % 2 == 0, - IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, (i + j) % 2 == 0)); + IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, (i + j) % 2 == 0, + IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, i % 2 == 0 ? 1.0 - 0.1 * i : 0.1 * i, + ML_TOP_CLASSES_FIELD, topClasses)); } } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 9315b5044e961..5cc876e752b26 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -40,6 +40,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall; @@ -957,9 +958,15 @@ private void assertEvaluation(String dependentVariable, List dependentVar new org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification( dependentVariable, predictedClassField, - Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); + null, + Arrays.asList( + new Accuracy(), + new AucRoc(true, dependentVariableValues.get(0).toString()), + new MulticlassConfusionMatrix(), + new Precision(), + new Recall()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4)); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(5)); { // Accuracy Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); @@ -970,9 +977,17 @@ private void assertEvaluation(String dependentVariable, List dependentVar } } + { // AucRoc + AucRoc.Result aucRocResult = (AucRoc.Result) evaluateDataFrameResponse.getMetrics().get(1); + assertThat(aucRocResult.getMetricName(), equalTo(AucRoc.NAME.getPreferredName())); + assertThat(aucRocResult.getScore(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))); + assertThat(aucRocResult.getDocCount(), allOf(greaterThanOrEqualTo(1L), lessThanOrEqualTo(350L))); + assertThat(aucRocResult.getCurve(), hasSize(greaterThan(0))); + } + { // MulticlassConfusionMatrix MulticlassConfusionMatrix.Result confusionMatrixResult = - (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(1); + (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(2); assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); List actualClasses = confusionMatrixResult.getConfusionMatrix(); assertThat( @@ -990,7 +1005,7 @@ private void assertEvaluation(String dependentVariable, List dependentVar } { // Precision - Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(2); + Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(3); assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName())); for (Precision.PerClassResult klass : precisionResult.getClasses()) { assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings))); @@ -999,7 +1014,7 @@ private void assertEvaluation(String dependentVariable, List dependentVar } { // Recall - Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(3); + Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(4); assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName())); for (Recall.PerClassResult klass : recallResult.getClasses()) { assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings))); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionEvaluationIT.java new file mode 100644 index 0000000000000..724c9e2d5a3ed --- /dev/null +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionEvaluationIT.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AucRoc; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.ConfusionMatrix; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.OutlierDetection; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Precision; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall; +import org.junit.After; +import org.junit.Before; + +import java.util.Arrays; + +import static java.util.stream.Collectors.toList; +import static org.elasticsearch.xpack.ml.integration.ClassificationEvaluationIT.ANIMALS_DATA_INDEX; +import static org.elasticsearch.xpack.ml.integration.ClassificationEvaluationIT.IS_PREDATOR_BOOLEAN_FIELD; +import static org.elasticsearch.xpack.ml.integration.ClassificationEvaluationIT.IS_PREDATOR_PREDICTION_PROBABILITY_FIELD; +import static org.elasticsearch.xpack.ml.integration.ClassificationEvaluationIT.createAnimalsIndex; +import static org.elasticsearch.xpack.ml.integration.ClassificationEvaluationIT.indexAnimalsData; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; + +public class OutlierDetectionEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase { + + @Before + public void setup() { + createAnimalsIndex(ANIMALS_DATA_INDEX); + indexAnimalsData(ANIMALS_DATA_INDEX); + } + + @After + public void cleanup() { + cleanUp(); + } + + public void testEvaluate_DefaultMetrics() { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, new OutlierDetection(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, null)); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(OutlierDetection.NAME.getPreferredName())); + assertThat( + evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()), + containsInAnyOrder( + AucRoc.NAME.getPreferredName(), + Precision.NAME.getPreferredName(), + Recall.NAME.getPreferredName(), + ConfusionMatrix.NAME.getPreferredName())); + } + + public void testEvaluate_AllMetrics() { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, + new OutlierDetection( + IS_PREDATOR_BOOLEAN_FIELD, + IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, + Arrays.asList( + new AucRoc(false), + new Precision(Arrays.asList(0.5)), + new Recall(Arrays.asList(0.5)), + new ConfusionMatrix(Arrays.asList(0.5))))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(OutlierDetection.NAME.getPreferredName())); + assertThat( + evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()), + containsInAnyOrder( + AucRoc.NAME.getPreferredName(), + Precision.NAME.getPreferredName(), + Recall.NAME.getPreferredName(), + ConfusionMatrix.NAME.getPreferredName())); + } + + private AucRoc.Result evaluateAucRoc(String actualField, String predictedField, boolean includeCurve) { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, + new OutlierDetection(actualField, predictedField, Arrays.asList(new AucRoc(includeCurve)))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(OutlierDetection.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + AucRoc.Result aucrocResult = (AucRoc.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(aucrocResult.getMetricName(), equalTo(AucRoc.NAME.getPreferredName())); + return aucrocResult; + } + + public void testEvaluate_AucRoc_DoNotIncludeCurve() { + AucRoc.Result aucrocResult = evaluateAucRoc(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, false); + assertThat(aucrocResult.getScore(), is(closeTo(1.0, 0.0001))); + assertThat(aucrocResult.getCurve(), hasSize(0)); + } + + public void testEvaluate_AucRoc_IncludeCurve() { + AucRoc.Result aucrocResult = evaluateAucRoc(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, true); + assertThat(aucrocResult.getScore(), is(closeTo(1.0, 0.0001))); + assertThat(aucrocResult.getCurve(), hasSize(greaterThan(0))); + } + + @Override + boolean supportsInference() { + return false; + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DestinationIndexTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DestinationIndexTests.java index aedfcc90adc3b..7d627e9d0830e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DestinationIndexTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DestinationIndexTests.java @@ -206,25 +206,25 @@ public void testCreateDestinationIndex_Regression() throws IOException { public void testCreateDestinationIndex_Classification() throws IOException { Map map = testCreateDestinationIndex(new Classification(NUMERICAL_FIELD)); assertThat(extractValue("_doc.properties.ml.numerical-field_prediction.type", map), equalTo("integer")); - assertThat(extractValue("_doc.properties.ml.top_classes.class_name.type", map), equalTo("integer")); + assertThat(extractValue("_doc.properties.ml.top_classes.properties.class_name.type", map), equalTo("integer")); } public void testCreateDestinationIndex_Classification_DependentVariableIsNested() throws IOException { Map map = testCreateDestinationIndex(new Classification(OUTER_FIELD + "." + INNER_FIELD)); assertThat(extractValue("_doc.properties.ml.outer-field.inner-field_prediction.type", map), equalTo("integer")); - assertThat(extractValue("_doc.properties.ml.top_classes.class_name.type", map), equalTo("integer")); + assertThat(extractValue("_doc.properties.ml.top_classes.properties.class_name.type", map), equalTo("integer")); } public void testCreateDestinationIndex_Classification_DependentVariableIsAlias() throws IOException { Map map = testCreateDestinationIndex(new Classification(ALIAS_TO_NUMERICAL_FIELD)); assertThat(extractValue("_doc.properties.ml.alias-to-numerical-field_prediction.type", map), equalTo("integer")); - assertThat(extractValue("_doc.properties.ml.top_classes.class_name.type", map), equalTo("integer")); + assertThat(extractValue("_doc.properties.ml.top_classes.properties.class_name.type", map), equalTo("integer")); } public void testCreateDestinationIndex_Classification_DependentVariableIsAliasToNested() throws IOException { Map map = testCreateDestinationIndex(new Classification(ALIAS_TO_NESTED_FIELD)); assertThat(extractValue("_doc.properties.ml.alias-to-nested-field_prediction.type", map), equalTo("integer")); - assertThat(extractValue("_doc.properties.ml.top_classes.class_name.type", map), equalTo("integer")); + assertThat(extractValue("_doc.properties.ml.top_classes.properties.class_name.type", map), equalTo("integer")); } public void testCreateDestinationIndex_ResultsFieldsExistsInSourceIndex() throws IOException { @@ -322,25 +322,25 @@ public void testUpdateMappingsToDestIndex_Regression() throws IOException { public void testUpdateMappingsToDestIndex_Classification() throws IOException { Map map = testUpdateMappingsToDestIndex(new Classification(NUMERICAL_FIELD)); assertThat(extractValue("properties.ml.numerical-field_prediction.type", map), equalTo("integer")); - assertThat(extractValue("properties.ml.top_classes.class_name.type", map), equalTo("integer")); + assertThat(extractValue("properties.ml.top_classes.properties.class_name.type", map), equalTo("integer")); } public void testUpdateMappingsToDestIndex_Classification_DependentVariableIsNested() throws IOException { Map map = testUpdateMappingsToDestIndex(new Classification(OUTER_FIELD + "." + INNER_FIELD)); assertThat(extractValue("properties.ml.outer-field.inner-field_prediction.type", map), equalTo("integer")); - assertThat(extractValue("properties.ml.top_classes.class_name.type", map), equalTo("integer")); + assertThat(extractValue("properties.ml.top_classes.properties.class_name.type", map), equalTo("integer")); } public void testUpdateMappingsToDestIndex_Classification_DependentVariableIsAlias() throws IOException { Map map = testUpdateMappingsToDestIndex(new Classification(ALIAS_TO_NUMERICAL_FIELD)); assertThat(extractValue("properties.ml.alias-to-numerical-field_prediction.type", map), equalTo("integer")); - assertThat(extractValue("properties.ml.top_classes.class_name.type", map), equalTo("integer")); + assertThat(extractValue("properties.ml.top_classes.properties.class_name.type", map), equalTo("integer")); } public void testUpdateMappingsToDestIndex_Classification_DependentVariableIsAliasToNested() throws IOException { Map map = testUpdateMappingsToDestIndex(new Classification(ALIAS_TO_NESTED_FIELD)); assertThat(extractValue("properties.ml.alias-to-nested-field_prediction.type", map), equalTo("integer")); - assertThat(extractValue("properties.ml.top_classes.class_name.type", map), equalTo("integer")); + assertThat(extractValue("properties.ml.top_classes.properties.class_name.type", map), equalTo("integer")); } public void testUpdateMappingsToDestIndex_ResultsFieldsExistsInSourceIndex() throws IOException { diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 57eb1d1116acb..94282acec9ac1 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -1,5 +1,14 @@ setup: + - do: + indices.create: + index: utopia + body: + mappings: + properties: + ml.top_classes: + type: nested + - do: index: index: utopia @@ -14,7 +23,11 @@ setup: "classification_field_act": "dog", "classification_field_pred": "dog", "all_true_field": true, - "all_false_field": false + "all_false_field": false, + "ml.top_classes": [ + {"class_name": "dog", "class_probability": 0.9}, + {"class_name": "cat", "class_probability": 0.1} + ] } - do: @@ -31,7 +44,11 @@ setup: "classification_field_act": "cat", "classification_field_pred": "cat", "all_true_field": true, - "all_false_field": false + "all_false_field": false, + "ml.top_classes": [ + {"class_name": "cat", "class_probability": 0.8}, + {"class_name": "dog", "class_probability": 0.2} + ] } - do: @@ -48,7 +65,11 @@ setup: "classification_field_act": "mouse", "classification_field_pred": "mouse", "all_true_field": true, - "all_false_field": false + "all_false_field": false, + "ml.top_classes": [ + {"class_name": "cat", "class_probability": 0.3}, + {"class_name": "dog", "class_probability": 0.1} + ] } - do: @@ -65,7 +86,11 @@ setup: "classification_field_act": "dog", "classification_field_pred": "cat", "all_true_field": true, - "all_false_field": false + "all_false_field": false, + "ml.top_classes": [ + {"class_name": "cat", "class_probability": 0.6}, + {"class_name": "dog", "class_probability": 0.3} + ] } - do: @@ -82,7 +107,11 @@ setup: "classification_field_act": "cat", "classification_field_pred": "dog", "all_true_field": true, - "all_false_field": false + "all_false_field": false, + "ml.top_classes": [ + {"class_name": "dog", "class_probability": 0.7}, + {"class_name": "cat", "class_probability": 0.3} + ] } - do: @@ -99,7 +128,11 @@ setup: "classification_field_act": "dog", "classification_field_pred": "dog", "all_true_field": true, - "all_false_field": false + "all_false_field": false, + "ml.top_classes": [ + {"class_name": "dog", "class_probability": 0.9}, + {"class_name": "cat", "class_probability": 0.1} + ] } - do: @@ -116,7 +149,11 @@ setup: "classification_field_act": "cat", "classification_field_pred": "cat", "all_true_field": true, - "all_false_field": false + "all_false_field": false, + "ml.top_classes": [ + {"class_name": "cat", "class_probability": 0.8}, + {"class_name": "dog", "class_probability": 0.2} + ] } - do: @@ -133,7 +170,11 @@ setup: "classification_field_act": "mouse", "classification_field_pred": "cat", "all_true_field": true, - "all_false_field": false + "all_false_field": false, + "ml.top_classes": [ + {"class_name": "cat", "class_probability": 0.8}, + {"class_name": "dog", "class_probability": 0.2} + ] } # This document misses the required fields and should be ignored @@ -166,6 +207,7 @@ setup: } } - match: { outlier_detection.auc_roc.score: 0.9899 } + - match: { outlier_detection.auc_roc.doc_count: 8 } - is_false: outlier_detection.auc_roc.curve --- @@ -186,6 +228,7 @@ setup: } } - match: { outlier_detection.auc_roc.score: 0.9899 } + - match: { outlier_detection.auc_roc.doc_count: 8 } - is_false: outlier_detection.auc_roc.curve --- @@ -206,12 +249,13 @@ setup: } } - match: { outlier_detection.auc_roc.score: 0.9899 } + - match: { outlier_detection.auc_roc.doc_count: 8 } - is_true: outlier_detection.auc_roc.curve --- "Test outlier_detection auc_roc given actual_field is always true": - do: - catch: /\[auc_roc\] requires at least one actual_field to have a different value than \[true\]/ + catch: /\[auc_roc\] requires at least one \[all_true_field\] to have a different value than \[true\]/ ml.evaluate_data_frame: body: > { @@ -230,7 +274,7 @@ setup: --- "Test outlier_detection auc_roc given actual_field is always false": - do: - catch: /\[auc_roc\] requires at least one actual_field to have the value \[true\]/ + catch: /\[auc_roc\] requires at least one \[all_false_field\] to have the value \[true\]/ ml.evaluate_data_frame: body: > { @@ -371,6 +415,7 @@ setup: } } - is_true: outlier_detection.auc_roc.score + - is_true: outlier_detection.auc_roc.doc_count - is_true: outlier_detection.precision.0\.25 - is_true: outlier_detection.precision.0\.5 - is_true: outlier_detection.precision.0\.75 @@ -443,7 +488,7 @@ setup: --- "Test outlier_detection given missing actual_field": - do: - catch: /No documents found containing both \[missing, outlier_score\] fields/ + catch: /No documents found containing all the required fields \[missing, outlier_score\]/ ml.evaluate_data_frame: body: > { @@ -459,7 +504,7 @@ setup: --- "Test outlier_detection given missing predicted_probability_field": - do: - catch: /No documents found containing both \[is_outlier, missing\] fields/ + catch: /No documents found containing all the required fields \[is_outlier, missing\]/ ml.evaluate_data_frame: body: > { @@ -598,7 +643,124 @@ setup: "classification": { "actual_field": "classification_field_act.keyword", "predicted_field": "classification_field_pred.keyword", - "metrics": { } + "metrics": {} + } + } + } +--- +"Test classification auc_roc with missing class_name": + - do: + # TODO: Revisit this error message as it does not give any clue about which field is missing + catch: /Failed to build \[auc_roc\] after last required field arrived/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "top_classes_field": "ml.top_classes", + "metrics": { + "auc_roc": {} + } + } + } + } +--- +"Test classification auc_roc given actual_field is never equal to fish": + - do: + catch: /\[auc_roc\] requires at least one \[classification_field_act.keyword\] to have the value \[fish\]/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "top_classes_field": "ml.top_classes", + "metrics": { + "auc_roc": { + "class_name": "fish" + } + } + } + } + } +--- +"Test classification auc_roc given predicted_class_field is never equal to mouse": + - do: + catch: /\[auc_roc\] requires at least one \[ml.top_classes.class_name\] to have the value \[mouse\]/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "top_classes_field": "ml.top_classes", + "metrics": { + "auc_roc": { + "class_name": "mouse" + } + } + } + } + } +--- +"Test classification auc_roc": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "top_classes_field": "ml.top_classes", + "metrics": { + "auc_roc": { + "class_name": "cat" + } + } + } + } + } + - match: { classification.auc_roc.score: 0.8050111095212122 } + - match: { classification.auc_roc.doc_count: 8 } + - is_false: classification.auc_roc.curve +--- +"Test classification auc_roc with default top_classes_field": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "metrics": { + "auc_roc": { + "class_name": "cat" + } + } + } + } + } + - match: { classification.auc_roc.score: 0.8050111095212122 } + - match: { classification.auc_roc.doc_count: 8 } + - is_false: classification.auc_roc.curve +--- +"Test classification accuracy with missing predicted_field": + - do: + catch: /\[classification\] must define \[predicted_field\] as required by the following metrics \[accuracy\]/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "metrics": { "accuracy": {} } } } } @@ -785,7 +947,7 @@ setup: --- "Test classification given missing actual_field": - do: - catch: /No documents found containing both \[missing, classification_field_pred.keyword\] fields/ + catch: /No documents found containing all the required fields \[missing, classification_field_pred.keyword\]/ ml.evaluate_data_frame: body: > { @@ -801,7 +963,7 @@ setup: --- "Test classification given missing predicted_field": - do: - catch: /No documents found containing both \[classification_field_act.keyword, missing\] fields/ + catch: /No documents found containing all the required fields \[classification_field_act.keyword, missing\]/ ml.evaluate_data_frame: body: > { @@ -815,6 +977,27 @@ setup: } --- +"Test classification given missing top_classes_field": + - do: + catch: /No documents found containing all the required fields \[classification_field_act.keyword, missing.class_name, missing.class_probability\]/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "predicted_field": "classification_field_pred.keyword", + "top_classes_field": "missing", + "metrics": { + "auc_roc": { + "class_name": "dummy" + } + } + } + } + } +--- "Test regression given evaluation with empty metrics": - do: catch: /\[regression\] must have one or more metrics/ @@ -932,7 +1115,7 @@ setup: --- "Test regression given missing actual_field": - do: - catch: /No documents found containing both \[missing, regression_field_pred\] fields/ + catch: /No documents found containing all the required fields \[missing, regression_field_pred\]/ ml.evaluate_data_frame: body: > { @@ -948,7 +1131,7 @@ setup: --- "Test regression given missing predicted_field": - do: - catch: /No documents found containing both \[regression_field_act, missing\] fields/ + catch: /No documents found containing all the required fields \[regression_field_act, missing\]/ ml.evaluate_data_frame: body: > {