Skip to content

Commit

Permalink
1. Fix the core bug of agg query in DEBUG mode (apache#28)
Browse files Browse the repository at this point in the history
2. Support the count aggregate_function
3. Some code refactor
  • Loading branch information
HappenLee committed Jul 13, 2021
1 parent 7dc50b6 commit 51ff865
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 28 deletions.
26 changes: 13 additions & 13 deletions be/src/vec/aggregate_functions/aggregate_function_count.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,27 @@
// specific language governing permissions and limitations
// under the License.

// #include <AggregateFunctions/AggregateFunctionFactory.h>
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include <vec/aggregate_functions/aggregate_function_count.h>
#include <vec/aggregate_functions/factory_helpers.h>

namespace DB {
namespace doris::vectorized {

namespace {
AggregateFunctionPtr createAggregateFunctionCount(const std::string & name, const DataTypes & argument_types, const Array & parameters) {
assertNoParameters(name, parameters);
assertArityAtMost<1>(name, argument_types);

// AggregateFunctionPtr createAggregateFunctionCount(const std::string & name, const DataTypes & argument_types, const Array & parameters)
// {
// assertNoParameters(name, parameters);
// assertArityAtMost<1>(name, argument_types);

// return std::make_shared<AggregateFunctionCount>(argument_types);
// }

} // namespace
return std::make_shared<AggregateFunctionCount>(argument_types);
}

// void registerAggregateFunctionCount(AggregateFunctionFactory & factory)
// {
// factory.registerFunction("count", createAggregateFunctionCount, AggregateFunctionFactory::CaseInsensitive);
// }

} // namespace DB
void registerAggregateFunctionCount(AggregateFunctionSimpleFactory& factory) {
factory.registerFunction("count", createAggregateFunctionCount);
}

} // namespace

8 changes: 4 additions & 4 deletions be/src/vec/aggregate_functions/aggregate_function_count.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class AggregateFunctionCount final

String getName() const override { return "count"; }

DataTypePtr getReturnType() const override { return std::make_shared<DataTypeUInt64>(); }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeInt64>(); }

void add(AggregateDataPtr place, const IColumn**, size_t, Arena*) const override {
++data(place).count;
Expand All @@ -67,7 +67,7 @@ class AggregateFunctionCount final
}

void insertResultInto(ConstAggregateDataPtr place, IColumn& to) const override {
assert_cast<ColumnUInt64&>(to).getData().push_back(data(place).count);
assert_cast<ColumnInt64&>(to).getData().push_back(data(place).count);
}

const char* getHeaderFilePath() const override { return __FILE__; }
Expand All @@ -90,7 +90,7 @@ class AggregateFunctionCountNotNullUnary final

String getName() const override { return "count"; }

DataTypePtr getReturnType() const override { return std::make_shared<DataTypeUInt64>(); }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeInt64>(); }

void add(AggregateDataPtr place, const IColumn** columns, size_t row_num,
Arena*) const override {
Expand All @@ -110,7 +110,7 @@ class AggregateFunctionCountNotNullUnary final
}

void insertResultInto(ConstAggregateDataPtr place, IColumn& to) const override {
assert_cast<ColumnUInt64&>(to).getData().push_back(data(place).count);
assert_cast<ColumnInt64&>(to).getData().push_back(data(place).count);
}

