Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 36 additions & 20 deletions be/src/vec/aggregate_functions/aggregate_function_percentile.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,20 @@ namespace doris::vectorized {
class Arena;
class BufferReadable;

/// Validates a quantile argument of the percentile aggregate functions.
///
/// @param quantile value that must lie in the closed interval [0, 1].
/// @throws Exception with ErrorCode::INVALID_ARGUMENT when the value is
///         outside [0, 1] or is NaN.
inline void check_quantile(double quantile) {
    // Expressed as a negated in-range test: every ordered comparison with NaN
    // is false, so `quantile < 0 || quantile > 1` would let NaN through.
    if (!(quantile >= 0 && quantile <= 1)) {
        throw Exception(ErrorCode::INVALID_ARGUMENT,
                        "quantile in func percentile should in [0, 1], but real data is:" +
                                std::to_string(quantile));
    }
}

struct PercentileApproxState {
static constexpr double INIT_QUANTILE = -1.0;
PercentileApproxState() = default;
~PercentileApproxState() = default;

void init(double compression = 10000) {
void init(double quantile, double compression = 10000) {
if (!init_flag) {
//https://doris.apache.org/zh-CN/sql-reference/sql-functions/aggregate-functions/percentile_approx.html#description
//The compression parameter setting range is [2048, 10000].
Expand All @@ -66,6 +74,8 @@ struct PercentileApproxState {
compression = 10000;
}
digest = TDigest::create_unique(compression);
check_quantile(quantile);
target_quantile = quantile;
compressions = compression;
init_flag = true;
}
Expand Down Expand Up @@ -126,18 +136,14 @@ struct PercentileApproxState {
}
}

void add(double source, double quantile) {
digest->add(source);
target_quantile = quantile;
}
void add(double source) { digest->add(source); }

void add_with_weight(double source, double weight, double quantile) {
void add_with_weight(double source, double weight) {
// the weight should be positive num, as have check the value valid use DCHECK_GT(c._weight, 0);
if (weight <= 0) {
return;
}
digest->add(source, weight);
target_quantile = quantile;
}

void reset() {
Expand Down Expand Up @@ -192,8 +198,8 @@ class AggregateFunctionPercentileApproxTwoParams : public AggregateFunctionPerce
assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[0]);
const auto& quantile =
assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);
this->data(place).init();
this->data(place).add(sources.get_element(row_num), quantile.get_element(row_num));
this->data(place).init(quantile.get_element(0));
this->data(place).add(sources.get_element(row_num));
}

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }
Expand Down Expand Up @@ -223,8 +229,8 @@ class AggregateFunctionPercentileApproxThreeParams : public AggregateFunctionPer
const auto& compression =
assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[2]);

this->data(place).init(compression.get_element(row_num));
this->data(place).add(sources.get_element(row_num), quantile.get_element(row_num));
this->data(place).init(quantile.get_element(0), compression.get_element(0));
this->data(place).add(sources.get_element(row_num));
}

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }
Expand Down Expand Up @@ -256,9 +262,9 @@ class AggregateFunctionPercentileApproxWeightedThreeParams
const auto& quantile =
assert_cast<const ColumnVector<Float64>&, TypeCheckOnRelease::DISABLE>(*columns[2]);

this->data(place).init();
this->data(place).add_with_weight(sources.get_element(row_num), weight.get_element(row_num),
quantile.get_element(row_num));
this->data(place).init(quantile.get_element(0));
this->data(place).add_with_weight(sources.get_element(row_num),
weight.get_element(row_num));
}

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }
Expand Down Expand Up @@ -291,9 +297,9 @@ class AggregateFunctionPercentileApproxWeightedFourParams
const auto& compression =
assert_cast<const ColumnVector<Float64>&, TypeCheckOnRelease::DISABLE>(*columns[3]);

this->data(place).init(compression.get_element(row_num));
this->data(place).add_with_weight(sources.get_element(row_num), weight.get_element(row_num),
quantile.get_element(row_num));
this->data(place).init(quantile.get_element(0), compression.get_element(0));
this->data(place).add_with_weight(sources.get_element(row_num),
weight.get_element(row_num));
}

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }
Expand Down Expand Up @@ -351,12 +357,19 @@ struct PercentileState {
}
}

