Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import org.apache.doris.nereids.rules.analysis.CollectJoinConstraint;
import org.apache.doris.nereids.rules.analysis.CollectSubQueryAlias;
import org.apache.doris.nereids.rules.analysis.EliminateDistinctConstant;
import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
import org.apache.doris.nereids.rules.analysis.EliminateLogicalSelectHint;
import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
import org.apache.doris.nereids.rules.analysis.HavingToFilter;
Expand Down Expand Up @@ -136,8 +135,6 @@ private static List<RewriteJob> buildAnalyzerJobs() {
// select SUM(lo_tax) FROM lineorder group by 1;
// errCode = 2, detailMessage = GROUP BY expression must not contain aggregate functions: sum(lo_tax)
bottomUp(new CheckAnalysis()),
topDown(new EliminateGroupByConstant()),

topDown(new SimplifyAggGroupBy()),
topDown(new NormalizeAggregate()),
topDown(new HavingToFilter()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet;
import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount;
import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite;
import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject;
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite;
Expand Down Expand Up @@ -158,7 +157,6 @@ public class Rewriter extends AbstractBatchJobExecutor {
topDown(
new EliminateOrderByConstant(),
new EliminateSortUnderSubqueryOrView(),
new EliminateGroupByConstant(),
// MergeProjects depends on this rule
new LogicalSubQueryAliasToLogicalProject(),
// TODO: we should do expression normalization after plan normalization
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@

package org.apache.doris.nereids.rules.analysis;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE;
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
Expand All @@ -35,6 +38,7 @@
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
Expand All @@ -50,6 +54,7 @@
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -111,14 +116,16 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
public List<Rule> buildRules() {
return ImmutableList.of(
logicalHaving(logicalAggregate().whenNot(LogicalAggregate::isNormalized))
.then(having -> normalizeAgg(having.child(), Optional.of(having)))
.thenApply(ctx -> normalizeAgg(ctx.root.child(), Optional.of(ctx.root), ctx.cascadesContext))
.toRule(RuleType.NORMALIZE_AGGREGATE),
logicalAggregate().whenNot(LogicalAggregate::isNormalized)
.then(aggregate -> normalizeAgg(aggregate, Optional.empty()))
.thenApply(ctx -> normalizeAgg(ctx.root, Optional.empty(), ctx.cascadesContext))
.toRule(RuleType.NORMALIZE_AGGREGATE));
}

private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<LogicalHaving<?>> having) {
@SuppressWarnings("checkstyle:UnusedLocalVariable")
private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<LogicalHaving<?>> having,
CascadesContext ctx) {
// The LogicalAggregate node may contain window agg functions and usual agg functions
// we call window agg functions as window-agg and usual agg functions as trivial-agg for short
// This rule simplify LogicalAggregate node by:
Expand Down Expand Up @@ -279,8 +286,10 @@ private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<Logi
List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
groupByExprContext, argsOfAggFuncNeedPushDownContext, normalizedAggFuncsToSlotContext);

// create a parent project node
LogicalProject<Plan> project = new LogicalProject<>(upperProjects, newAggregate);
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx);
LogicalProject<Plan> project = eliminateGroupByConstant(groupByExprContext, rewriteContext,
normalizedGroupExprs, normalizedAggOutput, bottomProjects, aggregate, upperProjects, newAggregate);

// verify project used slots are all coming from agg's output
List<Slot> slots = collectAllUsedSlots(upperProjects);
if (!slots.isEmpty()) {
Expand Down Expand Up @@ -389,4 +398,93 @@ private Expression normalizeAggFuncChildren(NormalizeToSlotContext context, Expr
return expr;
}
}

private LogicalProject<Plan> eliminateGroupByConstant(NormalizeToSlotContext groupByExprContext,
ExpressionRewriteContext rewriteContext, List<Expression> normalizedGroupExprs,
List<NamedExpression> normalizedAggOutput, Set<NamedExpression> bottomProjects,
LogicalAggregate<Plan> aggregate, List<NamedExpression> upperProjects, LogicalAggregate<?> newAggregate) {
// 1. Find the expressions in group by that can be folded into constants and build a map(slot, literal)
Map<Expression, NormalizeToSlotTriplet> replaceMap = groupByExprContext.getNormalizeToSlotMap();
if (replaceMap.isEmpty()) {
return new LogicalProject<>(upperProjects, newAggregate);
}
Map<Slot, Expression> slotToLiteral = new HashMap<>();
for (Map.Entry<Expression, NormalizeToSlotTriplet> entry : replaceMap.entrySet()) {
Expression foldExpression = FoldConstantRuleOnFE.evaluate(entry.getKey(), rewriteContext);
if (foldExpression.isConstant()) {
slotToLiteral.put(entry.getValue().remainExpr, foldExpression);
}
}
if (slotToLiteral.isEmpty()) {
return new LogicalProject<>(upperProjects, newAggregate);
}
// 2. Regenerate a group by list without constant key
List<Expression> newNormalizedGroupExprs = new ArrayList<>();
for (Expression normalizedGroupExpr : normalizedGroupExprs) {
if (!slotToLiteral.containsKey((Slot) normalizedGroupExpr)) {
newNormalizedGroupExprs.add(normalizedGroupExpr);
}
}
if (newNormalizedGroupExprs.size() == normalizedGroupExprs.size()) {
return new LogicalProject<>(upperProjects, newAggregate);
}
if (newNormalizedGroupExprs.isEmpty()) {
Alias tinyInt = new Alias(new TinyIntLiteral((byte) 1));
bottomProjects = new HashSet<>(bottomProjects);
bottomProjects.add(tinyInt);
normalizedAggOutput = new ArrayList<>(normalizedAggOutput);
Slot tinyIntSlot = tinyInt.toSlot();
normalizedAggOutput.add(tinyIntSlot);
newNormalizedGroupExprs.add(tinyIntSlot);
}
// 3. Replace the agg output expression and delete the constant group by key in the output
ImmutableList.Builder<NamedExpression> nonConstAggOutput = ImmutableList.builder();
for (NamedExpression ne : normalizedAggOutput) {
if (ne instanceof Alias) {
nonConstAggOutput.add(ExpressionUtils.replaceNameExpression(ne, slotToLiteral));
continue;
} else if (ne instanceof Slot) {
if (!slotToLiteral.containsKey(ne)) {
nonConstAggOutput.add(ne);
}
continue;
}
nonConstAggOutput.add(ne);
}

// 4. The constant expression calculation in bottom projects needs to be deleted
// and put into upperProjects for calculation
Plan bottomPlan;
if (!bottomProjects.isEmpty()) {
ImmutableList.Builder<NamedExpression> builder = ImmutableList.builder();
for (NamedExpression bottomProject : bottomProjects) {
if (!slotToLiteral.containsKey(bottomProject.toSlot())) {
builder.add(bottomProject);
}
}
bottomPlan = new LogicalProject<>(builder.build(), aggregate.child());
} else {
bottomPlan = aggregate.child();
}
LogicalAggregate<Plan> newAggAfterEliminate = aggregate.withNormalized(newNormalizedGroupExprs,
nonConstAggOutput.build(), bottomPlan);
// 5. This upperProjects needs to add the constant key that was deleted in the group by key
// and change the reference to the constant key to a constant expression
ImmutableList.Builder<NamedExpression> newUpperProjects = ImmutableList.builder();
for (NamedExpression upperProject : upperProjects) {
if (upperProject instanceof Alias) {
newUpperProjects.add(ExpressionUtils.replaceNameExpression(upperProject, slotToLiteral));
continue;
} else if (upperProject instanceof Slot) {
if (slotToLiteral.containsKey(upperProject)) {
Alias newLiteral = new Alias(upperProject.getExprId(), slotToLiteral.get(upperProject),
upperProject.getName());
newUpperProjects.add(newLiteral);
continue;
}
}
newUpperProjects.add(upperProject);
}
return new LogicalProject<>(newUpperProjects.build(), newAggAfterEliminate);
}
}

This file was deleted.

Loading