Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve BenchmarkDecimalAggregation #13939

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -39,6 +39,7 @@
import org.testng.annotations.Test;

import java.util.OptionalInt;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

@@ -53,12 +54,12 @@
@State(Scope.Thread)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Fork(3)
@Warmup(iterations = 10)
@Measurement(iterations = 10)
@Warmup(iterations = 20, time = 5)
@Measurement(iterations = 10, time = 1)
@BenchmarkMode(Mode.AverageTime)
public class BenchmarkDecimalAggregation
{
private static final int ELEMENT_COUNT = 1_000_000;
private static final int ELEMENT_COUNT = 10_000_000;

@Benchmark
@OperationsPerInvocation(ELEMENT_COUNT)
@@ -82,6 +83,27 @@ public Block benchmarkEvaluateIntermediate(BenchmarkData data)
return builder.build();
}

@Benchmark
@OutputTimeUnit(TimeUnit.MICROSECONDS)
public Block benchmarkEvaluateIntermediateOnly(BenchmarkData data)
{
GroupedAggregator aggregator = data.getPartialAggregator();
BlockBuilder builder = aggregator.getType().createBlockBuilder(null, data.getGroupCount());
for (int groupId = 0; groupId < data.getGroupCount(); groupId++) {
aggregator.evaluate(groupId, builder);
}
return builder.build();
}

@Benchmark
@OutputTimeUnit(TimeUnit.MICROSECONDS)
public Object benchmarkAddIntermediate(BenchmarkData data)
{
GroupedAggregator aggregator = data.getFinalAggregatorFactory().createGroupedAggregator();
aggregator.processPage(data.getGroupIds(), data.getIntermediateValues());
return aggregator;
}

@Benchmark
public Block benchmarkEvaluateFinal(BenchmarkData data)
{
@@ -108,36 +130,66 @@ public static class BenchmarkData
@Param({"10", "1000"})
private int groupCount = 10;

@Param({"SMALL", "BIG", "MIXED"})
private String decimalSize = "SMALL";

@Param({"true", "false"})
private boolean groupIdsInOrder;

private AggregatorFactory partialAggregatorFactory;
private AggregatorFactory finalAggregatorFactory;
private GroupByIdBlock groupIds;
private Page values;
private Page intermediateValues;
private GroupedAggregator partialAggregator;

@Setup
public void setup()
{
TestingFunctionResolution functionResolution = new TestingFunctionResolution();

Random random = new Random(343526534);
switch (type) {
case "SHORT": {
DecimalType type = createDecimalType(14, 3);
values = createValues(functionResolution, type, type::writeLong);
values = createValues(functionResolution, type, (builder, value) -> {
boolean writeShort = "SMALL".equals(decimalSize) || ("MIXED".equals(decimalSize) && random.nextBoolean());
if (writeShort) {
builder.writeLong(value);
}
else {
// long
builder.writeLong(Long.MAX_VALUE - value);
}
});
break;
}
case "LONG": {
DecimalType type = createDecimalType(30, 10);
values = createValues(functionResolution, type, (builder, value) -> type.writeObject(builder, Int128.valueOf(value)));
values = createValues(functionResolution, type, (builder, value) -> {
boolean writeShort = "SMALL".equals(decimalSize) || ("MIXED".equals(decimalSize) && random.nextBoolean());
if (writeShort) {
type.writeObject(builder, Int128.valueOf(value));
}
else {
// long
type.writeObject(builder, Int128.valueOf(Long.MAX_VALUE - value, value));
}
});
break;
}
}

BlockBuilder ids = BIGINT.createBlockBuilder(null, ELEMENT_COUNT);
for (int i = 0; i < ELEMENT_COUNT; i++) {
BIGINT.writeLong(ids, ThreadLocalRandom.current().nextLong(groupCount));
long groupId = groupIdsInOrder ? i % groupCount : ThreadLocalRandom.current().nextLong(groupCount);
BIGINT.writeLong(ids, groupId);
}

groupIds = new GroupByIdBlock(groupCount, ids.build());
intermediateValues = new Page(createIntermediateValues(partialAggregatorFactory.createGroupedAggregator(), groupIds, values));

partialAggregator = partialAggregatorFactory.createGroupedAggregator();
partialAggregator.processPage(getGroupIds(), getValues());
}

private Block createIntermediateValues(GroupedAggregator aggregator, GroupByIdBlock groupIds, Page inputPage)
@@ -193,6 +245,11 @@ public Page getIntermediateValues()
return intermediateValues;
}

public GroupedAggregator getPartialAggregator()
{
return partialAggregator;
}

interface ValueWriter
{
void write(BlockBuilder valuesBuilder, int value);