const char* getHeaderFilePath() const override { return __FILE__; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ namespace doris::vectorized {

class AggregateFunctionSimpleFactory;
void registerAggregateFunctionSum(AggregateFunctionSimpleFactory& factory);
void registerAggregateFunctionCount(AggregateFunctionSimpleFactory& factory);
void registerAggregateFunctionCombinatorNull(AggregateFunctionSimpleFactory& factory);
void registerAggregateFunctionMinMax(AggregateFunctionSimpleFactory& factory);
void registerAggregateFunctionAvg(AggregateFunctionSimpleFactory& factory);
Expand Down Expand Up @@ -82,6 +83,7 @@ class AggregateFunctionSimpleFactory {
static AggregateFunctionSimpleFactory instance;
std::call_once(oc, [&]() {
registerAggregateFunctionSum(instance);
registerAggregateFunctionCount(instance);
registerAggregateFunctionMinMax(instance);
registerAggregateFunctionAvg(instance);
registerAggregateFunctionCombinatorNull(instance);
Expand Down
12 changes: 7 additions & 5 deletions be/src/vec/exec/vaggregation_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

#include "vec/exec/vaggregation_node.h"

#include <memory>

#include "exec/exec_node.h"
#include "runtime/mem_pool.h"
#include "runtime/row_batch.h"
Expand All @@ -40,7 +42,7 @@ AggregationNode::AggregationNode(ObjectPool* pool, const TPlanNode& tnode,
_is_merge(false),
_agg_data() {}

AggregationNode::~AggregationNode() {}
AggregationNode::~AggregationNode() = default;

Status AggregationNode::init(const TPlanNode& tnode, RuntimeState* state) {
RETURN_IF_ERROR(ExecNode::init(tnode, state));
Expand All @@ -57,8 +59,8 @@ Status AggregationNode::init(const TPlanNode& tnode, RuntimeState* state) {
_aggregate_evaluators.push_back(evaluator);
}

auto agg_functions = tnode.agg_node.aggregate_functions;
_is_merge = std::any_of(agg_functions.begin(), agg_functions.end(),
const auto& agg_functions = tnode.agg_node.aggregate_functions;
_is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(),
[](const auto& e) { return e.nodes[0].agg_expr.is_merge_agg; });
return Status::OK();
}
Expand All @@ -73,7 +75,7 @@ Status AggregationNode::prepare(RuntimeState* state) {
RETURN_IF_ERROR(
VExpr::prepare(_probe_expr_ctxs, state, child(0)->row_desc(), expr_mem_tracker()));

_mem_pool.reset(new MemPool(mem_tracker().get()));
_mem_pool = std::make_unique<MemPool>(mem_tracker().get());

int j = _probe_expr_ctxs.size();
for (int i = 0; i < _aggregate_evaluators.size(); ++i, ++j) {
Expand All @@ -89,7 +91,7 @@ Status AggregationNode::prepare(RuntimeState* state) {
for (size_t i = 0; i < _aggregate_evaluators.size(); ++i) {
_offsets_of_aggregate_states[i] = _total_size_of_aggregate_states;

const auto agg_function = _aggregate_evaluators[i]->function();
const auto& agg_function = _aggregate_evaluators[i]->function();
// aggreate states are aligned based on maximum requirement
_align_aggregate_states = std::max(_align_aggregate_states, agg_function->alignOfData());
_total_size_of_aggregate_states += agg_function->sizeOfData();
Expand Down
13 changes: 9 additions & 4 deletions be/src/vec/exprs/vectorized_agg_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,19 @@ void AggFnEvaluator::destroy(AggregateDataPtr place) {
}

void AggFnEvaluator::execute_single_add(Block* block, AggregateDataPtr place, Arena* arena) {
std::vector<const IColumn*> column_arguments(_input_exprs_ctxs.size());
auto columns = block->getColumns();
std::vector<ColumnPtr> columns(_input_exprs_ctxs.size());
for (int i = 0; i < _input_exprs_ctxs.size(); ++i) {
int column_id = -1;
_input_exprs_ctxs[i]->execute(block, &column_id);
column_arguments[i] =
block->getByPosition(column_id).column->convertToFullColumnIfConst().get();
columns[i] =
block->getByPosition(column_id).column->convertToFullColumnIfConst();
}
// Because the `convertToFullColumnIfConst()` may return a temporary variable, so we need keep the reference of it
// to make sure program do not destroy it before we call `addBatchSinglePlace`.
// WARNING:
// There's danger to call `convertToFullColumnIfConst().get()` to get the `const IColumn*` directly.
std::vector<const IColumn*> column_arguments(columns.size());
std::transform(columns.cbegin(), columns.cend(), column_arguments.begin(), [](const auto& ptr) {return ptr.get();});
_function->addBatchSinglePlace(block->rows(), place, column_arguments.data(), nullptr);
}

Expand Down
4 changes: 2 additions & 2 deletions be/src/vec/utils/util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class VectorizedUtils {

static ColumnsWithTypeAndName create_columns_with_type_and_name(const RowDescriptor& row_desc) {
ColumnsWithTypeAndName columns_with_type_and_name;
for (const auto tuple_desc : row_desc.tuple_descriptors()) {
for (const auto slot_desc : tuple_desc->slots()) {
for (const auto& tuple_desc : row_desc.tuple_descriptors()) {
for (const auto& slot_desc : tuple_desc->slots()) {
columns_with_type_and_name.emplace_back(nullptr, slot_desc->get_data_type_ptr(),
slot_desc->col_name());
}
Expand Down

0 comments on commit 51ff865

Please sign in to comment.