From d895ba0355139b578e6d9ac53e30ac85ee72d2fc Mon Sep 17 00:00:00 2001 From: Rui Mo Date: Mon, 6 Jun 2022 11:51:16 +0800 Subject: [PATCH] [OPPRO-104] Support more cases of filter and its pushdown (#14) --- .../prestosql/aggregates/CMakeLists.txt | 1 + .../prestosql/aggregates/MinMaxAggregates.cpp | 7 +- .../prestosql/aggregates/MinMaxAggregates.h | 498 ++++++++++++ velox/substrait/CMakeLists.txt | 3 +- velox/substrait/SubstraitParser.cpp | 12 +- velox/substrait/SubstraitParser.h | 6 +- velox/substrait/SubstraitToVeloxExpr.cpp | 117 ++- velox/substrait/SubstraitToVeloxExpr.h | 26 +- velox/substrait/SubstraitToVeloxPlan.cpp | 767 ++++++++++++++++-- velox/substrait/SubstraitToVeloxPlan.h | 301 +++++-- .../SubstraitToVeloxPlanValidator.cpp | 165 ++-- .../substrait/SubstraitToVeloxPlanValidator.h | 15 +- velox/substrait/TypeUtils.h | 38 + velox/substrait/VectorCreater.cpp | 94 +++ velox/substrait/VectorCreater.h | 37 + velox/substrait/tests/PlanConversionTest.cpp | 94 ++- velox/type/Filter.cpp | 125 +++ velox/type/Filter.h | 79 +- 18 files changed, 2074 insertions(+), 311 deletions(-) create mode 100644 velox/functions/prestosql/aggregates/MinMaxAggregates.h create mode 100644 velox/substrait/VectorCreater.cpp create mode 100644 velox/substrait/VectorCreater.h diff --git a/velox/functions/prestosql/aggregates/CMakeLists.txt b/velox/functions/prestosql/aggregates/CMakeLists.txt index 30872f83a442..30f2c60b65d5 100644 --- a/velox/functions/prestosql/aggregates/CMakeLists.txt +++ b/velox/functions/prestosql/aggregates/CMakeLists.txt @@ -33,6 +33,7 @@ add_library( MapAggregateBase.cpp MapUnionAggregate.cpp MinMaxAggregates.cpp + MinMaxAggregates.h MinMaxByAggregates.cpp CountAggregate.cpp CountAggregate.h diff --git a/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp b/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp index 19d76d164fb4..dfa891ea3857 100644 --- a/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -14,13 +14,8 @@ * limitations under the License. */ -#include -#include "velox/exec/Aggregate.h" -#include "velox/exec/AggregationHook.h" -#include "velox/expression/FunctionSignature.h" +#include "velox/functions/prestosql/aggregates/MinMaxAggregates.h" #include "velox/functions/prestosql/aggregates/AggregateNames.h" -#include "velox/functions/prestosql/aggregates/SimpleNumericAggregate.h" -#include "velox/functions/prestosql/aggregates/SingleValueAccumulator.h" namespace facebook::velox::aggregate::prestosql { diff --git a/velox/functions/prestosql/aggregates/MinMaxAggregates.h b/velox/functions/prestosql/aggregates/MinMaxAggregates.h new file mode 100644 index 000000000000..56b500bd61a6 --- /dev/null +++ b/velox/functions/prestosql/aggregates/MinMaxAggregates.h @@ -0,0 +1,498 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "velox/exec/Aggregate.h" +#include "velox/exec/AggregationHook.h" +#include "velox/expression/FunctionSignature.h" +#include "velox/functions/prestosql/aggregates/AggregateNames.h" +#include "velox/functions/prestosql/aggregates/SimpleNumericAggregate.h" +#include "velox/functions/prestosql/aggregates/SingleValueAccumulator.h" + +namespace facebook::velox::aggregate { + +template +struct MinMaxTrait : public std::numeric_limits {}; + +template <> +struct MinMaxTrait { + static constexpr Timestamp min() { + return Timestamp(MinMaxTrait::min(), MinMaxTrait::min()); + } + + static constexpr Timestamp max() { + return Timestamp(MinMaxTrait::max(), MinMaxTrait::max()); + } +}; + +template <> +struct MinMaxTrait { + static constexpr Date min() { + return Date(std::numeric_limits::min()); + } + + static constexpr Date max() { + return Date(std::numeric_limits::max()); + } +}; + +template +class MinMaxAggregate : public SimpleNumericAggregate { + using BaseAggregate = SimpleNumericAggregate; + + public: + explicit MinMaxAggregate(TypePtr resultType) : BaseAggregate(resultType) {} + + int32_t accumulatorFixedWidthSize() const override { + return sizeof(T); + } + + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override { + BaseAggregate::template doExtractValues( + groups, numGroups, result, [&](char* group) { + return *BaseAggregate::Aggregate::template value(group); + }); + } + + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) + override { + BaseAggregate::template doExtractValues( + groups, numGroups, result, [&](char* group) { + return *BaseAggregate::Aggregate::template value(group); + }); + } +}; + +// Truncate timestamps to milliseconds precision. +template <> +void MinMaxAggregate::extractValues( + char** groups, + int32_t numGroups, + VectorPtr* result) { + BaseAggregate::template doExtractValues( + groups, numGroups, result, [&](char* group) { + auto ts = *BaseAggregate::Aggregate::template value(group); + return Timestamp::fromMillis(ts.toMillis()); + }); +} + +template +class MaxAggregate : public MinMaxAggregate { + using BaseAggregate = SimpleNumericAggregate; + + public: + explicit MaxAggregate(TypePtr resultType) : MinMaxAggregate(resultType) {} + + void initializeNewGroups( + char** groups, + folly::Range indices) override { + exec::Aggregate::setAllNulls(groups, indices); + for (auto i : indices) { + *exec::Aggregate::value(groups[i]) = kInitialValue_; + } + } + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + if (mayPushdown && args[0]->isLazy()) { + BaseAggregate::template pushdown>( + groups, rows, args[0]); + return; + } + BaseAggregate::template updateGroups( + groups, + rows, + args[0], + [](T& result, T value) { + if (result < value) { + result = value; + } + }, + mayPushdown); + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addRawInput(groups, rows, args, mayPushdown); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + BaseAggregate::updateOneGroup( + group, + rows, + args[0], + [](T& result, T value) { result = result > value ? result : value; }, + [](T& result, T value, int /* unused */) { result = value; }, + mayPushdown, + kInitialValue_); + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addSingleGroupRawInput(group, rows, args, mayPushdown); + } + + private: + static constexpr T kInitialValue_{MinMaxTrait::min()}; +}; + +template +class MinAggregate : public MinMaxAggregate { + using BaseAggregate = SimpleNumericAggregate; + + public: + explicit MinAggregate(TypePtr resultType) : MinMaxAggregate(resultType) {} + + void initializeNewGroups( + char** groups, + folly::Range indices) override { + exec::Aggregate::setAllNulls(groups, indices); + for (auto i : indices) { + *exec::Aggregate::value(groups[i]) = kInitialValue_; + } + } + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + if (mayPushdown && args[0]->isLazy()) { + BaseAggregate::template pushdown>( + groups, rows, args[0]); + return; + } + BaseAggregate::template updateGroups( + groups, + rows, + args[0], + [](T& result, T value) { + if (result > value) { + result = value; + } + }, + mayPushdown); + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addRawInput(groups, rows, args, mayPushdown); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + BaseAggregate::updateOneGroup( + group, + rows, + args[0], + [](T& result, T value) { result = result < value ? result : value; }, + [](T& result, T value, int /* unused */) { result = value; }, + mayPushdown, + kInitialValue_); + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addSingleGroupRawInput(group, rows, args, mayPushdown); + } + + private: + static constexpr T kInitialValue_{MinMaxTrait::max()}; +}; + +class NonNumericMinMaxAggregateBase : public exec::Aggregate { + public: + explicit NonNumericMinMaxAggregateBase(const TypePtr& resultType) + : exec::Aggregate(resultType) {} + + int32_t accumulatorFixedWidthSize() const override { + return sizeof(SingleValueAccumulator); + } + + void initializeNewGroups( + char** groups, + folly::Range indices) override { + exec::Aggregate::setAllNulls(groups, indices); + for (auto i : indices) { + new (groups[i] + offset_) SingleValueAccumulator(); + } + } + + void finalize(char** /* groups */, int32_t /* numGroups */) override { + // Nothing to do + } + + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override { + VELOX_CHECK(result); + (*result)->resize(numGroups); + + uint64_t* rawNulls = nullptr; + if ((*result)->mayHaveNulls()) { + BufferPtr nulls = (*result)->mutableNulls((*result)->size()); + rawNulls = nulls->asMutable(); + } + + for (auto i = 0; i < numGroups; ++i) { + char* group = groups[i]; + auto accumulator = value(group); + if (!accumulator->hasValue()) { + (*result)->setNull(i, true); + } else { + if (rawNulls) { + bits::clearBit(rawNulls, i); + } + accumulator->read(*result, i); + } + } + } + + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) + override { + // partial and final aggregations are the same + extractValues(groups, numGroups, result); + } + + void destroy(folly::Range groups) override { + for (auto group : groups) { + value(group)->destroy(allocator_); + } + } + + protected: + template + void doUpdate( + char** groups, + const SelectivityVector& rows, + const VectorPtr& arg, + TCompareTest compareTest) { + DecodedVector decoded(*arg, rows, true); + auto indices = decoded.indices(); + auto baseVector = decoded.base(); + + if (decoded.isConstantMapping() && decoded.isNullAt(0)) { + // nothing to do; all values are nulls + return; + } + + rows.applyToSelected([&](vector_size_t i) { + if (decoded.isNullAt(i)) { + return; + } + auto accumulator = value(groups[i]); + if (!accumulator->hasValue() || + compareTest(accumulator->compare(decoded, i))) { + accumulator->write(baseVector, indices[i], allocator_); + } + }); + } + + template + void doUpdateSingleGroup( + char* group, + const SelectivityVector& rows, + const VectorPtr& arg, + TCompareTest compareTest) { + DecodedVector decoded(*arg, rows, true); + auto indices = decoded.indices(); + auto baseVector = decoded.base(); + + if (decoded.isConstantMapping()) { + if (decoded.isNullAt(0)) { + // nothing to do; all values are nulls + return; + } + + auto accumulator = value(group); + if (!accumulator->hasValue() || + compareTest(accumulator->compare(decoded, 0))) { + accumulator->write(baseVector, indices[0], allocator_); + } + return; + } + + auto accumulator = value(group); + rows.applyToSelected([&](vector_size_t i) { + if (decoded.isNullAt(i)) { + return; + } + if (!accumulator->hasValue() || + compareTest(accumulator->compare(decoded, i))) { + accumulator->write(baseVector, indices[i], allocator_); + } + }); + } +}; + +class NonNumericMaxAggregate : public NonNumericMinMaxAggregateBase { + public: + explicit NonNumericMaxAggregate(const TypePtr& resultType) + : NonNumericMinMaxAggregateBase(resultType) {} + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + doUpdate(groups, rows, args[0], [](int32_t compareResult) { + return compareResult < 0; + }); + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addRawInput(groups, rows, args, mayPushdown); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + doUpdateSingleGroup(group, rows, args[0], [](int32_t compareResult) { + return compareResult < 0; + }); + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addSingleGroupRawInput(group, rows, args, mayPushdown); + } +}; + +class NonNumericMinAggregate : public NonNumericMinMaxAggregateBase { + public: + explicit NonNumericMinAggregate(const TypePtr& resultType) + : NonNumericMinMaxAggregateBase(resultType) {} + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + doUpdate(groups, rows, args[0], [](int32_t compareResult) { + return compareResult > 0; + }); + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addRawInput(groups, rows, args, mayPushdown); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + doUpdateSingleGroup(group, rows, args[0], [](int32_t compareResult) { + return compareResult > 0; + }); + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override { + addSingleGroupRawInput(group, rows, args, mayPushdown); + } +}; + +template