From 5827ca902ea2716cf94d68a7cee5459dc0e52d81 Mon Sep 17 00:00:00 2001 From: starocean999 <40539150+starocean999@users.noreply.github.com> Date: Fri, 27 Oct 2023 09:36:10 +0800 Subject: [PATCH] [fix](nereids) push down subquery exprs in non-distinct agg functions (#25955) --- .../rules/analysis/NormalizeAggregate.java | 26 ++++++++++++------- .../subquery/test_subquery_in_project.out | 6 +++++ .../subquery/test_subquery_in_project.groovy | 8 ++++++ 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index 180e9b915c3752..c287f2dffe9841 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -20,12 +20,14 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot; +import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext; import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.SubqueryExpr; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor; @@ -101,11 +103,17 @@ public Rule build() { List aggregateOutput = aggregate.getOutputExpressions(); Set existsAlias = 
ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance); + // we need to push down subquery exprs inside non-distinct agg functions + Set subqueryExprs = ExpressionUtils.mutableCollect( + Lists.newArrayList(ExpressionUtils.mutableCollect(aggregateOutput, + expr -> expr instanceof AggregateFunction + && !((AggregateFunction) expr).isDistinct())), + SubqueryExpr.class::isInstance); Set groupingByExprs = ImmutableSet.copyOf(aggregate.getGroupByExpressions()); - NormalizeToSlotContext groupByToSlotContext = - NormalizeToSlotContext.buildContext(existsAlias, groupingByExprs); - Set bottomGroupByProjects = - groupByToSlotContext.pushDownToNamedExpression(groupingByExprs); + NormalizeToSlotContext bottomSlotContext = + NormalizeToSlotContext.buildContext(existsAlias, Sets.union(groupingByExprs, subqueryExprs)); + Set bottomOutputs = + bottomSlotContext.pushDownToNamedExpression(Sets.union(groupingByExprs, subqueryExprs)); List aggFuncs = Lists.newArrayList(); aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs)); @@ -119,8 +127,8 @@ public Rule build() { // after normalize: // agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 1)[#1]) // +-- project((a[#0] + 1)[#1]) - List normalizedAggFuncs = groupByToSlotContext.normalizeToUseSlotRef(aggFuncs); - Set bottomProjects = Sets.newHashSet(bottomGroupByProjects); + List normalizedAggFuncs = bottomSlotContext.normalizeToUseSlotRef(aggFuncs); + Set bottomProjects = Sets.newHashSet(bottomOutputs); // TODO: if we have distinct agg, we must push down its children, // because need use it to generate distribution enforce // step 1: split agg functions into 2 parts: distinct and not distinct @@ -174,7 +182,7 @@ public Rule build() { NormalizeToSlotContext.buildContext(existsAlias, normalizedAggFuncs); // agg output include 2 part, normalized group by slots and normalized agg functions List normalizedAggOutput = ImmutableList.builder() - 
.addAll(bottomGroupByProjects.stream().map(NamedExpression::toSlot).iterator()) + .addAll(bottomOutputs.stream().map(NamedExpression::toSlot).iterator()) .addAll(normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs)) .build(); // add normalized agg's input slots to bottom projects @@ -188,7 +196,7 @@ public Rule build() { .collect(Collectors.toSet()); bottomProjects.addAll(aggInputSlots); // build group by exprs - List normalizedGroupExprs = groupByToSlotContext.normalizeToUseSlotRef(groupingByExprs); + List normalizedGroupExprs = bottomSlotContext.normalizeToUseSlotRef(groupingByExprs); Plan bottomPlan; if (!bottomProjects.isEmpty()) { @@ -198,7 +206,7 @@ public Rule build() { } List upperProjects = normalizeOutput(aggregateOutput, - groupByToSlotContext, normalizedAggFuncsToSlotContext); + bottomSlotContext, normalizedAggFuncsToSlotContext); return new LogicalProject<>(upperProjects, aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan)); diff --git a/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out b/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out index 5b97935639059d..4d8bd4c7361963 100644 --- a/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out +++ b/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out @@ -48,3 +48,9 @@ true \N 2.0 2020-09-09 2.0 +-- !sql15 -- +12 + +-- !sql16 -- +12 + diff --git a/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy b/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy index 0521334d8ae881..b9de14e530bcc7 100644 --- a/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy +++ b/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy @@ -116,5 +116,13 @@ suite("test_subquery_in_project") { end 'test' from test_sql group by cube(dt) order by dt; """ + qt_sql15 """ + select sum(age + (select sum(age) from test_sql)) 
from test_sql; + """ + + qt_sql16 """ + select sum(distinct age + (select sum(age) from test_sql)) from test_sql; + """ + sql """drop table if exists test_sql;""" }