Skip to content

Commit

Permalink
[ML] Implement AucRoc metric for classification (#60502)
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek authored Sep 30, 2020
1 parent 9b9f33e commit cd1a27f
Show file tree
Hide file tree
Showing 42 changed files with 2,009 additions and 594 deletions.
27 changes: 25 additions & 2 deletions docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ the probability that each document is an outlier.
`auc_roc`:::
(Optional, object) The AUC ROC (area under the curve of the receiver
operating characteristic) score and optionally the curve. Default value is
{"includes_curve": false}.
{"include_curve": false}.

`confusion_matrix`:::
(Optional, object) Set the different thresholds of the {olscore} at where
Expand Down Expand Up @@ -154,16 +154,39 @@ belongs.
The data type of this field must be categorical.

`predicted_field`::
(Required, string) The field in the `index` that contains the predicted value,
(Optional, string) The field in the `index` which contains the predicted value,
in other words the results of the {classanalysis}.

`top_classes_field`::
(Optional, string) The field in the `index` that contains an array of documents
of the form `{ "class_name": XXX, "class_probability": YYY }`.
This field must be defined as `nested` in the mappings.

`metrics`::
(Optional, object) Specifies the metrics that are used for the evaluation.
Available metrics:

`accuracy`:::
(Optional, object) Accuracy of predictions (per-class and overall).

`auc_roc`:::
(Optional, object) The AUC ROC (area under the curve of the receiver
operating characteristic) score and optionally the curve.
It is calculated for a specific class (provided as "class_name")
treated as positive.

`class_name`::::
(Required, string) Name of the only class that will be treated as
positive during AUC ROC calculation. Other classes will be treated as
negative ("one-vs-all" strategy). Documents which do not have `class_name`
in the list of their top classes will not be taken into account for evaluation.
The number of documents taken into account is returned in the evaluation result
(`auc_roc.doc_count` field).

`include_curve`::::
(Optional, boolean) Whether or not the curve should be returned in
addition to the score. Default value is false.

`multiclass_confusion_matrix`:::
(Optional, object) Multiclass confusion matrix.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,16 @@ public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mapping
return additionalProperties;
}
additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);

Map<String, Object> topClassesProperties = new HashMap<>();
topClassesProperties.put("class_name", dependentVariableMapping);
topClassesProperties.put("class_probability", Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));

Map<String, Object> topClassesMapping = new HashMap<>();
topClassesMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
topClassesMapping.put("properties", topClassesProperties);

additionalProperties.put(resultsFieldName + ".top_classes", topClassesMapping);
return additionalProperties;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;

import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.collect.Tuple;
Expand All @@ -21,11 +22,16 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;

/**
* Defines an evaluation
Expand All @@ -38,14 +44,9 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
String getName();

/**
* Returns the field containing the actual value
*/
String getActualField();

/**
* Returns the field containing the predicted value
* Returns the collection of fields required by evaluation
*/
String getPredictedField();
EvaluationFields getFields();

/**
* Returns the list of metrics to evaluate
Expand All @@ -59,27 +60,74 @@ default <T extends EvaluationMetric> List<T> initMetrics(@Nullable List<T> parse
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", getName());
}
Collections.sort(metrics, Comparator.comparing(EvaluationMetric::getName));
checkRequiredFieldsAreSet(metrics);
return metrics;
}

/**
 * Verifies that every field which is left unset in {@link #getFields()} is not needed by any of the given metrics.
 *
 * For each (descriptor, field) pair returned by {@code listPotentiallyRequiredFields()} whose field is {@code null},
 * the names of the metrics declaring that descriptor in their required fields are gathered. If the gathered names
 * form a non-empty string, a bad request exception naming the evaluation, the missing descriptor and the metrics
 * is thrown.
 *
 * @param metrics metrics to validate; asserted to be neither {@code null} nor empty
 */
private <T extends EvaluationMetric> void checkRequiredFieldsAreSet(List<T> metrics) {
    assert (metrics == null || metrics.isEmpty()) == false;
    for (Tuple<String, String> candidate : getFields().listPotentiallyRequiredFields()) {
        // A field that has been provided can never be the cause of a failure
        if (candidate.v2() != null) {
            continue;
        }
        String descriptor = candidate.v1();
        List<String> dependentMetricNames = new ArrayList<>();
        for (T metric : metrics) {
            if (metric.getRequiredFields().contains(descriptor)) {
                dependentMetricNames.add(metric.getName());
            }
        }
        String joinedMetricNames = String.join(", ", dependentMetricNames);
        if (joinedMetricNames.isEmpty()) {
            // No metric depends on this (unset) field
            continue;
        }
        throw ExceptionsHelper.badRequestException(
            "[{}] must define [{}] as required by the following metrics [{}]",
            getName(), descriptor, joinedMetricNames);
    }
}

