From bbe0a73bb48c1e5fb61925888d91ec1ee8086858 Mon Sep 17 00:00:00 2001 From: Oleg Doronin Date: Sat, 7 Jun 2025 19:00:18 +0000 Subject: [PATCH 1/3] float sum aggregation has been fixed --- ydb/core/formats/arrow/program/functions.cpp | 170 ++++++++++++++++++- ydb/core/formats/arrow/program/ya.make | 4 + ydb/core/kqp/ut/olap/aggregations_ut.cpp | 50 ++++++ 3 files changed, 223 insertions(+), 1 deletion(-) diff --git a/ydb/core/formats/arrow/program/functions.cpp b/ydb/core/formats/arrow/program/functions.cpp index 97cc26d9eccb..de3f9d6095ab 100644 --- a/ydb/core/formats/arrow/program/functions.cpp +++ b/ydb/core/formats/arrow/program/functions.cpp @@ -1,9 +1,176 @@ #include "functions.h" #include +#include +#include +#include #include namespace NKikimr::NArrow::NSSA { + +namespace internal { + +// Find the largest compatible primitive type for a primitive type. +template +struct FindAccumulatorType {}; + +template +struct FindAccumulatorType> { + using Type = arrow::UInt64Type; +}; + +template +struct FindAccumulatorType> { + using Type = arrow::Int64Type; +}; + +template +struct FindAccumulatorType> { + using Type = arrow::UInt64Type; +}; + +template +struct FindAccumulatorType> { + using Type = arrow::DoubleType; +}; + +template <> +struct FindAccumulatorType { + using Type = arrow::FloatType; +}; + +template +struct SumImpl : public arrow::compute::ScalarAggregator { + using ThisType = SumImpl; + using CType = typename ArrowType::c_type; + using SumType = typename FindAccumulatorType::Type; + using OutputType = typename arrow::TypeTraits::ScalarType; + + arrow::Status Consume(arrow::compute::KernelContext*, const arrow::compute::ExecBatch& batch) override { + if (batch[0].is_array()) { + const auto& data = batch[0].array(); + this->count += data->length - data->GetNullCount(); + if (arrow::is_boolean_type::value) { + this->sum += + static_cast(arrow::BooleanArray(data).true_count()); + } else { + this->sum += + arrow::compute::detail::SumArray( + *data); + } + } else { + const auto& data = *batch[0].scalar(); + this->count += data.is_valid * batch.length; + if (data.is_valid) { + this->sum += arrow::compute::internal::UnboxScalar::Unbox(data) * batch.length; + } + } + return arrow::Status::OK(); + } + + arrow::Status MergeFrom(arrow::compute::KernelContext*, arrow::compute::KernelState&& src) override { + const auto& other = arrow::checked_cast(src); + this->count += other.count; + this->sum += other.sum; + return arrow::Status::OK(); + } + + arrow::Status Finalize(arrow::compute::KernelContext*, arrow::Datum* out) override { + if (this->count < options.min_count) { + out->value = std::make_shared(); + } else { + out->value = arrow::MakeScalar(this->sum); + } + return arrow::Status::OK(); + } + + size_t count = 0; + typename SumType::c_type sum = 0; + arrow::compute::ScalarAggregateOptions options; +}; + +template +struct SumImplDefault : public SumImpl { + explicit SumImplDefault(const arrow::compute::ScalarAggregateOptions& options_) { + this->options = options_; + } +}; + +void AddScalarAggKernels(arrow::compute::KernelInit init, + const std::vector>& types, + std::shared_ptr out_ty, + arrow::compute::ScalarAggregateFunction* func) { + for (const auto& ty : types) { + // scalar[InT] -> scalar[OutT] + auto sig = arrow::compute::KernelSignature::Make({arrow::compute::InputType::Scalar(ty)}, arrow::ValueDescr::Scalar(out_ty)); + AddAggKernel(std::move(sig), init, func, arrow::compute::SimdLevel::NONE); + } +} + +void AddArrayScalarAggKernels(arrow::compute::KernelInit init, + const std::vector>& types, + std::shared_ptr out_ty, + arrow::compute::ScalarAggregateFunction* func, + arrow::compute::SimdLevel::type simd_level = arrow::compute::SimdLevel::NONE) { + arrow::compute::aggregate::AddBasicAggKernels(init, types, out_ty, func, simd_level); + AddScalarAggKernels(init, types, out_ty, func); +} + +arrow::Result> SumInit(arrow::compute::KernelContext* ctx, + const arrow::compute::KernelInitArgs& args) { + arrow::compute::aggregate::SumLikeInit visitor( + ctx, *args.inputs[0].type, + static_cast(*args.options)); + return visitor.Create(); +} + +static std::unique_ptr CreateCustomRegistry() { + arrow::compute::FunctionRegistry* defaultRegistry = arrow::compute::GetFunctionRegistry(); + auto registry = arrow::compute::FunctionRegistry::Make(); + for (const auto& func : defaultRegistry->GetFunctionNames()) { + if (func == "sum") { + auto aggregateFunc = dynamic_cast(defaultRegistry->GetFunction(func)->get()); + if (!aggregateFunc) { + DCHECK_OK(registry->AddFunction(*defaultRegistry->GetFunction(func))); + continue; + } + arrow::compute::ScalarAggregateFunction newFunc(func, aggregateFunc->arity(), &aggregateFunc->doc(), aggregateFunc->default_options()); + for (const arrow::compute::ScalarAggregateKernel* kernel : aggregateFunc->kernels()) { + auto shouldReplaceKernel = [](const arrow::compute::ScalarAggregateKernel& kernel) { + const auto& params = kernel.signature->in_types(); + if (params.empty()) { + return false; + } + + if (params[0].kind() == arrow::compute::InputType::Kind::EXACT_TYPE) { + auto type = params[0].type(); + return type->id() == arrow::Type::FLOAT; + } + + return false; + }; + + if (shouldReplaceKernel(*kernel)) { + AddArrayScalarAggKernels(SumInit, {arrow::float32()}, arrow::float32(), &newFunc); + } else { + DCHECK_OK(newFunc.AddKernel(*kernel)); + } + } + DCHECK_OK(registry->AddFunction(std::make_shared(std::move(newFunc)))); + } else { + DCHECK_OK(registry->AddFunction(*defaultRegistry->GetFunction(func))); + } + } + + return registry; +} +arrow::compute::FunctionRegistry* GetCustomFunctionRegistry() { + static auto registry = internal::CreateCustomRegistry(); + return registry.get(); +} + +} // namespace internal + TConclusion TInternalFunction::Call( const TExecFunctionContext& context, const std::shared_ptr& resources) const { auto funcNames = GetRegistryFunctionNames(); @@ -16,7 +183,8 @@ TConclusion TInternalFunction::Call( if (GetContext() && GetContext()->func_registry()->GetFunction(funcName).ok()) { result = arrow::compute::CallFunction(funcName, *arguments, FunctionOptions.get(), GetContext()); } else { - result = arrow::compute::CallFunction(funcName, *arguments, FunctionOptions.get()); + arrow::compute::ExecContext defaultContext(arrow::default_memory_pool(), nullptr, internal::GetCustomFunctionRegistry()); + result = arrow::compute::CallFunction(funcName, *arguments, FunctionOptions.get(), &defaultContext); } if (result.ok() && funcName == "count"sv) { diff --git a/ydb/core/formats/arrow/program/ya.make b/ydb/core/formats/arrow/program/ya.make index 720b0f0b3180..71f54350d63a 100644 --- a/ydb/core/formats/arrow/program/ya.make +++ b/ydb/core/formats/arrow/program/ya.make @@ -57,4 +57,8 @@ GENERATE_ENUM_SERIALIZATION(execution.h) YQL_LAST_ABI_VERSION() +CFLAGS( + -Wno-unused-parameter +) + END() diff --git a/ydb/core/kqp/ut/olap/aggregations_ut.cpp b/ydb/core/kqp/ut/olap/aggregations_ut.cpp index 70edea0c3ff6..77b46cd2ab32 100644 --- a/ydb/core/kqp/ut/olap/aggregations_ut.cpp +++ b/ydb/core/kqp/ut/olap/aggregations_ut.cpp @@ -1375,6 +1375,56 @@ Y_UNIT_TEST_SUITE(KqpOlapAggregations) { TestTableWithNulls({ testCase }, /* generic */ true); } + + Y_UNIT_TEST(FloatSum) { + NKikimrConfig::TAppConfig appConfig; + appConfig.MutableTableServiceConfig()->SetEnableOlapSink(true); + auto settings = TKikimrSettings() + .SetAppConfig(appConfig) + .SetWithSampleTables(false); + TKikimrRunner kikimr(settings); + + auto queryClient = kikimr.GetQueryClient(); + { + auto status = queryClient.ExecuteQuery( + R"( + CREATE TABLE `olap_table` ( + id Uint64 NOT NULL, + value Float, + PRIMARY KEY (id) + ) WITH (STORE = COLUMN); + )", NYdb::NQuery::TTxControl::NoTx() + ).GetValueSync(); + UNIT_ASSERT_C(status.IsSuccess(), status.GetIssues().ToString()); + } + + { + auto status = queryClient.ExecuteQuery( + R"( + INSERT INTO `olap_table` (id, value) VALUES (1u, 0.4f); + INSERT INTO `olap_table` (id, value) VALUES (2u, 0.85f); + INSERT INTO `olap_table` (id, value) VALUES (3u, 11.3f); + INSERT INTO `olap_table` (id, value) VALUES (4u, 7.15f); + INSERT INTO `olap_table` (id, value) VALUES (5u, 0.3f); + )", NYdb::NQuery::TTxControl::BeginTx().CommitTx() + ).GetValueSync(); + UNIT_ASSERT_C(status.IsSuccess(), status.GetIssues().ToString()); + } + + { + auto status = queryClient.ExecuteQuery(R"( + --!syntax_v1 + SELECT SUM(value) FROM `olap_table` + WHERE id = 1 + )", NYdb::NQuery::TTxControl::BeginTx().CommitTx() + ).GetValueSync(); + + UNIT_ASSERT_C(status.IsSuccess(), status.GetIssues().ToString()); + TString result = FormatResultSetYson(status.GetResultSet(0)); + Cout << result << Endl; + CompareYson(result, R"([[[0.400000006;]]])"); + } + } } } From ca0bef8eac00f663620264d6225fb98000c0c7d1 Mon Sep 17 00:00:00 2001 From: Oleg Doronin Date: Mon, 9 Jun 2025 12:45:47 +0000 Subject: [PATCH 2/3] style has been fixed --- ydb/core/formats/arrow/program/functions.cpp | 28 ++++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/ydb/core/formats/arrow/program/functions.cpp b/ydb/core/formats/arrow/program/functions.cpp index de3f9d6095ab..be20099faca6 100644 --- a/ydb/core/formats/arrow/program/functions.cpp +++ b/ydb/core/formats/arrow/program/functions.cpp @@ -49,20 +49,20 @@ struct SumImpl : public arrow::compute::ScalarAggregator { arrow::Status Consume(arrow::compute::KernelContext*, const arrow::compute::ExecBatch& batch) override { if (batch[0].is_array()) { const auto& data = batch[0].array(); - this->count += data->length - data->GetNullCount(); + this->Count += data->length - data->GetNullCount(); if (arrow::is_boolean_type::value) { - this->sum += + this->Sum += static_cast(arrow::BooleanArray(data).true_count()); } else { - this->sum += + this->Sum += arrow::compute::detail::SumArray( *data); } } else { const auto& data = *batch[0].scalar(); - this->count += data.is_valid * batch.length; + this->Count += data.is_valid * batch.length; if (data.is_valid) { - this->sum += arrow::compute::internal::UnboxScalar::Unbox(data) * batch.length; + this->Sum += arrow::compute::internal::UnboxScalar::Unbox(data) * batch.length; } } return arrow::Status::OK(); @@ -70,29 +70,29 @@ struct SumImpl : public arrow::compute::ScalarAggregator { arrow::Status MergeFrom(arrow::compute::KernelContext*, arrow::compute::KernelState&& src) override { const auto& other = arrow::checked_cast(src); - this->count += other.count; - this->sum += other.sum; + this->Count += other.Count; + this->Sum += other.Sum; return arrow::Status::OK(); } arrow::Status Finalize(arrow::compute::KernelContext*, arrow::Datum* out) override { - if (this->count < options.min_count) { + if (this->Count < Options.min_count) { out->value = std::make_shared(); } else { - out->value = arrow::MakeScalar(this->sum); + out->value = arrow::MakeScalar(this->Sum); } return arrow::Status::OK(); } - size_t count = 0; - typename SumType::c_type sum = 0; - arrow::compute::ScalarAggregateOptions options; + size_t Count = 0; + typename SumType::c_type Sum = 0; + arrow::compute::ScalarAggregateOptions Options; }; template struct SumImplDefault : public SumImpl { - explicit SumImplDefault(const arrow::compute::ScalarAggregateOptions& options_) { - this->options = options_; + explicit SumImplDefault(const arrow::compute::ScalarAggregateOptions& options) { + this->Options = options; } }; From b7a578f386087c962078a9bf316d91d18785ebd9 Mon Sep 17 00:00:00 2001 From: Oleg Doronin Date: Thu, 12 Jun 2025 17:47:42 +0000 Subject: [PATCH 3/3] cout has been removed --- ydb/core/kqp/ut/olap/aggregations_ut.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/ydb/core/kqp/ut/olap/aggregations_ut.cpp b/ydb/core/kqp/ut/olap/aggregations_ut.cpp index 77b46cd2ab32..3847daa5127f 100644 --- a/ydb/core/kqp/ut/olap/aggregations_ut.cpp +++ b/ydb/core/kqp/ut/olap/aggregations_ut.cpp @@ -1421,7 +1421,6 @@ Y_UNIT_TEST_SUITE(KqpOlapAggregations) { UNIT_ASSERT_C(status.IsSuccess(), status.GetIssues().ToString()); TString result = FormatResultSetYson(status.GetResultSet(0)); - Cout << result << Endl; CompareYson(result, R"([[[0.400000006;]]])"); } }