forked from apache/doris
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1. Fix the bug of sort node return empty block if child eos is true a…
…nd add some comment (apache#23) 2. Use SIMD to speed up has_null() in column nullable 3. Support UDAF of avg Change-Id: I13846d7275e1cc37085d3afbf41d60261e296662 Co-authored-by: lihaopeng <lihaopeng@baidu.com>
- Loading branch information
Showing
11 changed files
with
212 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
#include "vec/aggregate_functions/aggregate_function_simple_factory.h" | ||
#include "vec/aggregate_functions/aggregate_function_avg.h" | ||
#include "vec/aggregate_functions/helpers.h" | ||
#include "vec/aggregate_functions/factory_helpers.h" | ||
|
||
namespace doris::vectorized | ||
{ | ||
|
||
namespace | ||
{ | ||
|
||
template <typename T> | ||
struct Avg | ||
{ | ||
using FieldType = std::conditional_t<IsDecimalNumber<T>, Decimal128, NearestFieldType<T>>; | ||
using Function = AggregateFunctionAvg<T, AggregateFunctionAvgData<FieldType>>; | ||
}; | ||
|
||
template <typename T> | ||
using AggregateFuncAvg = typename Avg<T>::Function; | ||
|
||
AggregateFunctionPtr createAggregateFunctionAvg(const std::string & name, const DataTypes & argument_types, const Array & parameters) | ||
{ | ||
assertNoParameters(name, parameters); | ||
assertUnary(name, argument_types); | ||
|
||
AggregateFunctionPtr res; | ||
DataTypePtr data_type = argument_types[0]; | ||
if (isDecimal(data_type)) | ||
res.reset(createWithDecimalType<AggregateFuncAvg>(*data_type, *data_type, argument_types)); | ||
else | ||
res.reset(createWithNumericType<AggregateFuncAvg>(*data_type, argument_types)); | ||
|
||
if (!res) | ||
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, | ||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); | ||
return res; | ||
} | ||
|
||
} | ||
|
||
//void registerAggregateFunctionAvg(AggregateFunctionFactory & factory) | ||
//{ | ||
// factory.registerFunction("avg", createAggregateFunctionAvg, AggregateFunctionFactory::CaseInsensitive); | ||
//} | ||
|
||
void registerAggregateFunctionAvg(AggregateFunctionSimpleFactory& factory) { | ||
factory.registerFunction("avg", createAggregateFunctionAvg); | ||
} | ||
} |
115 changes: 115 additions & 0 deletions
115
be/src/vec/aggregate_functions/aggregate_function_avg.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
#pragma once | ||
|
||
#include "vec/data_types/data_types_number.h" | ||
#include "vec/data_types/data_types_decimal.h" | ||
#include "vec/columns/columns_number.h" | ||
#include "vec/aggregate_functions/aggregate_function.h" | ||
#include "vec/io/io_helper.h" | ||
|
||
namespace doris::vectorized | ||
{ | ||
|
||
namespace ErrorCodes | ||
{ | ||
extern const int LOGICAL_ERROR; | ||
} | ||
|
||
template <typename T> | ||
struct AggregateFunctionAvgData | ||
{ | ||
T sum = 0; | ||
UInt64 count = 0; | ||
|
||
template <typename ResultT> | ||
ResultT NO_SANITIZE_UNDEFINED result() const | ||
{ | ||
if constexpr (std::is_floating_point_v<ResultT>) | ||
if constexpr (std::numeric_limits<ResultT>::is_iec559) | ||
return static_cast<ResultT>(sum) / count; /// allow division by zero | ||
|
||
if (!count) | ||
throw Exception("AggregateFunctionAvg with zero values", ErrorCodes::LOGICAL_ERROR); | ||
return static_cast<ResultT>(sum) / count; | ||
} | ||
|
||
void write(std::ostream& buf) const { | ||
writeBinary(sum, buf); | ||
writeBinary(count, buf); | ||
} | ||
|
||
void read(std::istream& buf) { | ||
readBinary(sum, buf); | ||
readBinary(count, buf); | ||
} | ||
}; | ||
|
||
|
||
/// Calculates arithmetic mean of numbers. | ||
template <typename T, typename Data> | ||
class AggregateFunctionAvg final : public IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>> | ||
{ | ||
public: | ||
using ResultType = std::conditional_t<IsDecimalNumber<T>, Decimal128, Float64>; | ||
using ResultDataType = std::conditional_t<IsDecimalNumber<T>, DataTypeDecimal<Decimal128>, DataTypeNumber<Float64>>; | ||
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>; | ||
using ColVecResult = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128>, ColumnVector<Float64>>; | ||
|
||
/// ctor for native types | ||
AggregateFunctionAvg(const DataTypes & argument_types_) | ||
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>(argument_types_, {}) | ||
, scale(0) | ||
{} | ||
|
||
/// ctor for Decimals | ||
AggregateFunctionAvg(const IDataType & data_type, const DataTypes & argument_types_) | ||
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>(argument_types_, {}) | ||
, scale(getDecimalScale(data_type)) | ||
{} | ||
|
||
String getName() const override { return "avg"; } | ||
|
||
DataTypePtr getReturnType() const override | ||
{ | ||
if constexpr (IsDecimalNumber<T>) | ||
return std::make_shared<ResultDataType>(ResultDataType::maxPrecision(), scale); | ||
else | ||
return std::make_shared<ResultDataType>(); | ||
} | ||
|
||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override | ||
{ | ||
const auto & column = static_cast<const ColVecType &>(*columns[0]); | ||
this->data(place).sum += column.getData()[row_num]; | ||
++this->data(place).count; | ||
} | ||
|
||
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override | ||
{ | ||
this->data(place).sum += this->data(rhs).sum; | ||
this->data(place).count += this->data(rhs).count; | ||
} | ||
|
||
void serialize(ConstAggregateDataPtr place, std::ostream& buf) const override | ||
{ | ||
this->data(place).write(buf); | ||
} | ||
|
||
void deserialize(AggregateDataPtr place, std::istream& buf, Arena *) const override | ||
{ | ||
this->data(place).read(buf); | ||
} | ||
|
||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override | ||
{ | ||
auto & column = static_cast<ColVecResult &>(to); | ||
column.getData().push_back(this->data(place).template result<ResultType>()); | ||
} | ||
|
||
const char * getHeaderFilePath() const override { return __FILE__; } | ||
|
||
private: | ||
UInt32 scale; | ||
}; | ||
|
||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters