Skip to content

Commit

Permalink
[fix](nereids) push down subquery exprs in non-distinct agg functions (
Browse files Browse the repository at this point in the history
  • Loading branch information
starocean999 authored and 胥剑旭 committed Dec 14, 2023
1 parent fa52395 commit 4dda45a
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
Expand Down Expand Up @@ -101,11 +103,17 @@ public Rule build() {

List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
Set<Alias> existsAlias = ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
// we need push down subquery exprs in side non-distinct agg functions
Set<SubqueryExpr> subqueryExprs = ExpressionUtils.mutableCollect(
Lists.newArrayList(ExpressionUtils.mutableCollect(aggregateOutput,
expr -> expr instanceof AggregateFunction
&& !((AggregateFunction) expr).isDistinct())),
SubqueryExpr.class::isInstance);
Set<Expression> groupingByExprs = ImmutableSet.copyOf(aggregate.getGroupByExpressions());
NormalizeToSlotContext groupByToSlotContext =
NormalizeToSlotContext.buildContext(existsAlias, groupingByExprs);
Set<NamedExpression> bottomGroupByProjects =
groupByToSlotContext.pushDownToNamedExpression(groupingByExprs);
NormalizeToSlotContext bottomSlotContext =
NormalizeToSlotContext.buildContext(existsAlias, Sets.union(groupingByExprs, subqueryExprs));
Set<NamedExpression> bottomOutputs =
bottomSlotContext.pushDownToNamedExpression(Sets.union(groupingByExprs, subqueryExprs));

List<AggregateFunction> aggFuncs = Lists.newArrayList();
aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
Expand All @@ -119,8 +127,8 @@ public Rule build() {
// after normalize:
// agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 1)[#1])
// +-- project((a[#0] + 1)[#1])
List<AggregateFunction> normalizedAggFuncs = groupByToSlotContext.normalizeToUseSlotRef(aggFuncs);
Set<NamedExpression> bottomProjects = Sets.newHashSet(bottomGroupByProjects);
List<AggregateFunction> normalizedAggFuncs = bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
Set<NamedExpression> bottomProjects = Sets.newHashSet(bottomOutputs);
// TODO: if we have distinct agg, we must push down its children,
// because need use it to generate distribution enforce
// step 1: split agg functions into 2 parts: distinct and not distinct
Expand Down Expand Up @@ -174,7 +182,7 @@ public Rule build() {
NormalizeToSlotContext.buildContext(existsAlias, normalizedAggFuncs);
// agg output include 2 part, normalized group by slots and normalized agg functions
List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(bottomGroupByProjects.stream().map(NamedExpression::toSlot).iterator())
.addAll(bottomOutputs.stream().map(NamedExpression::toSlot).iterator())
.addAll(normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs))
.build();
// add normalized agg's input slots to bottom projects
Expand All @@ -188,7 +196,7 @@ public Rule build() {
.collect(Collectors.toSet());
bottomProjects.addAll(aggInputSlots);
// build group by exprs
List<Expression> normalizedGroupExprs = groupByToSlotContext.normalizeToUseSlotRef(groupingByExprs);
List<Expression> normalizedGroupExprs = bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);

Plan bottomPlan;
if (!bottomProjects.isEmpty()) {
Expand All @@ -198,7 +206,7 @@ public Rule build() {
}

List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
groupByToSlotContext, normalizedAggFuncsToSlotContext);
bottomSlotContext, normalizedAggFuncsToSlotContext);

return new LogicalProject<>(upperProjects,
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,9 @@ true
\N 2.0
2020-09-09 2.0

-- !sql15 --
12

-- !sql16 --
12

Original file line number Diff line number Diff line change
Expand Up @@ -116,5 +116,13 @@ suite("test_subquery_in_project") {
end 'test' from test_sql group by cube(dt) order by dt;
"""

qt_sql15 """
select sum(age + (select sum(age) from test_sql)) from test_sql;
"""

qt_sql16 """
select sum(distinct age + (select sum(age) from test_sql)) from test_sql;
"""

sql """drop table if exists test_sql;"""
}

0 comments on commit 4dda45a

Please sign in to comment.