Skip to content

Commit

Permalink
[fix](window_func) fix bug of agg function used in window function an…
Browse files Browse the repository at this point in the history
…d add many test cases (apache#40678)

## Proposed changes

Support usage in window function for many agg functions by adding
`reset` method.
Fix bug of `window_funnel` used in window function;
Fix bug of wrong result of `orthogonal_bitmap_intersect`.
  • Loading branch information
jacktengg authored Sep 20, 2024
1 parent 51ba957 commit 11dfd19
Show file tree
Hide file tree
Showing 14 changed files with 1,440 additions and 53 deletions.
3 changes: 0 additions & 3 deletions be/src/vec/aggregate_functions/aggregate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,6 @@ class IAggregateFunction {
virtual void deserialize_and_merge_from_column(AggregateDataPtr __restrict place,
const IColumn& column, Arena* arena) const = 0;

/// Returns true if a function requires Arena to handle own states (see add(), merge(), deserialize()).
virtual bool allocates_memory_in_arena() const { return false; }

/// Inserts results into a column.
virtual void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const = 0;

Expand Down
4 changes: 2 additions & 2 deletions be/src/vec/aggregate_functions/aggregate_function_binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ struct AggregateFunctionBinary

String get_name() const override { return StatFunc::Data::name(); }

/// Clears the aggregate state stored at `place` by delegating to the state
/// object's own reset(); the commit adds this so the function can be used as a
/// window function.
void reset(AggregateDataPtr __restrict place) const override {
    auto&& state = this->data(place);
    state.reset();
}

DataTypePtr get_return_type() const override {
return std::make_shared<DataTypeNumber<ResultType>>();
}

bool allocates_memory_in_arena() const override { return false; }

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
this->data(place).add(
Expand Down
2 changes: 0 additions & 2 deletions be/src/vec/aggregate_functions/aggregate_function_collect.h
Original file line number Diff line number Diff line change
Expand Up @@ -626,8 +626,6 @@ class AggregateFunctionCollect
return std::make_shared<DataTypeArray>(make_nullable(return_type));
}

bool allocates_memory_in_arena() const override { return ENABLE_ARENA; }

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena* arena) const override {
auto& data = this->data(place);
Expand Down
9 changes: 9 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_corr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ struct CorrMoment {
}

static String name() { return "corr"; }

/// Value-initializes the six accumulators assigned below (m0, x1, y1, xy, x2, y2)
/// so this state object can be reused; added to support use in window functions.
void reset() {
m0 = {};
x1 = {};
y1 = {};
xy = {};
x2 = {};
y2 = {};
}
};

AggregateFunctionPtr create_aggregate_corr_function(const std::string& name,
Expand Down
2 changes: 0 additions & 2 deletions be/src/vec/aggregate_functions/aggregate_function_distinct.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,6 @@ class AggregateFunctionDistinct

DataTypePtr get_return_type() const override { return nested_func->get_return_type(); }

bool allocates_memory_in_arena() const override { return true; }

AggregateFunctionPtr transmit_to_stable() override {
return AggregateFunctionPtr(new AggregateFunctionDistinct<Data, true>(
nested_func, IAggregateFunction::argument_types));
Expand Down
4 changes: 0 additions & 4 deletions be/src/vec/aggregate_functions/aggregate_function_foreach.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,6 @@ class AggregateFunctionForEach : public IAggregateFunctionDataHelper<AggregateFu
offsets_to.push_back(offsets_to.back() + state.dynamic_array_size);
}

bool allocates_memory_in_arena() const override {
return nested_function->allocates_memory_in_arena();
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena* arena) const override {
std::vector<const IColumn*> nested(num_arguments);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ struct AggregateFunctionGroupArrayIntersectData {
Set value;
bool init = false;

/// Puts the state back to "not initialized": clears the `init` flag (matching its
/// declared default of false) and replaces `value` with a newly allocated set.
void reset() {
init = false;
// A new set object is allocated and assigned, rather than clearing the old one in place.
value = std::make_unique<NullableNumericOrDateSetType>();
}

void process_col_data(auto& column_data, size_t offset, size_t arr_size, bool& init, Set& set) {
const bool is_column_data_nullable = column_data.is_nullable();

Expand Down Expand Up @@ -163,7 +168,7 @@ class AggregateFunctionGroupArrayIntersect

DataTypePtr get_return_type() const override { return argument_type; }

bool allocates_memory_in_arena() const override { return false; }
/// Resets the per-group state located at `place` by calling reset() on the
/// state object returned by data(place).
void reset(AggregateDataPtr __restrict place) const override {
    auto&& state = this->data(place);
    state.reset();
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
Expand Down Expand Up @@ -331,6 +336,11 @@ struct AggregateFunctionGroupArrayIntersectGenericData {
: value(std::make_unique<NullableStringSet>()) {}
Set value;
bool init = false;

/// Restores the same member values the constructor sets up: `init` back to false
/// and `value` pointing at a newly allocated NullableStringSet.
void reset() {
init = false;
// Assign a new set object instead of clearing the existing one in place.
value = std::make_unique<NullableStringSet>();
}
};

/** Template parameter with true value should be used for columns that store their elements in memory continuously.
Expand All @@ -357,7 +367,7 @@ class AggregateFunctionGroupArrayIntersectGeneric

DataTypePtr get_return_type() const override { return input_data_type; }

bool allocates_memory_in_arena() const override { return true; }
/// Resets the aggregate state at `place`; all of the work is done by the state
/// object's own reset().
void reset(AggregateDataPtr __restrict place) const override {
    auto&& state = this->data(place);
    state.reset();
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena* arena) const override {
Expand Down
8 changes: 0 additions & 8 deletions be/src/vec/aggregate_functions/aggregate_function_null.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,6 @@ class AggregateFunctionNullBaseInline : public IAggregateFunctionHelper<Derived>
nested_function->insert_result_into(nested_place(place), to);
}
}

bool allocates_memory_in_arena() const override {
return nested_function->allocates_memory_in_arena();
}
};

/** There are two cases: for single argument and variadic.
Expand Down Expand Up @@ -329,10 +325,6 @@ class AggregateFunctionNullVariadicInline final
arena);
}

bool allocates_memory_in_arena() const override {
return this->nested_function->allocates_memory_in_arena();
}

private:
// The array length is fixed in the implementation of some aggregate functions.
// Therefore we choose 256 as the appropriate maximum length limit.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ struct AggOrthBitmapBaseData {
public:
using ColVecData = std::conditional_t<IsNumber<T>, ColumnVector<T>, ColumnString>;

/// Resets the shared base state: `bitmap` is replaced with a value-initialized
/// object and `first_init` is set back to true.
/// NOTE(review): merge() in the derived structs returns early when the other side
/// has first_init == true, so true presumably means "nothing added yet" — confirm.
void reset() {
bitmap = {};
first_init = true;
}

void add(const IColumn** columns, size_t row_num) {
const auto& bitmap_col =
assert_cast<const ColumnBitmap&, TypeCheckOnRelease::DISABLE>(*columns[0]);
Expand Down Expand Up @@ -99,6 +104,11 @@ struct AggOrthBitMapIntersect : public AggOrthBitmapBaseData<T> {

static DataTypePtr get_return_type() { return std::make_shared<DataTypeBitMap>(); }

/// Resets the base-class state first, then resets this struct's own `result`.
void reset() {
AggOrthBitmapBaseData<T>::reset();
result.reset();
}

void merge(const AggOrthBitMapIntersect& rhs) {
if (rhs.first_init) {
return;
Expand All @@ -120,7 +130,8 @@ struct AggOrthBitMapIntersect : public AggOrthBitmapBaseData<T> {

// NOTE(review): this rendered diff shows both the pre-change statement and its
// replacement. The hunk header (-120,7 +130,8) indicates that the plain
// emplace_back(result) line is the one being removed; confirm against the actual
// file before copying this body.
void get(IColumn& to) const {
auto& column = assert_cast<ColumnBitmap&>(to);
column.get_data().emplace_back(result);
// Replacement form: when `result` is empty, emit the intersection computed from the
// base-class `bitmap` instead of the empty `result`.
column.get_data().emplace_back(result.empty() ? AggOrthBitmapBaseData<T>::bitmap.intersect()
: result);
}

private:
Expand Down Expand Up @@ -170,6 +181,11 @@ struct AggOrthBitMapIntersectCount : public AggOrthBitmapBaseData<T> {

static DataTypePtr get_return_type() { return std::make_shared<DataTypeInt64>(); }

/// Resets the base-class state, then zeroes this struct's `result`.
void reset() {
AggOrthBitmapBaseData<T>::reset();
result = 0;
}

void merge(const AggOrthBitMapIntersectCount& rhs) {
if (rhs.first_init) {
return;
Expand Down Expand Up @@ -225,6 +241,11 @@ struct AggOrthBitmapExprCalBaseData {
}
}

/// Resets the shared expression-calculation state: `bitmap_expr_cal` is replaced
/// with a value-initialized object and `first_init` returns to its declared
/// default of true.
void reset() {
bitmap_expr_cal = {};
first_init = true;
}

protected:
doris::BitmapExprCalculation bitmap_expr_cal;
bool first_init = true;
Expand Down Expand Up @@ -263,6 +284,11 @@ struct AggOrthBitMapExprCal : public AggOrthBitmapExprCalBaseData<T> {
->bitmap_expr_cal.bitmap_calculate());
}

/// Resets the base-class state, then resets the BitmapValue held in `result`.
void reset() {
AggOrthBitmapExprCalBaseData<T>::reset();
result.reset();
}

private:
BitmapValue result;
};
Expand Down Expand Up @@ -299,6 +325,11 @@ struct AggOrthBitMapExprCalCount : public AggOrthBitmapExprCalBaseData<T> {
->bitmap_expr_cal.bitmap_calculate_count());
}

/// Resets the base-class state, then sets `result` back to its declared default of 0.
void reset() {
AggOrthBitmapExprCalBaseData<T>::reset();
result = 0;
}

private:
int64_t result = 0;
};
Expand Down Expand Up @@ -330,6 +361,11 @@ struct OrthBitmapUnionCountData {
column.get_data().emplace_back(result ? result : value.cardinality());
}

/// Resets the `value` bitmap and zeroes `result`.
/// The result-emitting code of this struct uses `result ? result : value.cardinality()`,
/// so a zero `result` makes it fall back to the cardinality of `value`.
void reset() {
value.reset();
result = 0;
}

private:
BitmapValue value;
int64_t result = 0;
Expand All @@ -347,6 +383,8 @@ class AggFunctionOrthBitmapFunc final

DataTypePtr get_return_type() const override { return Impl::get_return_type(); }

/// Resets the Impl state stored at `place` by forwarding to its reset() member.
void reset(AggregateDataPtr __restrict place) const override {
    auto&& state = this->data(place);
    state.reset();
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
this->data(place).init_add_key(columns, row_num, _argument_size);
Expand Down
4 changes: 4 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_uniq.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ struct AggregateFunctionUniqExactData {
Set set;

static String get_name() { return "multi_distinct"; }

/// Removes every element from `set` so the state can be reused.
void reset() { set.clear(); }
};

namespace detail {
Expand Down Expand Up @@ -115,6 +117,8 @@ class AggregateFunctionUniq final

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }

/// Resets the distinct-value state stored at `place` through the state object's
/// own reset().
void reset(AggregateDataPtr __restrict place) const override {
    auto&& state = this->data(place);
    state.reset();
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
detail::OneAdder<T, Data>::add(this->data(place), *columns[0], row_num);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ struct AggregateFunctionUniqDistributeKeyData {

Set set;
UInt64 count = 0;

/// Clears `set` and puts `count` back to its declared default of 0.
void reset() {
set.clear();
count = 0;
}
};

template <typename T, typename Data>
Expand All @@ -83,6 +88,8 @@ class AggregateFunctionUniqDistributeKey final

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }

/// Resets the state stored at `place`; the state object's reset() does the
/// actual clearing.
void reset(AggregateDataPtr __restrict place) const override {
    auto&& state = this->data(place);
    state.reset();
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
detail::OneAdder<T, Data>::add(this->data(place), *columns[0], row_num);
Expand Down
Loading

0 comments on commit 11dfd19

Please sign in to comment.