void add(T source, const PaddedPODArray<Float64>& quantiles, int arg_size) {
void add(T source, const PaddedPODArray<Float64>& quantiles, const NullMap& null_maps,
int arg_size) {
if (!inited_flag) {
vec_counts.resize(arg_size);
vec_quantile.resize(arg_size, -1);
inited_flag = true;
for (int i = 0; i < arg_size; ++i) {
// throw Exception func call percentile_array(id, [1,0,null])
if (null_maps[i]) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"quantiles in func percentile_array should not have null");
}
check_quantile(quantiles[i]);
vec_quantile[i] = quantiles[i];
}
}
Expand Down Expand Up @@ -429,7 +442,7 @@ class AggregateFunctionPercentile final
const auto& quantile =
assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);
AggregateFunctionPercentile::data(place).add(sources.get_data()[row_num],
quantile.get_data(), 1);
quantile.get_data(), NullMap(1, 0), 1);
}

void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
Expand Down Expand Up @@ -490,14 +503,17 @@ class AggregateFunctionPercentileArray final
const auto& quantile_array =
assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[1]);
const auto& offset_column_data = quantile_array.get_offsets();
const auto& null_maps = assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(
quantile_array.get_data())
.get_null_map_data();
const auto& nested_column = assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(
quantile_array.get_data())
.get_nested_column();
const auto& nested_column_data =
assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(nested_column);

