Skip to content

Commit

Permalink
[7.x] Allow evaluation to consist of multiple steps. (#46653) (#47194)
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek authored Sep 27, 2019
1 parent a1e2e20 commit 3fbd58d
Show file tree
Hide file tree
Showing 21 changed files with 338 additions and 172 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,28 +105,31 @@ public String[] getIndices() {
return indices;
}

public final void setIndices(List<String> indices) {
public final Request setIndices(List<String> indices) {
ExceptionsHelper.requireNonNull(indices, INDEX);
if (indices.isEmpty()) {
throw ExceptionsHelper.badRequestException("At least one index must be specified");
}
this.indices = indices.toArray(new String[indices.size()]);
return this;
}

public QueryBuilder getParsedQuery() {
return Optional.ofNullable(queryProvider).orElseGet(QueryProvider::defaultQuery).getParsedQuery();
}

public final void setQueryProvider(QueryProvider queryProvider) {
public final Request setQueryProvider(QueryProvider queryProvider) {
this.queryProvider = queryProvider;
return this;
}

public Evaluation getEvaluation() {
return evaluation;
}

public final void setEvaluation(Evaluation evaluation) {
public final Request setEvaluation(Evaluation evaluation) {
this.evaluation = ExceptionsHelper.requireNonNull(evaluation, EVALUATION);
return this;
}

@Override
Expand Down Expand Up @@ -203,6 +206,14 @@ public Response(String evaluationName, List<EvaluationMetricResult> metrics) {
this.metrics = Objects.requireNonNull(metrics);
}

public String getEvaluationName() {
return evaluationName;
}

public List<EvaluationMetricResult> getMetrics() {
return metrics;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(evaluationName);
Expand All @@ -214,7 +225,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject();
builder.startObject(evaluationName);
for (EvaluationMetricResult metric : metrics) {
builder.field(metric.getName(), metric);
builder.field(metric.getMetricName(), metric);
}
builder.endObject();
builder.endObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder;

import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

/**
* Defines an evaluation
Expand All @@ -24,16 +27,54 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
*/
String getName();

/**
* Returns the list of metrics to evaluate
* @return list of metrics to evaluate
*/
List<? extends EvaluationMetric> getMetrics();

/**
* Builds the search required to collect data to compute the evaluation result
* @param queryBuilder User-provided query that must be respected when collecting data
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
*/
SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder);

/**
* Builds the search that verifies existence of required fields and applies user-provided query
* @param requiredFields fields that must exist
* @param userProvidedQueryBuilder user-provided query
*/
default SearchSourceBuilder newSearchSourceBuilder(List<String> requiredFields, QueryBuilder userProvidedQueryBuilder) {
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
for (String requiredField : requiredFields) {
boolQuery.filter(QueryBuilders.existsQuery(requiredField));
}
boolQuery.filter(userProvidedQueryBuilder);
return new SearchSourceBuilder().size(0).query(boolQuery);
}

/**
* Processes {@link SearchResponse} from the search action
* @param searchResponse response from the search action
*/
void process(SearchResponse searchResponse);

/**
* @return true iff all the metrics have their results computed
*/
SearchSourceBuilder buildSearch(QueryBuilder queryBuilder);
default boolean hasAllResults() {
return getMetrics().stream().map(EvaluationMetric::getResult).allMatch(Optional::isPresent);
}

/**
* Computes the evaluation result
* @param searchResponse The search response required to compute the result
* @param listener A listener of the results
* Returns the list of evaluation results
* @return list of evaluation results
*/
void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener);
default List<EvaluationMetricResult> getResults() {
return getMetrics().stream()
.map(EvaluationMetric::getResult)
.filter(Optional::isPresent)
.map(Optional::get)
.collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;

import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;

import java.util.Optional;

/**
* {@link EvaluationMetric} class represents a metric to evaluate.
*/
public interface EvaluationMetric extends ToXContentObject, NamedWriteable {

/**
* Returns the name of the metric (which may differ to the writeable name)
*/
String getName();

/**
* Gets the evaluation result for this metric.
* @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise
*/
Optional<EvaluationMetricResult> getResult();
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
public interface EvaluationMetricResult extends ToXContentObject, NamedWriteable {

/**
* Returns the name of the metric
* Returns the name of the metric (which may differ to the writeable name)
*/
String getName();
String getMetricName();
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@

import java.io.IOException;
import java.text.MessageFormat;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;

/**
* Calculates the mean squared error between two known numerical fields.
Expand All @@ -48,28 +50,34 @@ public static MeanSquaredError fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

public MeanSquaredError(StreamInput in) {
private EvaluationMetricResult result;

}

public MeanSquaredError() {
public MeanSquaredError(StreamInput in) {}

}
public MeanSquaredError() {}

@Override
public String getMetricName() {
public String getName() {
return NAME.getPreferredName();
}

@Override
public List<AggregationBuilder> aggs(String actualField, String predictedField) {
return Collections.singletonList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField))));
if (result != null) {
return Collections.emptyList();
}
return Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField))));
}

@Override
public EvaluationMetricResult evaluate(Aggregations aggs) {
public void process(Aggregations aggs) {
NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME);
return value == null ? new Result(0.0) : new Result(value.value());
result = value == null ? new Result(0.0) : new Result(value.value());
}

@Override
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result);
}

@Override
Expand Down Expand Up @@ -121,7 +129,7 @@ public String getWriteableName() {
}

@Override
public String getName() {
public String getMetricName() {
return NAME.getPreferredName();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
import java.io.IOException;
import java.text.MessageFormat;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;

/**
* Calculates R-Squared between two known numerical fields.
Expand Down Expand Up @@ -53,36 +55,42 @@ public static RSquared fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

public RSquared(StreamInput in) {
private EvaluationMetricResult result;

}

public RSquared() {
public RSquared(StreamInput in) {}

}
public RSquared() {}

@Override
public String getMetricName() {
public String getName() {
return NAME.getPreferredName();
}

@Override
public List<AggregationBuilder> aggs(String actualField, String predictedField) {
if (result != null) {
return Collections.emptyList();
}
return Arrays.asList(
AggregationBuilders.sum(SS_RES).script(new Script(buildScript(actualField, predictedField))),
AggregationBuilders.extendedStats(ExtendedStatsAggregationBuilder.NAME + "_actual").field(actualField));
}

@Override
public EvaluationMetricResult evaluate(Aggregations aggs) {
public void process(Aggregations aggs) {
NumericMetricsAggregation.SingleValue residualSumOfSquares = aggs.get(SS_RES);
ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual");
// extendedStats.getVariance() is the statistical sumOfSquares divided by count
return residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
result = residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
new Result(0.0) :
new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount())));
}

@Override
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result);
}

@Override
public String getWriteableName() {
return NAME.getPreferredName();
Expand Down Expand Up @@ -132,7 +140,7 @@ public String getWriteableName() {
}

@Override
public String getName() {
public String getMetricName() {
return NAME.getPreferredName();
}

Expand Down
Loading

0 comments on commit 3fbd58d

Please sign in to comment.