Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[7.x] [ML] Implement AucRoc metric for classification (#60502) #63051

Merged
merged 2 commits into from
Sep 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ the probability that each document is an outlier.
`auc_roc`:::
(Optional, object) The AUC ROC (area under the curve of the receiver
operating characteristic) score and optionally the curve. Default value is
{"includes_curve": false}.
{"include_curve": false}.

`confusion_matrix`:::
(Optional, object) Set the different thresholds of the {olscore} at where
Expand Down Expand Up @@ -153,16 +153,39 @@ belongs.
The data type of this field must be categorical.

`predicted_field`::
(Required, string) The field in the `index` that contains the predicted value,
(Optional, string) The field in the `index` which contains the predicted value,
in other words the results of the {classanalysis}.

`top_classes_field`::
(Optional, string) The field in the `index` that is an array of documents
of the form `{ "class_name": XXX, "class_probability": YYY }`.
This field must be defined as `nested` in the mappings.

`metrics`::
(Optional, object) Specifies the metrics that are used for the evaluation.
Available metrics:

`accuracy`:::
(Optional, object) Accuracy of predictions (per-class and overall).

`auc_roc`:::
(Optional, object) The AUC ROC (area under the curve of the receiver
operating characteristic) score and optionally the curve.
It is calculated for a specific class (provided as "class_name")
treated as positive.

`class_name`::::
(Required, string) Name of the only class that will be treated as
positive during AUC ROC calculation. Other classes will be treated as
negative ("one-vs-all" strategy). Documents which do not have `class_name`
in the list of their top classes will not be taken into account for evaluation.
The number of documents taken into account is returned in the evaluation result
(`auc_roc.doc_count` field).

`include_curve`::::
(Optional, boolean) Whether or not the curve should be returned in
addition to the score. Default value is false.

`multiclass_confusion_matrix`:::
(Optional, object) Multiclass confusion matrix.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,16 @@ public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mapping
return additionalProperties;
}
additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);

Map<String, Object> topClassesProperties = new HashMap<>();
topClassesProperties.put("class_name", dependentVariableMapping);
topClassesProperties.put("class_probability", Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));

Map<String, Object> topClassesMapping = new HashMap<>();
topClassesMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
topClassesMapping.put("properties", topClassesProperties);

additionalProperties.put(resultsFieldName + ".top_classes", topClassesMapping);
return additionalProperties;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;

import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.collect.Tuple;
Expand All @@ -21,11 +22,16 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;

/**
* Defines an evaluation
Expand All @@ -38,14 +44,9 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
String getName();

/**
* Returns the field containing the actual value
*/
String getActualField();

/**
* Returns the field containing the predicted value
* Returns the collection of fields required by evaluation
*/
String getPredictedField();
EvaluationFields getFields();

/**
* Returns the list of metrics to evaluate
Expand All @@ -59,27 +60,74 @@ default <T extends EvaluationMetric> List<T> initMetrics(@Nullable List<T> parse
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", getName());
}
Collections.sort(metrics, Comparator.comparing(EvaluationMetric::getName));
checkRequiredFieldsAreSet(metrics);
return metrics;
}

/**
 * Verifies that every field needed by at least one of the given metrics has been configured.
 *
 * Walks over the (descriptor, configured value) pairs exposed by {@link #getFields()} and, for each pair whose value is
 * {@code null}, looks for metrics that list the descriptor among their required fields.
 *
 * @param metrics metrics to check; asserted to be non-null and non-empty
 * @throws RuntimeException a bad request exception (built by {@code ExceptionsHelper}) naming the missing field and the metrics
 *                          that need it
 */
default <T extends EvaluationMetric> void checkRequiredFieldsAreSet(List<T> metrics) {
    assert (metrics == null || metrics.isEmpty()) == false;
    for (Tuple<String, String> candidate : getFields().listPotentiallyRequiredFields()) {
        if (candidate.v2() != null) {
            // The field is configured, there is nothing to complain about
            continue;
        }
        String descriptor = candidate.v1();
        String dependentMetricNames =
            metrics.stream()
                .filter(metric -> metric.getRequiredFields().contains(descriptor))
                .map(EvaluationMetric::getName)
                .collect(joining(", "));
        if (dependentMetricNames.isEmpty()) {
            // The field is missing but none of the metrics needs it
            continue;
        }
        throw ExceptionsHelper.badRequestException(
            "[{}] must define [{}] as required by the following metrics [{}]",
            getName(), descriptor, dependentMetricNames);
    }
}

