Skip to content

Commit

Permalink
Simplify handling actualField in process() method
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek committed Dec 19, 2019
1 parent 7f571ae commit dd74939
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ public class Precision implements EvaluationMetric {
static final String BY_PREDICTED_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_predicted_class";
static final String PER_PREDICTED_CLASS_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "per_predicted_class_precision";
static final String AVG_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "avg_precision";
private static String ACTUAL_FIELD_METADATA_KEY = "actual_field";

private static Script buildScript(Object...args) {
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
Expand All @@ -79,6 +78,7 @@ public static Precision fromXContent(XContentParser parser) {
private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000;

private final int maxClassesCardinality;
private String actualField;
private List<String> topActualClassNames;
private EvaluationMetricResult result;

Expand Down Expand Up @@ -107,14 +107,15 @@ public String getName() {

@Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
this.actualField = actualField;
if (topActualClassNames == null) { // This is step 1
return Tuple.tuple(
List.of(
AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME)
.field(actualField)
.order(List.of(BucketOrder.count(false), BucketOrder.key(true)))
.size(maxClassesCardinality)
.setMetaData(Collections.singletonMap(ACTUAL_FIELD_METADATA_KEY, actualField))),
.size(maxClassesCardinality)),
List.of());
}
if (result == null) { // This is step 2
Expand Down Expand Up @@ -143,8 +144,7 @@ public void process(Aggregations aggs) {
// This means there were more than {@code maxClassesCardinality} buckets.
// We cannot calculate average precision accurately, so we fail.
throw ExceptionsHelper.badRequestException(
"Cannot calculate average precision. Cardinality of field [{}] is too high",
topActualClassesAgg.getMetaData().get(ACTUAL_FIELD_METADATA_KEY));
"Cannot calculate average precision. Cardinality of field [{}] is too high", actualField);
}
topActualClassNames =
topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ public class Recall implements EvaluationMetric {
static final String BY_ACTUAL_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class";
static final String PER_ACTUAL_CLASS_RECALL_AGG_NAME = AGG_NAME_PREFIX + "per_actual_class_recall";
static final String AVG_RECALL_AGG_NAME = AGG_NAME_PREFIX + "avg_recall";
private static String ACTUAL_FIELD_METADATA_KEY = "actual_field";

private static Script buildScript(Object...args) {
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
Expand All @@ -73,6 +72,7 @@ public static Recall fromXContent(XContentParser parser) {
private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000;

private final int maxClassesCardinality;
private String actualField;
private EvaluationMetricResult result;

public Recall() {
Expand Down Expand Up @@ -100,6 +100,8 @@ public String getName() {

@Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
this.actualField = actualField;
if (result != null) {
return Tuple.tuple(List.of(), List.of());
}
Expand All @@ -109,7 +111,6 @@ public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> a
AggregationBuilders.terms(BY_ACTUAL_CLASS_AGG_NAME)
.field(actualField)
.size(maxClassesCardinality)
.setMetaData(Collections.singletonMap(ACTUAL_FIELD_METADATA_KEY, actualField))
.subAggregation(AggregationBuilders.avg(PER_ACTUAL_CLASS_RECALL_AGG_NAME).script(script))),
List.of(
PipelineAggregatorBuilders.avgBucket(
Expand All @@ -127,8 +128,7 @@ public void process(Aggregations aggs) {
// This means there were more than {@code maxClassesCardinality} buckets.
// We cannot calculate average recall accurately, so we fail.
throw ExceptionsHelper.badRequestException(
"Cannot calculate average recall. Cardinality of field [{}] is too high",
byActualClassAgg.getMetaData().get(ACTUAL_FIELD_METADATA_KEY));
"Cannot calculate average recall. Cardinality of field [{}] is too high", actualField);
}
NumericMetricsAggregation.SingleValue avgRecallAgg = aggs.get(AVG_RECALL_AGG_NAME);
List<PerClassResult> classes = new ArrayList<>(byActualClassAgg.getBuckets().size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.test.AbstractSerializingTestCase;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
Expand All @@ -26,7 +24,6 @@
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.when;

public class PrecisionTests extends AbstractSerializingTestCase<Precision> {

Expand Down Expand Up @@ -112,11 +109,10 @@ public void testProcess_GivenAggOfWrongType() {
}

public void testProcess_GivenCardinalityTooHigh() {
Terms actualClassesAgg = mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME, Collections.emptyList(), 1);
when(actualClassesAgg.getMetaData()).thenReturn(Map.of("actual_field", "foo"));
Aggregations aggs = new Aggregations(Collections.singletonList(actualClassesAgg));
Aggregations aggs =
new Aggregations(Collections.singletonList(mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME, Collections.emptyList(), 1)));
Precision precision = new Precision();

precision.aggs("foo", "bar");
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> precision.process(aggs));
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.test.AbstractSerializingTestCase;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
Expand All @@ -25,7 +23,6 @@
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.when;

public class RecallTests extends AbstractSerializingTestCase<Recall> {

Expand Down Expand Up @@ -110,11 +107,11 @@ public void testProcess_GivenAggOfWrongType() {
}

public void testProcess_GivenCardinalityTooHigh() {
Terms byActualClassAgg = mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME, Collections.emptyList(), 1);
when(byActualClassAgg.getMetaData()).thenReturn(Map.of("actual_field", "foo"));
Aggregations aggs = new Aggregations(Arrays.asList(byActualClassAgg, mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123)));
Aggregations aggs = new Aggregations(Arrays.asList(
mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME, Collections.emptyList(), 1),
mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123)));
Recall recall = new Recall();

recall.aggs("foo", "bar");
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> recall.process(aggs));
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
}
Expand Down

0 comments on commit dd74939

Please sign in to comment.