Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove ClassInfo interface and BinaryClassInfo class. #49649

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,23 @@
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;

import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
Expand All @@ -27,37 +35,67 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
*/
String getName();

/**
* Returns the field containing the actual value
*/
String getActualField();

/**
* Returns the field containing the predicted value
*/
String getPredictedField();

/**
* Returns the list of metrics to evaluate
* @return list of metrics to evaluate
*/
List<? extends EvaluationMetric> getMetrics();

/**
 * Resolves the metrics this evaluation will compute.
 * Uses {@code parsedMetrics} when provided, otherwise the list produced by {@code defaultMetricsSupplier}.
 * The chosen list is always copied, so neither the caller's list nor the supplier's list is modified and
 * immutable default lists are supported. The copy is sorted by metric name.
 * @param parsedMetrics metrics supplied by the caller, or {@code null} to use the defaults
 * @param defaultMetricsSupplier supplies the default metrics; only invoked when {@code parsedMetrics} is {@code null}
 * @return a new mutable list of metrics sorted by {@link EvaluationMetric#getName()}
 * @throws RuntimeException the bad-request exception built by {@link ExceptionsHelper} when the resolved list is empty
 */
default <T extends EvaluationMetric> List<T> initMetrics(@Nullable List<T> parsedMetrics, Supplier<List<T>> defaultMetricsSupplier) {
    List<T> source = parsedMetrics == null ? defaultMetricsSupplier.get() : parsedMetrics;
    // Defensive copy: sorting in place would fail on immutable lists and would mutate a list owned by the supplier.
    List<T> metrics = new ArrayList<>(source);
    if (metrics.isEmpty()) {
        throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", getName());
    }
    metrics.sort(Comparator.comparing(EvaluationMetric::getName));
    return metrics;
}

/**
* Builds the search required to collect data to compute the evaluation result
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
*/
SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder);