/**
* Builds the search required to collect data to compute the evaluation result
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
*/
default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBuilder userProvidedQueryBuilder) {
Objects.requireNonNull(userProvidedQueryBuilder);
BoolQueryBuilder boolQuery =
QueryBuilders.boolQuery()
// Verify existence of required fields
.filter(QueryBuilders.existsQuery(getActualField()))
.filter(QueryBuilders.existsQuery(getPredictedField()))
// Apply user-provided query
.filter(userProvidedQueryBuilder);
Set<String> requiredFields = new HashSet<>(getRequiredFields());
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
if (getFields().getActualField() != null && requiredFields.contains(getFields().getActualField())) {
// Verify existence of the actual field if required
boolQuery.filter(QueryBuilders.existsQuery(getFields().getActualField()));
}
if (getFields().getPredictedField() != null && requiredFields.contains(getFields().getPredictedField())) {
// Verify existence of the predicted field if required
boolQuery.filter(QueryBuilders.existsQuery(getFields().getPredictedField()));
}
if (getFields().getPredictedClassField() != null && requiredFields.contains(getFields().getPredictedClassField())) {
assert getFields().getTopClassesField() != null;
// Verify existence of the predicted class name field if required
QueryBuilder predictedClassFieldExistsQuery = QueryBuilders.existsQuery(getFields().getPredictedClassField());
boolQuery.filter(
QueryBuilders.nestedQuery(getFields().getTopClassesField(), predictedClassFieldExistsQuery, ScoreMode.None)
.ignoreUnmapped(true));
}
if (getFields().getPredictedProbabilityField() != null && requiredFields.contains(getFields().getPredictedProbabilityField())) {
// Verify existence of the predicted probability field if required
QueryBuilder predictedProbabilityFieldExistsQuery = QueryBuilders.existsQuery(getFields().getPredictedProbabilityField());
// predicted probability field may be either nested (just like in case of classification evaluation) or non-nested (just like
// in case of outlier detection evaluation). Here we support both modes.
if (getFields().isPredictedProbabilityFieldNested()) {
assert getFields().getTopClassesField() != null;
boolQuery.filter(
QueryBuilders.nestedQuery(getFields().getTopClassesField(), predictedProbabilityFieldExistsQuery, ScoreMode.None)
.ignoreUnmapped(true));
} else {
boolQuery.filter(predictedProbabilityFieldExistsQuery);
}
}
// Apply user-provided query
boolQuery.filter(userProvidedQueryBuilder);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
for (EvaluationMetric metric : getMetrics()) {
// Fetch aggregations requested by individual metrics
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs =
metric.aggs(parameters, getActualField(), getPredictedField());
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = metric.aggs(parameters, getFields());
aggs.v1().forEach(searchSourceBuilder::aggregation);
aggs.v2().forEach(searchSourceBuilder::aggregation);
}
Expand All @@ -93,14 +141,31 @@ default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBu
default void process(SearchResponse searchResponse) {
Objects.requireNonNull(searchResponse);
if (searchResponse.getHits().getTotalHits().value == 0) {
throw ExceptionsHelper.badRequestException(
"No documents found containing both [{}, {}] fields", getActualField(), getPredictedField());
String requiredFieldsString = String.join(", ", getRequiredFields());
throw ExceptionsHelper.badRequestException("No documents found containing all the required fields [{}]", requiredFieldsString);
}
for (EvaluationMetric metric : getMetrics()) {
metric.process(searchResponse.getAggregations());
}
}

/**
 * Collects the names of the configured fields that at least one of the metrics declares as required.
 *
 * The result follows the order of {@code getFields().listPotentiallyRequiredFields()} and keeps the configured value as-is,
 * so an entry is {@code null} when a required field has not been set.
 *
 * @return list of fields which are required by at least one of the metrics
 */
default List<String> getRequiredFields() {
    // Union of the field descriptors requested by all the metrics
    Set<String> descriptorsInUse = new HashSet<>();
    for (EvaluationMetric metric : getMetrics()) {
        descriptorsInUse.addAll(metric.getRequiredFields());
    }
    // Translate descriptors into the actual field names configured for this evaluation
    return getFields().listPotentiallyRequiredFields().stream()
        .filter(candidate -> descriptorsInUse.contains(candidate.v1()))
        .map(Tuple::v2)
        .collect(toList());
}

/**
* @return true iff all the metrics have their results computed
*/
Expand All @@ -117,6 +182,6 @@ default List<EvaluationMetricResult> getResults() {
.map(EvaluationMetric::getResult)
.filter(Optional::isPresent)
.map(Optional::get)
.collect(Collectors.toList());
.collect(toList());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
 * Encapsulates fields needed by evaluation.
 *
 * Instances are immutable. Each field name may be {@code null}, which means the field has not been configured.
 */
public final class EvaluationFields {

    public static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
    public static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
    public static final ParseField TOP_CLASSES_FIELD = new ParseField("top_classes_field");
    public static final ParseField PREDICTED_CLASS_FIELD = new ParseField("predicted_class_field");
    public static final ParseField PREDICTED_PROBABILITY_FIELD = new ParseField("predicted_probability_field");

    /**
     * The field containing the actual value
     */
    private final String actualField;

    /**
     * The field containing the predicted value
     */
    private final String predictedField;

    /**
     * The field containing the array of top classes
     */
    private final String topClassesField;

    /**
     * The field containing the predicted class name value
     */
    private final String predictedClassField;

    /**
     * The field containing the predicted probability value in [0.0, 1.0]
     */
    private final String predictedProbabilityField;

    /**
     * Whether the {@code predictedProbabilityField} should be treated as nested (e.g.: when used in exists queries).
     */
    private final boolean predictedProbabilityFieldNested;

    /**
     * Creates a new set of evaluation fields. Every field name is optional and stored as given.
     *
     * @param actualField                     the field containing the actual value
     * @param predictedField                  the field containing the predicted value
     * @param topClassesField                 the field containing the array of top classes
     * @param predictedClassField             the field containing the predicted class name value
     * @param predictedProbabilityField       the field containing the predicted probability value
     * @param predictedProbabilityFieldNested whether {@code predictedProbabilityField} should be treated as nested
     */
    public EvaluationFields(@Nullable String actualField,
                            @Nullable String predictedField,
                            @Nullable String topClassesField,
                            @Nullable String predictedClassField,
                            @Nullable String predictedProbabilityField,
                            boolean predictedProbabilityFieldNested) {

        this.actualField = actualField;
        this.predictedField = predictedField;
        this.topClassesField = topClassesField;
        this.predictedClassField = predictedClassField;
        this.predictedProbabilityField = predictedProbabilityField;
        this.predictedProbabilityFieldNested = predictedProbabilityFieldNested;
    }

    /**
     * Returns the field containing the actual value, or {@code null} if it has not been set
     */
    @Nullable
    public String getActualField() {
        return actualField;
    }

    /**
     * Returns the field containing the predicted value, or {@code null} if it has not been set
     */
    @Nullable
    public String getPredictedField() {
        return predictedField;
    }

    /**
     * Returns the field containing the array of top classes, or {@code null} if it has not been set
     */
    @Nullable
    public String getTopClassesField() {
        return topClassesField;
    }

    /**
     * Returns the field containing the predicted class name value, or {@code null} if it has not been set
     */
    @Nullable
    public String getPredictedClassField() {
        return predictedClassField;
    }

    /**
     * Returns the field containing the predicted probability value in [0.0, 1.0], or {@code null} if it has not been set
     */
    @Nullable
    public String getPredictedProbabilityField() {
        return predictedProbabilityField;
    }

    /**
     * Returns whether the {@code predictedProbabilityField} should be treated as nested (e.g.: when used in exists queries).
     */
    public boolean isPredictedProbabilityFieldNested() {
        return predictedProbabilityFieldNested;
    }

    /**
     * Lists all the fields that a metric may declare as required.
     *
     * Each tuple pairs the field descriptor (the preferred name of the corresponding {@link ParseField}, e.g.
     * {@code actual_field}) as {@code v1} with the configured field name as {@code v2}. {@code v2} is {@code null} when the
     * field has not been set. The list always has the same five entries in the same order; the nested flag is not part of it.
     *
     * @return fixed-size list of (descriptor, configured field name) tuples
     */
    public List<Tuple<String, String>> listPotentiallyRequiredFields() {
        return Arrays.asList(
            Tuple.tuple(ACTUAL_FIELD.getPreferredName(), actualField),
            Tuple.tuple(PREDICTED_FIELD.getPreferredName(), predictedField),
            Tuple.tuple(TOP_CLASSES_FIELD.getPreferredName(), topClassesField),
            Tuple.tuple(PREDICTED_CLASS_FIELD.getPreferredName(), predictedClassField),
            Tuple.tuple(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField));
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        EvaluationFields that = (EvaluationFields) o;
        return Objects.equals(that.actualField, this.actualField)
            && Objects.equals(that.predictedField, this.predictedField)
            && Objects.equals(that.topClassesField, this.topClassesField)
            && Objects.equals(that.predictedClassField, this.predictedClassField)
            && Objects.equals(that.predictedProbabilityField, this.predictedProbabilityField)
            // primitive flag: compare directly instead of boxing both sides through Objects.equals
            && that.predictedProbabilityFieldNested == this.predictedProbabilityFieldNested;
    }

    @Override
    public int hashCode() {
        return Objects.hash(
            actualField, predictedField, topClassesField, predictedClassField, predictedProbabilityField, predictedProbabilityFieldNested);
    }
}
Loading