Skip to content

Commit

Permalink
fix sparksql decimal avg agg accuracy problem (facebookincubator#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 authored and JkSelf committed Mar 24, 2023
1 parent 7f7dd5a commit c30b97f
Showing 1 changed file with 152 additions and 113 deletions.
265 changes: 152 additions & 113 deletions velox/functions/sparksql/aggregates/DecimalAvgAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@ namespace facebook::velox::functions::sparksql::aggregates {

using velox::aggregate::LongDecimalWithOverflowState;

template <typename TResultType, typename TInputType = TResultType>
class DecimalAggregate : public exec::Aggregate {
template <typename TInputType, typename TResultType>
class DecimalAverageAggregate : public exec::Aggregate {
public:
explicit DecimalAggregate(TypePtr resultType) : exec::Aggregate(resultType) {}
explicit DecimalAverageAggregate(TypePtr inputType, TypePtr resultType)
: exec::Aggregate(resultType), inputType_(inputType) {}

int32_t accumulatorFixedWidthSize() const override {
return sizeof(DecimalAggregate);
return sizeof(DecimalAverageAggregate);
}

int32_t accumulatorAlignmentSize() const override {
Expand All @@ -58,7 +59,7 @@ class DecimalAggregate : public exec::Aggregate {
if (!decodedRaw_.isNullAt(0)) {
auto value = decodedRaw_.valueAt<TInputType>(0);
rows.applyToSelected([&](vector_size_t i) {
updateNonNullValue(groups[i], TResultType(value));
updateNonNullValue(groups[i], UnscaledLongDecimal(value));
});
} else {
// Spark expects the result of partial avg to be non-nullable.
Expand All @@ -73,17 +74,17 @@ class DecimalAggregate : public exec::Aggregate {
return;
}
updateNonNullValue(
groups[i], TResultType(decodedRaw_.valueAt<TInputType>(i)));
groups[i], UnscaledLongDecimal(decodedRaw_.valueAt<TInputType>(i)));
});
} else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) {
auto data = decodedRaw_.data<TInputType>();
rows.applyToSelected([&](vector_size_t i) {
updateNonNullValue<false>(groups[i], TResultType(data[i]));
updateNonNullValue<false>(groups[i], UnscaledLongDecimal(data[i]));
});
} else {
rows.applyToSelected([&](vector_size_t i) {
updateNonNullValue(
groups[i], TResultType(decodedRaw_.valueAt<TInputType>(i)));
groups[i], UnscaledLongDecimal(decodedRaw_.valueAt<TInputType>(i)));
});
}
}
Expand All @@ -101,7 +102,7 @@ class DecimalAggregate : public exec::Aggregate {
int128_t totalSum{0};
auto value = decodedRaw_.valueAt<TInputType>(0);
rows.template applyToSelected([&](vector_size_t i) {
updateNonNullValue(group, TResultType(value));
updateNonNullValue(group, UnscaledLongDecimal(value));
});
} else {
// Spark expects the result of partial avg to be non-nullable.
Expand All @@ -111,7 +112,7 @@ class DecimalAggregate : public exec::Aggregate {
rows.applyToSelected([&](vector_size_t i) {
if (!decodedRaw_.isNullAt(i)) {
updateNonNullValue(
group, TResultType(decodedRaw_.valueAt<TInputType>(i)));
group, UnscaledLongDecimal(decodedRaw_.valueAt<TInputType>(i)));
} else {
// Spark expects the result of partial avg to be non-nullable.
exec::Aggregate::clearNull(group);
Expand Down Expand Up @@ -227,38 +228,9 @@ class DecimalAggregate : public exec::Aggregate {
bool /* mayPushdown */) override {
decodedPartial_.decode(*args[0], rows);
auto baseRowVector = dynamic_cast<const RowVector*>(decodedPartial_.base());
auto sumCol = baseRowVector->childAt(0);
auto countCol = baseRowVector->childAt(1);
switch (sumCol->typeKind()) {
case TypeKind::SHORT_DECIMAL: {
addSingleGroupIntermediateDecimalResults(
group,
rows,
sumCol->as<SimpleVector<UnscaledShortDecimal>>(),
countCol->as<SimpleVector<int64_t>>());
break;
}
case TypeKind::LONG_DECIMAL: {
addSingleGroupIntermediateDecimalResults(
group,
rows,
sumCol->as<SimpleVector<UnscaledLongDecimal>>(),
countCol->as<SimpleVector<int64_t>>());
break;
}
default:
VELOX_FAIL(
"Unsupported sum type for decimal aggregation: {}",
sumCol->typeKind());
}
}
auto sumVector = baseRowVector->childAt(0)->as<SimpleVector<TInputType>>();
auto countVector = baseRowVector->childAt(1)->as<SimpleVector<int64_t>>();

template <class UnscaledType>
void addSingleGroupIntermediateDecimalResults(
char* group,
const SelectivityVector& rows,
SimpleVector<UnscaledType>* sumVector,
SimpleVector<int64_t>* countVector) {
if (decodedPartial_.isConstantMapping()) {
if (!decodedPartial_.isNullAt(0)) {
auto decodedIndex = decodedPartial_.index(0);
Expand Down Expand Up @@ -292,48 +264,17 @@ class DecimalAggregate : public exec::Aggregate {
void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result)
override {
auto rowVector = (*result)->as<RowVector>();
auto sumCol = rowVector->childAt(0);
auto countCol = rowVector->childAt(1);
switch (sumCol->typeKind()) {
case TypeKind::SHORT_DECIMAL: {
extractDecimalAccumulator(
groups,
numGroups,
rowVector,
sumCol->asFlatVector<UnscaledShortDecimal>(),
countCol->asFlatVector<int64_t>());
break;
}
case TypeKind::LONG_DECIMAL: {
extractDecimalAccumulator(
groups,
numGroups,
rowVector,
sumCol->asFlatVector<UnscaledLongDecimal>(),
countCol->asFlatVector<int64_t>());
break;
}
default:
VELOX_FAIL(
"Unsupported sum type for decimal aggregation: {}",
sumCol->typeKind());
}
}

template <class UnscaledType>
void extractDecimalAccumulator(
char** groups,
int32_t numGroups,
RowVector* rowVector,
FlatVector<UnscaledType>* sumVector,
FlatVector<int64_t>* countVector) {
auto sumVector = rowVector->childAt(0)->asFlatVector<TResultType>();
auto countVector = rowVector->childAt(1)->asFlatVector<int64_t>();
rowVector->resize(numGroups);
sumVector->resize(numGroups);
countVector->resize(numGroups);

uint64_t* rawNulls = getRawNulls(rowVector);

int64_t* rawCounts = countVector->mutableRawValues();
UnscaledType* rawSums = sumVector->mutableRawValues();
TResultType* rawSums = sumVector->mutableRawValues();

for (auto i = 0; i < numGroups; ++i) {
char* group = groups[i];
if (isNull(group)) {
Expand All @@ -342,13 +283,99 @@ class DecimalAggregate : public exec::Aggregate {
clearNull(rawNulls, i);
auto* accumulator = decimalAccumulator(group);
rawCounts[i] = accumulator->count;
rawSums[i] = (UnscaledType)accumulator->sum;
rawSums[i] = (TResultType)accumulator->sum;
}
}
}

virtual TResultType computeFinalValue(
LongDecimalWithOverflowState* accumulator) = 0;
TResultType computeFinalValue(LongDecimalWithOverflowState* accumulator) {
// Handles round-up of fraction results.
auto [sumPrecision, sumScale] =
getDecimalPrecisionScale(*this->inputType().get());
auto [rPrecision, rScale] =
getDecimalPrecisionScale(*this->resultType().get());
int countScale = 0;
auto sumRescale = computeRescaleFactor(sumScale, countScale, rScale);

TResultType average = TResultType(0);
if constexpr (std::is_same_v<TInputType, UnscaledLongDecimal>) {
if constexpr (std::is_same_v<TResultType, UnscaledLongDecimal>) {
// Spark use DECIMAL(20, 0) to represent long value
auto countDecimal = UnscaledLongDecimal(accumulator->count);

DecimalUtil::divideWithRoundUp<
UnscaledLongDecimal,
UnscaledLongDecimal,
UnscaledLongDecimal>(
average,
(UnscaledLongDecimal)accumulator->sum,
countDecimal,
false,
sumRescale,
0);
} else if constexpr (std::is_same_v<TResultType, UnscaledShortDecimal>) {
// we enter this case when input type is DECIMAL(10, 2), final Agg input
// sum type is DECIMAL(20, 2), but output type is DECIMAL(14, 2) Spark

// Spark use DECIMAL(20, 0) to represent long value, but we need to
// create a SHORT_DECIMAL to get SHORT_DECIMAL result
auto countDecimal = UnscaledShortDecimal(accumulator->count);
auto longUnscaledSum = (int64_t)accumulator->sum;
DecimalUtil::divideWithRoundUp<
UnscaledShortDecimal,
UnscaledShortDecimal,
UnscaledShortDecimal>(
average,
UnscaledShortDecimal(longUnscaledSum),
countDecimal,
false,
sumRescale,
0);
} else {
VELOX_FAIL("Final Avg Agg result type must be DECIMAL");
}
} else if constexpr (std::is_same_v<TInputType, UnscaledShortDecimal>) {
if constexpr (std::is_same_v<TResultType, UnscaledLongDecimal>) {
// Spark use DECIMAL(20, 0) to represent long value
auto countDecimal = UnscaledLongDecimal(accumulator->count);
auto longUnscaledSum = (int64_t)accumulator->sum;
DecimalUtil::divideWithRoundUp<
UnscaledLongDecimal,
UnscaledShortDecimal,
UnscaledLongDecimal>(
average,
UnscaledShortDecimal(longUnscaledSum),
countDecimal,
false,
sumRescale,
0);
} else if constexpr (std::is_same_v<TResultType, UnscaledShortDecimal>) {
// we enter this case when input type is DECIMAL(10, 2), final Agg input
// sum type is DECIMAL(20, 2), but output type is DECIMAL(14, 2) Spark

// Spark use DECIMAL(20, 0) to represent long value, but we need to
// create a SHORT_DECIMAL to get SHORT_DECIMAL result
auto countDecimal = UnscaledShortDecimal(accumulator->count);
auto longUnscaledSum = (int64_t)accumulator->sum;
DecimalUtil::divideWithRoundUp<
UnscaledShortDecimal,
UnscaledShortDecimal,
UnscaledShortDecimal>(
average,
UnscaledShortDecimal(longUnscaledSum),
countDecimal,
false,
sumRescale,
0);
} else {
VELOX_FAIL("Final Avg Agg result type must be DECIMAL");
}
} else {
VELOX_FAIL("Final Avg Agg result type must be DECIMAL");
}

return average;
}

void extractValues(char** groups, int32_t numGroups, VectorPtr* result)
override {
Expand Down Expand Up @@ -392,7 +419,7 @@ class DecimalAggregate : public exec::Aggregate {
}

template <bool tableHasNulls = true>
void updateNonNullValue(char* group, TResultType value) {
void updateNonNullValue(char* group, UnscaledLongDecimal value) {
if constexpr (tableHasNulls) {
exec::Aggregate::clearNull(group);
}
Expand All @@ -412,40 +439,35 @@ class DecimalAggregate : public exec::Aggregate {
accumulator->sum, sum.unscaledValue(), accumulator->sum);
}

TypePtr inputType() const {
return inputType_;
}

private:
inline LongDecimalWithOverflowState* decimalAccumulator(char* group) {
return exec::Aggregate::value<LongDecimalWithOverflowState>(group);
}

inline static uint8_t
computeRescaleFactor(uint8_t fromScale, uint8_t toScale, uint8_t rScale) {
return rScale - fromScale + toScale;
}

DecodedVector decodedRaw_;
DecodedVector decodedPartial_;
};

template <typename TUnscaledType>
class DecimalAverageAggregate : public DecimalAggregate<TUnscaledType> {
public:
explicit DecimalAverageAggregate(TypePtr resultType)
: DecimalAggregate<TUnscaledType>(resultType) {}

virtual TUnscaledType computeFinalValue(
LongDecimalWithOverflowState* accumulator) final {
// Handles round-up of fraction results.
int128_t average{0};
DecimalUtil::computeAverage(
average, accumulator->sum, accumulator->count, accumulator->overflow);
return TUnscaledType(average);
}
const TypePtr inputType_;
};

bool registerDecimalAvgAggregate(const std::string& name) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;
signatures.push_back(exec::AggregateFunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.argumentType("DECIMAL(a_precision, a_scale)")
.intermediateType("VARBINARY")
.returnType("DECIMAL(a_precision, a_scale)")
.build());
signatures.push_back(
exec::AggregateFunctionSignatureBuilder()
.integerVariable("a_precision")
.integerVariable("a_scale")
.argumentType("DECIMAL(a_precision, a_scale)")
.intermediateType("ROW(DECIMAL(a_precision, a_scale), BIGINT)")
.returnType("DECIMAL(a_precision, a_scale)")
.build());

return exec::registerAggregateFunction(
name,
Expand All @@ -459,21 +481,38 @@ bool registerDecimalAvgAggregate(const std::string& name) {
auto& inputType = argTypes[0];
switch (inputType->kind()) {
case TypeKind::SHORT_DECIMAL:
return std::make_unique<
DecimalAverageAggregate<UnscaledShortDecimal>>(resultType);
if (resultType->kind() == TypeKind::SHORT_DECIMAL) {
return std::make_unique<DecimalAverageAggregate<
UnscaledShortDecimal,
UnscaledShortDecimal>>(inputType, resultType);
} else {
return std::make_unique<DecimalAverageAggregate<
UnscaledShortDecimal,
UnscaledLongDecimal>>(inputType, resultType);
}
case TypeKind::LONG_DECIMAL:
return std::make_unique<
DecimalAverageAggregate<UnscaledLongDecimal>>(resultType);
if (resultType->kind() == TypeKind::LONG_DECIMAL) {
return std::make_unique<DecimalAverageAggregate<
UnscaledLongDecimal,
UnscaledLongDecimal>>(inputType, resultType);
} else {
VELOX_FAIL(
"Partial Avg Agg result type must greater than input type.");
}
case TypeKind::ROW: {
DCHECK(!exec::isRawInput(step));
auto sumInputType = inputType->asRow().childAt(0);
switch (sumInputType->kind()) {
case TypeKind::SHORT_DECIMAL:
return std::make_unique<
DecimalAverageAggregate<UnscaledShortDecimal>>(resultType);
case TypeKind::LONG_DECIMAL:
return std::make_unique<
DecimalAverageAggregate<UnscaledLongDecimal>>(resultType);
if (resultType->kind() == TypeKind::SHORT_DECIMAL) {
return std::make_unique<DecimalAverageAggregate<
UnscaledLongDecimal,
UnscaledShortDecimal>>(sumInputType, resultType);
} else {
return std::make_unique<DecimalAverageAggregate<
UnscaledLongDecimal,
UnscaledLongDecimal>>(sumInputType, resultType);
}
default:
VELOX_FAIL(
"Unknown sum type for {} aggregation {}",
Expand Down

0 comments on commit c30b97f

Please sign in to comment.