diff --git a/be/src/vec/aggregate_functions/aggregate_function_count.cpp b/be/src/vec/aggregate_functions/aggregate_function_count.cpp index 3ad88093093196..f68f2219dcc1fc 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_count.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_count.cpp @@ -15,27 +15,27 @@ // specific language governing permissions and limitations // under the License. -// #include +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" #include #include -namespace DB { +namespace doris::vectorized { -namespace { +AggregateFunctionPtr createAggregateFunctionCount(const std::string & name, const DataTypes & argument_types, const Array & parameters) { + assertNoParameters(name, parameters); + assertArityAtMost<1>(name, argument_types); -// AggregateFunctionPtr createAggregateFunctionCount(const std::string & name, const DataTypes & argument_types, const Array & parameters) -// { -// assertNoParameters(name, parameters); -// assertArityAtMost<1>(name, argument_types); - -// return std::make_shared(argument_types); -// } - -} // namespace + return std::make_shared(argument_types); +} // void registerAggregateFunctionCount(AggregateFunctionFactory & factory) // { // factory.registerFunction("count", createAggregateFunctionCount, AggregateFunctionFactory::CaseInsensitive); // } -} // namespace DB +void registerAggregateFunctionCount(AggregateFunctionSimpleFactory& factory) { + factory.registerFunction("count", createAggregateFunctionCount); +} + +} // namespace + diff --git a/be/src/vec/aggregate_functions/aggregate_function_count.h b/be/src/vec/aggregate_functions/aggregate_function_count.h index ad1a054b8d3110..51c5bc0a749eb7 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_count.h +++ b/be/src/vec/aggregate_functions/aggregate_function_count.h @@ -48,7 +48,7 @@ class AggregateFunctionCount final String getName() const override { return "count"; } - DataTypePtr getReturnType() const override { return std::make_shared(); } + DataTypePtr getReturnType() const override { return std::make_shared(); } void add(AggregateDataPtr place, const IColumn**, size_t, Arena*) const override { ++data(place).count; @@ -67,7 +67,7 @@ class AggregateFunctionCount final } void insertResultInto(ConstAggregateDataPtr place, IColumn& to) const override { - assert_cast(to).getData().push_back(data(place).count); + assert_cast(to).getData().push_back(data(place).count); } const char* getHeaderFilePath() const override { return __FILE__; } @@ -90,7 +90,7 @@ class AggregateFunctionCountNotNullUnary final String getName() const override { return "count"; } - DataTypePtr getReturnType() const override { return std::make_shared(); } + DataTypePtr getReturnType() const override { return std::make_shared(); } void add(AggregateDataPtr place, const IColumn** columns, size_t row_num, Arena*) const override { @@ -110,7 +110,7 @@ class AggregateFunctionCountNotNullUnary final } void insertResultInto(ConstAggregateDataPtr place, IColumn& to) const override { - assert_cast(to).getData().push_back(data(place).count); + assert_cast(to).getData().push_back(data(place).count); } const char* getHeaderFilePath() const override { return __FILE__; } diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h index 6f721cc05ed971..79455d68fc4790 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h @@ -31,6 +31,7 @@ namespace doris::vectorized { class AggregateFunctionSimpleFactory; void registerAggregateFunctionSum(AggregateFunctionSimpleFactory& factory); +void registerAggregateFunctionCount(AggregateFunctionSimpleFactory& factory); void registerAggregateFunctionCombinatorNull(AggregateFunctionSimpleFactory& factory); void registerAggregateFunctionMinMax(AggregateFunctionSimpleFactory& factory); void registerAggregateFunctionAvg(AggregateFunctionSimpleFactory& factory); @@ -82,6 +83,7 @@ class AggregateFunctionSimpleFactory { static AggregateFunctionSimpleFactory instance; std::call_once(oc, [&]() { registerAggregateFunctionSum(instance); + registerAggregateFunctionCount(instance); registerAggregateFunctionMinMax(instance); registerAggregateFunctionAvg(instance); registerAggregateFunctionCombinatorNull(instance); diff --git a/be/src/vec/exec/vaggregation_node.cpp b/be/src/vec/exec/vaggregation_node.cpp index 46bb5394d09dc4..848602722b5df1 100644 --- a/be/src/vec/exec/vaggregation_node.cpp +++ b/be/src/vec/exec/vaggregation_node.cpp @@ -17,6 +17,8 @@ #include "vec/exec/vaggregation_node.h" +#include + #include "exec/exec_node.h" #include "runtime/mem_pool.h" #include "runtime/row_batch.h" @@ -40,7 +42,7 @@ AggregationNode::AggregationNode(ObjectPool* pool, const TPlanNode& tnode, _is_merge(false), _agg_data() {} -AggregationNode::~AggregationNode() {} +AggregationNode::~AggregationNode() = default; Status AggregationNode::init(const TPlanNode& tnode, RuntimeState* state) { RETURN_IF_ERROR(ExecNode::init(tnode, state)); @@ -57,8 +59,8 @@ Status AggregationNode::init(const TPlanNode& tnode, RuntimeState* state) { _aggregate_evaluators.push_back(evaluator); } - auto agg_functions = tnode.agg_node.aggregate_functions; - _is_merge = std::any_of(agg_functions.begin(), agg_functions.end(), + const auto& agg_functions = tnode.agg_node.aggregate_functions; + _is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(), [](const auto& e) { return e.nodes[0].agg_expr.is_merge_agg; }); return Status::OK(); } @@ -73,7 +75,7 @@ Status AggregationNode::prepare(RuntimeState* state) { RETURN_IF_ERROR( VExpr::prepare(_probe_expr_ctxs, state, child(0)->row_desc(), expr_mem_tracker())); - _mem_pool.reset(new MemPool(mem_tracker().get())); + _mem_pool = std::make_unique(mem_tracker().get()); int j = _probe_expr_ctxs.size(); for (int i = 0; i < _aggregate_evaluators.size(); ++i, ++j) { @@ -89,7 +91,7 @@ Status AggregationNode::prepare(RuntimeState* state) { for (size_t i = 0; i < _aggregate_evaluators.size(); ++i) { _offsets_of_aggregate_states[i] = _total_size_of_aggregate_states; - const auto agg_function = _aggregate_evaluators[i]->function(); + const auto& agg_function = _aggregate_evaluators[i]->function(); // aggreate states are aligned based on maximum requirement _align_aggregate_states = std::max(_align_aggregate_states, agg_function->alignOfData()); _total_size_of_aggregate_states += agg_function->sizeOfData(); diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp b/be/src/vec/exprs/vectorized_agg_fn.cpp index b4281ddcbcf546..1a301cbe1fd635 100644 --- a/be/src/vec/exprs/vectorized_agg_fn.cpp +++ b/be/src/vec/exprs/vectorized_agg_fn.cpp @@ -99,14 +99,19 @@ void AggFnEvaluator::destroy(AggregateDataPtr place) { } void AggFnEvaluator::execute_single_add(Block* block, AggregateDataPtr place, Arena* arena) { - std::vector column_arguments(_input_exprs_ctxs.size()); - auto columns = block->getColumns(); + std::vector columns(_input_exprs_ctxs.size()); for (int i = 0; i < _input_exprs_ctxs.size(); ++i) { int column_id = -1; _input_exprs_ctxs[i]->execute(block, &column_id); - column_arguments[i] = - block->getByPosition(column_id).column->convertToFullColumnIfConst().get(); + columns[i] = + block->getByPosition(column_id).column->convertToFullColumnIfConst(); } + // Because the `convertToFullColumnIfConst()` may return a temporary variable, so we need keep the reference of it + // to make sure program do not destroy it before we call `addBatchSinglePlace`. + // WARNING: + // There's danger to call `convertToFullColumnIfConst().get()` to get the `const IColumn*` directly. + std::vector column_arguments(columns.size()); + std::transform(columns.cbegin(), columns.cend(), column_arguments.begin(), [](const auto& ptr) {return ptr.get();}); _function->addBatchSinglePlace(block->rows(), place, column_arguments.data(), nullptr); } diff --git a/be/src/vec/utils/util.hpp b/be/src/vec/utils/util.hpp index 3bd6220b76cb11..f0362bf35e17b7 100644 --- a/be/src/vec/utils/util.hpp +++ b/be/src/vec/utils/util.hpp @@ -15,8 +15,8 @@ class VectorizedUtils { static ColumnsWithTypeAndName create_columns_with_type_and_name(const RowDescriptor& row_desc) { ColumnsWithTypeAndName columns_with_type_and_name; - for (const auto tuple_desc : row_desc.tuple_descriptors()) { - for (const auto slot_desc : tuple_desc->slots()) { + for (const auto& tuple_desc : row_desc.tuple_descriptors()) { + for (const auto& slot_desc : tuple_desc->slots()) { columns_with_type_and_name.emplace_back(nullptr, slot_desc->get_data_type_ptr(), slot_desc->col_name()); }