Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
import org.apache.doris.nereids.rules.rewrite.InferSetOperatorDistinct;
import org.apache.doris.nereids.rules.rewrite.InitJoinOrder;
import org.apache.doris.nereids.rules.rewrite.InlineLogicalView;
import org.apache.doris.nereids.rules.rewrite.JoinExtractOrFromCaseWhen;
import org.apache.doris.nereids.rules.rewrite.LimitAggToTopNAgg;
import org.apache.doris.nereids.rules.rewrite.LimitSortToTopN;
import org.apache.doris.nereids.rules.rewrite.LogicalResultSinkToShortCircuitPointQuery;
Expand Down Expand Up @@ -323,7 +324,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
new PushFilterInsideJoin(),
new FindHashConditionForJoin(),
new ConvertInnerOrCrossJoin(),
new EliminateNullAwareLeftAntiJoin()
new EliminateNullAwareLeftAntiJoin(),
new JoinExtractOrFromCaseWhen()
),
// push down SEMI Join
bottomUp(
Expand Down Expand Up @@ -553,7 +555,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
new PushFilterInsideJoin(),
new FindHashConditionForJoin(),
new ConvertInnerOrCrossJoin(),
new EliminateNullAwareLeftAntiJoin()
new EliminateNullAwareLeftAntiJoin(),
new JoinExtractOrFromCaseWhen()
),
// push down SEMI Join
bottomUp(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ public void execute(JobContext jobContext) {
null, -1, originPlan, jobContext, rules, false, new ProcessState(originPlan)
);
jobContext.getCascadesContext().setRewritePlan(root);

jobContext.getCascadesContext().setRewritePlan(root);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ public enum RuleType {
REWRITE_PARTITION_TOPN_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_QUALIFY_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_TOPN_EXPRESSION(RuleTypeClass.REWRITE),
JOIN_EXTRACT_OR_FROM_CASE_WHEN(RuleTypeClass.REWRITE),
EXTRACT_FILTER_FROM_JOIN(RuleTypeClass.REWRITE),
REORDER_JOIN(RuleTypeClass.REWRITE),
INIT_JOIN_ORDER(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ protected List<? extends Expression> bindSlotByThisScope(UnboundSlot unboundSlot

Set<Expression> havingExprs = having.getConjuncts();
ImmutableSet.Builder<Expression> analyzedHaving = ImmutableSet.builder();
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext);
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(having, cascadesContext);
Map<Expression, Expression> bindUniqueIdReplaceMap
= getGroupByUniqueFuncReplaceMap(aggregate.getGroupByExpressions());
for (Expression expression : havingExprs) {
Expand Down Expand Up @@ -668,7 +668,7 @@ private LogicalPlan bindUsingJoin(MatchingContext<LogicalUsingJoin<Plan, Plan>>

Scope leftScope = toScope(cascadesContext, using.left().getOutput(), using.left().getAsteriskOutput());
Scope rightScope = toScope(cascadesContext, using.right().getOutput(), using.right().getAsteriskOutput());
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext);
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(using, cascadesContext);

Builder<Expression> hashEqExprs = ImmutableList.builderWithExpectedSize(unboundHashJoinConjunct.size());
List<Slot> rightConjunctsSlots = Lists.newArrayList();
Expand Down Expand Up @@ -1066,7 +1066,7 @@ protected List<? extends Expression> bindSlotByThisScope(UnboundSlot unboundSlot

Map<Expression, Expression> bindUniqueIdReplaceMap
= getGroupByUniqueFuncReplaceMap(aggregate.getGroupByExpressions());
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext);
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(qualify, cascadesContext);
for (Expression expression : qualify.getConjuncts()) {
Expression boundExpr = qualifyAnalyzer.analyze(expression, rewriteContext);
// logical plan builder no extract conjunction
Expand Down Expand Up @@ -1598,7 +1598,7 @@ protected SimpleExprAnalyzer buildSimpleExprAnalyzer(
Plan currentPlan, CascadesContext cascadesContext, List<Plan> children) {
Scope scope = toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(children),
PlanUtils.fastGetChildrenAsteriskOutputs(children));
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext);
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(currentPlan, cascadesContext);
ExpressionAnalyzer expressionAnalyzer = new ExpressionAnalyzer(currentPlan,
scope, cascadesContext, true, true);
return expr -> expressionAnalyzer.analyze(expr, rewriteContext);
Expand All @@ -1607,7 +1607,7 @@ protected SimpleExprAnalyzer buildSimpleExprAnalyzer(
private SimpleExprAnalyzer buildCustomSlotBinderAnalyzer(
Plan currentPlan, CascadesContext cascadesContext, Scope defaultScope,
boolean enableExactMatch, boolean bindSlotInOuterScope, CustomSlotBinderAnalyzer customSlotBinder) {
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext);
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(currentPlan, cascadesContext);
ExpressionAnalyzer expressionAnalyzer = new ExpressionAnalyzer(currentPlan, defaultScope, cascadesContext,
enableExactMatch, bindSlotInOuterScope) {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.DereferenceExpression;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.EqualTo;
Expand Down Expand Up @@ -127,13 +128,19 @@

/** ExpressionAnalyzer */
public class ExpressionAnalyzer extends SubExprAnalyzer<ExpressionRewriteContext> {
// This rule only used in unit test
@VisibleForTesting
public static final AbstractExpressionRewriteRule FUNCTION_ANALYZER_RULE = new AbstractExpressionRewriteRule() {
@Override
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
return new ExpressionAnalyzer(
null, new Scope(ImmutableList.of()), null, false, false
).analyze(expr, ctx);
return new ExpressionAnalyzer(null, new Scope(ImmutableList.of()), null, false, false) {
@Override
protected Expression processCompoundNewChildren(CompoundPredicate cp, List<Expression> newChildren) {
// ExpressionUtils.and/ExpressionUtils.or will remove duplicate children, and simplify FALSE / TRUE.
// But we don't want to simplify them in unit test.
return cp.withChildren(newChildren);
}
}.analyze(expr, ctx);
}
};

Expand Down Expand Up @@ -170,13 +177,22 @@ public static Expression analyzeFunction(
cascadesContext, false, false);
return analyzer.analyze(
expression,
cascadesContext == null ? null : new ExpressionRewriteContext(cascadesContext)
cascadesContext == null ? null : new ExpressionRewriteContext(plan, cascadesContext)
);
}

/** analyze */
public Expression analyze(Expression expression) {
CascadesContext cascadesContext = getCascadesContext();
return analyze(expression, cascadesContext == null ? null : new ExpressionRewriteContext(cascadesContext));
ExpressionRewriteContext rewriteContext;
if (cascadesContext == null) {
rewriteContext = null;
} else if (currentPlan == null) {
rewriteContext = new ExpressionRewriteContext(cascadesContext);
} else {
rewriteContext = new ExpressionRewriteContext(currentPlan, cascadesContext);
}
return analyze(expression, rewriteContext);
}

/** analyze */
Expand Down Expand Up @@ -655,7 +671,7 @@ public Expression visitOr(Or or, ExpressionRewriteContext context) {
newChild = child;
}
if (newChild.getDataType().isNullType()) {
newChild = new NullLiteral(BooleanType.INSTANCE);
newChild = NullLiteral.BOOLEAN_INSTANCE;
} else {
newChild = TypeCoercionUtils.castIfNotSameType(newChild, BooleanType.INSTANCE);
}
Expand All @@ -666,7 +682,7 @@ public Expression visitOr(Or or, ExpressionRewriteContext context) {
newChildren.add(newChild);
}
if (hasNewChild) {
return ExpressionUtils.or(newChildren);
return processCompoundNewChildren(or, newChildren);
} else {
return or;
}
Expand All @@ -683,23 +699,31 @@ public Expression visitAnd(And and, ExpressionRewriteContext context) {
newChild = child;
}
if (newChild.getDataType().isNullType()) {
newChild = new NullLiteral(BooleanType.INSTANCE);
newChild = NullLiteral.BOOLEAN_INSTANCE;
} else {
newChild = TypeCoercionUtils.castIfNotSameType(newChild, BooleanType.INSTANCE);
}

if (! child.equals(newChild)) {
if (!child.equals(newChild)) {
hasNewChild = true;
}
newChildren.add(newChild);
}
if (hasNewChild) {
return ExpressionUtils.and(newChildren);
return processCompoundNewChildren(and, newChildren);
} else {
return and;
}
}

protected Expression processCompoundNewChildren(CompoundPredicate cp, List<Expression> newChildren) {
if (cp instanceof And) {
return ExpressionUtils.and(newChildren);
} else {
return ExpressionUtils.or(newChildren);
}
}

@Override
public Expression visitNot(Not not, ExpressionRewriteContext context) {
// maybe is `not subquery`, we should bind it first
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ public List<Rule> buildRules() {
* if it's semi join with non-null mark slot
* we can safely change the mark conjunct to hash conjunct
*/
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx.cascadesContext);
ExpressionRewriteContext rewriteContext
= new ExpressionRewriteContext(join, ctx.cascadesContext);
boolean isMarkSlotNotNull = conjunct.containsType(MarkJoinSlotReference.class)
? ExpressionUtils.canInferNotNullForMarkSlot(
TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, rewriteContext),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.doris.nereids.rules.expression.rules.SimplifyArithmeticComparisonRule;
import org.apache.doris.nereids.rules.expression.rules.SimplifyArithmeticRule;
import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule;
import org.apache.doris.nereids.rules.expression.rules.SimplifyEqualBooleanLiteral;
import org.apache.doris.nereids.rules.expression.rules.SimplifyNotExprRule;
import org.apache.doris.nereids.rules.expression.rules.SupportJavaDateFormatter;
import org.apache.doris.nereids.rules.expression.rules.TimestampToAddTime;
Expand Down Expand Up @@ -72,7 +73,8 @@ public class ExpressionNormalization extends ExpressionRewrite {
ConvertAggStateCast.INSTANCE,
MergeDateTrunc.INSTANCE,
NormalizeStructElement.INSTANCE,
CheckCast.INSTANCE
CheckCast.INSTANCE,
SimplifyEqualBooleanLiteral.INSTANCE
)
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@
import org.apache.doris.nereids.rules.expression.rules.AddMinMax;
import org.apache.doris.nereids.rules.expression.rules.ArrayContainToArrayOverlap;
import org.apache.doris.nereids.rules.expression.rules.BetweenToEqual;
import org.apache.doris.nereids.rules.expression.rules.CaseWhenToCompoundPredicate;
import org.apache.doris.nereids.rules.expression.rules.CaseWhenToIf;
import org.apache.doris.nereids.rules.expression.rules.CondReplaceNullWithFalse;
import org.apache.doris.nereids.rules.expression.rules.DateFunctionRewrite;
import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule;
import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule;
import org.apache.doris.nereids.rules.expression.rules.LikeToEqualRewrite;
import org.apache.doris.nereids.rules.expression.rules.NestedCaseWhenCondToLiteral;
import org.apache.doris.nereids.rules.expression.rules.NullSafeEqualToEqual;
import org.apache.doris.nereids.rules.expression.rules.PushIntoCaseWhenBranch;
import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
import org.apache.doris.nereids.rules.expression.rules.SimplifyConflictCompound;
import org.apache.doris.nereids.rules.expression.rules.SimplifyInPredicate;
Expand Down Expand Up @@ -57,7 +61,11 @@ public class ExpressionOptimization extends ExpressionRewrite {

DateFunctionRewrite.INSTANCE,
ArrayContainToArrayOverlap.INSTANCE,
CondReplaceNullWithFalse.INSTANCE,
NestedCaseWhenCondToLiteral.INSTANCE,
CaseWhenToIf.INSTANCE,
CaseWhenToCompoundPredicate.INSTANCE,
PushIntoCaseWhenBranch.INSTANCE,
TopnToMax.INSTANCE,
NullSafeEqualToEqual.INSTANCE,
LikeToEqualRewrite.INSTANCE,
Expand Down
Loading