From 843fa42c1ea2e57f59b901ea0d73673f0798debe Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Tue, 12 Oct 2021 10:38:09 -0400 Subject: [PATCH] [ML] add new normalize_above parameter to p_value significant terms heuristic (#78833) This commit adds the new normalize_above parameter to the p_value significant terms heuristic. This parameter allows for consistent significance results at various scales. When the total document count of a set (the documents in the class or the documents not in the class) is above the normalize_above parameter, both that total and the count of documents in that set containing the term are scaled by normalize_above/total, where total is that set's size. --- .../significantterms-aggregation.asciidoc | 6 +- ...AbstractSignificanceHeuristicTestCase.java | 8 +- .../xpack/ml/aggs/heuristic/PValueScore.java | 83 ++++++++++++++--- .../ml/aggs/heuristic/PValueScoreTests.java | 89 ++++++++++++++----- 4 files changed, 151 insertions(+), 35 deletions(-) diff --git a/docs/reference/aggregations/bucket/significantterms-aggregation.asciidoc b/docs/reference/aggregations/bucket/significantterms-aggregation.asciidoc index 42fec9c9d74ba..003bdd047113e 100644 --- a/docs/reference/aggregations/bucket/significantterms-aggregation.asciidoc +++ b/docs/reference/aggregations/bucket/significantterms-aggregation.asciidoc @@ -404,6 +404,10 @@ the foreground set of "ended in failure" versus "NOT ended in failure". `"background_is_superset": false` indicates that the background set does not contain the counts of the foreground set as they are filtered out. +`"normalize_above": 1000` facilitates returning consistent significance results +at various scales. `1000` indicates that when a set's total count is greater than `1000`, +its counts are scaled down by a factor of `1000/total_count`. 
+ [source,console] -------------------------------------------------- GET /_search @@ -466,7 +470,7 @@ GET /_search ] } }, - "p_value": {"background_is_superset": false} + "p_value": {"background_is_superset": false, "normalize_above": 1000} } } } diff --git a/test/framework/src/main/java/org/elasticsearch/search/aggregations/bucket/AbstractSignificanceHeuristicTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/aggregations/bucket/AbstractSignificanceHeuristicTestCase.java index 5195619945f45..7e0143e3a9ceb 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/aggregations/bucket/AbstractSignificanceHeuristicTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/search/aggregations/bucket/AbstractSignificanceHeuristicTestCase.java @@ -37,6 +37,7 @@ import org.elasticsearch.search.aggregations.bucket.terms.heuristic.SignificanceHeuristic; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.InternalAggregationTestCase; +import org.elasticsearch.test.VersionUtils; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -52,7 +53,6 @@ import static java.util.Collections.emptyMap; import static java.util.Collections.singletonList; import static org.elasticsearch.search.aggregations.AggregationBuilders.significantTerms; -import static org.elasticsearch.test.VersionUtils.randomVersion; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -69,9 +69,13 @@ public abstract class AbstractSignificanceHeuristicTestCase extends ESTestCase { */ protected abstract SignificanceHeuristic getHeuristic(); + protected Version randomVersion() { + return VersionUtils.randomVersion(random()); + } + // test that stream output can actually be read - does not replace bwc test public void testStreamResponse() throws Exception { - Version version = randomVersion(random()); + Version version = randomVersion(); 
InternalMappedSignificantTerms sigTerms = getRandomSignificantTerms(getHeuristic()); // write diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/heuristic/PValueScore.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/heuristic/PValueScore.java index 51d9f5ee54682..0de8e6dbd1c56 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/heuristic/PValueScore.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/heuristic/PValueScore.java @@ -10,9 +10,11 @@ import org.apache.commons.math3.util.FastMath; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.search.aggregations.AggregationExecutionException; @@ -20,47 +22,80 @@ import org.elasticsearch.search.aggregations.bucket.terms.heuristic.SignificanceHeuristic; import java.io.IOException; +import java.util.Objects; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; +/** + * Significant terms heuristic that calculates the p-value between the term existing in foreground and background sets. + * + * The p-value is the probability of obtaining test results at least as extreme as + * the results actually observed, under the assumption that the null hypothesis is + * correct. The p-value is calculated assuming that the foreground set and the + * background set are independent <a href="https://en.wikipedia.org/wiki/Bernoulli_trial">Bernoulli trials</a>, with the null + * hypothesis that the probabilities are the same. 
+ */ public class PValueScore extends NXYSignificanceHeuristic { public static final String NAME = "p_value"; + public static final ParseField NORMALIZE_ABOVE = new ParseField("normalize_above"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { boolean backgroundIsSuperset = args[0] == null || (boolean) args[0]; - return new PValueScore(backgroundIsSuperset); + return new PValueScore(backgroundIsSuperset, (Long)args[1]); }); static { PARSER.declareBoolean(optionalConstructorArg(), BACKGROUND_IS_SUPERSET); + PARSER.declareLong(optionalConstructorArg(), NORMALIZE_ABOVE); } private static final MlChiSquaredDistribution CHI_SQUARED_DISTRIBUTION = new MlChiSquaredDistribution(1); - public PValueScore(boolean backgroundIsSuperset) { + // NOTE: `0` is a magic value indicating no normalization occurs + private final long normalizeAbove; + + /** + * @param backgroundIsSuperset Does the background contain the foreground docs? + * @param normalizeAbove Should the results be normalized when above the given value. + * Note: `0` is a special value which means no normalization (set as such when `null` is provided) + */ + public PValueScore(boolean backgroundIsSuperset, Long normalizeAbove) { super(true, backgroundIsSuperset); + if (normalizeAbove != null && normalizeAbove <= 0) { + throw new IllegalArgumentException( + "[" + NORMALIZE_ABOVE.getPreferredName() + "] must be a positive value, provided [" + normalizeAbove + "]" + ); + } + this.normalizeAbove = normalizeAbove == null ? 
0L : normalizeAbove; } public PValueScore(StreamInput in) throws IOException { super(true, in.readBoolean()); + if (in.getVersion().onOrAfter(Version.V_8_0_0)) { + normalizeAbove = in.readVLong(); + } else { + normalizeAbove = 0L; + } } @Override public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(backgroundIsSuperset); + if (out.getVersion().onOrAfter(Version.V_8_0_0)) { + out.writeVLong(normalizeAbove); + } } @Override - public boolean equals(Object obj) { - if ((obj instanceof PValueScore) == false) { - return false; - } - return super.equals(obj); + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; + PValueScore that = (PValueScore) o; + return normalizeAbove == that.normalizeAbove; } @Override public int hashCode() { - int result = NAME.hashCode(); - result = 31 * result + super.hashCode(); - return result; + return Objects.hash(super.hashCode(), normalizeAbove); } @Override @@ -72,6 +107,9 @@ public String getWriteableName() { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(NAME); builder.field(BACKGROUND_IS_SUPERSET.getPreferredName(), backgroundIsSuperset); + if (normalizeAbove > 0) { + builder.field(NORMALIZE_ABOVE.getPreferredName(), normalizeAbove); + } builder.endObject(); return builder; } @@ -113,6 +151,19 @@ public double getScore(long subsetFreq, long subsetSize, long supersetFreq, long return 0.0; } + if (normalizeAbove > 0L) { + if (allDocsInClass > normalizeAbove) { + double factor = (double) normalizeAbove / allDocsInClass; + allDocsInClass = (long)(allDocsInClass * factor); + docsContainTermInClass = (long)(docsContainTermInClass * factor); + } + if (allDocsNotInClass > normalizeAbove) { + double factor = (double) normalizeAbove / allDocsNotInClass; + allDocsNotInClass = (long)(allDocsNotInClass * factor); + docsContainTermNotInClass = 
(long)(docsContainTermNotInClass * factor); + } + } + // casting to `long` to round down to nearest whole number double epsAllDocsInClass = (long)eps(allDocsInClass); double epsAllDocsNotInClass = (long)eps(allDocsNotInClass); @@ -164,15 +215,25 @@ private double eps(double value) { } public static class PValueScoreBuilder extends NXYBuilder { + private final long normalizeAbove; - public PValueScoreBuilder(boolean backgroundIsSuperset) { + public PValueScoreBuilder(boolean backgroundIsSuperset, Long normalizeAbove) { super(true, backgroundIsSuperset); + this.normalizeAbove = normalizeAbove == null ? 0L : normalizeAbove; + if (normalizeAbove != null && normalizeAbove <= 0) { + throw new IllegalArgumentException( + "[" + NORMALIZE_ABOVE.getPreferredName() + "] must be a positive value, provided [" + normalizeAbove + "]" + ); + } } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(NAME); builder.field(BACKGROUND_IS_SUPERSET.getPreferredName(), backgroundIsSuperset); + if (normalizeAbove > 0) { + builder.field(NORMALIZE_ABOVE.getPreferredName(), normalizeAbove); + } builder.endObject(); return builder; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/heuristic/PValueScoreTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/heuristic/PValueScoreTests.java index f7bcce1e3b88f..f5d0813d949ed 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/heuristic/PValueScoreTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/heuristic/PValueScoreTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.aggs.heuristic; import org.apache.commons.math3.util.FastMath; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -29,19 +30,27 @@ public class 
PValueScoreTests extends AbstractNXYSignificanceHeuristicTestCase { private static final double eps = 1e-9; + @Override + protected Version randomVersion() { + return Version.V_8_0_0; + } + @Override protected SignificanceHeuristic getHeuristic() { - return new PValueScore(randomBoolean()); + return new PValueScore(randomBoolean(), randomBoolean() ? null : randomLongBetween(1, 10000000L)); } @Override protected SignificanceHeuristic getHeuristic(boolean includeNegatives, boolean backgroundIsSuperset) { - return new PValueScore(backgroundIsSuperset); + return new PValueScore(backgroundIsSuperset, randomBoolean() ? null : randomLongBetween(1, 10000000L)); } @Override public void testAssertions() { - testBackgroundAssertions(new PValueScore(true), new PValueScore(false)); + testBackgroundAssertions( + new PValueScore(true, randomBoolean() ? null : randomLongBetween(1, 10000000L)), + new PValueScore(false, randomBoolean() ? null : randomLongBetween(1, 10000000L)) + ); } @Override @@ -59,7 +68,7 @@ protected NamedWriteableRegistry writableRegistry() { } public void testPValueScore_WhenAllDocsContainTerm() { - PValueScore pValueScore = new PValueScore(randomBoolean()); + PValueScore pValueScore = new PValueScore(randomBoolean(), null); long supersetCount = randomNonNegativeLong(); long subsetCount = randomLongBetween(0L, supersetCount); assertThat(pValueScore.getScore(subsetCount, subsetCount, supersetCount, supersetCount), equalTo(0.0)); @@ -78,7 +87,7 @@ public void testHighPValueScore() { supersetFreqCount += subsetFreqCount; } - PValueScore pValueScore = new PValueScore(backgroundIsSuperset); + PValueScore pValueScore = new PValueScore(backgroundIsSuperset, null); assertThat(pValueScore.getScore(subsetFreqCount, subsetCount, supersetFreqCount, supersetCount), greaterThanOrEqualTo(700.0)); } @@ -95,7 +104,7 @@ public void testLowPValueScore() { supersetFreqCount += subsetFreqCount; } - PValueScore pValueScore = new PValueScore(backgroundIsSuperset); + PValueScore 
pValueScore = new PValueScore(backgroundIsSuperset, null); assertThat( pValueScore.getScore(subsetFreqCount, subsetCount, supersetFreqCount, supersetCount), allOf(lessThanOrEqualTo(5.0), greaterThanOrEqualTo(0.0)) @@ -104,66 +113,104 @@ public void testLowPValueScore() { public void testPValueScore() { assertThat( - FastMath.exp(-new PValueScore(false).getScore(10, 100, 100, 1000)), + FastMath.exp(-new PValueScore(false, null).getScore(10, 100, 100, 1000)), closeTo(1.0, eps) ); assertThat( - FastMath.exp(-new PValueScore(false).getScore(10, 100, 10, 1000)), + FastMath.exp(-new PValueScore(false, 200L).getScore(10, 100, 100, 1000)), + closeTo(1.0, eps) + ); + assertThat( + FastMath.exp(-new PValueScore(false, null).getScore(10, 100, 10, 1000)), closeTo(0.003972388976814195, eps) ); assertThat( - FastMath.exp(-new PValueScore(false).getScore(10, 100, 200, 1000)), + FastMath.exp(-new PValueScore(false, 200L).getScore(10, 100, 10, 1000)), + closeTo(0.020890782016496683, eps) + ); + assertThat( + FastMath.exp(-new PValueScore(false, null).getScore(10, 100, 200, 1000)), + closeTo(1.0, eps) + ); + assertThat( + FastMath.exp(-new PValueScore(false, 200L).getScore(10, 100, 200, 1000)), + closeTo(1.0, eps) + ); + assertThat( + FastMath.exp(-new PValueScore(false, null).getScore(20, 10000, 5, 10000)), closeTo(1.0, eps) ); assertThat( - FastMath.exp(-new PValueScore(false).getScore(20, 10000, 5, 10000)), + FastMath.exp(-new PValueScore(false, 200L).getScore(20, 10000, 5, 10000)), closeTo(1.0, eps) ); } public void testSmallChanges() { assertThat( - FastMath.exp(-new PValueScore(false).getScore(1, 4205, 0, 821496)), + FastMath.exp(-new PValueScore(false, null).getScore(1, 4205, 0, 821496)), closeTo(0.9999037287868853, eps) ); + // Same(ish) ratios assertThat( - FastMath.exp(-new PValueScore(false).getScore(10, 4205, 195, 82149)), + FastMath.exp(-new PValueScore(false, null).getScore(10, 4205, 195, 82149)), closeTo(0.9995943820612134, eps) ); assertThat( - FastMath.exp(-new 
PValueScore(false).getScore(10, 4205, 1950, 821496)), + FastMath.exp(-new PValueScore(false, 100L).getScore(10, 4205, 195, 82149)), + closeTo(0.9876284079864467, eps) + ); + + assertThat( + FastMath.exp(-new PValueScore(false, null).getScore(10, 4205, 1950, 821496)), closeTo(0.9999942565428899, eps) ); + assertThat( + FastMath.exp(-new PValueScore(false, 100L).getScore(10, 4205, 1950, 821496)), + closeTo(1.0, eps) + ); // 4% vs 0% assertThat( - FastMath.exp(-new PValueScore(false).getScore(168, 4205, 0, 821496)), + FastMath.exp(-new PValueScore(false, null).getScore(168, 4205, 0, 821496)), closeTo(1.2680918648731284e-26, eps) ); + assertThat( + FastMath.exp(-new PValueScore(false, 100L).getScore(168, 4205, 0, 821496)), + closeTo(0.3882951183744724, eps) + ); // 4% vs 2% assertThat( - FastMath.exp(-new PValueScore(false).getScore(168, 4205, 16429, 821496)), + FastMath.exp(-new PValueScore(false, null).getScore(168, 4205, 16429, 821496)), closeTo(8.542608559219833e-5, eps) ); + assertThat( + FastMath.exp(-new PValueScore(false, 100L).getScore(168, 4205, 16429, 821496)), + closeTo(0.579463586350363, eps) + ); // 4% vs 3.5% assertThat( - FastMath.exp(-new PValueScore(false).getScore(168, 4205, 28752, 821496)), + FastMath.exp(-new PValueScore(false, null).getScore(168, 4205, 28752, 821496)), closeTo(0.8833950526957098, eps) ); + assertThat( + FastMath.exp(-new PValueScore(false, 100L).getScore(168, 4205, 28752, 821496)), + closeTo(1.0, eps) + ); } public void testLargerValues() { assertThat( - FastMath.exp(-new PValueScore(false).getScore(101000, 1000000, 500000, 5000000)), + FastMath.exp(-new PValueScore(false, null).getScore(101000, 1000000, 500000, 5000000)), closeTo(1.0, eps) ); assertThat( - FastMath.exp(-new PValueScore(false).getScore(102000, 1000000, 500000, 5000000)), + FastMath.exp(-new PValueScore(false, null).getScore(102000, 1000000, 500000, 5000000)), closeTo(1.0, eps) ); assertThat( - FastMath.exp(-new PValueScore(false).getScore(103000, 1000000, 500000, 
5000000)), + FastMath.exp(-new PValueScore(false, null).getScore(103000, 1000000, 500000, 5000000)), closeTo(1.0, eps) ); } @@ -171,7 +218,7 @@ public void testLargerValues() { public void testScoreIsZero() { for (int j = 0; j < 10; j++) { assertThat( - new PValueScore(false).getScore((j + 1)*5, (j + 10)*100, (j + 1)*10, (j + 10)*100), + new PValueScore(false, null).getScore((j + 1)*5, (j + 10)*100, (j + 1)*10, (j + 10)*100), equalTo(0.0) ); } @@ -179,7 +226,7 @@ public void testScoreIsZero() { public void testIncreasedSubsetIncreasedScore() { final Function getScore = (subsetFreq) -> - new PValueScore(false).getScore(subsetFreq, 5000, 5, 5000); + new PValueScore(false, null).getScore(subsetFreq, 5000, 5, 5000); double priorScore = getScore.apply(5L); assertThat(priorScore, greaterThanOrEqualTo(0.0)); for (int j = 1; j < 11; j++) {