AggregateFunctionPercentileArray::data(place).add(
sources.get_int(row_num), nested_column_data.get_data(),
sources.get_int(row_num), nested_column_data.get_data(), null_maps,
offset_column_data.data()[row_num] - offset_column_data[(ssize_t)row_num - 1]);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,6 @@ class AggregateFunctionSimpleFactory {
std::unordered_map<std::string, std::string> function_alias;

public:
/// Registers `creator` as the nullable variant of every function present in
/// `aggregate_functions` that has no entry in `nullable_aggregate_functions`
/// yet. Entries that already exist are left unchanged.
///
/// @param creator factory stored for each function that lacks a nullable entry.
void register_nullable_function_combinator(const Creator& creator) {
    for (const auto& entity : aggregate_functions) {
        // try_emplace inserts only when the key is absent, which replaces the
        // former find() + operator[] pair with a single lookup.
        // NOTE(review): assumes nullable_aggregate_functions is a standard
        // (unordered_)map keyed by function name - confirm at its declaration.
        nullable_aggregate_functions.try_emplace(entity.first, creator);
    }
}

static bool is_foreach(const std::string& name) {
constexpr std::string_view suffix = "_foreach";
if (name.length() < suffix.length()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral;
Expand Down Expand Up @@ -74,6 +75,14 @@ public PercentileArray(boolean distinct, Expression arg0, Expression arg1) {
super("percentile_array", distinct, arg0, arg1);
}

@Override
public void checkLegalityBeforeTypeCoercion() {
if (!getArgument(1).isConstant()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the second argument is now required to be a constant, can we also check here that it is not null and that its values fall within the [0, 1] range?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point the FE cannot be sure that constant folding has already been applied, so the null and range checks are done together in the BE.

throw new AnalysisException(
"percentile_array requires second parameter must be a constant : " + this.toSql());
}
}

/**
* withDistinctAndChildren.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,27 @@

import com.google.common.collect.ImmutableList;

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

/**
* combinator foreach
*/
public class ForEachCombinator extends NullableAggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, Combinator {

/**
 * Lower-case names of the aggregate functions that are rejected when used as the nested
 * function of a foreach combinator. The set is unmodifiable.
 */
public static final Set<String> UNSUPPORTED_AGGREGATE_FUNCTION = Collections.unmodifiableSet(
        // Built from a plain list instead of double-brace initialization, which would
        // create an extra anonymous HashSet subclass just to populate the set.
        new HashSet<>(java.util.Arrays.asList(
                "percentile",
                "percentile_array",
                "percentile_approx",
                "percentile_approx_weighted")));

private final AggregateFunction nested;

/**
Expand All @@ -48,10 +60,27 @@ public ForEachCombinator(List<Expression> arguments, AggregateFunction nested) {
this(arguments, false, nested);
}

/**
 * Constructs a new instance of {@code ForEachCombinator}.
 *
 * <p>This constructor initializes a combinator that will iterate over each item in the input list
 * and apply the nested aggregate function.
 * If the name of the nested aggregate function is contained in {@link #UNSUPPORTED_AGGREGATE_FUNCTION}
 * (compared case-insensitively), an {@link UnsupportedOperationException} is thrown.
 *
 * @param arguments A list of {@code Expression} objects that serve as parameters to the aggregate function.
 * @param alwaysNullable A boolean flag indicating whether this combinator should always return a nullable result.
 * @param nested The nested aggregate function to apply to each element. It must not be {@code null}.
 * @throws NullPointerException If the provided nested aggregate function is {@code null}.
 * @throws UnsupportedOperationException If nested aggregate function is one of the unsupported aggregate functions
 */
public ForEachCombinator(List<Expression> arguments, boolean alwaysNullable, AggregateFunction nested) {
    // The null check must sit inside the super(...) argument: that argument is evaluated before any
    // statement of this body, so a check placed after super(...) could never report a null `nested`.
    super(Objects.requireNonNull(nested, "nested can not be null").getName()
            + AggCombinerFunctionBuilder.FOREACH_SUFFIX, false, alwaysNullable, arguments);

    this.nested = nested;
    // Locale.ROOT keeps the lookup independent of the JVM default locale; under a Turkish locale
    // "PERCENTILE".toLowerCase() yields a dotless i and would miss the entries of the set.
    if (UNSUPPORTED_AGGREGATE_FUNCTION.contains(nested.getName().toLowerCase(java.util.Locale.ROOT))) {
        throw new UnsupportedOperationException("Unsupport the func:" + nested.getName() + " use in foreach");
    }
}

public static ForEachCombinator create(AggregateFunction nested) {
Expand Down
9 changes: 0 additions & 9 deletions regression-test/data/function_p0/test_agg_foreach.out
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,6 @@
-- !sql --
["{"num_buckets":3,"buckets":[{"lower":"1","upper":"1","ndv":1,"count":1,"pre_sum":0},{"lower":"20","upper":"20","ndv":1,"count":1,"pre_sum":1},{"lower":"100","upper":"100","ndv":1,"count":1,"pre_sum":2}]}", "{"num_buckets":1,"buckets":[{"lower":"2","upper":"2","ndv":1,"count":2,"pre_sum":0}]}", "{"num_buckets":1,"buckets":[{"lower":"3","upper":"3","ndv":1,"count":1,"pre_sum":0}]}"]

-- !sql --
[100, 2, 3]

-- !sql --
[[1], [2, 2, 2], [3]]

-- !sql --
[0, 0, 0]

-- !sql --
[0, 2, 3] [117, 2, 3] [113, 0, 3]

Expand Down
9 changes: 0 additions & 9 deletions regression-test/data/function_p0/test_agg_foreach_notnull.out
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,6 @@
-- !sql --
["{"num_buckets":3,"buckets":[{"lower":"1","upper":"1","ndv":1,"count":1,"pre_sum":0},{"lower":"20","upper":"20","ndv":1,"count":1,"pre_sum":1},{"lower":"100","upper":"100","ndv":1,"count":1,"pre_sum":2}]}", "{"num_buckets":1,"buckets":[{"lower":"2","upper":"2","ndv":1,"count":2,"pre_sum":0}]}", "{"num_buckets":1,"buckets":[{"lower":"3","upper":"3","ndv":1,"count":1,"pre_sum":0}]}"]

-- !sql --
[100, 2, 3]

-- !sql --
[[1], [2, 2, 2], [3]]

-- !sql --
[0, 0, 0]

-- !sql --
[0, 2, 3] [117, 2, 3] [113, 0, 3]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ beijing chengdu shanghai
5 29.1
6 101.1

-- !select22_1_1 --
1 \N
2 \N
3 \N
5 \N
6 \N

-- !select23 --
1 10.0
2 224.5
Expand Down Expand Up @@ -192,6 +199,13 @@ beijing chengdu shanghai
5 29.0
6 101.0

-- !select28_1 --
1 \N
2 \N
3 \N
5 \N
6 \N

-- !select29 --
1 0.0
2 216.5
Expand Down
26 changes: 16 additions & 10 deletions regression-test/suites/function_p0/test_agg_foreach.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,24 @@ suite("test_agg_foreach") {
select histogram_foreach(a) from foreach_table;
"""

qt_sql """
select PERCENTILE_foreach(a,a) from foreach_table;
"""
// PERCENTILE is not supported as the nested function of foreach, so this query must be rejected.
try {
    sql "select PERCENTILE_foreach(a,a) from foreach_table;"
    // Reaching this line means no exception was raised. The AssertionError thrown here is an
    // Error, not an Exception, so the catch clause below does not swallow it.
    assert false : "PERCENTILE_foreach is expected to be rejected"
} catch (Exception ex) {
    assert("${ex}".contains("Unsupport the func"))
}

qt_sql """
select PERCENTILE_ARRAY_foreach(a,b) from foreach_table where id = 1;
"""

qt_sql """

select PERCENTILE_APPROX_foreach(a,a) from foreach_table;
"""
// PERCENTILE_ARRAY is not supported as the nested function of foreach, so this query must be rejected.
try {
    sql "select PERCENTILE_ARRAY_foreach(a,b) from foreach_table where id = 1;"
    // Reaching this line means no exception was raised. The AssertionError thrown here is an
    // Error, not an Exception, so the catch clause below does not swallow it.
    assert false : "PERCENTILE_ARRAY_foreach is expected to be rejected"
} catch (Exception ex) {
    assert("${ex}".contains("Unsupport the func"))
}

// PERCENTILE_APPROX is not supported as the nested function of foreach, so this query must be rejected.
try {
    sql "select PERCENTILE_APPROX_foreach(a,a) from foreach_table;"
    // Reaching this line means no exception was raised. The AssertionError thrown here is an
    // Error, not an Exception, so the catch clause below does not swallow it.
    assert false : "PERCENTILE_APPROX_foreach is expected to be rejected"
} catch (Exception ex) {
    assert("${ex}".contains("Unsupport the func"))
}

qt_sql """
select GROUP_BIT_AND_foreach(a), GROUP_BIT_OR_foreach(a), GROUP_BIT_XOR_foreach(a) from foreach_table;
Expand Down
30 changes: 18 additions & 12 deletions regression-test/suites/function_p0/test_agg_foreach_notnull.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -85,20 +85,26 @@ suite("test_agg_foreach_not_null") {
qt_sql """
select histogram_foreach(a) from foreach_table_not_null;
"""

qt_sql """
select PERCENTILE_foreach(a,a) from foreach_table_not_null;
"""

qt_sql """
select PERCENTILE_ARRAY_foreach(a,b) from foreach_table_not_null where id = 1;
"""

qt_sql """

select PERCENTILE_APPROX_foreach(a,a) from foreach_table_not_null;
"""
// PERCENTILE is not supported as the nested function of foreach, so this query must be rejected.
try {
    sql "select PERCENTILE_foreach(a,a) from foreach_table_not_null;"
    // Reaching this line means no exception was raised. The AssertionError thrown here is an
    // Error, not an Exception, so the catch clause below does not swallow it.
    assert false : "PERCENTILE_foreach is expected to be rejected"
} catch (Exception ex) {
    assert("${ex}".contains("Unsupport the func"))
}


// PERCENTILE_ARRAY is not supported as the nested function of foreach, so this query must be rejected.
try {
    sql "select PERCENTILE_ARRAY_foreach(a,b) from foreach_table_not_null where id = 1;"
    // Reaching this line means no exception was raised. The AssertionError thrown here is an
    // Error, not an Exception, so the catch clause below does not swallow it.
    assert false : "PERCENTILE_ARRAY_foreach is expected to be rejected"
} catch (Exception ex) {
    assert("${ex}".contains("Unsupport the func"))
}

// PERCENTILE_APPROX is not supported as the nested function of foreach, so this query must be rejected.
try {
    sql "select PERCENTILE_APPROX_foreach(a,a) from foreach_table_not_null;"
    // Reaching this line means no exception was raised. The AssertionError thrown here is an
    // Error, not an Exception, so the catch clause below does not swallow it.
    assert false : "PERCENTILE_APPROX_foreach is expected to be rejected"
} catch (Exception ex) {
    assert("${ex}".contains("Unsupport the func"))
}

qt_sql """
select GROUP_BIT_AND_foreach(a), GROUP_BIT_OR_foreach(a), GROUP_BIT_XOR_foreach(a) from foreach_table_not_null;
"""
Expand Down
Loading
Loading