From fdf88387a05ba5a55ed5bb5d5296bdbcb49e66c2 Mon Sep 17 00:00:00 2001 From: HappenLee Date: Thu, 16 Jan 2025 21:59:47 +0800 Subject: [PATCH 1/2] [Fix](bug) Percentile* func core when percent args is negative number --- .../aggregate_function_percentile.h | 56 ++++++++++++------- .../aggregate_function_simple_factory.h | 9 --- .../functions/agg/PercentileArray.java | 9 +++ .../test_aggregate_all_functions.out | 14 +++++ .../query_p0/aggregate/aggregate.groovy | 17 ++++++ .../test_aggregate_all_functions.groovy | 24 ++++++++ 6 files changed, 100 insertions(+), 29 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile.h b/be/src/vec/aggregate_functions/aggregate_function_percentile.h index 0766c59f3de1c3..dbd52af923f71b 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_percentile.h +++ b/be/src/vec/aggregate_functions/aggregate_function_percentile.h @@ -51,12 +51,20 @@ namespace doris::vectorized { class Arena; class BufferReadable; +inline void check_quantile(double quantile) { + if (quantile < 0 || quantile > 1) { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "quantile in func percentile should in [0, 1], but real data is:" + + std::to_string(quantile)); + } +} + struct PercentileApproxState { static constexpr double INIT_QUANTILE = -1.0; PercentileApproxState() = default; ~PercentileApproxState() = default; - void init(double compression = 10000) { + void init(double quantile, double compression = 10000) { if (!init_flag) { //https://doris.apache.org/zh-CN/sql-reference/sql-functions/aggregate-functions/percentile_approx.html#description //The compression parameter setting range is [2048, 10000]. @@ -66,6 +74,8 @@ struct PercentileApproxState { compression = 10000; } digest = TDigest::create_unique(compression); + check_quantile(quantile); + target_quantile = quantile; compressions = compression; init_flag = true; } @@ -126,18 +136,14 @@ struct PercentileApproxState { } } - void add(double source, double quantile) { - digest->add(source); - target_quantile = quantile; - } + void add(double source) { digest->add(source); } - void add_with_weight(double source, double weight, double quantile) { + void add_with_weight(double source, double weight) { // the weight should be positive num, as have check the value valid use DCHECK_GT(c._weight, 0); if (weight <= 0) { return; } digest->add(source, weight); - target_quantile = quantile; } void reset() { @@ -192,8 +198,8 @@ class AggregateFunctionPercentileApproxTwoParams : public AggregateFunctionPerce assert_cast(*columns[0]); const auto& quantile = assert_cast(*columns[1]); - this->data(place).init(); - this->data(place).add(sources.get_element(row_num), quantile.get_element(row_num)); + this->data(place).init(quantile.get_element(0)); + this->data(place).add(sources.get_element(row_num)); } DataTypePtr get_return_type() const override { return std::make_shared(); } @@ -223,8 +229,8 @@ class AggregateFunctionPercentileApproxThreeParams : public AggregateFunctionPer const auto& compression = assert_cast(*columns[2]); - this->data(place).init(compression.get_element(row_num)); - this->data(place).add(sources.get_element(row_num), quantile.get_element(row_num)); + this->data(place).init(quantile.get_element(0), compression.get_element(0)); + this->data(place).add(sources.get_element(row_num)); } DataTypePtr get_return_type() const override { return std::make_shared(); } @@ -256,9 +262,9 @@ class AggregateFunctionPercentileApproxWeightedThreeParams const auto& quantile = assert_cast&, TypeCheckOnRelease::DISABLE>(*columns[2]); - this->data(place).init(); - this->data(place).add_with_weight(sources.get_element(row_num), weight.get_element(row_num), - quantile.get_element(row_num)); + this->data(place).init(quantile.get_element(0)); + this->data(place).add_with_weight(sources.get_element(row_num), + weight.get_element(row_num)); } DataTypePtr get_return_type() const override { return std::make_shared(); } @@ -291,9 +297,9 @@ class AggregateFunctionPercentileApproxWeightedFourParams const auto& compression = assert_cast&, TypeCheckOnRelease::DISABLE>(*columns[3]); - this->data(place).init(compression.get_element(row_num)); - this->data(place).add_with_weight(sources.get_element(row_num), weight.get_element(row_num), - quantile.get_element(row_num)); + this->data(place).init(quantile.get_element(0), compression.get_element(0)); + this->data(place).add_with_weight(sources.get_element(row_num), + weight.get_element(row_num)); } DataTypePtr get_return_type() const override { return std::make_shared(); } @@ -351,12 +357,19 @@ struct PercentileState { } } - void add(T source, const PaddedPODArray& quantiles, int arg_size) { + void add(T source, const PaddedPODArray& quantiles, const NullMap& null_maps, + int arg_size) { if (!inited_flag) { vec_counts.resize(arg_size); vec_quantile.resize(arg_size, -1); inited_flag = true; for (int i = 0; i < arg_size; ++i) { + // throw Exception func call percentile_array(id, [1,0,null]) + if (null_maps[i]) { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "quantiles in func percentile_array should not have null"); + } + check_quantile(quantiles[i]); vec_quantile[i] = quantiles[i]; } } @@ -429,7 +442,7 @@ class AggregateFunctionPercentile final const auto& quantile = assert_cast(*columns[1]); AggregateFunctionPercentile::data(place).add(sources.get_data()[row_num], - quantile.get_data(), 1); + quantile.get_data(), NullMap(1, 0), 1); } void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns, @@ -490,6 +503,9 @@ class AggregateFunctionPercentileArray final const auto& quantile_array = assert_cast(*columns[1]); const auto& offset_column_data = quantile_array.get_offsets(); + const auto& null_maps = assert_cast( + quantile_array.get_data()) + .get_null_map_data(); const auto& nested_column = assert_cast( quantile_array.get_data()) .get_nested_column(); @@ -497,7 +513,7 @@ class AggregateFunctionPercentileArray final assert_cast(nested_column); AggregateFunctionPercentileArray::data(place).add( - sources.get_int(row_num), nested_column_data.get_data(), + sources.get_int(row_num), nested_column_data.get_data(), null_maps, offset_column_data.data()[row_num] - offset_column_data[(ssize_t)row_num - 1]); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h index 842170c18eba45..4807600ad2fe4e 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h @@ -62,15 +62,6 @@ class AggregateFunctionSimpleFactory { std::unordered_map function_alias; public: - void register_nullable_function_combinator(const Creator& creator) { - for (const auto& entity : aggregate_functions) { - if (nullable_aggregate_functions.find(entity.first) == - nullable_aggregate_functions.end()) { - nullable_aggregate_functions[entity.first] = creator; - } - } - } - static bool is_foreach(const std::string& name) { constexpr std::string_view suffix = "_foreach"; if (name.length() < suffix.length()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java index 1abbe4d5450531..61d4a328c0f177 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral; @@ -74,6 +75,14 @@ public PercentileArray(boolean distinct, Expression arg0, Expression arg1) { super("percentile_array", distinct, arg0, arg1); } + @Override + public void checkLegalityBeforeTypeCoercion() { + if (!getArgument(1).isConstant()) { + throw new AnalysisException( + "percentile_array requires second parameter must be a constant : " + this.toSql()); + } + } + /** * withDistinctAndChildren. */ diff --git a/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.out b/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.out index 75d9a18679f78a..90953b0a11c84f 100644 --- a/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.out +++ b/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.out @@ -150,6 +150,13 @@ beijing chengdu shanghai 5 29.1 6 101.1 +-- !select22_1_1 -- +1 \N +2 \N +3 \N +5 \N +6 \N + -- !select23 -- 1 10.0 2 224.5 @@ -192,6 +199,13 @@ beijing chengdu shanghai 5 29.0 6 101.0 +-- !select28_1 -- +1 \N +2 \N +3 \N +5 \N +6 \N + -- !select29 -- 1 0.0 2 216.5 diff --git a/regression-test/suites/query_p0/aggregate/aggregate.groovy b/regression-test/suites/query_p0/aggregate/aggregate.groovy index b611ff92b0eaba..6079d09577f496 100644 --- a/regression-test/suites/query_p0/aggregate/aggregate.groovy +++ b/regression-test/suites/query_p0/aggregate/aggregate.groovy @@ -141,6 +141,23 @@ suite("aggregate") { qt_aggregate32" select topn_weighted(c_string,c_bigint,3) from ${tableName}" qt_aggregate33" select avg_weighted(c_double,c_bigint) from ${tableName};" qt_aggregate34" select percentile_array(c_bigint,[0.2,0.5,0.9]) from ${tableName};" + + try { + sql "select percentile_array(c_bigint,[-1,0.5,0.9]) from ${tableName};" + } catch (Exception ex) { + assert("${ex}".contains("-1")) + } + try { + sql "select percentile_array(c_bigint,[0.5,0.9,3000]) from ${tableName};" + } catch (Exception ex) { + assert("${ex}".contains("3000")) + } + try { + sql "select percentile_array(c_bigint,[0.5,0.9,null]) from ${tableName};" + } catch (Exception ex) { + assert("${ex}".contains("null")) + } + qt_aggregate """ SELECT c_bigint, CASE diff --git a/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.groovy b/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.groovy index cdab9472e27dbd..c64d33e1e82b81 100644 --- a/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.groovy +++ b/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.groovy @@ -286,6 +286,18 @@ suite("test_aggregate_all_functions", "arrow_flight_sql") { qt_select20_1 "select id,percentile(level + 0.1,0.5) from ${tableName_13} group by id order by id" qt_select21_1 "select id,percentile(level + 0.1,0.55) from ${tableName_13} group by id order by id" qt_select22_1 "select id,percentile(level + 0.1,0.805) from ${tableName_13} group by id order by id" + qt_select22_1_1 "select id,percentile(level + 0.1, null) from ${tableName_13} group by id order by id" + + try { + sql "select id,percentile(level + 0.1, -1) from ${tableName_13} group by id order by id" + } catch (Exception ex) { + assert("${ex}".contains("-1")) + } + try { + sql "select id,percentile(level + 0.1, 3000) from ${tableName_13} group by id order by id" + } catch (Exception ex) { + assert("${ex}".contains("3000")) + } sql "DROP TABLE IF EXISTS ${tableName_13}" @@ -313,6 +325,18 @@ suite("test_aggregate_all_functions", "arrow_flight_sql") { qt_select26 "select id,PERCENTILE_APPROX(level,0.5,2048) from ${tableName_14} group by id order by id" qt_select27 "select id,PERCENTILE_APPROX(level,0.55,2048) from ${tableName_14} group by id order by id" qt_select28 "select id,PERCENTILE_APPROX(level,0.805,2048) from ${tableName_14} group by id order by id" + qt_select28_1 "select id,PERCENTILE_APPROX(level, null ,2048) from ${tableName_14} group by id order by id" + + try { + sql "select id,PERCENTILE_APPROX(level, -1, 2048) from ${tableName_14} group by id order by id" + } catch (Exception ex) { + assert("${ex}".contains("-1")) + } + try { + sql "select id,PERCENTILE_APPROX(level, 3000 ,2048) from ${tableName_14} group by id order by id" + } catch (Exception ex) { + assert("${ex}".contains("3000")) + } sql "DROP TABLE IF EXISTS ${tableName_14}" From cdfc5aaaaad5c6f173d20d84c05b23f55c128d74 Mon Sep 17 00:00:00 2001 From: HappenLee Date: Fri, 17 Jan 2025 18:24:56 +0800 Subject: [PATCH 2/2] fix regression test --- .../combinator/ForEachCombinator.java | 29 ++++++++++++++++++ .../data/function_p0/test_agg_foreach.out | 9 ------ .../function_p0/test_agg_foreach_notnull.out | 9 ------ .../function_p0/test_agg_foreach.groovy | 26 +++++++++------- .../test_agg_foreach_notnull.groovy | 30 +++++++++++-------- 5 files changed, 63 insertions(+), 40 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java index a6d011ff0fbbf4..ddd92f894e15a0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java @@ -30,8 +30,11 @@ import com.google.common.collect.ImmutableList; +import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Objects; +import java.util.Set; /** * combinator foreach @@ -39,6 +42,15 @@ public class ForEachCombinator extends NullableAggregateFunction implements UnaryExpression, ExplicitlyCastableSignature, Combinator { + public static final Set UNSUPPORTED_AGGREGATE_FUNCTION = Collections.unmodifiableSet(new HashSet() { + { + add("percentile"); + add("percentile_array"); + add("percentile_approx"); + add("percentile_approx_weighted"); + } + }); + private final AggregateFunction nested; /** @@ -48,10 +60,27 @@ public ForEachCombinator(List arguments, AggregateFunction nested) { this(arguments, false, nested); } + /** + * Constructs a new instance of {@code ForEachCombinator}. + * + *

This constructor initializes a combinator that will iterate over each item in the input list + * and apply the nested aggregate function. + * If the provided aggregate function name is within the list of unsupported functions, + * an {@link UnsupportedOperationException} will be thrown. + * + * @param arguments A list of {@code Expression} objects that serve as parameters to the aggregate function. + * @param alwaysNullable A boolean flag indicating whether this combinator should always return a nullable result. + * @param nested The nested aggregate function to apply to each element. It must not be {@code null}. + * @throws NullPointerException If the provided nested aggregate function is {@code null}. + * @throws UnsupportedOperationException If nested aggregate function is one of the unsupported aggregate functions + */ public ForEachCombinator(List arguments, boolean alwaysNullable, AggregateFunction nested) { super(nested.getName() + AggCombinerFunctionBuilder.FOREACH_SUFFIX, false, alwaysNullable, arguments); this.nested = Objects.requireNonNull(nested, "nested can not be null"); + if (UNSUPPORTED_AGGREGATE_FUNCTION.contains(nested.getName().toLowerCase())) { + throw new UnsupportedOperationException("Unsupport the func:" + nested.getName() + " use in foreach"); + } } public static ForEachCombinator create(AggregateFunction nested) { diff --git a/regression-test/data/function_p0/test_agg_foreach.out b/regression-test/data/function_p0/test_agg_foreach.out index c45ae9f67a974a..693009d389044f 100644 --- a/regression-test/data/function_p0/test_agg_foreach.out +++ b/regression-test/data/function_p0/test_agg_foreach.out @@ -17,15 +17,6 @@ -- !sql -- ["{"num_buckets":3,"buckets":[{"lower":"1","upper":"1","ndv":1,"count":1,"pre_sum":0},{"lower":"20","upper":"20","ndv":1,"count":1,"pre_sum":1},{"lower":"100","upper":"100","ndv":1,"count":1,"pre_sum":2}]}", "{"num_buckets":1,"buckets":[{"lower":"2","upper":"2","ndv":1,"count":2,"pre_sum":0}]}", "{"num_buckets":1,"buckets":[{"lower":"3","upper":"3","ndv":1,"count":1,"pre_sum":0}]}"] --- !sql -- -[100, 2, 3] - --- !sql -- -[[1], [2, 2, 2], [3]] - --- !sql -- -[0, 0, 0] - -- !sql -- [0, 2, 3] [117, 2, 3] [113, 0, 3] diff --git a/regression-test/data/function_p0/test_agg_foreach_notnull.out b/regression-test/data/function_p0/test_agg_foreach_notnull.out index c45ae9f67a974a..693009d389044f 100644 --- a/regression-test/data/function_p0/test_agg_foreach_notnull.out +++ b/regression-test/data/function_p0/test_agg_foreach_notnull.out @@ -17,15 +17,6 @@ -- !sql -- ["{"num_buckets":3,"buckets":[{"lower":"1","upper":"1","ndv":1,"count":1,"pre_sum":0},{"lower":"20","upper":"20","ndv":1,"count":1,"pre_sum":1},{"lower":"100","upper":"100","ndv":1,"count":1,"pre_sum":2}]}", "{"num_buckets":1,"buckets":[{"lower":"2","upper":"2","ndv":1,"count":2,"pre_sum":0}]}", "{"num_buckets":1,"buckets":[{"lower":"3","upper":"3","ndv":1,"count":1,"pre_sum":0}]}"] --- !sql -- -[100, 2, 3] - --- !sql -- -[[1], [2, 2, 2], [3]] - --- !sql -- -[0, 0, 0] - -- !sql -- [0, 2, 3] [117, 2, 3] [113, 0, 3] diff --git a/regression-test/suites/function_p0/test_agg_foreach.groovy b/regression-test/suites/function_p0/test_agg_foreach.groovy index 281fdea6a3bea7..fad9925af81c3b 100644 --- a/regression-test/suites/function_p0/test_agg_foreach.groovy +++ b/regression-test/suites/function_p0/test_agg_foreach.groovy @@ -87,18 +87,24 @@ suite("test_agg_foreach") { select histogram_foreach(a) from foreach_table; """ - qt_sql """ - select PERCENTILE_foreach(a,a) from foreach_table; - """ + try { + sql "select PERCENTILE_foreach(a,a) from foreach_table;" + } catch (Exception ex) { + assert("${ex}".contains("Unsupport the func")) + } - qt_sql """ - select PERCENTILE_ARRAY_foreach(a,b) from foreach_table where id = 1; - """ - qt_sql """ - - select PERCENTILE_APPROX_foreach(a,a) from foreach_table; - """ + try { + sql "select PERCENTILE_ARRAY_foreach(a,b) from foreach_table where id = 1;" + } catch (Exception ex) { + assert("${ex}".contains("Unsupport the func")) + } + + try { + sql "select PERCENTILE_APPROX_foreach(a,a) from foreach_table;" + } catch (Exception ex) { + assert("${ex}".contains("Unsupport the func")) + } qt_sql """ select GROUP_BIT_AND_foreach(a), GROUP_BIT_OR_foreach(a), GROUP_BIT_XOR_foreach(a) from foreach_table; diff --git a/regression-test/suites/function_p0/test_agg_foreach_notnull.groovy b/regression-test/suites/function_p0/test_agg_foreach_notnull.groovy index 91f4ea902dd6a3..68f27e6d049e05 100644 --- a/regression-test/suites/function_p0/test_agg_foreach_notnull.groovy +++ b/regression-test/suites/function_p0/test_agg_foreach_notnull.groovy @@ -85,20 +85,26 @@ suite("test_agg_foreach_not_null") { qt_sql """ select histogram_foreach(a) from foreach_table_not_null; """ - - qt_sql """ - select PERCENTILE_foreach(a,a) from foreach_table_not_null; - """ - - qt_sql """ - select PERCENTILE_ARRAY_foreach(a,b) from foreach_table_not_null where id = 1; - """ - - qt_sql """ - select PERCENTILE_APPROX_foreach(a,a) from foreach_table_not_null; - """ + try { + sql "select PERCENTILE_foreach(a,a) from foreach_table_not_null;" + } catch (Exception ex) { + assert("${ex}".contains("Unsupport the func")) + } + + try { + sql "select PERCENTILE_ARRAY_foreach(a,b) from foreach_table_not_null where id = 1;" + } catch (Exception ex) { + assert("${ex}".contains("Unsupport the func")) + } + + try { + sql "select PERCENTILE_APPROX_foreach(a,a) from foreach_table_not_null;" + } catch (Exception ex) { + assert("${ex}".contains("Unsupport the func")) + } + qt_sql """ select GROUP_BIT_AND_foreach(a), GROUP_BIT_OR_foreach(a), GROUP_BIT_XOR_foreach(a) from foreach_table_not_null; """