/**
* Builds the search required to collect data to compute the evaluation result
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
*/
default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBuilder userProvidedQueryBuilder) {
Objects.requireNonNull(userProvidedQueryBuilder);
BoolQueryBuilder boolQuery =
QueryBuilders.boolQuery()
// Verify existence of required fields
.filter(QueryBuilders.existsQuery(getActualField()))
.filter(QueryBuilders.existsQuery(getPredictedField()))
// Apply user-provided query
.filter(userProvidedQueryBuilder);
Set<String> requiredFields = new HashSet<>(getRequiredFields());
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
if (getFields().getActualField() != null && requiredFields.contains(getFields().getActualField())) {
// Verify existence of the actual field if required
boolQuery.filter(QueryBuilders.existsQuery(getFields().getActualField()));
}
if (getFields().getPredictedField() != null && requiredFields.contains(getFields().getPredictedField())) {
// Verify existence of the predicted field if required
boolQuery.filter(QueryBuilders.existsQuery(getFields().getPredictedField()));
}
if (getFields().getPredictedClassField() != null && requiredFields.contains(getFields().getPredictedClassField())) {
assert getFields().getTopClassesField() != null;
// Verify existence of the predicted class name field if required
QueryBuilder predictedClassFieldExistsQuery = QueryBuilders.existsQuery(getFields().getPredictedClassField());
boolQuery.filter(
QueryBuilders.nestedQuery(getFields().getTopClassesField(), predictedClassFieldExistsQuery, ScoreMode.None)
.ignoreUnmapped(true));
}
if (getFields().getPredictedProbabilityField() != null && requiredFields.contains(getFields().getPredictedProbabilityField())) {
// Verify existence of the predicted probability field if required
QueryBuilder predictedProbabilityFieldExistsQuery = QueryBuilders.existsQuery(getFields().getPredictedProbabilityField());
// predicted probability field may be either nested (just like in case of classification evaluation) or non-nested (just like
// in case of outlier detection evaluation). Here we support both modes.
if (getFields().isPredictedProbabilityFieldNested()) {
assert getFields().getTopClassesField() != null;
boolQuery.filter(
QueryBuilders.nestedQuery(getFields().getTopClassesField(), predictedProbabilityFieldExistsQuery, ScoreMode.None)
.ignoreUnmapped(true));
} else {
boolQuery.filter(predictedProbabilityFieldExistsQuery);
}
}
// Apply user-provided query
boolQuery.filter(userProvidedQueryBuilder);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
for (EvaluationMetric metric : getMetrics()) {
// Fetch aggregations requested by individual metrics
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs =
metric.aggs(parameters, getActualField(), getPredictedField());
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = metric.aggs(parameters, getFields());
aggs.v1().forEach(searchSourceBuilder::aggregation);
aggs.v2().forEach(searchSourceBuilder::aggregation);
}
Expand All @@ -93,14 +141,31 @@ default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBu
default void process(SearchResponse searchResponse) {
Objects.requireNonNull(searchResponse);
if (searchResponse.getHits().getTotalHits().value == 0) {
throw ExceptionsHelper.badRequestException(
"No documents found containing both [{}, {}] fields", getActualField(), getPredictedField());
String requiredFieldsString = String.join(", ", getRequiredFields());
throw ExceptionsHelper.badRequestException("No documents found containing all the required fields [{}]", requiredFieldsString);
}
for (EvaluationMetric metric : getMetrics()) {
metric.process(searchResponse.getAggregations());
}
}

/**
 * Collects the field names needed by the configured metrics.
 *
 * First the required field descriptors of all metrics are merged into a single set. Then the
 * (descriptor, field) pairs from {@code getFields().listPotentiallyRequiredFields()} are scanned in order and
 * the field of every pair whose descriptor is in that set is kept.
 *
 * @return list of fields which are required by at least one of the metrics
 */
private List<String> getRequiredFields() {
    Set<String> descriptorsNeededByMetrics = new HashSet<>();
    for (EvaluationMetric metric : getMetrics()) {
        descriptorsNeededByMetrics.addAll(metric.getRequiredFields());
    }
    List<String> neededFields = new ArrayList<>();
    for (Tuple<String, String> candidate : getFields().listPotentiallyRequiredFields()) {
        if (descriptorsNeededByMetrics.contains(candidate.v1())) {
            neededFields.add(candidate.v2());
        }
    }
    return neededFields;
}

/**
* @return true iff all the metrics have their results computed
*/
Expand All @@ -117,6 +182,6 @@ default List<EvaluationMetricResult> getResults() {
.map(EvaluationMetric::getResult)
.filter(Optional::isPresent)
.map(Optional::get)
.collect(Collectors.toList());
.collect(toList());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
 * Immutable holder of the field names an evaluation may need.
 *
 * Every field name is optional (may be {@code null}); which of them are actually needed is decided by the caller,
 * see {@link #listPotentiallyRequiredFields()}.
 */
public final class EvaluationFields {

    public static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
    public static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
    public static final ParseField TOP_CLASSES_FIELD = new ParseField("top_classes_field");
    public static final ParseField PREDICTED_CLASS_FIELD = new ParseField("predicted_class_field");
    public static final ParseField PREDICTED_PROBABILITY_FIELD = new ParseField("predicted_probability_field");

    /** Name of the field holding the actual value. */
    private final String actualField;

    /** Name of the field holding the predicted value. */
    private final String predictedField;

    /** Name of the field holding the array of top classes. */
    private final String topClassesField;

    /** Name of the field holding the predicted class name value. */
    private final String predictedClassField;

    /** Name of the field holding the predicted probability value in [0.0, 1.0]. */
    private final String predictedProbabilityField;

    /** Tells whether {@code predictedProbabilityField} should be treated as nested (e.g.: when used in exists queries). */
    private final boolean predictedProbabilityFieldNested;

    /**
     * Creates the holder; all field names are stored as given, without validation.
     *
     * @param actualField field containing the actual value, may be {@code null}
     * @param predictedField field containing the predicted value, may be {@code null}
     * @param topClassesField field containing the array of top classes, may be {@code null}
     * @param predictedClassField field containing the predicted class name value, may be {@code null}
     * @param predictedProbabilityField field containing the predicted probability value, may be {@code null}
     * @param predictedProbabilityFieldNested whether {@code predictedProbabilityField} should be treated as nested
     */
    public EvaluationFields(@Nullable String actualField,
                            @Nullable String predictedField,
                            @Nullable String topClassesField,
                            @Nullable String predictedClassField,
                            @Nullable String predictedProbabilityField,
                            boolean predictedProbabilityFieldNested) {
        this.actualField = actualField;
        this.predictedField = predictedField;
        this.topClassesField = topClassesField;
        this.predictedClassField = predictedClassField;
        this.predictedProbabilityField = predictedProbabilityField;
        this.predictedProbabilityFieldNested = predictedProbabilityFieldNested;
    }

    /**
     * @return the field containing the actual value (possibly {@code null})
     */
    public String getActualField() {
        return actualField;
    }

    /**
     * @return the field containing the predicted value (possibly {@code null})
     */
    public String getPredictedField() {
        return predictedField;
    }

    /**
     * @return the field containing the array of top classes (possibly {@code null})
     */
    public String getTopClassesField() {
        return topClassesField;
    }

    /**
     * @return the field containing the predicted class name value (possibly {@code null})
     */
    public String getPredictedClassField() {
        return predictedClassField;
    }

    /**
     * @return the field containing the predicted probability value in [0.0, 1.0] (possibly {@code null})
     */
    public String getPredictedProbabilityField() {
        return predictedProbabilityField;
    }

    /**
     * @return whether the {@code predictedProbabilityField} should be treated as nested (e.g.: when used in exists queries)
     */
    public boolean isPredictedProbabilityFieldNested() {
        return predictedProbabilityFieldNested;
    }

    /**
     * Lists all the field names held by this object, each paired with its descriptor.
     *
     * @return fixed-size list of (descriptor, field name) tuples where the descriptor is the preferred name of the
     *         corresponding {@link ParseField} constant and the field name is {@code null} when it was not provided
     */
    public List<Tuple<String, String>> listPotentiallyRequiredFields() {
        Tuple<String, String> actual = Tuple.tuple(ACTUAL_FIELD.getPreferredName(), actualField);
        Tuple<String, String> predicted = Tuple.tuple(PREDICTED_FIELD.getPreferredName(), predictedField);
        Tuple<String, String> topClasses = Tuple.tuple(TOP_CLASSES_FIELD.getPreferredName(), topClassesField);
        Tuple<String, String> predictedClass = Tuple.tuple(PREDICTED_CLASS_FIELD.getPreferredName(), predictedClassField);
        Tuple<String, String> predictedProbability =
            Tuple.tuple(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField);
        return Arrays.asList(actual, predicted, topClasses, predictedClass, predictedProbability);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        EvaluationFields other = (EvaluationFields) o;
        return predictedProbabilityFieldNested == other.predictedProbabilityFieldNested
            && Objects.equals(other.actualField, actualField)
            && Objects.equals(other.predictedField, predictedField)
            && Objects.equals(other.topClassesField, topClassesField)
            && Objects.equals(other.predictedClassField, predictedClassField)
            && Objects.equals(other.predictedProbabilityField, predictedProbabilityField);
    }

    @Override
    public int hashCode() {
        return Objects.hash(
            actualField,
            predictedField,
            topClassesField,
            predictedClassField,
            predictedProbabilityField,
            predictedProbabilityFieldNested);
    }
}
Loading

0 comments on commit cd1a27f

Please sign in to comment.