/**
* Builds the search that verifies existence of required fields and applies user-provided query
* @param requiredFields fields that must exist
* @param userProvidedQueryBuilder user-provided query
*/
default SearchSourceBuilder newSearchSourceBuilder(List<String> requiredFields, QueryBuilder userProvidedQueryBuilder) {
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
for (String requiredField : requiredFields) {
boolQuery.filter(QueryBuilders.existsQuery(requiredField));
default SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
Objects.requireNonNull(userProvidedQueryBuilder);
BoolQueryBuilder boolQuery =
QueryBuilders.boolQuery()
// Verify existence of required fields
.filter(QueryBuilders.existsQuery(getActualField()))
.filter(QueryBuilders.existsQuery(getPredictedField()))
// Apply user-provided query
.filter(userProvidedQueryBuilder);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
for (EvaluationMetric metric : getMetrics()) {
// Fetch aggregations requested by individual metrics
List<AggregationBuilder> aggs = metric.aggs(getActualField(), getPredictedField());
aggs.forEach(searchSourceBuilder::aggregation);
}
boolQuery.filter(userProvidedQueryBuilder);
return new SearchSourceBuilder().size(0).query(boolQuery);
return searchSourceBuilder;
}

/**
* Processes {@link SearchResponse} from the search action
* @param searchResponse response from the search action
*/
void process(SearchResponse searchResponse);
default void process(SearchResponse searchResponse) {
    Objects.requireNonNull(searchResponse);
    // An empty result means no document has both required fields, so there is nothing to evaluate.
    final boolean noMatchingDocuments = searchResponse.getHits().getTotalHits().value == 0;
    if (noMatchingDocuments) {
        throw ExceptionsHelper.badRequestException(
            "No documents found containing both [{}, {}] fields", getActualField(), getPredictedField());
    }
    // Hand the aggregations of the response to every metric in turn.
    getMetrics().forEach(metric -> metric.process(searchResponse.getAggregations()));
}

/**
* @return true iff all the metrics have their results computed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;

import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;

import java.util.List;
import java.util.Optional;

/**
Expand All @@ -20,6 +24,20 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
*/
String getName();

/**
* Builds the aggregations that collect the data required to compute the metric
* @param actualField the field that stores the actual value
* @param predictedField the field that stores the predicted value (class name or probability)
* @return the aggregations required to compute the metric
*/
List<AggregationBuilder> aggs(String actualField, String predictedField);

/**
* Processes given aggregations as a step towards computing result
* @param aggs aggregations from {@link SearchResponse}
*/
void process(Aggregations aggs);

/**
* Gets the evaluation result for this metric.
* @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,18 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;

import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;

Expand Down Expand Up @@ -55,13 +48,13 @@ public static Classification fromXContent(XContentParser parser) {

/**
* The field containing the actual value
* The value of this field is assumed to be numeric
* The value of this field is assumed to be categorical
*/
private final String actualField;

/**
* The field containing the predicted value
* The value of this field is assumed to be numeric
* The value of this field is assumed to be categorical
*/
private final String predictedField;

Expand All @@ -73,7 +66,11 @@ public static Classification fromXContent(XContentParser parser) {
public Classification(String actualField, String predictedField, @Nullable List<ClassificationMetric> metrics) {
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
this.metrics = initMetrics(metrics);
this.metrics = initMetrics(metrics, Classification::defaultMetrics);
}

/**
 * Returns the metrics used when the caller does not specify any.
 * @return a fixed-size list containing a single {@link MulticlassConfusionMatrix}
 */
private static List<ClassificationMetric> defaultMetrics() {
return Arrays.asList(new MulticlassConfusionMatrix());
}

public Classification(StreamInput in) throws IOException {
Expand All @@ -82,49 +79,24 @@ public Classification(StreamInput in) throws IOException {
this.metrics = in.readNamedWriteableList(ClassificationMetric.class);
}

private static List<ClassificationMetric> initMetrics(@Nullable List<ClassificationMetric> parsedMetrics) {
List<ClassificationMetric> metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics);
if (metrics.isEmpty()) {
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
}
Collections.sort(metrics, Comparator.comparing(ClassificationMetric::getName));
return metrics;
}

private static List<ClassificationMetric> defaultMetrics() {
return Arrays.asList(new MulticlassConfusionMatrix());
}

@Override
public String getName() {
return NAME.getPreferredName();
}

@Override
public List<ClassificationMetric> getMetrics() {
return metrics;
public String getActualField() {
return actualField;
}

@Override
public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
SearchSourceBuilder searchSourceBuilder = newSearchSourceBuilder(List.of(actualField, predictedField), userProvidedQueryBuilder);
for (ClassificationMetric metric : metrics) {
List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
aggs.forEach(searchSourceBuilder::aggregation);
}
return searchSourceBuilder;
public String getPredictedField() {
return predictedField;
}

@Override
public void process(SearchResponse searchResponse) {
ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
if (searchResponse.getHits().getTotalHits().value == 0) {
throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField);
}
for (ClassificationMetric metric : metrics) {
metric.process(searchResponse.getAggregations());
}
public List<ClassificationMetric> getMetrics() {
return metrics;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;

import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;

import java.util.List;

public interface ClassificationMetric extends EvaluationMetric {

/**
* Builds the aggregation that collect required data to compute the metric
* @param actualField the field that stores the actual value
* @param predictedField the field that stores the predicted value
* @return the aggregations required to compute the metric
*/
List<AggregationBuilder> aggs(String actualField, String predictedField);

/**
* Processes given aggregations as a step towards computing result
* @param aggs aggregations from {@link SearchResponse}
*/
void process(Aggregations aggs);
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,18 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;

import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;

Expand Down Expand Up @@ -73,7 +66,11 @@ public static Regression fromXContent(XContentParser parser) {
public Regression(String actualField, String predictedField, @Nullable List<RegressionMetric> metrics) {
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
this.metrics = initMetrics(metrics);
this.metrics = initMetrics(metrics, Regression::defaultMetrics);
}

/**
 * Returns the metrics used when the caller does not specify any.
 * @return a fixed-size list containing a {@link MeanSquaredError} and an {@link RSquared}
 */
private static List<RegressionMetric> defaultMetrics() {
return Arrays.asList(new MeanSquaredError(), new RSquared());
}

public Regression(StreamInput in) throws IOException {
Expand All @@ -82,49 +79,24 @@ public Regression(StreamInput in) throws IOException {
this.metrics = in.readNamedWriteableList(RegressionMetric.class);
}

private static List<RegressionMetric> initMetrics(@Nullable List<RegressionMetric> parsedMetrics) {
List<RegressionMetric> metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics);
if (metrics.isEmpty()) {
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
}
Collections.sort(metrics, Comparator.comparing(RegressionMetric::getName));
return metrics;
}

private static List<RegressionMetric> defaultMetrics() {
return Arrays.asList(new MeanSquaredError(), new RSquared());
}

@Override
public String getName() {
return NAME.getPreferredName();
}

@Override
public List<RegressionMetric> getMetrics() {
return metrics;
public String getActualField() {
return actualField;
}

@Override
public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
SearchSourceBuilder searchSourceBuilder = newSearchSourceBuilder(List.of(actualField, predictedField), userProvidedQueryBuilder);
for (RegressionMetric metric : metrics) {
List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
aggs.forEach(searchSourceBuilder::aggregation);
}
return searchSourceBuilder;
public String getPredictedField() {
return predictedField;
}

@Override
public void process(SearchResponse searchResponse) {
ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
if (searchResponse.getHits().getTotalHits().value == 0) {
throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField);
}
for (RegressionMetric metric : metrics) {
metric.process(searchResponse.getAggregations());
}
public List<RegressionMetric> getMetrics() {
return metrics;
}

@Override
Expand Down
Loading