diff --git a/velox/functions/prestosql/aggregates/CMakeLists.txt b/velox/functions/prestosql/aggregates/CMakeLists.txt index 30872f83a442..30f2c60b65d5 100644 --- a/velox/functions/prestosql/aggregates/CMakeLists.txt +++ b/velox/functions/prestosql/aggregates/CMakeLists.txt @@ -33,6 +33,7 @@ add_library( MapAggregateBase.cpp MapUnionAggregate.cpp MinMaxAggregates.cpp + MinMaxAggregates.h MinMaxByAggregates.cpp CountAggregate.cpp CountAggregate.h diff --git a/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp b/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp index 19d76d164fb4..dfa891ea3857 100644 --- a/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -14,13 +14,8 @@ * limitations under the License. */ -#include -#include "velox/exec/Aggregate.h" -#include "velox/exec/AggregationHook.h" -#include "velox/expression/FunctionSignature.h" +#include "velox/functions/prestosql/aggregates/MinMaxAggregates.h" #include "velox/functions/prestosql/aggregates/AggregateNames.h" -#include "velox/functions/prestosql/aggregates/SimpleNumericAggregate.h" -#include "velox/functions/prestosql/aggregates/SingleValueAccumulator.h" namespace facebook::velox::aggregate::prestosql { diff --git a/velox/functions/prestosql/aggregates/MinMaxAggregates.h b/velox/functions/prestosql/aggregates/MinMaxAggregates.h new file mode 100644 index 000000000000..56b500bd61a6 --- /dev/null +++ b/velox/functions/prestosql/aggregates/MinMaxAggregates.h @@ -0,0 +1,498 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "velox/exec/Aggregate.h" +#include "velox/exec/AggregationHook.h" +#include "velox/expression/FunctionSignature.h" +#include "velox/functions/prestosql/aggregates/AggregateNames.h" +#include "velox/functions/prestosql/aggregates/SimpleNumericAggregate.h" +#include "velox/functions/prestosql/aggregates/SingleValueAccumulator.h" + +namespace facebook::velox::aggregate { + +template +struct MinMaxTrait : public std::numeric_limits {}; + +template <> +struct MinMaxTrait { + static constexpr Timestamp min() { + return Timestamp(MinMaxTrait::min(), MinMaxTrait::min()); + } + + static constexpr Timestamp max() { + return Timestamp(MinMaxTrait::max(), MinMaxTrait::max()); + } +}; + +template <> +struct MinMaxTrait { + static constexpr Date min() { + return Date(std::numeric_limits::min()); + } + + static constexpr Date max() { + return Date(std::numeric_limits::max()); + } +}; + +template +class MinMaxAggregate : public SimpleNumericAggregate { + using BaseAggregate = SimpleNumericAggregate; + + public: + explicit MinMaxAggregate(TypePtr resultType) : BaseAggregate(resultType) {} + + int32_t accumulatorFixedWidthSize() const override { + return sizeof(T); + } + + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override { + BaseAggregate::template doExtractValues( + groups, numGroups, result, [&](char* group) { + return *BaseAggregate::Aggregate::template value(group); + }); + } + + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) + override { + BaseAggregate::template doExtractValues( + groups, numGroups, result, [&](char* group) { + return *BaseAggregate::Aggregate::template value(group); + }); + } +}; + +// Truncate timestamps to milliseconds precision. +template <> +void MinMaxAggregate::extractValues( + char** groups, + int32_t numGroups, + VectorPtr* result) { + BaseAggregate::template doExtractValues( + groups, numGroups, result, [&](char* group) { + auto ts = *BaseAggregate::Aggregate::template value(group); + return Timestamp::fromMillis(ts.toMillis()); + }); +} + +template +class MaxAggregate : public MinMaxAggregate { + using BaseAggregate = SimpleNumericAggregate; + + public: + explicit MaxAggregate(TypePtr resultType) : MinMaxAggregate(resultType) {} + + void initializeNewGroups( + char** groups, + folly::Range indices) override { + exec::Aggregate::setAllNulls(groups, indices); + for (auto i : indices) { + *exec::Aggregate::value(groups[i]) = kInitialValue_; + } + } + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + if (mayPushdown && args[0]->isLazy()) { + BaseAggregate::template pushdown>( + groups, rows, args[0]); + return; + } + BaseAggregate::template updateGroups( + groups, + rows, + args[0], + [](T& result, T value) { + if (result < value) { + result = value; + } + }, + mayPushdown); + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addRawInput(groups, rows, args, mayPushdown); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + BaseAggregate::updateOneGroup( + group, + rows, + args[0], + [](T& result, T value) { result = result > value ? result : value; }, + [](T& result, T value, int /* unused */) { result = value; }, + mayPushdown, + kInitialValue_); + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addSingleGroupRawInput(group, rows, args, mayPushdown); + } + + private: + static constexpr T kInitialValue_{MinMaxTrait::min()}; +}; + +template +class MinAggregate : public MinMaxAggregate { + using BaseAggregate = SimpleNumericAggregate; + + public: + explicit MinAggregate(TypePtr resultType) : MinMaxAggregate(resultType) {} + + void initializeNewGroups( + char** groups, + folly::Range indices) override { + exec::Aggregate::setAllNulls(groups, indices); + for (auto i : indices) { + *exec::Aggregate::value(groups[i]) = kInitialValue_; + } + } + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + if (mayPushdown && args[0]->isLazy()) { + BaseAggregate::template pushdown>( + groups, rows, args[0]); + return; + } + BaseAggregate::template updateGroups( + groups, + rows, + args[0], + [](T& result, T value) { + if (result > value) { + result = value; + } + }, + mayPushdown); + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addRawInput(groups, rows, args, mayPushdown); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + BaseAggregate::updateOneGroup( + group, + rows, + args[0], + [](T& result, T value) { result = result < value ? result : value; }, + [](T& result, T value, int /* unused */) { result = value; }, + mayPushdown, + kInitialValue_); + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addSingleGroupRawInput(group, rows, args, mayPushdown); + } + + private: + static constexpr T kInitialValue_{MinMaxTrait::max()}; +}; + +class NonNumericMinMaxAggregateBase : public exec::Aggregate { + public: + explicit NonNumericMinMaxAggregateBase(const TypePtr& resultType) + : exec::Aggregate(resultType) {} + + int32_t accumulatorFixedWidthSize() const override { + return sizeof(SingleValueAccumulator); + } + + void initializeNewGroups( + char** groups, + folly::Range indices) override { + exec::Aggregate::setAllNulls(groups, indices); + for (auto i : indices) { + new (groups[i] + offset_) SingleValueAccumulator(); + } + } + + void finalize(char** /* groups */, int32_t /* numGroups */) override { + // Nothing to do + } + + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override { + VELOX_CHECK(result); + (*result)->resize(numGroups); + + uint64_t* rawNulls = nullptr; + if ((*result)->mayHaveNulls()) { + BufferPtr nulls = (*result)->mutableNulls((*result)->size()); + rawNulls = nulls->asMutable(); + } + + for (auto i = 0; i < numGroups; ++i) { + char* group = groups[i]; + auto accumulator = value(group); + if (!accumulator->hasValue()) { + (*result)->setNull(i, true); + } else { + if (rawNulls) { + bits::clearBit(rawNulls, i); + } + accumulator->read(*result, i); + } + } + } + + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) + override { + // partial and final aggregations are the same + extractValues(groups, numGroups, result); + } + + void destroy(folly::Range groups) override { + for (auto group : groups) { + value(group)->destroy(allocator_); + } + } + + protected: + template + void doUpdate( + char** groups, + const SelectivityVector& rows, + const VectorPtr& arg, + TCompareTest compareTest) { + DecodedVector decoded(*arg, rows, true); + auto indices = decoded.indices(); + auto baseVector = decoded.base(); + + if (decoded.isConstantMapping() && decoded.isNullAt(0)) { + // nothing to do; all values are nulls + return; + } + + rows.applyToSelected([&](vector_size_t i) { + if (decoded.isNullAt(i)) { + return; + } + auto accumulator = value(groups[i]); + if (!accumulator->hasValue() || + compareTest(accumulator->compare(decoded, i))) { + accumulator->write(baseVector, indices[i], allocator_); + } + }); + } + + template + void doUpdateSingleGroup( + char* group, + const SelectivityVector& rows, + const VectorPtr& arg, + TCompareTest compareTest) { + DecodedVector decoded(*arg, rows, true); + auto indices = decoded.indices(); + auto baseVector = decoded.base(); + + if (decoded.isConstantMapping()) { + if (decoded.isNullAt(0)) { + // nothing to do; all values are nulls + return; + } + + auto accumulator = value(group); + if (!accumulator->hasValue() || + compareTest(accumulator->compare(decoded, 0))) { + accumulator->write(baseVector, indices[0], allocator_); + } + return; + } + + auto accumulator = value(group); + rows.applyToSelected([&](vector_size_t i) { + if (decoded.isNullAt(i)) { + return; + } + if (!accumulator->hasValue() || + compareTest(accumulator->compare(decoded, i))) { + accumulator->write(baseVector, indices[i], allocator_); + } + }); + } +}; + +class NonNumericMaxAggregate : public NonNumericMinMaxAggregateBase { + public: + explicit NonNumericMaxAggregate(const TypePtr& resultType) + : NonNumericMinMaxAggregateBase(resultType) {} + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + doUpdate(groups, rows, args[0], [](int32_t compareResult) { + return compareResult < 0; + }); + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addRawInput(groups, rows, args, mayPushdown); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + doUpdateSingleGroup(group, rows, args[0], [](int32_t compareResult) { + return compareResult < 0; + }); + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addSingleGroupRawInput(group, rows, args, mayPushdown); + } +}; + +class NonNumericMinAggregate : public NonNumericMinMaxAggregateBase { + public: + explicit NonNumericMinAggregate(const TypePtr& resultType) + : NonNumericMinMaxAggregateBase(resultType) {} + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + doUpdate(groups, rows, args[0], [](int32_t compareResult) { + return compareResult > 0; + }); + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addRawInput(groups, rows, args, mayPushdown); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + doUpdateSingleGroup(group, rows, args[0], [](int32_t compareResult) { + return compareResult > 0; + }); + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addSingleGroupRawInput(group, rows, args, mayPushdown); + } +}; + +template