diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt index 752d8b2f314a06..51b8fb32481f3d 100644 --- a/be/src/vec/CMakeLists.txt +++ b/be/src/vec/CMakeLists.txt @@ -22,6 +22,7 @@ set(EXECUTABLE_OUTPUT_PATH "${BUILD_DIR}/src/vec") set(VEC_FILES aggregate_functions/aggregate_function_avg.cpp aggregate_functions/aggregate_function_count.cpp + aggregate_functions/aggregate_function_distinct.cpp aggregate_functions/aggregate_function_sum.cpp aggregate_functions/aggregate_function_min_max.cpp aggregate_functions/aggregate_function_null.cpp diff --git a/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp b/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp new file mode 100644 index 00000000000000..7875b48c09972b --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp @@ -0,0 +1,79 @@ +#include "vec/aggregate_functions/aggregate_function_distinct.h" + +#include + +#include "boost/algorithm/string.hpp" +#include "vec/aggregate_functions/aggregate_function_combinator.h" +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/helpers.h" +#include "vec/common/typeid_cast.h" +#include "vec/data_types/data_type_nullable.h" +// #include "registerAggregateFunctions.h" + +namespace doris::vectorized { +namespace ErrorCodes { +extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} + +class AggregateFunctionCombinatorDistinct final : public IAggregateFunctionCombinator { +public: + String getName() const override { return "Distinct"; } + + DataTypes transformArguments(const DataTypes& arguments) const override { + if (arguments.empty()) + throw Exception( + "Incorrect number of arguments for aggregate function with Distinct suffix", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + return arguments; + } + + AggregateFunctionPtr transformAggregateFunction(const AggregateFunctionPtr& nested_function, + const DataTypes& arguments, + const Array& params) const override { + AggregateFunctionPtr res; + if (arguments.size() == 1) { + res.reset(createWithNumericType( + *arguments[0], nested_function, arguments)); + + if (res) return res; + + if (arguments[0]->isValueUnambiguouslyRepresentedInContiguousMemoryRegion()) + return std::make_shared>>(nested_function, + arguments); + else + return std::make_shared>>(nested_function, + arguments); + } + + return std::make_shared< + AggregateFunctionDistinct>( + nested_function, arguments); + } +}; + +const std::string DISTINCT_FUNCTION_PREFIX = "multi_distinct_"; + +void registerAggregateFunctionCombinatorDistinct(AggregateFunctionSimpleFactory& factory) { + AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types, + const Array& params) { + // 1. we should get not nullable types; + DataTypes nested_types(types.size()); + std::transform(types.begin(), types.end(), nested_types.begin(), + [](const auto& e) { return removeNullable(e); }); + auto function_combinator = std::make_shared(); + auto transformArguments = function_combinator->transformArguments(nested_types); + if (!boost::algorithm::starts_with(name, DISTINCT_FUNCTION_PREFIX)) { + return AggregateFunctionPtr(); + } + auto nested_function_name = name.substr(DISTINCT_FUNCTION_PREFIX.size()); + auto nested_function = factory.get(nested_function_name, transformArguments, params); + return function_combinator->transformAggregateFunction(nested_function, types, params); + }; + factory.registerDistinctFunctionCombinator(creator, DISTINCT_FUNCTION_PREFIX); + // factory.registerCombinator(std::make_shared()); +} +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_distinct.h b/be/src/vec/aggregate_functions/aggregate_function_distinct.h new file mode 100644 index 00000000000000..3648bed0f0ac29 --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_distinct.h @@ -0,0 +1,203 @@ +#pragma once + +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/key_holder_helpers.h" + +// #include +// #include +#include "vec/common/aggregation_common.h" +#include "vec/common/assert_cast.h" +#include "vec/common/field_visitors.h" +#include "vec/common/hash_table/hash_set.h" +#include "vec/common/hash_table/hash_table.h" +#include "vec/common/sip_hash.h" +#include "vec/io/io_helper.h" + +namespace doris::vectorized { + +template +struct AggregateFunctionDistinctSingleNumericData { + /// When creating, the hash table must be small. + using Set = HashSetWithStackMemory, 4>; + using Self = AggregateFunctionDistinctSingleNumericData; + Set set; + + void add(const IColumn** columns, size_t /* columns_num */, size_t row_num, Arena*) { + const auto& vec = assert_cast&>(*columns[0]).getData(); + set.insert(vec[row_num]); + } + + void merge(const Self& rhs, Arena*) { set.merge(rhs.set); } + + void serialize(std::ostream& buf) const { set.write(buf); } + + void deserialize(std::istream& buf, Arena*) { set.read(buf); } + + MutableColumns getArguments(const DataTypes& argument_types) const { + MutableColumns argument_columns; + argument_columns.emplace_back(argument_types[0]->createColumn()); + for (const auto& elem : set) argument_columns[0]->insert(elem.getValue()); + + return argument_columns; + } +}; + +struct AggregateFunctionDistinctGenericData { + /// When creating, the hash table must be small. + using Set = HashSetWithSavedHashWithStackMemory; + using Self = AggregateFunctionDistinctGenericData; + Set set; + + void merge(const Self& rhs, Arena* arena) { + Set::LookupResult it; + bool inserted; + for (const auto& elem : rhs.set) + set.emplace(ArenaKeyHolder{elem.getValue(), *arena}, it, inserted); + } + + void serialize(std::ostream& buf) const { + writeVarUInt(set.size(), buf); + for (const auto& elem : set) writeStringBinary(elem.getValue(), buf); + } + + void deserialize(std::istream& buf, Arena* arena) { + size_t size; + readVarUInt(size, buf); + for (size_t i = 0; i < size; ++i) set.insert(readStringBinaryInto(*arena, buf)); + } +}; + +template +struct AggregateFunctionDistinctSingleGenericData : public AggregateFunctionDistinctGenericData { + void add(const IColumn** columns, size_t /* columns_num */, size_t row_num, Arena* arena) { + Set::LookupResult it; + bool inserted; + auto key_holder = getKeyHolder(*columns[0], row_num, *arena); + set.emplace(key_holder, it, inserted); + } + + MutableColumns getArguments(const DataTypes& argument_types) const { + MutableColumns argument_columns; + argument_columns.emplace_back(argument_types[0]->createColumn()); + for (const auto& elem : set) + deserializeAndInsert(elem.getValue(), *argument_columns[0]); + + return argument_columns; + } +}; + +struct AggregateFunctionDistinctMultipleGenericData : public AggregateFunctionDistinctGenericData { + void add(const IColumn** columns, size_t columns_num, size_t row_num, Arena* arena) { + const char* begin = nullptr; + StringRef value(begin, 0); + for (size_t i = 0; i < columns_num; ++i) { + auto cur_ref = columns[i]->serializeValueIntoArena(row_num, *arena, begin); + value.data = cur_ref.data - value.size; + value.size += cur_ref.size; + } + + Set::LookupResult it; + bool inserted; + auto key_holder = SerializedKeyHolder{value, *arena}; + set.emplace(key_holder, it, inserted); + } + + MutableColumns getArguments(const DataTypes& argument_types) const { + MutableColumns argument_columns(argument_types.size()); + for (size_t i = 0; i < argument_types.size(); ++i) + argument_columns[i] = argument_types[i]->createColumn(); + + for (const auto& elem : set) { + const char* begin = elem.getValue().data; + for (auto& column : argument_columns) + begin = column->deserializeAndInsertFromArena(begin); + } + + return argument_columns; + } +}; + +/** Adaptor for aggregate functions. + * Adding -Distinct suffix to aggregate function +**/ +template +class AggregateFunctionDistinct + : public IAggregateFunctionDataHelper> { +private: + static constexpr auto prefix_size = sizeof(Data); + AggregateFunctionPtr nested_func; + size_t arguments_num; + + AggregateDataPtr getNestedPlace(AggregateDataPtr __restrict place) const noexcept { + return place + prefix_size; + } + + ConstAggregateDataPtr getNestedPlace(ConstAggregateDataPtr __restrict place) const noexcept { + return place + prefix_size; + } + +public: + AggregateFunctionDistinct(AggregateFunctionPtr nested_func_, const DataTypes& arguments) + : IAggregateFunctionDataHelper( + arguments, nested_func_->getParameters()), + nested_func(nested_func_), + arguments_num(arguments.size()) {} + + void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + Arena* arena) const override { + this->data(place).add(columns, arguments_num, row_num, arena); + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena* arena) const override { + this->data(place).merge(this->data(rhs), arena); + } + + void serialize(ConstAggregateDataPtr place, std::ostream& buf) const override { + this->data(place).serialize(buf); + } + + void deserialize(AggregateDataPtr place, std::istream& buf, Arena* arena) const override { + this->data(place).deserialize(buf, arena); + } + + // void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * arena) const override + void insertResultInto(ConstAggregateDataPtr targetplace, IColumn& to) const override { + auto place = const_cast(targetplace); + auto arguments = this->data(place).getArguments(this->argument_types); + ColumnRawPtrs arguments_raw(arguments.size()); + for (size_t i = 0; i < arguments.size(); ++i) arguments_raw[i] = arguments[i].get(); + + assert(!arguments.empty()); + // nested_func->addBatchSinglePlace(arguments[0]->size(), getNestedPlace(place), arguments_raw.data(), arena); + // nested_func->insertResultInto(getNestedPlace(place), to, arena); + + nested_func->addBatchSinglePlace(arguments[0]->size(), getNestedPlace(place), + arguments_raw.data(), nullptr); + nested_func->insertResultInto(getNestedPlace(place), to); + } + + size_t sizeOfData() const override { return prefix_size + nested_func->sizeOfData(); } + + void create(AggregateDataPtr place) const override { + new (place) Data; + nested_func->create(getNestedPlace(place)); + } + + void destroy(AggregateDataPtr place) const noexcept override { + this->data(place).~Data(); + nested_func->destroy(getNestedPlace(place)); + } + + String getName() const override { return nested_func->getName() + "Distinct"; } + + DataTypePtr getReturnType() const override { return nested_func->getReturnType(); } + + bool allocatesMemoryInArena() const override { return true; } + + const char* getHeaderFilePath() const override { return __FILE__; } + + // AggregateFunctionPtr getNestedFunction() const override { return nested_func; } +}; + +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h index f408a2a8526ccc..f82a944f0a4833 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h @@ -36,6 +36,7 @@ void registerAggregateFunctionMinMax(AggregateFunctionSimpleFactory& factory); void registerAggregateFunctionAvg(AggregateFunctionSimpleFactory& factory); void registerAggregateFunctionCount(AggregateFunctionSimpleFactory& factory); void registerAggregateFunctionsUniq(AggregateFunctionSimpleFactory& factory); +void registerAggregateFunctionCombinatorDistinct(AggregateFunctionSimpleFactory& factory); using DataTypePtr = std::shared_ptr; using DataTypes = std::vector; @@ -61,6 +62,19 @@ class AggregateFunctionSimpleFactory { } } + void registerDistinctFunctionCombinator(Creator creator, const std::string& prefix) { + std::vector need_insert; + for (auto entity : aggregate_functions) { + std::string target_value = prefix + entity.first; + if (aggregate_functions[target_value] == nullptr) { + need_insert.emplace_back(std::move(target_value)); + } + } + for (const auto& function_name : need_insert) { + aggregate_functions[function_name] = creator; + } + } + void registerFunction(const std::string& name, Creator creator, bool nullable = false) { if (nullable) { nullable_aggregate_functions[name] = creator; @@ -98,6 +112,7 @@ class AggregateFunctionSimpleFactory { registerAggregateFunctionAvg(instance); registerAggregateFunctionCount(instance); registerAggregateFunctionsUniq(instance); + registerAggregateFunctionCombinatorDistinct(instance); registerAggregateFunctionCombinatorNull(instance); }); return instance; diff --git a/be/src/vec/aggregate_functions/helpers.h b/be/src/vec/aggregate_functions/helpers.h index eb3e4e425b9d5f..744afe8111dca0 100644 --- a/be/src/vec/aggregate_functions/helpers.h +++ b/be/src/vec/aggregate_functions/helpers.h @@ -99,6 +99,20 @@ static IAggregateFunction* createWithNumericType(const IDataType& argument_type, return nullptr; } +template