Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat](skew & kurt) New aggregate function skew & kurt #40945

Merged
merged 8 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions be/src/pipeline/exec/aggregation_source_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,11 @@ Status AggLocalState::_get_without_key_result(RuntimeState* state, vectorized::B
}
}

// Result of operator is nullable, but aggregate function result is not nullable
// this happens when:
// 1. no group by
// 2. input of aggregate function is empty
// 3. all of input columns are not nullable
if (column_type->is_nullable() && !data_types[i]->is_nullable()) {
vectorized::ColumnPtr ptr = std::move(columns[i]);
// except for `count`, an aggregate function applied to an empty set should return null
Expand Down
80 changes: 80 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "common/status.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/aggregate_function_statistic.h"
#include "vec/aggregate_functions/helpers.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_nullable.h"

namespace doris::vectorized {

// Builds the kurtosis aggregate function for input element type T.
//
// argument_types     : argument types forwarded to the aggregate function's constructor.
// result_is_nullable : forwarded unchanged to creator_without_type::create_ignore_nullable.
// nullable_input     : picks the NullableInput template flag of AggregateFunctionVarianceSimple
//                      (true when the argument column is nullable).
//
// Returns whatever create_ignore_nullable produces for the selected instantiation.
template <typename T>
AggregateFunctionPtr type_dispatch_for_aggregate_function_kurt(const DataTypes& argument_types,
                                                               const bool result_is_nullable,
                                                               bool nullable_input) {
    // Level 4 selects VarMoments<Float64, 4> as the aggregation state.
    using StatFunctionTemplate = StatFuncOneArg<T, 4>;

    if (nullable_input) {
        return creator_without_type::create_ignore_nullable<
                AggregateFunctionVarianceSimple<StatFunctionTemplate, true>>(
                argument_types, result_is_nullable, STATISTICS_FUNCTION_KIND::KURT_POP);
    }
    return creator_without_type::create_ignore_nullable<
            AggregateFunctionVarianceSimple<StatFunctionTemplate, false>>(
            argument_types, result_is_nullable, STATISTICS_FUNCTION_KIND::KURT_POP);
}

// Factory entry point for the kurtosis aggregate function.
//
// Logs a warning and returns nullptr when:
//   - the call does not carry exactly one argument,
//   - the requested result type is not nullable,
//   - the argument's type (with nullability removed) matches none of the types
//     enumerated by FOR_NUMERIC_TYPES.
// Otherwise dispatches on the argument's element type.
AggregateFunctionPtr create_aggregate_function_kurt(const std::string& name,
                                                    const DataTypes& argument_types,
                                                    const bool result_is_nullable) {
    if (argument_types.size() != 1) {
        LOG(WARNING) << "aggregate function " << name << " requires exactly 1 argument";
        return nullptr;
    }
    if (!result_is_nullable) {
        LOG(WARNING) << "aggregate function " << name << " requires nullable result type";
        return nullptr;
    }

    const auto& input_type = argument_types[0];
    const bool input_is_nullable = input_type->is_nullable();
    // Dispatch is done on the nested type, so strip Nullable(...) first.
    const WhichDataType which(remove_nullable(input_type));

#define DISPATCH(TYPE)                                                  \
    if (which.idx == TypeIndex::TYPE) {                                 \
        return type_dispatch_for_aggregate_function_kurt<TYPE>(         \
                argument_types, result_is_nullable, input_is_nullable); \
    }
    FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH

    LOG(WARNING) << "unsupported input type " << input_type->get_name()
                 << " for aggregate function " << name;
    return nullptr;
}

// Registers the kurtosis aggregate under the name "kurt" (via register_function_both)
// and then adds the aliases "kurt_pop" and "kurtosis" for it, in that order.
void register_aggregate_function_kurtosis(AggregateFunctionSimpleFactory& factory) {
    const char* const canonical_name = "kurt";
    const char* const aliases[] = {"kurt_pop", "kurtosis"};

    factory.register_function_both(canonical_name, create_aggregate_function_kurt);
    for (const char* alias : aliases) {
        factory.register_alias(canonical_name, alias);
    }
}

} // namespace doris::vectorized
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ void register_aggregate_function_bitmap_agg(AggregateFunctionSimpleFactory& fact
void register_aggregate_functions_corr(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_covar_pop(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_covar_samp(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_skewness(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_kurtosis(AggregateFunctionSimpleFactory& factory);

AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
static std::once_flag oc;
Expand Down Expand Up @@ -119,6 +121,9 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
register_aggregate_function_covar_samp(instance);

register_aggregate_function_combinator_foreach(instance);

register_aggregate_function_skewness(instance);
register_aggregate_function_kurtosis(instance);
});
return instance;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ class AggregateFunctionSimpleFactory {
if (function_alias.contains(name)) {
name_str = function_alias[name];
}

if (nullable) {
return nullable_aggregate_functions.find(name_str) == nullable_aggregate_functions.end()
? nullptr
Expand Down
80 changes: 80 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_skew.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "common/status.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/aggregate_function_statistic.h"
#include "vec/aggregate_functions/helpers.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_nullable.h"

namespace doris::vectorized {

// Builds the skewness aggregate function for input element type T.
//
// argument_types     : argument types forwarded to the aggregate function's constructor.
// result_is_nullable : forwarded unchanged to creator_without_type::create_ignore_nullable.
// nullable_input     : picks the NullableInput template flag of AggregateFunctionVarianceSimple
//                      (true when the argument column is nullable).
//
// Returns whatever create_ignore_nullable produces for the selected instantiation.
template <typename T>
AggregateFunctionPtr type_dispatch_for_aggregate_function_skew(const DataTypes& argument_types,
                                                               const bool result_is_nullable,
                                                               bool nullable_input) {
    // Level 3 selects VarMoments<Float64, 3> as the aggregation state.
    using StatFunctionTemplate = StatFuncOneArg<T, 3>;

    if (nullable_input) {
        return creator_without_type::create_ignore_nullable<
                AggregateFunctionVarianceSimple<StatFunctionTemplate, true>>(
                argument_types, result_is_nullable, STATISTICS_FUNCTION_KIND::SKEW_POP);
    }
    return creator_without_type::create_ignore_nullable<
            AggregateFunctionVarianceSimple<StatFunctionTemplate, false>>(
            argument_types, result_is_nullable, STATISTICS_FUNCTION_KIND::SKEW_POP);
}

// Factory entry point for the skewness aggregate function.
//
// Logs a warning and returns nullptr when:
//   - the call does not carry exactly one argument,
//   - the requested result type is not nullable,
//   - the argument's type (with nullability removed) matches none of the types
//     enumerated by FOR_NUMERIC_TYPES.
// Otherwise dispatches on the argument's element type.
AggregateFunctionPtr create_aggregate_function_skew(const std::string& name,
                                                    const DataTypes& argument_types,
                                                    const bool result_is_nullable) {
    if (argument_types.size() != 1) {
        LOG(WARNING) << "aggregate function " << name << " requires exactly 1 argument";
        return nullptr;
    }
    if (!result_is_nullable) {
        LOG(WARNING) << "aggregate function " << name << " requires nullable result type";
        return nullptr;
    }

    const auto& input_type = argument_types[0];
    const bool input_is_nullable = input_type->is_nullable();
    // Dispatch is done on the nested type, so strip Nullable(...) first.
    const WhichDataType which(remove_nullable(input_type));

#define DISPATCH(TYPE)                                                  \
    if (which.idx == TypeIndex::TYPE) {                                 \
        return type_dispatch_for_aggregate_function_skew<TYPE>(         \
                argument_types, result_is_nullable, input_is_nullable); \
    }
    FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH

    LOG(WARNING) << "unsupported input type " << input_type->get_name()
                 << " for aggregate function " << name;
    return nullptr;
}

// Registers the skewness aggregate under the name "skew" (via register_function_both)
// and then adds the aliases "skew_pop" and "skewness" for it, in that order.
void register_aggregate_function_skewness(AggregateFunctionSimpleFactory& factory) {
    const char* const canonical_name = "skew";
    const char* const aliases[] = {"skew_pop", "skewness"};

    factory.register_function_both(canonical_name, create_aggregate_function_skew);
    for (const char* alias : aliases) {
        factory.register_alias(canonical_name, alias);
    }
}

} // namespace doris::vectorized
162 changes: 162 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_statistic.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once
#include <cmath>
#include <cstdint>
#include <string>
#include <type_traits>

#include "common/exception.h"
#include "common/status.h"
#include "moments.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/moments.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_vector.h"
#include "vec/common/assert_cast.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_number.h"

namespace doris::vectorized {

// Statistic computed by AggregateFunctionVarianceSimple.
enum class STATISTICS_FUNCTION_KIND : uint8_t { SKEW_POP, KURT_POP };

// Returns the display name for `kind`: "skewness" for SKEW_POP, "kurtosis" for KURT_POP,
// and "Unknown" for any other value.
inline std::string to_string(STATISTICS_FUNCTION_KIND kind) {
    if (kind == STATISTICS_FUNCTION_KIND::SKEW_POP) {
        return "skewness";
    }
    if (kind == STATISTICS_FUNCTION_KIND::KURT_POP) {
        return "kurtosis";
    }
    return "Unknown";
}

// Trait bundle for a one-argument statistic.
//   Type : the type T supplied by the caller.
//   Data : aggregation state, VarMoments over Float64 at the requested level.
template <typename T, std::size_t MomentLevel>
struct StatFuncOneArg {
    using Data = VarMoments<Float64, MomentLevel>;
    using Type = T;
};

template <typename StatFunc, bool NullableInput>
class AggregateFunctionVarianceSimple
: public IAggregateFunctionDataHelper<
typename StatFunc::Data,
AggregateFunctionVarianceSimple<StatFunc, NullableInput>> {
public:
using InputCol = ColumnVector<typename StatFunc::Type>;
using ResultCol = ColumnVector<Float64>;

// kind_           : statistic this instance reports and emits.
// argument_types_ : argument types, handed to the base helper; expected to be
//                   non-empty (checked with DCHECK).
explicit AggregateFunctionVarianceSimple(STATISTICS_FUNCTION_KIND kind_,
                                         const DataTypes& argument_types_)
        // The injected class name denotes this very specialization.
        : IAggregateFunctionDataHelper<typename StatFunc::Data,
                                       AggregateFunctionVarianceSimple>(argument_types_),
          kind(kind_) {
    DCHECK(!argument_types_.empty());
}

// Name of the function, derived from the configured statistic kind.
String get_name() const override {
    return to_string(kind);
}

// The result type is always a Float64 wrapped by make_nullable.
DataTypePtr get_return_type() const override {
    auto float64_type = std::make_shared<DataTypeFloat64>();
    return make_nullable(float64_type);
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
if constexpr (NullableInput) {
Copy link
Contributor

@HappenLee HappenLee Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should skip the null value

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function is using creator_without_type::create_ignore_nullable, aggregate_function_null will not be used since this return type is always nullable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

const ColumnNullable& column_with_nullable =
assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*columns[0]);

if (column_with_nullable.is_null_at(row_num)) {
return;
} else {
this->data(place).add(assert_cast<const InputCol&, TypeCheckOnRelease::DISABLE>(
column_with_nullable.get_nested_column())
.get_data()[row_num]);
}

} else {
this->data(place).add(
assert_cast<const InputCol&, TypeCheckOnRelease::DISABLE>(*columns[0])
.get_data()[row_num]);
}
}

// Folds the state stored at `rhs` into the state stored at `place`.
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
           Arena*) const override {
    const auto& incoming = this->data(rhs);
    auto& accumulated = this->data(place);
    accumulated.merge(incoming);
}

// Writes the state stored at `place` into `buf`.
void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
    const auto& state = this->data(place);
    state.write(buf);
}

// Reads a state from `buf` into the state stored at `place`.
void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                 Arena*) const override {
    auto& state = this->data(place);
    state.read(buf);
}

// Appends the final statistic for the state at `place` as one row of `to`.
//
// `to` must be a ColumnNullable whose nested column is a ColumnVector<Float64>.
// The row is NULL when the variance or the relevant moment is NaN, or when the
// variance is <= 0; otherwise it holds:
//   SKEW_POP : moment_3 / variance^1.5
//   KURT_POP : moment_4 / variance^2
// Throws doris::Exception(INTERNAL_ERROR) for any other kind.
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
    const auto& data = this->data(place);
    auto& nullable_dst = assert_cast<ColumnNullable&>(to);
    auto& nested_dst = assert_cast<ResultCol&>(nullable_dst.get_nested_column());

    // Shared tail of both statistics: emit NULL when the value is undefined,
    // otherwise emit moment / variance^exponent.
    const auto append_row = [&](Float64 variance, Float64 moment, Float64 exponent) {
        if (std::isnan(variance) || std::isnan(moment) || variance <= 0) {
            nullable_dst.get_null_map_data().push_back(1);
            // Keep the nested column the same length as the null map.
            nested_dst.insert_default();
        } else {
            nullable_dst.get_null_map_data().push_back(0);
            // std::pow (not the unqualified C name) is what <cmath> guarantees.
            nested_dst.get_data().push_back(moment / std::pow(variance, exponent));
        }
    };

    switch (kind) {
    case STATISTICS_FUNCTION_KIND::SKEW_POP: {
        // If input is empty set, we will get NAN from get_population()
        const Float64 variance = data.get_population();
        const Float64 moment_3 = data.get_moment_3();
        append_row(variance, moment_3, 1.5);
        break;
    }
    case STATISTICS_FUNCTION_KIND::KURT_POP: {
        const Float64 variance = data.get_population();
        const Float64 moment_4 = data.get_moment_4();
        append_row(variance, moment_4, 2.0);
        break;
    }
    default:
        throw doris::Exception(ErrorCode::INTERNAL_ERROR, "Unknown statistics function kind");
    }
}

private:
STATISTICS_FUNCTION_KIND kind;
};

} // namespace doris::vectorized
Loading
Loading