From d07c4bff4106c4dc0f8890c3211978738fb71373 Mon Sep 17 00:00:00 2001 From: Rui Mo Date: Fri, 6 May 2022 17:22:22 +0800 Subject: [PATCH] Expose the registrations for avg and count (#4) --- .../prestosql/aggregates/AverageAggregate.cpp | 332 +---------------- .../prestosql/aggregates/AverageAggregate.h | 347 ++++++++++++++++++ .../prestosql/aggregates/CMakeLists.txt | 2 + .../prestosql/aggregates/CountAggregate.cpp | 160 +------- .../prestosql/aggregates/CountAggregate.h | 176 +++++++++ 5 files changed, 528 insertions(+), 489 deletions(-) create mode 100644 velox/functions/prestosql/aggregates/AverageAggregate.h create mode 100644 velox/functions/prestosql/aggregates/CountAggregate.h diff --git a/velox/functions/prestosql/aggregates/AverageAggregate.cpp b/velox/functions/prestosql/aggregates/AverageAggregate.cpp index 6cf3fc10fa10f..d23bc7d751178 100644 --- a/velox/functions/prestosql/aggregates/AverageAggregate.cpp +++ b/velox/functions/prestosql/aggregates/AverageAggregate.cpp @@ -13,340 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "velox/exec/Aggregate.h" -#include "velox/expression/FunctionSignature.h" +#include "velox/functions/prestosql/aggregates/AverageAggregate.h" #include "velox/functions/prestosql/aggregates/AggregateNames.h" -#include "velox/vector/ComplexVector.h" -#include "velox/vector/DecodedVector.h" -#include "velox/vector/FlatVector.h" namespace facebook::velox::aggregate { -namespace { - -struct SumCount { - double sum{0}; - int64_t count{0}; -}; - -// Partial aggregation produces a pair of sum and count. -// Final aggregation takes a pair of sum and count and returns a real for real -// input types and double for other input types. -// T is the input type for partial aggregation. Not used for final aggregation. -template -class AverageAggregate : public exec::Aggregate { - public: - explicit AverageAggregate(TypePtr resultType) : exec::Aggregate(resultType) {} - - int32_t accumulatorFixedWidthSize() const override { - return sizeof(SumCount); - } - - void initializeNewGroups( - char** groups, - folly::Range indices) override { - setAllNulls(groups, indices); - for (auto i : indices) { - new (groups[i] + offset_) SumCount(); - } - } - - void finalize(char** /* unused */, int32_t /* unused */) override {} - - void extractValues(char** groups, int32_t numGroups, VectorPtr* result) - override { - // Real input type in Presto has special case and returns REAL, not DOUBLE. - if (resultType_->isDouble()) { - extractValuesImpl(groups, numGroups, result); - } else { - extractValuesImpl(groups, numGroups, result); - } - } - - void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) - override { - auto rowVector = (*result)->as(); - auto sumVector = rowVector->childAt(0)->asFlatVector(); - auto countVector = rowVector->childAt(1)->asFlatVector(); - - rowVector->resize(numGroups); - sumVector->resize(numGroups); - countVector->resize(numGroups); - uint64_t* rawNulls = getRawNulls(rowVector); - - int64_t* rawCounts = countVector->mutableRawValues(); - double* rawSums = sumVector->mutableRawValues(); - for (auto i = 0; i < numGroups; ++i) { - char* group = groups[i]; - if (isNull(group)) { - rowVector->setNull(i, true); - } else { - clearNull(rawNulls, i); - auto* sumCount = accumulator(group); - rawCounts[i] = sumCount->count; - rawSums[i] = sumCount->sum; - } - } - } - - void addRawInput( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool /*mayPushdown*/) override { - decodedRaw_.decode(*args[0], rows); - if (decodedRaw_.isConstantMapping()) { - if (!decodedRaw_.isNullAt(0)) { - auto value = decodedRaw_.valueAt(0); - rows.applyToSelected( - [&](vector_size_t i) { updateNonNullValue(groups[i], value); }); - } - } else if (decodedRaw_.mayHaveNulls()) { - rows.applyToSelected([&](vector_size_t i) { - if (decodedRaw_.isNullAt(i)) { - return; - } - updateNonNullValue(groups[i], decodedRaw_.valueAt(i)); - }); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { - auto data = decodedRaw_.data(); - rows.applyToSelected([&](vector_size_t i) { - updateNonNullValue(groups[i], data[i]); - }); - } else { - rows.applyToSelected([&](vector_size_t i) { - updateNonNullValue(groups[i], decodedRaw_.valueAt(i)); - }); - } - } - - void addSingleGroupRawInput( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool /*mayPushdown*/) override { - decodedRaw_.decode(*args[0], rows); - - if (decodedRaw_.isConstantMapping()) { - if (!decodedRaw_.isNullAt(0)) { - const T value = decodedRaw_.valueAt(0); - const auto numRows = rows.countSelected(); - updateNonNullValue(group, numRows, value * numRows); - } - } else if (decodedRaw_.mayHaveNulls()) { - rows.applyToSelected([&](vector_size_t i) { - if (!decodedRaw_.isNullAt(i)) { - updateNonNullValue(group, decodedRaw_.valueAt(i)); - } - }); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { - const T* data = decodedRaw_.data(); - double totalSum = 0; - rows.applyToSelected([&](vector_size_t i) { totalSum += data[i]; }); - updateNonNullValue(group, rows.countSelected(), totalSum); - } else { - double totalSum = 0; - rows.applyToSelected( - [&](vector_size_t i) { totalSum += decodedRaw_.valueAt(i); }); - updateNonNullValue(group, rows.countSelected(), totalSum); - } - } - - void addIntermediateResults( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool /* mayPushdown */) override { - decodedPartial_.decode(*args[0], rows); - auto baseRowVector = dynamic_cast(decodedPartial_.base()); - auto baseSumVector = baseRowVector->childAt(0)->as>(); - auto baseCountVector = - baseRowVector->childAt(1)->as>(); - - if (decodedPartial_.isConstantMapping()) { - if (!decodedPartial_.isNullAt(0)) { - auto decodedIndex = decodedPartial_.index(0); - auto count = baseCountVector->valueAt(decodedIndex); - auto sum = baseSumVector->valueAt(decodedIndex); - rows.applyToSelected([&](vector_size_t i) { - updateNonNullValue(groups[i], count, sum); - }); - } - } else if (decodedPartial_.mayHaveNulls()) { - rows.applyToSelected([&](vector_size_t i) { - if (decodedPartial_.isNullAt(i)) { - return; - } - auto decodedIndex = decodedPartial_.index(i); - updateNonNullValue( - groups[i], - baseCountVector->valueAt(decodedIndex), - baseSumVector->valueAt(decodedIndex)); - }); - } else { - rows.applyToSelected([&](vector_size_t i) { - auto decodedIndex = decodedPartial_.index(i); - updateNonNullValue( - groups[i], - baseCountVector->valueAt(decodedIndex), - baseSumVector->valueAt(decodedIndex)); - }); - } - } - - void addSingleGroupIntermediateResults( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool /* mayPushdown */) override { - decodedPartial_.decode(*args[0], rows); - auto baseRowVector = dynamic_cast(decodedPartial_.base()); - auto baseSumVector = baseRowVector->childAt(0)->as>(); - auto baseCountVector = - baseRowVector->childAt(1)->as>(); - - if (decodedPartial_.isConstantMapping()) { - if (!decodedPartial_.isNullAt(0)) { - auto decodedIndex = decodedPartial_.index(0); - const auto numRows = rows.countSelected(); - auto totalCount = baseCountVector->valueAt(decodedIndex) * numRows; - auto totalSum = baseSumVector->valueAt(decodedIndex) * numRows; - updateNonNullValue(group, totalCount, totalSum); - } - } else if (decodedPartial_.mayHaveNulls()) { - rows.applyToSelected([&](vector_size_t i) { - if (!decodedPartial_.isNullAt(i)) { - auto decodedIndex = decodedPartial_.index(i); - updateNonNullValue( - group, - baseCountVector->valueAt(decodedIndex), - baseSumVector->valueAt(decodedIndex)); - } - }); - } else { - double totalSum = 0; - int64_t totalCount = 0; - rows.applyToSelected([&](vector_size_t i) { - auto decodedIndex = decodedPartial_.index(i); - totalCount += baseCountVector->valueAt(decodedIndex); - totalSum += baseSumVector->valueAt(decodedIndex); - }); - updateNonNullValue(group, totalCount, totalSum); - } - } - - private: - // partial - template - inline void updateNonNullValue(char* group, T value) { - if constexpr (tableHasNulls) { - exec::Aggregate::clearNull(group); - } - accumulator(group)->sum += value; - accumulator(group)->count += 1; - } - - template - inline void updateNonNullValue(char* group, int64_t count, double sum) { - if constexpr (tableHasNulls) { - exec::Aggregate::clearNull(group); - } - accumulator(group)->sum += sum; - accumulator(group)->count += count; - } - - inline SumCount* accumulator(char* group) { - return exec::Aggregate::value(group); - } - - template - void extractValuesImpl(char** groups, int32_t numGroups, VectorPtr* result) { - auto vector = (*result)->as>(); - VELOX_CHECK(vector); - vector->resize(numGroups); - uint64_t* rawNulls = getRawNulls(vector); - - TResult* rawValues = vector->mutableRawValues(); - for (int32_t i = 0; i < numGroups; ++i) { - char* group = groups[i]; - if (isNull(group)) { - vector->setNull(i, true); - } else { - clearNull(rawNulls, i); - auto* sumCount = accumulator(group); - rawValues[i] = (TResult)sumCount->sum / sumCount->count; - } - } - } - - DecodedVector decodedRaw_; - DecodedVector decodedPartial_; -}; - -void checkSumCountRowType(TypePtr type, const std::string& errorMessage) { - VELOX_CHECK_EQ(type->kind(), TypeKind::ROW, "{}", errorMessage); - VELOX_CHECK_EQ( - type->childAt(0)->kind(), TypeKind::DOUBLE, "{}", errorMessage); - VELOX_CHECK_EQ( - type->childAt(1)->kind(), TypeKind::BIGINT, "{}", errorMessage); -} - -bool registerAverageAggregate(const std::string& name) { - std::vector> signatures; - - for (const auto& inputType : {"smallint", "integer", "bigint", "double"}) { - signatures.push_back(exec::AggregateFunctionSignatureBuilder() - .returnType("double") - .intermediateType("row(double,bigint)") - .argumentType(inputType) - .build()); - } - // Real input type in Presto has special case and returns REAL, not DOUBLE. - signatures.push_back(exec::AggregateFunctionSignatureBuilder() - .returnType("real") - .intermediateType("row(double,bigint)") - .argumentType("real") - .build()); - - exec::registerAggregateFunction( - name, - std::move(signatures), - [name]( - core::AggregationNode::Step step, - const std::vector& argTypes, - const TypePtr& resultType) -> std::unique_ptr { - VELOX_CHECK_LE( - argTypes.size(), 1, "{} takes at most one argument", name); - auto inputType = argTypes[0]; - if (exec::isRawInput(step)) { - switch (inputType->kind()) { - case TypeKind::SMALLINT: - return std::make_unique>(resultType); - case TypeKind::INTEGER: - return std::make_unique>(resultType); - case TypeKind::BIGINT: - return std::make_unique>(resultType); - case TypeKind::REAL: - return std::make_unique>(resultType); - case TypeKind::DOUBLE: - return std::make_unique>(resultType); - default: - VELOX_FAIL( - "Unknown input type for {} aggregation {}", - name, - inputType->kindName()); - } - } else { - checkSumCountRowType( - inputType, - "Input type for final aggregation must be (sum:double, count:bigint) struct"); - return std::make_unique>(resultType); - } - }); - return true; -} - static bool FB_ANONYMOUS_VARIABLE(g_AggregateFunction) = registerAverageAggregate(kAvg); -} // namespace + } // namespace facebook::velox::aggregate diff --git a/velox/functions/prestosql/aggregates/AverageAggregate.h b/velox/functions/prestosql/aggregates/AverageAggregate.h new file mode 100644 index 0000000000000..490dde456e3da --- /dev/null +++ b/velox/functions/prestosql/aggregates/AverageAggregate.h @@ -0,0 +1,347 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/exec/Aggregate.h" +#include "velox/expression/FunctionSignature.h" +#include "velox/functions/prestosql/aggregates/AggregateNames.h" +#include "velox/vector/ComplexVector.h" +#include "velox/vector/DecodedVector.h" +#include "velox/vector/FlatVector.h" + +namespace facebook::velox::aggregate { + +struct SumCount { + double sum{0}; + int64_t count{0}; +}; + +// Partial aggregation produces a pair of sum and count. +// Final aggregation takes a pair of sum and count and returns a real for real +// input types and double for other input types. +// T is the input type for partial aggregation. Not used for final aggregation. +template +class AverageAggregate : public exec::Aggregate { + public: + explicit AverageAggregate(TypePtr resultType) : exec::Aggregate(resultType) {} + + int32_t accumulatorFixedWidthSize() const override { + return sizeof(SumCount); + } + + void initializeNewGroups( + char** groups, + folly::Range indices) override { + setAllNulls(groups, indices); + for (auto i : indices) { + new (groups[i] + offset_) SumCount(); + } + } + + void finalize(char** /* unused */, int32_t /* unused */) override {} + + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override { + // Real input type in Presto has special case and returns REAL, not DOUBLE. + if (resultType_->isDouble()) { + extractValuesImpl(groups, numGroups, result); + } else { + extractValuesImpl(groups, numGroups, result); + } + } + + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) + override { + auto rowVector = (*result)->as(); + auto sumVector = rowVector->childAt(0)->asFlatVector(); + auto countVector = rowVector->childAt(1)->asFlatVector(); + + rowVector->resize(numGroups); + sumVector->resize(numGroups); + countVector->resize(numGroups); + uint64_t* rawNulls = getRawNulls(rowVector); + + int64_t* rawCounts = countVector->mutableRawValues(); + double* rawSums = sumVector->mutableRawValues(); + for (auto i = 0; i < numGroups; ++i) { + char* group = groups[i]; + if (isNull(group)) { + rowVector->setNull(i, true); + } else { + clearNull(rawNulls, i); + auto* sumCount = accumulator(group); + rawCounts[i] = sumCount->count; + rawSums[i] = sumCount->sum; + } + } + } + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + decodedRaw_.decode(*args[0], rows); + if (decodedRaw_.isConstantMapping()) { + if (!decodedRaw_.isNullAt(0)) { + auto value = decodedRaw_.valueAt(0); + rows.applyToSelected( + [&](vector_size_t i) { updateNonNullValue(groups[i], value); }); + } + } else if (decodedRaw_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (decodedRaw_.isNullAt(i)) { + return; + } + updateNonNullValue(groups[i], decodedRaw_.valueAt(i)); + }); + } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + auto data = decodedRaw_.data(); + rows.applyToSelected([&](vector_size_t i) { + updateNonNullValue(groups[i], data[i]); + }); + } else { + rows.applyToSelected([&](vector_size_t i) { + updateNonNullValue(groups[i], decodedRaw_.valueAt(i)); + }); + } + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + decodedRaw_.decode(*args[0], rows); + + if (decodedRaw_.isConstantMapping()) { + if (!decodedRaw_.isNullAt(0)) { + const T value = decodedRaw_.valueAt(0); + const auto numRows = rows.countSelected(); + updateNonNullValue(group, numRows, value * numRows); + } + } else if (decodedRaw_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (!decodedRaw_.isNullAt(i)) { + updateNonNullValue(group, decodedRaw_.valueAt(i)); + } + }); + } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + const T* data = decodedRaw_.data(); + double totalSum = 0; + rows.applyToSelected([&](vector_size_t i) { totalSum += data[i]; }); + updateNonNullValue(group, rows.countSelected(), totalSum); + } else { + double totalSum = 0; + rows.applyToSelected( + [&](vector_size_t i) { totalSum += decodedRaw_.valueAt(i); }); + updateNonNullValue(group, rows.countSelected(), totalSum); + } + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + decodedPartial_.decode(*args[0], rows); + auto baseRowVector = dynamic_cast(decodedPartial_.base()); + auto baseSumVector = baseRowVector->childAt(0)->as>(); + auto baseCountVector = + baseRowVector->childAt(1)->as>(); + + if (decodedPartial_.isConstantMapping()) { + if (!decodedPartial_.isNullAt(0)) { + auto decodedIndex = decodedPartial_.index(0); + auto count = baseCountVector->valueAt(decodedIndex); + auto sum = baseSumVector->valueAt(decodedIndex); + rows.applyToSelected([&](vector_size_t i) { + updateNonNullValue(groups[i], count, sum); + }); + } + } else if (decodedPartial_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (decodedPartial_.isNullAt(i)) { + return; + } + auto decodedIndex = decodedPartial_.index(i); + updateNonNullValue( + groups[i], + baseCountVector->valueAt(decodedIndex), + baseSumVector->valueAt(decodedIndex)); + }); + } else { + rows.applyToSelected([&](vector_size_t i) { + auto decodedIndex = decodedPartial_.index(i); + updateNonNullValue( + groups[i], + baseCountVector->valueAt(decodedIndex), + baseSumVector->valueAt(decodedIndex)); + }); + } + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + decodedPartial_.decode(*args[0], rows); + auto baseRowVector = dynamic_cast(decodedPartial_.base()); + auto baseSumVector = baseRowVector->childAt(0)->as>(); + auto baseCountVector = + baseRowVector->childAt(1)->as>(); + + if (decodedPartial_.isConstantMapping()) { + if (!decodedPartial_.isNullAt(0)) { + auto decodedIndex = decodedPartial_.index(0); + const auto numRows = rows.countSelected(); + auto totalCount = baseCountVector->valueAt(decodedIndex) * numRows; + auto totalSum = baseSumVector->valueAt(decodedIndex) * numRows; + updateNonNullValue(group, totalCount, totalSum); + } + } else if (decodedPartial_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (!decodedPartial_.isNullAt(i)) { + auto decodedIndex = decodedPartial_.index(i); + updateNonNullValue( + group, + baseCountVector->valueAt(decodedIndex), + baseSumVector->valueAt(decodedIndex)); + } + }); + } else { + double totalSum = 0; + int64_t totalCount = 0; + rows.applyToSelected([&](vector_size_t i) { + auto decodedIndex = decodedPartial_.index(i); + totalCount += baseCountVector->valueAt(decodedIndex); + totalSum += baseSumVector->valueAt(decodedIndex); + }); + updateNonNullValue(group, totalCount, totalSum); + } + } + + private: + // partial + template + inline void updateNonNullValue(char* group, T value) { + if constexpr (tableHasNulls) { + exec::Aggregate::clearNull(group); + } + accumulator(group)->sum += value; + accumulator(group)->count += 1; + } + + template + inline void updateNonNullValue(char* group, int64_t count, double sum) { + if constexpr (tableHasNulls) { + exec::Aggregate::clearNull(group); + } + accumulator(group)->sum += sum; + accumulator(group)->count += count; + } + + inline SumCount* accumulator(char* group) { + return exec::Aggregate::value(group); + } + + template + void extractValuesImpl(char** groups, int32_t numGroups, VectorPtr* result) { + auto vector = (*result)->as>(); + VELOX_CHECK(vector); + vector->resize(numGroups); + uint64_t* rawNulls = getRawNulls(vector); + + TResult* rawValues = vector->mutableRawValues(); + for (int32_t i = 0; i < numGroups; ++i) { + char* group = groups[i]; + if (isNull(group)) { + vector->setNull(i, true); + } else { + clearNull(rawNulls, i); + auto* sumCount = accumulator(group); + rawValues[i] = (TResult)sumCount->sum / sumCount->count; + } + } + } + + DecodedVector decodedRaw_; + DecodedVector decodedPartial_; +}; + +void checkSumCountRowType(TypePtr type, const std::string& errorMessage) { + VELOX_CHECK_EQ(type->kind(), TypeKind::ROW, "{}", errorMessage); + VELOX_CHECK_EQ( + type->childAt(0)->kind(), TypeKind::DOUBLE, "{}", errorMessage); + VELOX_CHECK_EQ( + type->childAt(1)->kind(), TypeKind::BIGINT, "{}", errorMessage); +} + +bool registerAverageAggregate(const std::string& name) { + std::vector> signatures; + + for (const auto& inputType : {"smallint", "integer", "bigint", "double"}) { + signatures.push_back(exec::AggregateFunctionSignatureBuilder() + .returnType("double") + .intermediateType("row(double,bigint)") + .argumentType(inputType) + .build()); + } + // Real input type in Presto has special case and returns REAL, not DOUBLE. + signatures.push_back(exec::AggregateFunctionSignatureBuilder() + .returnType("real") + .intermediateType("row(double,bigint)") + .argumentType("real") + .build()); + + exec::registerAggregateFunction( + name, + std::move(signatures), + [name]( + core::AggregationNode::Step step, + const std::vector& argTypes, + const TypePtr& resultType) -> std::unique_ptr { + VELOX_CHECK_LE( + argTypes.size(), 1, "{} takes at most one argument", name); + auto inputType = argTypes[0]; + if (exec::isRawInput(step)) { + switch (inputType->kind()) { + case TypeKind::SMALLINT: + return std::make_unique>(resultType); + case TypeKind::INTEGER: + return std::make_unique>(resultType); + case TypeKind::BIGINT: + return std::make_unique>(resultType); + case TypeKind::REAL: + return std::make_unique>(resultType); + case TypeKind::DOUBLE: + return std::make_unique>(resultType); + default: + VELOX_FAIL( + "Unknown input type for {} aggregation {}", + name, + inputType->kindName()); + } + } else { + checkSumCountRowType( + inputType, + "Input type for final aggregation must be (sum:double, count:bigint) struct"); + return std::make_unique>(resultType); + } + }); + return true; +} + +} // namespace facebook::velox::aggregate diff --git a/velox/functions/prestosql/aggregates/CMakeLists.txt b/velox/functions/prestosql/aggregates/CMakeLists.txt index 6b6ca032ec84d..9a9cbe87699e3 100644 --- a/velox/functions/prestosql/aggregates/CMakeLists.txt +++ b/velox/functions/prestosql/aggregates/CMakeLists.txt @@ -25,6 +25,7 @@ add_library( ArbitraryAggregate.cpp ArrayAggAggregate.cpp AverageAggregate.cpp + AverageAggregate.h BitwiseAggregates.cpp BoolAggregates.cpp CountIfAggregate.cpp @@ -38,6 +39,7 @@ add_library( MinMaxAggregates.cpp MinMaxByAggregates.cpp CountAggregate.cpp + CountAggregate.h PrestoHasher.cpp SingleValueAccumulator.cpp SumAggregate.cpp diff --git a/velox/functions/prestosql/aggregates/CountAggregate.cpp b/velox/functions/prestosql/aggregates/CountAggregate.cpp index 6a73566f97efc..ca5f58db4f602 100644 --- a/velox/functions/prestosql/aggregates/CountAggregate.cpp +++ b/velox/functions/prestosql/aggregates/CountAggregate.cpp @@ -13,170 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "velox/common/base/Exceptions.h" -#include "velox/expression/FunctionSignature.h" +#include "velox/functions/prestosql/aggregates/CountAggregate.h" #include "velox/functions/prestosql/aggregates/AggregateNames.h" -#include "velox/functions/prestosql/aggregates/SumAggregate.h" namespace facebook::velox::aggregate { -namespace { - -class CountAggregate : public SimpleNumericAggregate { - using BaseAggregate = SimpleNumericAggregate; - - public: - explicit CountAggregate() : BaseAggregate(BIGINT()) {} - - int32_t accumulatorFixedWidthSize() const override { - return sizeof(int64_t); - } - - void initializeNewGroups( - char** groups, - folly::Range indices) override { - for (auto i : indices) { - // result of count is never null - *value(groups[i]) = (int64_t)0; - } - } - - void extractValues(char** groups, int32_t numGroups, VectorPtr* result) - override { - BaseAggregate::doExtractValues(groups, numGroups, result, [&](char* group) { - return *value(group); - }); - } - - void addRawInput( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool /*mayPushdown*/) override { - if (args.empty()) { - rows.applyToSelected([&](vector_size_t i) { addToGroup(groups[i], 1); }); - return; - } - - DecodedVector decoded(*args[0], rows); - if (decoded.isConstantMapping()) { - if (!decoded.isNullAt(0)) { - rows.applyToSelected( - [&](vector_size_t i) { addToGroup(groups[i], 1); }); - } - } else if (decoded.mayHaveNulls()) { - rows.applyToSelected([&](vector_size_t i) { - if (decoded.isNullAt(i)) { - return; - } - addToGroup(groups[i], 1); - }); - } else { - rows.applyToSelected([&](vector_size_t i) { addToGroup(groups[i], 1); }); - } - } - - void addIntermediateResults( - char** groups, - const SelectivityVector& rows, - const std::vector& args, - bool /*mayPushdown*/) override { - decodedIntermediate_.decode(*args[0], rows); - rows.applyToSelected([&](vector_size_t i) { - addToGroup(groups[i], decodedIntermediate_.valueAt(i)); - }); - } - - void addSingleGroupRawInput( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool /*mayPushdown*/) override { - if (args.empty()) { - addToGroup(group, rows.size()); - return; - } - - DecodedVector decoded(*args[0], rows); - if (decoded.isConstantMapping()) { - if (!decoded.isNullAt(0)) { - addToGroup(group, rows.size()); - } - } else if (decoded.mayHaveNulls()) { - int64_t nonNullCount = 0; - rows.applyToSelected([&](vector_size_t i) { - if (!decoded.isNullAt(i)) { - ++nonNullCount; - } - }); - addToGroup(group, nonNullCount); - } else { - addToGroup(group, rows.size()); - } - } - - void addSingleGroupIntermediateResults( - char* group, - const SelectivityVector& rows, - const std::vector& args, - bool /*mayPushdown*/) override { - decodedIntermediate_.decode(*args[0], rows); - - int64_t count = 0; - if (decodedIntermediate_.mayHaveNulls()) { - rows.applyToSelected([&](vector_size_t i) { - if (!decodedIntermediate_.isNullAt(i)) { - count += decodedIntermediate_.valueAt(i); - } - }); - } else { - rows.applyToSelected([&](vector_size_t i) { - count += decodedIntermediate_.valueAt(i); - }); - } - - addToGroup(group, count); - } - - private: - inline void addToGroup(char* group, int64_t count) { - *value(group) += count; - } - - DecodedVector decodedIntermediate_; -}; - -bool registerCountAggregate(const std::string& name) { - std::vector> signatures{ - exec::AggregateFunctionSignatureBuilder() - .returnType("bigint") - .intermediateType("bigint") - .build(), - exec::AggregateFunctionSignatureBuilder() - .typeVariable("T") - .returnType("bigint") - .intermediateType("bigint") - .argumentType("T") - .build(), - }; - - exec::registerAggregateFunction( - name, - std::move(signatures), - [name]( - core::AggregationNode::Step step, - const std::vector& argTypes, - const TypePtr& - /*resultType*/) -> std::unique_ptr { - VELOX_CHECK_LE( - argTypes.size(), 1, "{} takes at most one argument", name); - return std::make_unique(); - }); - return true; -} - static bool FB_ANONYMOUS_VARIABLE(g_AggregateFunction) = registerCountAggregate(kCount); -} // namespace } // namespace facebook::velox::aggregate diff --git a/velox/functions/prestosql/aggregates/CountAggregate.h b/velox/functions/prestosql/aggregates/CountAggregate.h new file mode 100644 index 0000000000000..9a54c894cf021 --- /dev/null +++ b/velox/functions/prestosql/aggregates/CountAggregate.h @@ -0,0 +1,176 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/common/base/Exceptions.h" +#include "velox/expression/FunctionSignature.h" +#include "velox/functions/prestosql/aggregates/AggregateNames.h" +#include "velox/functions/prestosql/aggregates/SumAggregate.h" + +namespace facebook::velox::aggregate { + +class CountAggregate : public SimpleNumericAggregate { + using BaseAggregate = SimpleNumericAggregate; + + public: + explicit CountAggregate() : BaseAggregate(BIGINT()) {} + + int32_t accumulatorFixedWidthSize() const override { + return sizeof(int64_t); + } + + void initializeNewGroups( + char** groups, + folly::Range indices) override { + for (auto i : indices) { + // result of count is never null + *value(groups[i]) = (int64_t)0; + } + } + + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override { + BaseAggregate::doExtractValues(groups, numGroups, result, [&](char* group) { + return *value(group); + }); + } + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + if (args.empty()) { + rows.applyToSelected([&](vector_size_t i) { addToGroup(groups[i], 1); }); + return; + } + + DecodedVector decoded(*args[0], rows); + if (decoded.isConstantMapping()) { + if (!decoded.isNullAt(0)) { + rows.applyToSelected( + [&](vector_size_t i) { addToGroup(groups[i], 1); }); + } + } else if (decoded.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (decoded.isNullAt(i)) { + return; + } + addToGroup(groups[i], 1); + }); + } else { + rows.applyToSelected([&](vector_size_t i) { addToGroup(groups[i], 1); }); + } + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + decodedIntermediate_.decode(*args[0], rows); + rows.applyToSelected([&](vector_size_t i) { + addToGroup(groups[i], decodedIntermediate_.valueAt(i)); + }); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + if (args.empty()) { + addToGroup(group, rows.size()); + return; + } + + DecodedVector decoded(*args[0], rows); + if (decoded.isConstantMapping()) { + if (!decoded.isNullAt(0)) { + addToGroup(group, rows.size()); + } + } else if (decoded.mayHaveNulls()) { + int64_t nonNullCount = 0; + rows.applyToSelected([&](vector_size_t i) { + if (!decoded.isNullAt(i)) { + ++nonNullCount; + } + }); + addToGroup(group, nonNullCount); + } else { + addToGroup(group, rows.size()); + } + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + decodedIntermediate_.decode(*args[0], rows); + + int64_t count = 0; + if (decodedIntermediate_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (!decodedIntermediate_.isNullAt(i)) { + count += decodedIntermediate_.valueAt(i); + } + }); + } else { + rows.applyToSelected([&](vector_size_t i) { + count += decodedIntermediate_.valueAt(i); + }); + } + + addToGroup(group, count); + } + + private: + inline void addToGroup(char* group, int64_t count) { + *value(group) += count; + } + + DecodedVector decodedIntermediate_; +}; + +bool registerCountAggregate(const std::string& name) { + std::vector> signatures{ + exec::AggregateFunctionSignatureBuilder() + .returnType("bigint") + .intermediateType("bigint") + .build(), + exec::AggregateFunctionSignatureBuilder() + .typeVariable("T") + .returnType("bigint") + .intermediateType("bigint") + .argumentType("T") + .build(), + }; + + exec::registerAggregateFunction( + name, + std::move(signatures), + [name]( + core::AggregationNode::Step step, + const std::vector& argTypes, + const TypePtr& + /*resultType*/) -> std::unique_ptr { + VELOX_CHECK_LE( + argTypes.size(), 1, "{} takes at most one argument", name); + return std::make_unique(); + }); + return true; +} + +} // namespace facebook::velox::aggregate