Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reuse CompensatedSum object in agg collect loop #49548

Merged
merged 2 commits into from
Nov 25, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
}
final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);

return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
Expand All @@ -87,7 +89,8 @@ public void collect(int doc, long bucket) throws IOException {
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);

kahanSummation.reset(sum, compensation);

for (int i = 0; i < valueCount; i++) {
double value = values.nextValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ public CompensatedSum add(double value) {
return add(value, NO_CORRECTION);
}

/**
* Resets the internal state to use the new value and compensation delta
*/
public void reset(double value, double delta) {
this.value = value;
this.delta = delta;
}

/**
* Increments the Kahan sum by adding two sums, and updating the correction term for reducing numeric errors.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
}
final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
final CompensatedSum compensatedSum = new CompensatedSum(0, 0);
final CompensatedSum compensatedSumOfSqr = new CompensatedSum(0, 0);
return new LeafBucketCollectorBase(sub, values) {

@Override
Expand Down Expand Up @@ -117,11 +119,11 @@ public void collect(int doc, long bucket) throws IOException {
// which is more accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
CompensatedSum compensatedSum = new CompensatedSum(sum, compensation);
compensatedSum.reset(sum, compensation);

double sumOfSqr = sumOfSqrs.get(bucket);
double compensationOfSqr = compensationOfSqrs.get(bucket);
CompensatedSum compensatedSumOfSqr = new CompensatedSum(sumOfSqr, compensationOfSqr);
compensatedSumOfSqr.reset(sumOfSqr, compensationOfSqr);

for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol
}
final BigArrays bigArrays = context.bigArrays();
final MultiGeoPointValues values = valuesSource.geoPointValues(ctx);
final CompensatedSum compensatedSumLat = new CompensatedSum(0, 0);
final CompensatedSum compensatedSumLon = new CompensatedSum(0, 0);

return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
Expand All @@ -88,8 +91,8 @@ public void collect(int doc, long bucket) throws IOException {
double sumLon = lonSum.get(bucket);
double compensationLon = lonCompensations.get(bucket);

CompensatedSum compensatedSumLat = new CompensatedSum(sumLat, compensationLat);
CompensatedSum compensatedSumLon = new CompensatedSum(sumLon, compensationLon);
compensatedSumLat.reset(sumLat, compensationLat);
compensatedSumLon.reset(sumLon, compensationLon);

// update the sum
for (int i = 0; i < valueCount; ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
}
final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);

return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
Expand All @@ -105,7 +107,7 @@ public void collect(int doc, long bucket) throws IOException {
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);
kahanSummation.reset(sum, compensation);

for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
}
final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
Expand All @@ -81,7 +82,7 @@ public void collect(int doc, long bucket) throws IOException {
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);
kahanSummation.reset(sum, compensation);

for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue {
private final MultiValuesSource.NumericMultiValuesSource valuesSources;

private DoubleArray weights;
private DoubleArray sums;
private DoubleArray sumCompensations;
private DoubleArray valueSums;
private DoubleArray valueCompensations;
private DoubleArray weightCompensations;
private DocValueFormat format;

Expand All @@ -60,8 +60,8 @@ class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue {
if (valuesSources != null) {
final BigArrays bigArrays = context.bigArrays();
weights = bigArrays.newDoubleArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
sumCompensations = bigArrays.newDoubleArray(1, true);
valueSums = bigArrays.newDoubleArray(1, true);
valueCompensations = bigArrays.newDoubleArray(1, true);
weightCompensations = bigArrays.newDoubleArray(1, true);
}
}
Expand All @@ -80,13 +80,15 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues docValues = valuesSources.getField(VALUE_FIELD.getPreferredName(), ctx);
final SortedNumericDoubleValues docWeights = valuesSources.getField(WEIGHT_FIELD.getPreferredName(), ctx);
final CompensatedSum compensatedValueSum = new CompensatedSum(0, 0);
final CompensatedSum compensatedWeightSum = new CompensatedSum(0, 0);

return new LeafBucketCollectorBase(sub, docValues) {
@Override
public void collect(int doc, long bucket) throws IOException {
weights = bigArrays.grow(weights, bucket + 1);
sums = bigArrays.grow(sums, bucket + 1);
sumCompensations = bigArrays.grow(sumCompensations, bucket + 1);
valueSums = bigArrays.grow(valueSums, bucket + 1);
valueCompensations = bigArrays.grow(valueCompensations, bucket + 1);
weightCompensations = bigArrays.grow(weightCompensations, bucket + 1);

if (docValues.advanceExact(doc) && docWeights.advanceExact(doc)) {
Expand All @@ -102,42 +104,43 @@ public void collect(int doc, long bucket) throws IOException {
final int numValues = docValues.docValueCount();
assert numValues > 0;

double valueSum = valueSums.get(bucket);
double valueCompensation = valueCompensations.get(bucket);
compensatedValueSum.reset(valueSum, valueCompensation);

double weightSum = weights.get(bucket);
double weightCompensation = weightCompensations.get(bucket);
compensatedWeightSum.reset(weightSum, weightCompensation);

for (int i = 0; i < numValues; i++) {
kahanSum(docValues.nextValue() * weight, sums, sumCompensations, bucket);
kahanSum(weight, weights, weightCompensations, bucket);
compensatedValueSum.add(docValues.nextValue() * weight);
compensatedWeightSum.add(weight);
}

valueSums.set(bucket, compensatedValueSum.value());
valueCompensations.set(bucket, compensatedValueSum.delta());
weights.set(bucket, compensatedWeightSum.value());
weightCompensations.set(bucket, compensatedWeightSum.delta());
}
}
};
}

private static void kahanSum(double value, DoubleArray values, DoubleArray compensations, long bucket) {
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double sum = values.get(bucket);
double compensation = compensations.get(bucket);

CompensatedSum kahanSummation = new CompensatedSum(sum, compensation)
.add(value);

values.set(bucket, kahanSummation.value());
compensations.set(bucket, kahanSummation.delta());
}

@Override
public double metric(long owningBucketOrd) {
if (valuesSources == null || owningBucketOrd >= sums.size()) {
if (valuesSources == null || owningBucketOrd >= valueSums.size()) {
return Double.NaN;
}
return sums.get(owningBucketOrd) / weights.get(owningBucketOrd);
return valueSums.get(owningBucketOrd) / weights.get(owningBucketOrd);
}

@Override
public InternalAggregation buildAggregation(long bucket) {
if (valuesSources == null || bucket >= sums.size()) {
if (valuesSources == null || bucket >= valueSums.size()) {
return buildEmptyAggregation();
}
return new InternalWeightedAvg(name, sums.get(bucket), weights.get(bucket), format, pipelineAggregators(), metaData());
return new InternalWeightedAvg(name, valueSums.get(bucket), weights.get(bucket), format, pipelineAggregators(), metaData());
}

@Override
Expand All @@ -147,7 +150,7 @@ public InternalAggregation buildEmptyAggregation() {

@Override
public void doClose() {
Releasables.close(weights, sums, sumCompensations, weightCompensations);
Releasables.close(weights, valueSums, valueCompensations, weightCompensations);
}

}