forked from facebookincubator/velox
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
14a963b
commit 7b49ca1
Showing
16 changed files
with
578 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
347 changes: 347 additions & 0 deletions
347
velox/functions/prestosql/aggregates/AverageAggregate.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,347 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#include "velox/exec/Aggregate.h" | ||
#include "velox/expression/FunctionSignature.h" | ||
#include "velox/functions/prestosql/aggregates/AggregateNames.h" | ||
#include "velox/vector/ComplexVector.h" | ||
#include "velox/vector/DecodedVector.h" | ||
#include "velox/vector/FlatVector.h" | ||
|
||
namespace facebook::velox::aggregate { | ||
|
||
struct SumCount { | ||
double sum{0}; | ||
int64_t count{0}; | ||
}; | ||
|
||
// Partial aggregation produces a pair of sum and count. | ||
// Final aggregation takes a pair of sum and count and returns a real for real | ||
// input types and double for other input types. | ||
// T is the input type for partial aggregation. Not used for final aggregation. | ||
template <typename T> | ||
class AverageAggregate : public exec::Aggregate { | ||
public: | ||
explicit AverageAggregate(TypePtr resultType) : exec::Aggregate(resultType) {} | ||
|
||
int32_t accumulatorFixedWidthSize() const override { | ||
return sizeof(SumCount); | ||
} | ||
|
||
void initializeNewGroups( | ||
char** groups, | ||
folly::Range<const vector_size_t*> indices) override { | ||
setAllNulls(groups, indices); | ||
for (auto i : indices) { | ||
new (groups[i] + offset_) SumCount(); | ||
} | ||
} | ||
|
||
void finalize(char** /* unused */, int32_t /* unused */) override {} | ||
|
||
void extractValues(char** groups, int32_t numGroups, VectorPtr* result) | ||
override { | ||
// Real input type in Presto has special case and returns REAL, not DOUBLE. | ||
if (resultType_->isDouble()) { | ||
extractValuesImpl<double>(groups, numGroups, result); | ||
} else { | ||
extractValuesImpl<float>(groups, numGroups, result); | ||
} | ||
} | ||
|
||
void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) | ||
override { | ||
auto rowVector = (*result)->as<RowVector>(); | ||
auto sumVector = rowVector->childAt(0)->asFlatVector<double>(); | ||
auto countVector = rowVector->childAt(1)->asFlatVector<int64_t>(); | ||
|
||
rowVector->resize(numGroups); | ||
sumVector->resize(numGroups); | ||
countVector->resize(numGroups); | ||
uint64_t* rawNulls = getRawNulls(rowVector); | ||
|
||
int64_t* rawCounts = countVector->mutableRawValues(); | ||
double* rawSums = sumVector->mutableRawValues(); | ||
for (auto i = 0; i < numGroups; ++i) { | ||
char* group = groups[i]; | ||
if (isNull(group)) { | ||
rowVector->setNull(i, true); | ||
} else { | ||
clearNull(rawNulls, i); | ||
auto* sumCount = accumulator(group); | ||
rawCounts[i] = sumCount->count; | ||
rawSums[i] = sumCount->sum; | ||
} | ||
} | ||
} | ||
|
||
void addRawInput( | ||
char** groups, | ||
const SelectivityVector& rows, | ||
const std::vector<VectorPtr>& args, | ||
bool /*mayPushdown*/) override { | ||
decodedRaw_.decode(*args[0], rows); | ||
if (decodedRaw_.isConstantMapping()) { | ||
if (!decodedRaw_.isNullAt(0)) { | ||
auto value = decodedRaw_.valueAt<T>(0); | ||
rows.applyToSelected( | ||
[&](vector_size_t i) { updateNonNullValue(groups[i], value); }); | ||
} | ||
} else if (decodedRaw_.mayHaveNulls()) { | ||
rows.applyToSelected([&](vector_size_t i) { | ||
if (decodedRaw_.isNullAt(i)) { | ||
return; | ||
} | ||
updateNonNullValue(groups[i], decodedRaw_.valueAt<T>(i)); | ||
}); | ||
} else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { | ||
auto data = decodedRaw_.data<T>(); | ||
rows.applyToSelected([&](vector_size_t i) { | ||
updateNonNullValue<false>(groups[i], data[i]); | ||
}); | ||
} else { | ||
rows.applyToSelected([&](vector_size_t i) { | ||
updateNonNullValue(groups[i], decodedRaw_.valueAt<T>(i)); | ||
}); | ||
} | ||
} | ||
|
||
void addSingleGroupRawInput( | ||
char* group, | ||
const SelectivityVector& rows, | ||
const std::vector<VectorPtr>& args, | ||
bool /*mayPushdown*/) override { | ||
decodedRaw_.decode(*args[0], rows); | ||
|
||
if (decodedRaw_.isConstantMapping()) { | ||
if (!decodedRaw_.isNullAt(0)) { | ||
const T value = decodedRaw_.valueAt<T>(0); | ||
const auto numRows = rows.countSelected(); | ||
updateNonNullValue(group, numRows, value * numRows); | ||
} | ||
} else if (decodedRaw_.mayHaveNulls()) { | ||
rows.applyToSelected([&](vector_size_t i) { | ||
if (!decodedRaw_.isNullAt(i)) { | ||
updateNonNullValue(group, decodedRaw_.valueAt<T>(i)); | ||
} | ||
}); | ||
} else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { | ||
const T* data = decodedRaw_.data<T>(); | ||
double totalSum = 0; | ||
rows.applyToSelected([&](vector_size_t i) { totalSum += data[i]; }); | ||
updateNonNullValue<false>(group, rows.countSelected(), totalSum); | ||
} else { | ||
double totalSum = 0; | ||
rows.applyToSelected( | ||
[&](vector_size_t i) { totalSum += decodedRaw_.valueAt<T>(i); }); | ||
updateNonNullValue(group, rows.countSelected(), totalSum); | ||
} | ||
} | ||
|
||
void addIntermediateResults( | ||
char** groups, | ||
const SelectivityVector& rows, | ||
const std::vector<VectorPtr>& args, | ||
bool /* mayPushdown */) override { | ||
decodedPartial_.decode(*args[0], rows); | ||
auto baseRowVector = dynamic_cast<const RowVector*>(decodedPartial_.base()); | ||
auto baseSumVector = baseRowVector->childAt(0)->as<SimpleVector<double>>(); | ||
auto baseCountVector = | ||
baseRowVector->childAt(1)->as<SimpleVector<int64_t>>(); | ||
|
||
if (decodedPartial_.isConstantMapping()) { | ||
if (!decodedPartial_.isNullAt(0)) { | ||
auto decodedIndex = decodedPartial_.index(0); | ||
auto count = baseCountVector->valueAt(decodedIndex); | ||
auto sum = baseSumVector->valueAt(decodedIndex); | ||
rows.applyToSelected([&](vector_size_t i) { | ||
updateNonNullValue(groups[i], count, sum); | ||
}); | ||
} | ||
} else if (decodedPartial_.mayHaveNulls()) { | ||
rows.applyToSelected([&](vector_size_t i) { | ||
if (decodedPartial_.isNullAt(i)) { | ||
return; | ||
} | ||
auto decodedIndex = decodedPartial_.index(i); | ||
updateNonNullValue( | ||
groups[i], | ||
baseCountVector->valueAt(decodedIndex), | ||
baseSumVector->valueAt(decodedIndex)); | ||
}); | ||
} else { | ||
rows.applyToSelected([&](vector_size_t i) { | ||
auto decodedIndex = decodedPartial_.index(i); | ||
updateNonNullValue( | ||
groups[i], | ||
baseCountVector->valueAt(decodedIndex), | ||
baseSumVector->valueAt(decodedIndex)); | ||
}); | ||
} | ||
} | ||
|
||
void addSingleGroupIntermediateResults( | ||
char* group, | ||
const SelectivityVector& rows, | ||
const std::vector<VectorPtr>& args, | ||
bool /* mayPushdown */) override { | ||
decodedPartial_.decode(*args[0], rows); | ||
auto baseRowVector = dynamic_cast<const RowVector*>(decodedPartial_.base()); | ||
auto baseSumVector = baseRowVector->childAt(0)->as<SimpleVector<double>>(); | ||
auto baseCountVector = | ||
baseRowVector->childAt(1)->as<SimpleVector<int64_t>>(); | ||
|
||
if (decodedPartial_.isConstantMapping()) { | ||
if (!decodedPartial_.isNullAt(0)) { | ||
auto decodedIndex = decodedPartial_.index(0); | ||
const auto numRows = rows.countSelected(); | ||
auto totalCount = baseCountVector->valueAt(decodedIndex) * numRows; | ||
auto totalSum = baseSumVector->valueAt(decodedIndex) * numRows; | ||
updateNonNullValue(group, totalCount, totalSum); | ||
} | ||
} else if (decodedPartial_.mayHaveNulls()) { | ||
rows.applyToSelected([&](vector_size_t i) { | ||
if (!decodedPartial_.isNullAt(i)) { | ||
auto decodedIndex = decodedPartial_.index(i); | ||
updateNonNullValue( | ||
group, | ||
baseCountVector->valueAt(decodedIndex), | ||
baseSumVector->valueAt(decodedIndex)); | ||
} | ||
}); | ||
} else { | ||
double totalSum = 0; | ||
int64_t totalCount = 0; | ||
rows.applyToSelected([&](vector_size_t i) { | ||
auto decodedIndex = decodedPartial_.index(i); | ||
totalCount += baseCountVector->valueAt(decodedIndex); | ||
totalSum += baseSumVector->valueAt(decodedIndex); | ||
}); | ||
updateNonNullValue(group, totalCount, totalSum); | ||
} | ||
} | ||
|
||
private: | ||
// partial | ||
template <bool tableHasNulls = true> | ||
inline void updateNonNullValue(char* group, T value) { | ||
if constexpr (tableHasNulls) { | ||
exec::Aggregate::clearNull(group); | ||
} | ||
accumulator(group)->sum += value; | ||
accumulator(group)->count += 1; | ||
} | ||
|
||
template <bool tableHasNulls = true> | ||
inline void updateNonNullValue(char* group, int64_t count, double sum) { | ||
if constexpr (tableHasNulls) { | ||
exec::Aggregate::clearNull(group); | ||
} | ||
accumulator(group)->sum += sum; | ||
accumulator(group)->count += count; | ||
} | ||
|
||
inline SumCount* accumulator(char* group) { | ||
return exec::Aggregate::value<SumCount>(group); | ||
} | ||
|
||
template <typename TResult> | ||
void extractValuesImpl(char** groups, int32_t numGroups, VectorPtr* result) { | ||
auto vector = (*result)->as<FlatVector<TResult>>(); | ||
VELOX_CHECK(vector); | ||
vector->resize(numGroups); | ||
uint64_t* rawNulls = getRawNulls(vector); | ||
|
||
TResult* rawValues = vector->mutableRawValues(); | ||
for (int32_t i = 0; i < numGroups; ++i) { | ||
char* group = groups[i]; | ||
if (isNull(group)) { | ||
vector->setNull(i, true); | ||
} else { | ||
clearNull(rawNulls, i); | ||
auto* sumCount = accumulator(group); | ||
rawValues[i] = (TResult)sumCount->sum / sumCount->count; | ||
} | ||
} | ||
} | ||
|
||
DecodedVector decodedRaw_; | ||
DecodedVector decodedPartial_; | ||
}; | ||
|
||
void checkSumCountRowType(TypePtr type, const std::string& errorMessage) { | ||
VELOX_CHECK_EQ(type->kind(), TypeKind::ROW, "{}", errorMessage); | ||
VELOX_CHECK_EQ( | ||
type->childAt(0)->kind(), TypeKind::DOUBLE, "{}", errorMessage); | ||
VELOX_CHECK_EQ( | ||
type->childAt(1)->kind(), TypeKind::BIGINT, "{}", errorMessage); | ||
} | ||
|
||
bool registerAverageAggregate(const std::string& name) { | ||
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures; | ||
|
||
for (const auto& inputType : {"smallint", "integer", "bigint", "double"}) { | ||
signatures.push_back(exec::AggregateFunctionSignatureBuilder() | ||
.returnType("double") | ||
.intermediateType("row(double,bigint)") | ||
.argumentType(inputType) | ||
.build()); | ||
} | ||
// Real input type in Presto has special case and returns REAL, not DOUBLE. | ||
signatures.push_back(exec::AggregateFunctionSignatureBuilder() | ||
.returnType("real") | ||
.intermediateType("row(double,bigint)") | ||
.argumentType("real") | ||
.build()); | ||
|
||
exec::registerAggregateFunction( | ||
name, | ||
std::move(signatures), | ||
[name]( | ||
core::AggregationNode::Step step, | ||
const std::vector<TypePtr>& argTypes, | ||
const TypePtr& resultType) -> std::unique_ptr<exec::Aggregate> { | ||
VELOX_CHECK_LE( | ||
argTypes.size(), 1, "{} takes at most one argument", name); | ||
auto inputType = argTypes[0]; | ||
if (exec::isRawInput(step)) { | ||
switch (inputType->kind()) { | ||
case TypeKind::SMALLINT: | ||
return std::make_unique<AverageAggregate<int16_t>>(resultType); | ||
case TypeKind::INTEGER: | ||
return std::make_unique<AverageAggregate<int32_t>>(resultType); | ||
case TypeKind::BIGINT: | ||
return std::make_unique<AverageAggregate<int64_t>>(resultType); | ||
case TypeKind::REAL: | ||
return std::make_unique<AverageAggregate<float>>(resultType); | ||
case TypeKind::DOUBLE: | ||
return std::make_unique<AverageAggregate<double>>(resultType); | ||
default: | ||
VELOX_FAIL( | ||
"Unknown input type for {} aggregation {}", | ||
name, | ||
inputType->kindName()); | ||
} | ||
} else { | ||
checkSumCountRowType( | ||
inputType, | ||
"Input type for final aggregation must be (sum:double, count:bigint) struct"); | ||
return std::make_unique<AverageAggregate<int64_t>>(resultType); | ||
} | ||
}); | ||
return true; | ||
} | ||
|
||
} // namespace facebook::velox::aggregate |
Oops, something went wrong.