-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[fix](nereids) do eliminate constant group by key in normalizeagg #49589
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b1c206e
7feb0df
084e6cd
23caabb
0161f8b
b2b321d
9d3f17d
9d6c739
33093af
adfd9b2
ceb9aad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,9 +17,12 @@ | |
|
|
||
| package org.apache.doris.nereids.rules.analysis; | ||
|
|
||
| import org.apache.doris.nereids.CascadesContext; | ||
| import org.apache.doris.nereids.exceptions.AnalysisException; | ||
| import org.apache.doris.nereids.rules.Rule; | ||
| import org.apache.doris.nereids.rules.RuleType; | ||
| import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; | ||
| import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE; | ||
| import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot; | ||
| import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory; | ||
| import org.apache.doris.nereids.trees.expressions.Alias; | ||
|
|
@@ -35,6 +38,7 @@ | |
| import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; | ||
| import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction; | ||
| import org.apache.doris.nereids.trees.expressions.literal.Literal; | ||
| import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; | ||
| import org.apache.doris.nereids.trees.plans.Plan; | ||
| import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; | ||
| import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; | ||
|
|
@@ -50,6 +54,7 @@ | |
| import com.google.common.collect.Sets; | ||
|
|
||
| import java.util.ArrayList; | ||
| import java.util.HashMap; | ||
| import java.util.HashSet; | ||
| import java.util.List; | ||
| import java.util.Map; | ||
|
|
@@ -111,14 +116,16 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot { | |
| public List<Rule> buildRules() { | ||
| return ImmutableList.of( | ||
| logicalHaving(logicalAggregate().whenNot(LogicalAggregate::isNormalized)) | ||
| .then(having -> normalizeAgg(having.child(), Optional.of(having))) | ||
| .thenApply(ctx -> normalizeAgg(ctx.root.child(), Optional.of(ctx.root), ctx.cascadesContext)) | ||
| .toRule(RuleType.NORMALIZE_AGGREGATE), | ||
| logicalAggregate().whenNot(LogicalAggregate::isNormalized) | ||
| .then(aggregate -> normalizeAgg(aggregate, Optional.empty())) | ||
| .thenApply(ctx -> normalizeAgg(ctx.root, Optional.empty(), ctx.cascadesContext)) | ||
| .toRule(RuleType.NORMALIZE_AGGREGATE)); | ||
| } | ||
|
|
||
| private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<LogicalHaving<?>> having) { | ||
| @SuppressWarnings("checkstyle:UnusedLocalVariable") | ||
| private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<LogicalHaving<?>> having, | ||
| CascadesContext ctx) { | ||
| // The LogicalAggregate node may contain window agg functions and usual agg functions | ||
| // we call window agg functions as window-agg and usual agg functions as trivial-agg for short | ||
| // This rule simplify LogicalAggregate node by: | ||
|
|
@@ -279,8 +286,10 @@ private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<Logi | |
| List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput, | ||
| groupByExprContext, argsOfAggFuncNeedPushDownContext, normalizedAggFuncsToSlotContext); | ||
|
|
||
| // create a parent project node | ||
| LogicalProject<Plan> project = new LogicalProject<>(upperProjects, newAggregate); | ||
| ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx); | ||
| LogicalProject<Plan> project = eliminateGroupByConstant(groupByExprContext, rewriteContext, | ||
| normalizedGroupExprs, normalizedAggOutput, bottomProjects, aggregate, upperProjects, newAggregate); | ||
|
|
||
| // verify project used slots are all coming from agg's output | ||
| List<Slot> slots = collectAllUsedSlots(upperProjects); | ||
| if (!slots.isEmpty()) { | ||
|
|
@@ -389,4 +398,93 @@ private Expression normalizeAggFuncChildren(NormalizeToSlotContext context, Expr | |
| return expr; | ||
| } | ||
| } | ||
|
|
||
| private LogicalProject<Plan> eliminateGroupByConstant(NormalizeToSlotContext groupByExprContext, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could we test this method in UT? |
||
| ExpressionRewriteContext rewriteContext, List<Expression> normalizedGroupExprs, | ||
| List<NamedExpression> normalizedAggOutput, Set<NamedExpression> bottomProjects, | ||
| LogicalAggregate<Plan> aggregate, List<NamedExpression> upperProjects, LogicalAggregate<?> newAggregate) { | ||
| // 1. Find the expressions in group by that can be folded into constants and build a map(slot, literal) | ||
| Map<Expression, NormalizeToSlotTriplet> replaceMap = groupByExprContext.getNormalizeToSlotMap(); | ||
| if (replaceMap.isEmpty()) { | ||
| return new LogicalProject<>(upperProjects, newAggregate); | ||
| } | ||
| Map<Slot, Expression> slotToLiteral = new HashMap<>(); | ||
| for (Map.Entry<Expression, NormalizeToSlotTriplet> entry : replaceMap.entrySet()) { | ||
| Expression foldExpression = FoldConstantRuleOnFE.evaluate(entry.getKey(), rewriteContext); | ||
| if (foldExpression.isConstant()) { | ||
| slotToLiteral.put(entry.getValue().remainExpr, foldExpression); | ||
| } | ||
| } | ||
| if (slotToLiteral.isEmpty()) { | ||
| return new LogicalProject<>(upperProjects, newAggregate); | ||
| } | ||
| // 2. Regenerate a group by list without constant key | ||
| List<Expression> newNormalizedGroupExprs = new ArrayList<>(); | ||
| for (Expression normalizedGroupExpr : normalizedGroupExprs) { | ||
| if (!slotToLiteral.containsKey((Slot) normalizedGroupExpr)) { | ||
| newNormalizedGroupExprs.add(normalizedGroupExpr); | ||
| } | ||
| } | ||
| if (newNormalizedGroupExprs.size() == normalizedGroupExprs.size()) { | ||
| return new LogicalProject<>(upperProjects, newAggregate); | ||
| } | ||
| if (newNormalizedGroupExprs.isEmpty()) { | ||
| Alias tinyInt = new Alias(new TinyIntLiteral((byte) 1)); | ||
| bottomProjects = new HashSet<>(bottomProjects); | ||
| bottomProjects.add(tinyInt); | ||
| normalizedAggOutput = new ArrayList<>(normalizedAggOutput); | ||
| Slot tinyIntSlot = tinyInt.toSlot(); | ||
| normalizedAggOutput.add(tinyIntSlot); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please check if it's be's requirement that all group by exprs should be in agg's output? |
||
| newNormalizedGroupExprs.add(tinyIntSlot); | ||
| } | ||
| // 3. Replace the agg output expression and delete the constant group by key in the output | ||
| ImmutableList.Builder<NamedExpression> nonConstAggOutput = ImmutableList.builder(); | ||
| for (NamedExpression ne : normalizedAggOutput) { | ||
| if (ne instanceof Alias) { | ||
| nonConstAggOutput.add(ExpressionUtils.replaceNameExpression(ne, slotToLiteral)); | ||
| continue; | ||
| } else if (ne instanceof Slot) { | ||
| if (!slotToLiteral.containsKey(ne)) { | ||
| nonConstAggOutput.add(ne); | ||
| } | ||
| continue; | ||
| } | ||
| nonConstAggOutput.add(ne); | ||
| } | ||
|
|
||
| // 4. The constant expression calculation in bottom projects needs to be deleted | ||
| // and put into upperProjects for calculation | ||
| Plan bottomPlan; | ||
| if (!bottomProjects.isEmpty()) { | ||
| ImmutableList.Builder<NamedExpression> builder = ImmutableList.builder(); | ||
| for (NamedExpression bottomProject : bottomProjects) { | ||
| if (!slotToLiteral.containsKey(bottomProject.toSlot())) { | ||
| builder.add(bottomProject); | ||
| } | ||
| } | ||
| bottomPlan = new LogicalProject<>(builder.build(), aggregate.child()); | ||
| } else { | ||
| bottomPlan = aggregate.child(); | ||
| } | ||
| LogicalAggregate<Plan> newAggAfterEliminate = aggregate.withNormalized(newNormalizedGroupExprs, | ||
| nonConstAggOutput.build(), bottomPlan); | ||
| // 5. This upperProjects needs to add the constant key that was deleted in the group by key | ||
| // and change the reference to the constant key to a constant expression | ||
| ImmutableList.Builder<NamedExpression> newUpperProjects = ImmutableList.builder(); | ||
| for (NamedExpression upperProject : upperProjects) { | ||
| if (upperProject instanceof Alias) { | ||
| newUpperProjects.add(ExpressionUtils.replaceNameExpression(upperProject, slotToLiteral)); | ||
| continue; | ||
| } else if (upperProject instanceof Slot) { | ||
| if (slotToLiteral.containsKey(upperProject)) { | ||
| Alias newLiteral = new Alias(upperProject.getExprId(), slotToLiteral.get(upperProject), | ||
| upperProject.getName()); | ||
| newUpperProjects.add(newLiteral); | ||
| continue; | ||
| } | ||
| } | ||
| newUpperProjects.add(upperProject); | ||
| } | ||
| return new LogicalProject<>(newUpperProjects.build(), newAggAfterEliminate); | ||
| } | ||
| } | ||
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove EliminateGroupByConstant file