Skip to content

Commit

Permalink
[fix](nereids)push down subquery exprs in non-distinct agg functions #…
Browse files Browse the repository at this point in the history
  • Loading branch information
starocean999 authored Nov 3, 2023
1 parent 72c125d commit 0efcff6
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
Expand Down Expand Up @@ -102,14 +104,21 @@ public Rule build() {

List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
Set<Alias> existsAlias = ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
Set<Expression> groupingByExprs = ImmutableSet.copyOf(aggregate.getGroupByExpressions());
NormalizeToSlotContext groupByToSlotContext =
NormalizeToSlotContext.buildContext(existsAlias, groupingByExprs);
Set<NamedExpression> bottomGroupByProjects =
groupByToSlotContext.pushDownToNamedExpression(groupingByExprs);

List<AggregateFunction> aggFuncs = Lists.newArrayList();
aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));

// we need push down subquery exprs inside non-window and non-distinct agg functions
// because the distinct agg's children would be pushed down in later process
Set<SubqueryExpr> subqueryExprs = ExpressionUtils.mutableCollect(aggFuncs.stream()
.filter(aggFunc -> !aggFunc.isDistinct()).collect(Collectors.toList()),
SubqueryExpr.class::isInstance);
Set<Expression> groupingByExprs = ImmutableSet.copyOf(aggregate.getGroupByExpressions());
NormalizeToSlotContext bottomSlotContext =
NormalizeToSlotContext.buildContext(existsAlias, Sets.union(groupingByExprs, subqueryExprs));
Set<NamedExpression> bottomOutputs =
bottomSlotContext.pushDownToNamedExpression(Sets.union(groupingByExprs, subqueryExprs));

// use group by context to normalize agg functions to process
// sql like: select sum(a + 1) from t group by a + 1
//
Expand All @@ -120,8 +129,8 @@ public Rule build() {
// after normalize:
// agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 1)[#1])
// +-- project((a[#0] + 1)[#1])
List<AggregateFunction> normalizedAggFuncs = groupByToSlotContext.normalizeToUseSlotRef(aggFuncs);
Set<NamedExpression> bottomProjects = Sets.newHashSet(bottomGroupByProjects);
List<AggregateFunction> normalizedAggFuncs = bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
Set<NamedExpression> bottomProjects = Sets.newHashSet(bottomOutputs);
// TODO: if we have distinct agg, we must push down its children,
// because need use it to generate distribution enforce
// step 1: split agg functions into 2 parts: distinct and not distinct
Expand Down Expand Up @@ -175,7 +184,7 @@ public Rule build() {
NormalizeToSlotContext.buildContext(existsAlias, normalizedAggFuncs);
// agg output include 2 part, normalized group by slots and normalized agg functions
List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(bottomGroupByProjects.stream().map(NamedExpression::toSlot).iterator())
.addAll(bottomOutputs.stream().map(NamedExpression::toSlot).iterator())
.addAll(normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs))
.build();
// add normalized agg's input slots to bottom projects
Expand All @@ -189,7 +198,7 @@ public Rule build() {
.collect(Collectors.toSet());
bottomProjects.addAll(aggInputSlots);
// build group by exprs
List<Expression> normalizedGroupExprs = groupByToSlotContext.normalizeToUseSlotRef(groupingByExprs);
List<Expression> normalizedGroupExprs = bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);

Plan bottomPlan;
if (!bottomProjects.isEmpty()) {
Expand All @@ -199,7 +208,7 @@ public Rule build() {
}

List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
groupByToSlotContext, normalizedAggFuncsToSlotContext);
bottomSlotContext, normalizedAggFuncsToSlotContext);

return new LogicalProject<>(upperProjects,
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,21 @@ true
\N 2.0
2020-09-09 2.0

-- !sql15 --
12

-- !sql16 --
12

-- !sql17 --
12
12

-- !sql18 --
12
12

-- !sql20 --
5
7

Original file line number Diff line number Diff line change
Expand Up @@ -116,5 +116,25 @@ suite("test_subquery_in_project") {
end 'test' from test_sql group by cube(dt) order by dt;
"""

qt_sql15 """
select sum(age + (select sum(age) from test_sql)) from test_sql order by 1;
"""

qt_sql16 """
select sum(distinct age + (select sum(age) from test_sql)) from test_sql order by 1;
"""

qt_sql17 """
select sum(age + (select sum(age) from test_sql)) over() from test_sql order by 1;
"""

qt_sql18 """
select sum(age + (select sum(age) from test_sql)) over() from test_sql group by dt, age order by 1;
"""

qt_sql20 """
select sum(age + (select sum(age) from test_sql)) from test_sql group by dt, age order by 1;
"""

sql """drop table if exists test_sql;"""
}

0 comments on commit 0efcff6

Please sign in to comment.