From 9e5d3be6ed7b8eb7a8292940afc1ae11438831ac Mon Sep 17 00:00:00 2001 From: Zachary Tong Date: Mon, 25 Nov 2019 13:47:12 -0500 Subject: [PATCH] Reuse CompensatedSum object in agg collect loops --- .../aggregations/metrics/AvgAggregator.java | 5 +- .../aggregations/metrics/CompensatedSum.java | 8 +++ .../metrics/ExtendedStatsAggregator.java | 6 ++- .../metrics/GeoCentroidAggregator.java | 7 ++- .../aggregations/metrics/StatsAggregator.java | 4 +- .../aggregations/metrics/SumAggregator.java | 3 +- .../metrics/WeightedAvgAggregator.java | 53 ++++++++++--------- 7 files changed, 54 insertions(+), 32 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/AvgAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/AvgAggregator.java index 843e380e425ea..9dee689831ac9 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/AvgAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/AvgAggregator.java @@ -73,6 +73,8 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, } final BigArrays bigArrays = context.bigArrays(); final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx); + final CompensatedSum kahanSummation = new CompensatedSum(0, 0); + return new LeafBucketCollectorBase(sub, values) { @Override public void collect(int doc, long bucket) throws IOException { @@ -87,7 +89,8 @@ public void collect(int doc, long bucket) throws IOException { // accurate than naive summation. double sum = sums.get(bucket); double compensation = compensations.get(bucket); - CompensatedSum kahanSummation = new CompensatedSum(sum, compensation); + + kahanSummation.reset(sum, compensation); for (int i = 0; i < valueCount; i++) { double value = values.nextValue(); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/CompensatedSum.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/CompensatedSum.java index 965ac665159a0..85ae940681637 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/CompensatedSum.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/CompensatedSum.java @@ -68,6 +68,14 @@ public CompensatedSum add(double value) { return add(value, NO_CORRECTION); } + /** + * Resets the internal state to use the new value and compensation delta + */ + public void reset(double value, double delta) { + this.value = value; + this.delta = delta; + } + /** * Increments the Kahan sum by adding two sums, and updating the correction term for reducing numeric errors. */ diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ExtendedStatsAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ExtendedStatsAggregator.java index c4dcfebf5e1be..f6b9420997fc0 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ExtendedStatsAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ExtendedStatsAggregator.java @@ -90,6 +90,8 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, } final BigArrays bigArrays = context.bigArrays(); final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx); + final CompensatedSum compensatedSum = new CompensatedSum(0, 0); + final CompensatedSum compensatedSumOfSqr = new CompensatedSum(0, 0); return new LeafBucketCollectorBase(sub, values) { @Override @@ -117,11 +119,11 @@ public void collect(int doc, long bucket) throws IOException { // which is more accurate than naive summation. double sum = sums.get(bucket); double compensation = compensations.get(bucket); - CompensatedSum compensatedSum = new CompensatedSum(sum, compensation); + compensatedSum.reset(sum, compensation); double sumOfSqr = sumOfSqrs.get(bucket); double compensationOfSqr = compensationOfSqrs.get(bucket); - CompensatedSum compensatedSumOfSqr = new CompensatedSum(sumOfSqr, compensationOfSqr); + compensatedSumOfSqr.reset(sumOfSqr, compensationOfSqr); for (int i = 0; i < valuesCount; i++) { double value = values.nextValue(); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/GeoCentroidAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/GeoCentroidAggregator.java index d5a91b002213e..bf318896e55e0 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/GeoCentroidAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/GeoCentroidAggregator.java @@ -68,6 +68,9 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol } final BigArrays bigArrays = context.bigArrays(); final MultiGeoPointValues values = valuesSource.geoPointValues(ctx); + final CompensatedSum compensatedSumLat = new CompensatedSum(0, 0); + final CompensatedSum compensatedSumLon = new CompensatedSum(0, 0); + return new LeafBucketCollectorBase(sub, values) { @Override public void collect(int doc, long bucket) throws IOException { @@ -88,8 +91,8 @@ public void collect(int doc, long bucket) throws IOException { double sumLon = lonSum.get(bucket); double compensationLon = lonCompensations.get(bucket); - CompensatedSum compensatedSumLat = new CompensatedSum(sumLat, compensationLat); - CompensatedSum compensatedSumLon = new CompensatedSum(sumLon, compensationLon); + compensatedSumLat.reset(sumLat, compensationLat); + compensatedSumLon.reset(sumLon, compensationLon); // update the sum for (int i = 0; i < valueCount; ++i) { diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/StatsAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/StatsAggregator.java index 7799f498dd491..7ae5b016f75c3 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/StatsAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/StatsAggregator.java @@ -81,6 +81,8 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, } final BigArrays bigArrays = context.bigArrays(); final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx); + final CompensatedSum kahanSummation = new CompensatedSum(0, 0); + return new LeafBucketCollectorBase(sub, values) { @Override public void collect(int doc, long bucket) throws IOException { @@ -105,7 +107,7 @@ public void collect(int doc, long bucket) throws IOException { // accurate than naive summation. double sum = sums.get(bucket); double compensation = compensations.get(bucket); - CompensatedSum kahanSummation = new CompensatedSum(sum, compensation); + kahanSummation.reset(sum, compensation); for (int i = 0; i < valuesCount; i++) { double value = values.nextValue(); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/SumAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/SumAggregator.java index cc440fd7d0554..ebb0e36dbf5db 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/SumAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/SumAggregator.java @@ -69,6 +69,7 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, } final BigArrays bigArrays = context.bigArrays(); final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx); + final CompensatedSum kahanSummation = new CompensatedSum(0, 0); return new LeafBucketCollectorBase(sub, values) { @Override public void collect(int doc, long bucket) throws IOException { @@ -81,7 +82,7 @@ public void collect(int doc, long bucket) throws IOException { // accurate than naive summation. double sum = sums.get(bucket); double compensation = compensations.get(bucket); - CompensatedSum kahanSummation = new CompensatedSum(sum, compensation); + kahanSummation.reset(sum, compensation); for (int i = 0; i < valuesCount; i++) { double value = values.nextValue(); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/WeightedAvgAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/WeightedAvgAggregator.java index 11b4a5df951dd..ab5d1669e036f 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/WeightedAvgAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/WeightedAvgAggregator.java @@ -46,8 +46,8 @@ class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue { private final MultiValuesSource.NumericMultiValuesSource valuesSources; private DoubleArray weights; - private DoubleArray sums; - private DoubleArray sumCompensations; + private DoubleArray valueSums; + private DoubleArray valueCompensations; private DoubleArray weightCompensations; private DocValueFormat format; @@ -60,8 +60,8 @@ class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue { if (valuesSources != null) { final BigArrays bigArrays = context.bigArrays(); weights = bigArrays.newDoubleArray(1, true); - sums = bigArrays.newDoubleArray(1, true); - sumCompensations = bigArrays.newDoubleArray(1, true); + valueSums = bigArrays.newDoubleArray(1, true); + valueCompensations = bigArrays.newDoubleArray(1, true); weightCompensations = bigArrays.newDoubleArray(1, true); } } @@ -80,13 +80,15 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final BigArrays bigArrays = context.bigArrays(); final SortedNumericDoubleValues docValues = valuesSources.getField(VALUE_FIELD.getPreferredName(), ctx); final SortedNumericDoubleValues docWeights = valuesSources.getField(WEIGHT_FIELD.getPreferredName(), ctx); + final CompensatedSum compensatedValueSum = new CompensatedSum(0, 0); + final CompensatedSum compensatedWeightSum = new CompensatedSum(0, 0); return new LeafBucketCollectorBase(sub, docValues) { @Override public void collect(int doc, long bucket) throws IOException { weights = bigArrays.grow(weights, bucket + 1); - sums = bigArrays.grow(sums, bucket + 1); - sumCompensations = bigArrays.grow(sumCompensations, bucket + 1); + valueSums = bigArrays.grow(valueSums, bucket + 1); + valueCompensations = bigArrays.grow(valueCompensations, bucket + 1); weightCompensations = bigArrays.grow(weightCompensations, bucket + 1); if (docValues.advanceExact(doc) && docWeights.advanceExact(doc)) { @@ -102,42 +104,43 @@ public void collect(int doc, long bucket) throws IOException { final int numValues = docValues.docValueCount(); assert numValues > 0; + double valueSum = valueSums.get(bucket); + double valueCompensation = valueCompensations.get(bucket); + compensatedValueSum.reset(valueSum, valueCompensation); + + double weightSum = weights.get(bucket); + double weightCompensation = weightCompensations.get(bucket); + compensatedWeightSum.reset(weightSum, weightCompensation); + for (int i = 0; i < numValues; i++) { - kahanSum(docValues.nextValue() * weight, sums, sumCompensations, bucket); - kahanSum(weight, weights, weightCompensations, bucket); + compensatedValueSum.add(docValues.nextValue() * weight); + compensatedWeightSum.add(weight); } + + valueSums.set(bucket, compensatedValueSum.value()); + valueCompensations.set(bucket, compensatedValueSum.delta()); + weights.set(bucket, compensatedWeightSum.value()); + weightCompensations.set(bucket, compensatedWeightSum.delta()); } } }; } - private static void kahanSum(double value, DoubleArray values, DoubleArray compensations, long bucket) { - // Compute the sum of double values with Kahan summation algorithm which is more - // accurate than naive summation. - double sum = values.get(bucket); - double compensation = compensations.get(bucket); - - CompensatedSum kahanSummation = new CompensatedSum(sum, compensation) - .add(value); - - values.set(bucket, kahanSummation.value()); - compensations.set(bucket, kahanSummation.delta()); - } @Override public double metric(long owningBucketOrd) { - if (valuesSources == null || owningBucketOrd >= sums.size()) { + if (valuesSources == null || owningBucketOrd >= valueSums.size()) { return Double.NaN; } - return sums.get(owningBucketOrd) / weights.get(owningBucketOrd); + return valueSums.get(owningBucketOrd) / weights.get(owningBucketOrd); } @Override public InternalAggregation buildAggregation(long bucket) { - if (valuesSources == null || bucket >= sums.size()) { + if (valuesSources == null || bucket >= valueSums.size()) { return buildEmptyAggregation(); } - return new InternalWeightedAvg(name, sums.get(bucket), weights.get(bucket), format, pipelineAggregators(), metaData()); + return new InternalWeightedAvg(name, valueSums.get(bucket), weights.get(bucket), format, pipelineAggregators(), metaData()); } @Override @@ -147,7 +150,7 @@ public InternalAggregation buildEmptyAggregation() { @Override public void doClose() { - Releasables.close(weights, sums, sumCompensations, weightCompensations); + Releasables.close(weights, valueSums, valueCompensations, weightCompensations); } }