diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index a814d903a34217..d676e00c206363 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -99,6 +99,7 @@ import org.apache.doris.nereids.rules.rewrite.InferSetOperatorDistinct; import org.apache.doris.nereids.rules.rewrite.InitJoinOrder; import org.apache.doris.nereids.rules.rewrite.InlineLogicalView; +import org.apache.doris.nereids.rules.rewrite.JoinExtractOrFromCaseWhen; import org.apache.doris.nereids.rules.rewrite.LimitAggToTopNAgg; import org.apache.doris.nereids.rules.rewrite.LimitSortToTopN; import org.apache.doris.nereids.rules.rewrite.LogicalResultSinkToShortCircuitPointQuery; @@ -323,7 +324,8 @@ public class Rewriter extends AbstractBatchJobExecutor { new PushFilterInsideJoin(), new FindHashConditionForJoin(), new ConvertInnerOrCrossJoin(), - new EliminateNullAwareLeftAntiJoin() + new EliminateNullAwareLeftAntiJoin(), + new JoinExtractOrFromCaseWhen() ), // push down SEMI Join bottomUp( @@ -553,7 +555,8 @@ public class Rewriter extends AbstractBatchJobExecutor { new PushFilterInsideJoin(), new FindHashConditionForJoin(), new ConvertInnerOrCrossJoin(), - new EliminateNullAwareLeftAntiJoin() + new EliminateNullAwareLeftAntiJoin(), + new JoinExtractOrFromCaseWhen() ), // push down SEMI Join bottomUp( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/TopDownVisitorRewriteJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/TopDownVisitorRewriteJob.java index 5275de0eb0ac30..8ae29a77f470f8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/TopDownVisitorRewriteJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/TopDownVisitorRewriteJob.java @@ -57,8 +57,6 @@ public void execute(JobContext jobContext) { null, -1, originPlan, jobContext, rules, false, new ProcessState(originPlan) ); jobContext.getCascadesContext().setRewritePlan(root); - - jobContext.getCascadesContext().setRewritePlan(root); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index aef4ddcbb7c38f..9da2a0360bba96 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -260,6 +260,7 @@ public enum RuleType { REWRITE_PARTITION_TOPN_EXPRESSION(RuleTypeClass.REWRITE), REWRITE_QUALIFY_EXPRESSION(RuleTypeClass.REWRITE), REWRITE_TOPN_EXPRESSION(RuleTypeClass.REWRITE), + JOIN_EXTRACT_OR_FROM_CASE_WHEN(RuleTypeClass.REWRITE), EXTRACT_FILTER_FROM_JOIN(RuleTypeClass.REWRITE), REORDER_JOIN(RuleTypeClass.REWRITE), INIT_JOIN_ORDER(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java index 9409261fc52d02..f0dd951e08f53b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java @@ -529,7 +529,7 @@ protected List bindSlotByThisScope(UnboundSlot unboundSlot Set havingExprs = having.getConjuncts(); ImmutableSet.Builder analyzedHaving = ImmutableSet.builder(); - ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext); + ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(having, cascadesContext); Map bindUniqueIdReplaceMap = getGroupByUniqueFuncReplaceMap(aggregate.getGroupByExpressions()); for (Expression expression : havingExprs) { @@ -668,7 +668,7 @@ private LogicalPlan bindUsingJoin(MatchingContext> Scope leftScope = toScope(cascadesContext, using.left().getOutput(), using.left().getAsteriskOutput()); Scope rightScope = toScope(cascadesContext, using.right().getOutput(), using.right().getAsteriskOutput()); - ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext); + ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(using, cascadesContext); Builder hashEqExprs = ImmutableList.builderWithExpectedSize(unboundHashJoinConjunct.size()); List rightConjunctsSlots = Lists.newArrayList(); @@ -1066,7 +1066,7 @@ protected List bindSlotByThisScope(UnboundSlot unboundSlot Map bindUniqueIdReplaceMap = getGroupByUniqueFuncReplaceMap(aggregate.getGroupByExpressions()); - ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext); + ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(qualify, cascadesContext); for (Expression expression : qualify.getConjuncts()) { Expression boundExpr = qualifyAnalyzer.analyze(expression, rewriteContext); // logical plan builder no extract conjunction @@ -1598,7 +1598,7 @@ protected SimpleExprAnalyzer buildSimpleExprAnalyzer( Plan currentPlan, CascadesContext cascadesContext, List children) { Scope scope = toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(children), PlanUtils.fastGetChildrenAsteriskOutputs(children)); - ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext); + ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(currentPlan, cascadesContext); ExpressionAnalyzer expressionAnalyzer = new ExpressionAnalyzer(currentPlan, scope, cascadesContext, true, true); return expr -> expressionAnalyzer.analyze(expr, rewriteContext); @@ -1607,7 +1607,7 @@ protected SimpleExprAnalyzer buildSimpleExprAnalyzer( private SimpleExprAnalyzer buildCustomSlotBinderAnalyzer( Plan currentPlan, CascadesContext cascadesContext, Scope defaultScope, boolean enableExactMatch, boolean bindSlotInOuterScope, CustomSlotBinderAnalyzer customSlotBinder) { - ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext); + ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(currentPlan, cascadesContext); ExpressionAnalyzer expressionAnalyzer = new ExpressionAnalyzer(currentPlan, defaultScope, cascadesContext, enableExactMatch, bindSlotInOuterScope) { @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java index fdb709516ea28f..eb1f47129040f2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java @@ -49,6 +49,7 @@ import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; +import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.DereferenceExpression; import org.apache.doris.nereids.trees.expressions.Divide; import org.apache.doris.nereids.trees.expressions.EqualTo; @@ -127,13 +128,19 @@ /** ExpressionAnalyzer */ public class ExpressionAnalyzer extends SubExprAnalyzer { + // This rule only used in unit test @VisibleForTesting public static final AbstractExpressionRewriteRule FUNCTION_ANALYZER_RULE = new AbstractExpressionRewriteRule() { @Override public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { - return new ExpressionAnalyzer( - null, new Scope(ImmutableList.of()), null, false, false - ).analyze(expr, ctx); + return new ExpressionAnalyzer(null, new Scope(ImmutableList.of()), null, false, false) { + @Override + protected Expression processCompoundNewChildren(CompoundPredicate cp, List newChildren) { + // ExpressionUtils.and/ExpressionUtils.or will remove duplicate children, and simplify FALSE / TRUE. + // But we don't want to simplify them in unit test. + return cp.withChildren(newChildren); + } + }.analyze(expr, ctx); } }; @@ -170,13 +177,22 @@ public static Expression analyzeFunction( cascadesContext, false, false); return analyzer.analyze( expression, - cascadesContext == null ? null : new ExpressionRewriteContext(cascadesContext) + cascadesContext == null ? null : new ExpressionRewriteContext(plan, cascadesContext) ); } + /** analyze */ public Expression analyze(Expression expression) { CascadesContext cascadesContext = getCascadesContext(); - return analyze(expression, cascadesContext == null ? null : new ExpressionRewriteContext(cascadesContext)); + ExpressionRewriteContext rewriteContext; + if (cascadesContext == null) { + rewriteContext = null; + } else if (currentPlan == null) { + rewriteContext = new ExpressionRewriteContext(cascadesContext); + } else { + rewriteContext = new ExpressionRewriteContext(currentPlan, cascadesContext); + } + return analyze(expression, rewriteContext); } /** analyze */ @@ -655,7 +671,7 @@ public Expression visitOr(Or or, ExpressionRewriteContext context) { newChild = child; } if (newChild.getDataType().isNullType()) { - newChild = new NullLiteral(BooleanType.INSTANCE); + newChild = NullLiteral.BOOLEAN_INSTANCE; } else { newChild = TypeCoercionUtils.castIfNotSameType(newChild, BooleanType.INSTANCE); } @@ -666,7 +682,7 @@ public Expression visitOr(Or or, ExpressionRewriteContext context) { newChildren.add(newChild); } if (hasNewChild) { - return ExpressionUtils.or(newChildren); + return processCompoundNewChildren(or, newChildren); } else { return or; } @@ -683,23 +699,31 @@ public Expression visitAnd(And and, ExpressionRewriteContext context) { newChild = child; } if (newChild.getDataType().isNullType()) { - newChild = new NullLiteral(BooleanType.INSTANCE); + newChild = NullLiteral.BOOLEAN_INSTANCE; } else { newChild = TypeCoercionUtils.castIfNotSameType(newChild, BooleanType.INSTANCE); } - if (! child.equals(newChild)) { + if (!child.equals(newChild)) { hasNewChild = true; } newChildren.add(newChild); } if (hasNewChild) { - return ExpressionUtils.and(newChildren); + return processCompoundNewChildren(and, newChildren); } else { return and; } } + protected Expression processCompoundNewChildren(CompoundPredicate cp, List newChildren) { + if (cp instanceof And) { + return ExpressionUtils.and(newChildren); + } else { + return ExpressionUtils.or(newChildren); + } + } + @Override public Expression visitNot(Not not, ExpressionRewriteContext context) { // maybe is `not subquery`, we should bind it first diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java index 2e41a88c62f63c..d98d1dfcd1dc15 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java @@ -254,7 +254,8 @@ public List buildRules() { * if it's semi join with non-null mark slot * we can safely change the mark conjunct to hash conjunct */ - ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx.cascadesContext); + ExpressionRewriteContext rewriteContext + = new ExpressionRewriteContext(join, ctx.cascadesContext); boolean isMarkSlotNotNull = conjunct.containsType(MarkJoinSlotReference.class) ? ExpressionUtils.canInferNotNullForMarkSlot( TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, rewriteContext), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java index f20bcd54ecb0d0..e4823347d94a9e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java @@ -33,6 +33,7 @@ import org.apache.doris.nereids.rules.expression.rules.SimplifyArithmeticComparisonRule; import org.apache.doris.nereids.rules.expression.rules.SimplifyArithmeticRule; import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule; +import org.apache.doris.nereids.rules.expression.rules.SimplifyEqualBooleanLiteral; import org.apache.doris.nereids.rules.expression.rules.SimplifyNotExprRule; import org.apache.doris.nereids.rules.expression.rules.SupportJavaDateFormatter; import org.apache.doris.nereids.rules.expression.rules.TimestampToAddTime; @@ -72,7 +73,8 @@ public class ExpressionNormalization extends ExpressionRewrite { ConvertAggStateCast.INSTANCE, MergeDateTrunc.INSTANCE, NormalizeStructElement.INSTANCE, - CheckCast.INSTANCE + CheckCast.INSTANCE, + SimplifyEqualBooleanLiteral.INSTANCE ) ); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java index 489af4b331cb3d..b5b806aba39c23 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java @@ -20,12 +20,16 @@ import org.apache.doris.nereids.rules.expression.rules.AddMinMax; import org.apache.doris.nereids.rules.expression.rules.ArrayContainToArrayOverlap; import org.apache.doris.nereids.rules.expression.rules.BetweenToEqual; +import org.apache.doris.nereids.rules.expression.rules.CaseWhenToCompoundPredicate; import org.apache.doris.nereids.rules.expression.rules.CaseWhenToIf; +import org.apache.doris.nereids.rules.expression.rules.CondReplaceNullWithFalse; import org.apache.doris.nereids.rules.expression.rules.DateFunctionRewrite; import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule; import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule; import org.apache.doris.nereids.rules.expression.rules.LikeToEqualRewrite; +import org.apache.doris.nereids.rules.expression.rules.NestedCaseWhenCondToLiteral; import org.apache.doris.nereids.rules.expression.rules.NullSafeEqualToEqual; +import org.apache.doris.nereids.rules.expression.rules.PushIntoCaseWhenBranch; import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate; import org.apache.doris.nereids.rules.expression.rules.SimplifyConflictCompound; import org.apache.doris.nereids.rules.expression.rules.SimplifyInPredicate; @@ -57,7 +61,11 @@ public class ExpressionOptimization extends ExpressionRewrite { DateFunctionRewrite.INSTANCE, ArrayContainToArrayOverlap.INSTANCE, + CondReplaceNullWithFalse.INSTANCE, + NestedCaseWhenCondToLiteral.INSTANCE, CaseWhenToIf.INSTANCE, + CaseWhenToCompoundPredicate.INSTANCE, + PushIntoCaseWhenBranch.INSTANCE, TopnToMax.INSTANCE, NullSafeEqualToEqual.INSTANCE, LikeToEqualRewrite.INSTANCE, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java index e41ce5c39925d9..22c5d29043b2ea 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.properties.OrderKey; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext.ExpressionSource; import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.And; @@ -53,6 +54,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; import org.apache.doris.nereids.trees.plans.logical.LogicalWindow; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.JoinUtils; import org.apache.doris.nereids.util.Utils; import com.google.common.collect.ImmutableList; @@ -121,7 +123,7 @@ public class GenerateExpressionRewrite extends OneRewriteRuleFactory { public Rule build() { return logicalGenerate().thenApply(ctx -> { LogicalGenerate generate = ctx.root; - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + ExpressionRewriteContext context = new ExpressionRewriteContext(generate, ctx.cascadesContext); List generators = generate.getGenerators(); List newGenerators = generators.stream() .map(func -> (Function) rewriter.rewrite(func, context)) @@ -141,7 +143,7 @@ public Rule build() { return logicalOneRowRelation().thenApply(ctx -> { LogicalOneRowRelation oneRowRelation = ctx.root; List projects = oneRowRelation.getProjects(); - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + ExpressionRewriteContext context = new ExpressionRewriteContext(oneRowRelation, ctx.cascadesContext); Builder rewrittenExprs = ImmutableList.builderWithExpectedSize(projects.size()); @@ -166,7 +168,7 @@ public class ProjectExpressionRewrite extends OneRewriteRuleFactory { public Rule build() { return logicalProject().thenApply(ctx -> { LogicalProject project = ctx.root; - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + ExpressionRewriteContext context = new ExpressionRewriteContext(project, ctx.cascadesContext); List projects = project.getProjects(); RewriteResult result = rewriteAll(projects, rewriter, context); if (!result.changed) { @@ -183,7 +185,7 @@ public class FilterExpressionRewrite extends OneRewriteRuleFactory { public Rule build() { return logicalFilter().thenApply(ctx -> { LogicalFilter filter = ctx.root; - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + ExpressionRewriteContext context = new ExpressionRewriteContext(filter, ctx.cascadesContext); Expression originPredicate = filter.getPredicate(); Expression predicate = rewriter.rewrite(originPredicate, context); if (predicate == originPredicate && !(predicate instanceof And)) { @@ -206,7 +208,7 @@ public class LogicalOlapTableSinkExpressionRewrite extends OneRewriteRuleFactory public Rule build() { return logicalOlapTableSink().thenApply(ctx -> { LogicalOlapTableSink olapTableSink = ctx.root; - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + ExpressionRewriteContext context = new ExpressionRewriteContext(olapTableSink, ctx.cascadesContext); List partitionExprList = olapTableSink.getPartitionExprList(); RewriteResult result = rewriteAll(partitionExprList, rewriter, context); Map syncMvWhereClauses = olapTableSink.getSyncMvWhereClauses(); @@ -229,7 +231,7 @@ public Rule build() { return logicalAggregate().thenApply(ctx -> { LogicalAggregate agg = ctx.root; List groupByExprs = agg.getGroupByExpressions(); - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + ExpressionRewriteContext context = new ExpressionRewriteContext(agg, ctx.cascadesContext); List newGroupByExprs = rewriter.rewrite(groupByExprs, context); List outputExpressions = agg.getOutputExpressions(); @@ -257,18 +259,32 @@ public Rule build() { return join; } - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); - Pair> newHashJoinConjuncts = rewriteConjuncts(hashJoinConjuncts, context); - Pair> newOtherJoinConjuncts = rewriteConjuncts(otherJoinConjuncts, context); - Pair> newMarkJoinConjuncts = rewriteConjuncts(markJoinConjuncts, context); + Pair> newHashJoinConjuncts = rewriteConjuncts(hashJoinConjuncts, + new ExpressionRewriteContext(join, ExpressionSource.JOIN_HASH_CONDITION, ctx.cascadesContext)); + Pair> newOtherJoinConjuncts = rewriteConjuncts(otherJoinConjuncts, + new ExpressionRewriteContext(join, ExpressionSource.JOIN_OTHER_CONDITION, ctx.cascadesContext)); + Pair> newMarkJoinConjuncts = rewriteConjuncts(markJoinConjuncts, + new ExpressionRewriteContext(join, ExpressionSource.JOIN_MARK_CONDITION, ctx.cascadesContext)); if (!newHashJoinConjuncts.first && !newOtherJoinConjuncts.first && !newMarkJoinConjuncts.first) { return join; } - return new LogicalJoin<>(join.getJoinType(), newHashJoinConjuncts.second, - newOtherJoinConjuncts.second, newMarkJoinConjuncts.second, + List newOtherConjunctsList = newOtherJoinConjuncts.second; + List newHashConjunctsList = newHashJoinConjuncts.second; + List newMarkConjunctsList = newMarkJoinConjuncts.second; + // split hash join conjuncts and other conjuncts + Pair, List> splitResult = JoinUtils.extractExpressionForHashTable( + join.left().getOutput(), join.right().getOutput(), newHashConjunctsList); + if (!splitResult.second.isEmpty()) { + newHashConjunctsList = ImmutableList.copyOf(splitResult.first); + newOtherConjunctsList = ImmutableList.builder() + .addAll(newOtherConjunctsList).addAll(splitResult.second).build(); + } + + return new LogicalJoin<>(join.getJoinType(), + newHashConjunctsList, newOtherConjunctsList, newMarkConjunctsList, join.getDistributeHint(), join.getMarkJoinSlotReference(), join.children(), join.getJoinReorderContext()); }).toRule(RuleType.REWRITE_JOIN_EXPRESSION); @@ -309,7 +325,7 @@ public Rule build() { List orderKeys = sort.getOrderKeys(); ImmutableList.Builder rewrittenOrderKeys = ImmutableList.builderWithExpectedSize(orderKeys.size()); - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + ExpressionRewriteContext context = new ExpressionRewriteContext(sort, ctx.cascadesContext); boolean changed = false; for (OrderKey k : orderKeys) { Expression expression = rewriter.rewrite(k.getExpr(), context); @@ -327,7 +343,7 @@ public class HavingExpressionRewrite extends OneRewriteRuleFactory { public Rule build() { return logicalHaving().thenApply(ctx -> { LogicalHaving having = ctx.root; - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + ExpressionRewriteContext context = new ExpressionRewriteContext(having, ctx.cascadesContext); Set newConjuncts = ImmutableSet.copyOf(ExpressionUtils.extractConjunction( rewriter.rewrite(having.getPredicate(), context))); if (newConjuncts.equals(having.getConjuncts())) { @@ -344,7 +360,7 @@ public Rule build() { return logicalWindow().thenApply(ctx -> { LogicalWindow window = ctx.root; List windowExpressions = window.getWindowExpressions(); - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + ExpressionRewriteContext context = new ExpressionRewriteContext(window, ctx.cascadesContext); RewriteResult result = rewriteAll(windowExpressions, rewriter, context); if (!result.changed) { return window; @@ -362,7 +378,7 @@ public Rule build() { LogicalSetOperation setOperation = ctx.root; List> slotsList = setOperation.getRegularChildrenOutputs(); List> newSlotsList = new ArrayList<>(); - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + ExpressionRewriteContext context = new ExpressionRewriteContext(setOperation, ctx.cascadesContext); boolean changed = false; for (List slots : slotsList) { RewriteResult result = rewriteAll(slots, rewriter, context); @@ -409,7 +425,7 @@ public Rule build() { List orderKeys = topN.getOrderKeys(); ImmutableList.Builder rewrittenOrderKeys = ImmutableList.builderWithExpectedSize(orderKeys.size()); - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + ExpressionRewriteContext context = new ExpressionRewriteContext(topN, ctx.cascadesContext); boolean changed = false; for (OrderKey k : orderKeys) { Expression expression = rewriter.rewrite(k.getExpr(), context); @@ -426,7 +442,7 @@ private class LogicalPartitionTopNExpressionRewrite extends OneRewriteRuleFactor public Rule build() { return logicalPartitionTopN().thenApply(ctx -> { LogicalPartitionTopN partitionTopN = ctx.root; - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + ExpressionRewriteContext context = new ExpressionRewriteContext(partitionTopN, ctx.cascadesContext); List newOrderExpressions = new ArrayList<>(); boolean changed = false; for (OrderExpression orderExpression : partitionTopN.getOrderKeys()) { @@ -453,7 +469,7 @@ public Rule build() { return logicalCTEConsumer().thenApply(ctx -> { LogicalCTEConsumer consumer = ctx.root; boolean changed = false; - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + ExpressionRewriteContext context = new ExpressionRewriteContext(consumer, ctx.cascadesContext); ImmutableMap.Builder cToPBuilder = ImmutableMap.builder(); ImmutableMultimap.Builder pToCBuilder = ImmutableMultimap.builder(); for (Map.Entry entry : consumer.getConsumerToProducerOutputMap().entrySet()) { @@ -536,7 +552,7 @@ public Rule build() { private LogicalSink applyRewriteToSink(MatchingContext> ctx) { LogicalSink sink = ctx.root; - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + ExpressionRewriteContext context = new ExpressionRewriteContext(sink, ctx.cascadesContext); List outputExprs = sink.getOutputExprs(); RewriteResult result = rewriteAll(outputExprs, rewriter, context); if (!result.changed) { @@ -552,7 +568,7 @@ public Rule build() { return logicalRepeat().thenApply(ctx -> { LogicalRepeat repeat = ctx.root; ImmutableList.Builder> groupingExprs = ImmutableList.builder(); - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + ExpressionRewriteContext context = new ExpressionRewriteContext(repeat, ctx.cascadesContext); for (List expressions : repeat.getGroupingSets()) { groupingExprs.add(expressions.stream() .map(expr -> rewriter.rewrite(expr, context)) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteContext.java index 35633e7594f717..b634649a6b6e00 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteContext.java @@ -18,18 +18,46 @@ package org.apache.doris.nereids.rules.expression; import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.trees.plans.Plan; import java.util.Objects; +import java.util.Optional; /** * expression rewrite context. */ public class ExpressionRewriteContext { + public final Optional plan; + public final Optional source; public final CascadesContext cascadesContext; public ExpressionRewriteContext(CascadesContext cascadesContext) { + this(Optional.empty(), Optional.empty(), cascadesContext); + } + + public ExpressionRewriteContext(Plan plan, CascadesContext cascadesContext) { + this(Optional.of(plan), Optional.empty(), cascadesContext); + } + + public ExpressionRewriteContext(Plan plan, ExpressionSource source, CascadesContext cascadesContext) { + this(Optional.of(plan), Optional.of(source), cascadesContext); + } + + private ExpressionRewriteContext(Optional plan, Optional source, + CascadesContext cascadesContext) { + this.plan = Objects.requireNonNull(plan, "plan can not be null, or use Optional.empty()"); + this.source = Objects.requireNonNull(source, "source can not be null, or use Optional.empty()"); this.cascadesContext = Objects.requireNonNull(cascadesContext, "cascadesContext can not be null"); } + /** + * Expression detail source from. + * Currently only used in Join, add more if needed. + */ + public enum ExpressionSource { + JOIN_HASH_CONDITION, + JOIN_OTHER_CONDITION, + JOIN_MARK_CONDITION, + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java index b8803095d8108f..01e5263c878ba8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java @@ -24,8 +24,10 @@ public enum ExpressionRuleType { ADD_MIN_MAX, ARRAY_CONTAIN_TO_ARRAY_OVERLAP, BETWEEN_TO_EQUAL, + CASE_WHEN_TO_COMPOUND_PREDICATE, CASE_WHEN_TO_IF, CHECK_CAST, + COND_REPLACE_NULL_WITH_FALSE, CONVERT_AGG_STATE_CAST, CONCATWS_MULTI_ARRAY_TO_ONE, DATE_FUNCTION_REWRITE, @@ -37,14 +39,17 @@ public enum ExpressionRuleType { FOLD_CONSTANT_ON_BE, FOLD_CONSTANT_ON_FE, LOG_TO_LN, + IF_TO_COMPOUND_PREDICATE, IN_PREDICATE_DEDUP, IN_PREDICATE_EXTRACT_NON_CONSTANT, IN_PREDICATE_TO_EQUAL_TO, LIKE_TO_EQUAL, MERGE_DATE_TRUNC, MEDIAN_CONVERT, + NESTED_CASE_WHEN_COND_TO_LITERAL, NORMALIZE_BINARY_PREDICATES, NULL_SAFE_EQUAL_TO_EQUAL, + PUSH_INTO_CASE_WHEN_BRANCH, REPLACE_VARIABLE_BY_LITERAL, SIMPLIFY_ARITHMETIC_COMPARISON, SIMPLIFY_ARITHMETIC, @@ -52,6 +57,7 @@ public enum ExpressionRuleType { SIMPLIFY_COMPARISON_PREDICATE, SIMPLIFY_CONDITIONAL_FUNCTION, SIMPLIFY_CONFLICT_COMPOUND, + SIMPLIFY_EQUAL_BOOLEAN_LITERAL, SIMPLIFY_IN_PREDICATE, SIMPLIFY_NOT_EXPR, SIMPLIFY_RANGE, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/AddMinMax.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/AddMinMax.java index e3ea1a43b3ab47..bf42a796846f48 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/AddMinMax.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/AddMinMax.java @@ -21,11 +21,17 @@ import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.rules.expression.ExpressionRuleType; +import org.apache.doris.nereids.rules.expression.rules.AddMinMax.MinMaxValue; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.CompoundValue; import org.apache.doris.nereids.rules.expression.rules.RangeInference.DiscreteValue; import org.apache.doris.nereids.rules.expression.rules.RangeInference.EmptyValue; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.IsNotNullValue; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.IsNullValue; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.NotDiscreteValue; import org.apache.doris.nereids.rules.expression.rules.RangeInference.RangeValue; import org.apache.doris.nereids.rules.expression.rules.RangeInference.UnknownValue; import org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDesc; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDescVisitor; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; @@ -39,6 +45,7 @@ import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.PlanUtils; import com.google.common.collect.BoundType; import com.google.common.collect.ImmutableList; @@ -46,7 +53,6 @@ import com.google.common.collect.Maps; import com.google.common.collect.Range; import com.google.common.collect.Sets; -import org.apache.commons.lang3.NotImplementedException; import java.util.List; import java.util.Map; @@ -63,13 +69,14 @@ * a between 10 and 20 and b between 10 and 20 or a between 100 and 200 and b between 100 and 200 * => (a <= 20 and b <= 20 or a >= 100 and b >= 100) and a >= 10 and a <= 200 and b >= 10 and b <= 200 */ -public class AddMinMax implements ExpressionPatternRuleFactory { +public class AddMinMax implements ExpressionPatternRuleFactory, ValueDescVisitor, Void> { public static final AddMinMax INSTANCE = new AddMinMax(); @Override public List> buildRules() { return ImmutableList.of( matchesTopType(CompoundPredicate.class) + .whenCtx(ctx -> PlanUtils.isConditionExpressionPlan(ctx.rewriteContext.plan.orElse(null))) .thenApply(ctx -> rewrite(ctx.expr, ctx.rewriteContext)) .toRule(ExpressionRuleType.ADD_MIN_MAX) ); @@ -78,7 +85,7 @@ public List> buildRules() { /** rewrite */ public Expression rewrite(CompoundPredicate expr, ExpressionRewriteContext context) { ValueDesc valueDesc = (new RangeInference()).getValue(expr, context); - Map exprMinMaxValues = getExprMinMaxValues(valueDesc); + Map exprMinMaxValues = valueDesc.accept(this, null); removeUnnecessaryMinMaxValues(expr, exprMinMaxValues); if (!exprMinMaxValues.isEmpty()) { return addExprMinMaxValues(expr, context, exprMinMaxValues); @@ -92,7 +99,8 @@ private enum MatchMinMax { MATCH_NONE, } - private static class MinMaxValue { + /** record each expression's min and max value */ + public static class MinMaxValue { // min max range, if range = null means empty Range range; @@ -280,21 +288,8 @@ private boolean isExprNeedAddMinMax(Expression expr) { return (expr instanceof SlotReference) && ((SlotReference) expr).getOriginalColumn().isPresent(); } - private Map getExprMinMaxValues(ValueDesc value) { - if (value instanceof EmptyValue) { - return getExprMinMaxValues((EmptyValue) value); - } else if (value instanceof DiscreteValue) { - return getExprMinMaxValues((DiscreteValue) value); - } else if (value instanceof RangeValue) { - return getExprMinMaxValues((RangeValue) value); - } else if (value instanceof UnknownValue) { - return getExprMinMaxValues((UnknownValue) value); - } else { - throw new NotImplementedException("not implements"); - } - } - - private Map getExprMinMaxValues(EmptyValue value) { + @Override + public Map visitEmptyValue(EmptyValue value, Void context) { Expression reference = value.getReference(); Map exprMinMaxValues = Maps.newHashMap(); if (isExprNeedAddMinMax(reference)) { @@ -303,7 +298,8 @@ private Map getExprMinMaxValues(EmptyValue value) { return exprMinMaxValues; } - private Map getExprMinMaxValues(DiscreteValue value) { + @Override + public Map visitDiscreteValue(DiscreteValue value, Void context) { Expression reference = value.getReference(); Map exprMinMaxValues = Maps.newHashMap(); if (isExprNeedAddMinMax(reference)) { @@ -312,7 +308,23 @@ private Map getExprMinMaxValues(DiscreteValue value) { return exprMinMaxValues; } - private Map getExprMinMaxValues(RangeValue value) { + @Override + public Map visitNotDiscreteValue(NotDiscreteValue value, Void context) { + return Maps.newHashMap(); + } + + @Override + public Map visitIsNullValue(IsNullValue value, Void context) { + return Maps.newHashMap(); + } + + @Override + public Map visitIsNotNullValue(IsNotNullValue value, Void context) { + return Maps.newHashMap(); + } + + @Override + public Map visitRangeValue(RangeValue value, Void context) { Expression reference = value.getReference(); Map exprMinMaxValues = Maps.newHashMap(); if (isExprNeedAddMinMax(reference)) { @@ -321,16 +333,14 @@ private Map getExprMinMaxValues(RangeValue value) { return exprMinMaxValues; } - private Map getExprMinMaxValues(UnknownValue valueDesc) { + @Override + public Map visitCompoundValue(CompoundValue valueDesc, Void context) { List sourceValues = valueDesc.getSourceValues(); - if (sourceValues.isEmpty()) { - return Maps.newHashMap(); - } - Map result = Maps.newHashMap(getExprMinMaxValues(sourceValues.get(0))); + Map result = Maps.newHashMap(sourceValues.get(0).accept(this, context)); int nextExprOrderIndex = result.values().stream().mapToInt(k -> k.exprOrderIndex).max().orElse(0); for (int i = 1; i < sourceValues.size(); i++) { // process in sourceValues[i] - Map minMaxValues = getExprMinMaxValues(sourceValues.get(i)); + Map minMaxValues = sourceValues.get(i).accept(this, context); // merge values of sourceValues[i] into result. // also keep the value's relative order in sourceValues[i]. // for example, if a and b in sourceValues[i], but not in result, then during merging, @@ -398,4 +408,9 @@ private Map getExprMinMaxValues(UnknownValue valueDesc) } return result; } + + @Override + public Map visitUnknownValue(UnknownValue valueDesc, Void context) { + return Maps.newHashMap(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToCompoundPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToCompoundPredicate.java new file mode 100644 index 00000000000000..51f4fce171af70 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToCompoundPredicate.java @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; +import org.apache.doris.nereids.rules.expression.ExpressionRuleType; +import org.apache.doris.nereids.trees.expressions.And; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Not; +import org.apache.doris.nereids.trees.expressions.NullSafeEqual; +import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.WhenClause; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Optional; + +/** + * if case when all branch value are true/false literal, and the ELSE default value can be any expression, + * then can eliminate this case when. + * + * for example: + * 1. case when c1 then true when c2 then false end => (c1 <=> true or (not (c2 <=> true) and null)) + * 2. if (c1, true, false) => c1 <=> true or false + * 3. in a condition expression: + * if(c, p, true) => not(c <=> true) or p + * if(c, p, false) => c and p + * + */ +public class CaseWhenToCompoundPredicate implements ExpressionPatternRuleFactory { + public static final CaseWhenToCompoundPredicate INSTANCE = new CaseWhenToCompoundPredicate(); + private static final IfToCompoundPredicateInCond IF_REWRITE_IN_COND = new IfToCompoundPredicateInCond(); + + @Override + public List> buildRules() { + ImmutableList.Builder> rulesBuilder + = ImmutableList.builder(); + rulesBuilder.add(matchesType(CaseWhen.class) + .when(this::checkBooleanType) + .then(this::rewriteCaseWhen) + .toRule(ExpressionRuleType.CASE_WHEN_TO_COMPOUND_PREDICATE)); + rulesBuilder.add(matchesType(If.class) + .when(this::checkBooleanType) + .then(this::rewriteIf) + .toRule(ExpressionRuleType.IF_TO_COMPOUND_PREDICATE)); + rulesBuilder.addAll(IF_REWRITE_IN_COND.buildRules()); + return rulesBuilder.build(); + } + + private boolean checkBooleanType(Expression expression) { + return expression.getDataType().isBooleanType(); + } + + private Expression rewriteCaseWhen(CaseWhen caseWhen) { + Expression defaultValue = caseWhen.getDefaultValue().orElse(NullLiteral.BOOLEAN_INSTANCE); + return rewrite(caseWhen.getWhenClauses(), defaultValue).orElse(caseWhen); + } + + private Expression rewriteIf(If ifExpr) { + List whenClauses = ImmutableList.of(new WhenClause(ifExpr.getCondition(), ifExpr.getTrueValue())); + Expression defaultValue = ifExpr.getFalseValue(); + return rewrite(whenClauses, defaultValue).orElse(ifExpr); + } + + // for a branch, suppose the branches later it can rewrite to X, then given the branch: + // 1. when c then true ..., will rewrite to (c <=> true OR X), + // 2. when c then false ..., will rewrite to (not(c <=> true) AND X), + // for the ELSE branch, it can rewrite to `when true then defaultValue`, + // process the branches from back to front, the default value process first, while the first when clause will + // process last. + private Optional rewrite(List whenClauses, Expression defaultValue) { + for (WhenClause whenClause : whenClauses) { + Expression result = whenClause.getResult(); + if (!(result instanceof BooleanLiteral)) { + return Optional.empty(); + } + } + Expression result = defaultValue; + try { + for (int i = whenClauses.size() - 1; i >= 0; i--) { + WhenClause whenClause = whenClauses.get(i); + // operand <=> true + Expression condition = new NullSafeEqual(whenClause.getOperand(), BooleanLiteral.TRUE); + if (whenClause.getResult().equals(BooleanLiteral.TRUE)) { + result = new Or(condition, result); + } else { + result = new And(new Not(condition), result); + } + } + } catch (Exception e) { + // expression may exceed expression limit + return Optional.empty(); + } + return Optional.of(result); + } + + private static class IfToCompoundPredicateInCond extends ConditionRewrite { + @Override + public List> buildRules() { + return buildCondRules(ExpressionRuleType.IF_TO_COMPOUND_PREDICATE); + } + + // rewrite all the expression tree, not only the condition part. + @Override + protected boolean needRewrite(Expression expression, boolean isInsideCondition) { + return expression.containsType(If.class) + && expression.containsType(BooleanLiteral.class, NullLiteral.class); + } + + @Override + public Expression visitIf(If ifExpr, Boolean isInsideCondition) { + If newIf = (If) super.visitIf(ifExpr, isInsideCondition); + if (isInsideCondition) { + Expression newCondition = newIf.getCondition(); + Expression newTrueValue = newIf.getTrueValue(); + Expression newFalseValue = newIf.getFalseValue(); + if (newFalseValue.equals(BooleanLiteral.TRUE)) { + // if (c, p, true) => not(c <=> true) || p + return ExpressionUtils.or( + new Not(new NullSafeEqual(newCondition, BooleanLiteral.TRUE)), newTrueValue); + } else if (newFalseValue.equals(BooleanLiteral.FALSE) + || newFalseValue.equals(NullLiteral.BOOLEAN_INSTANCE)) { + // if (c, p, false) => c and p + return ExpressionUtils.and(newCondition, newTrueValue); + } + } + return newIf; + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CondReplaceNullWithFalse.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CondReplaceNullWithFalse.java new file mode 100644 index 00000000000000..a2b2926ae50849 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CondReplaceNullWithFalse.java @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionRuleType; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; + +import java.util.List; + +/** + * replace null with false for condition expression. + * a) if(null and a > 1, ...) => if(false and a > 1, ...) + * b) case when null and a > 1 then ... => case when false and a > 1 then ... + * c) null or (null and a > 1) or not(null) => false or (false and a > 1) or not(null) + */ +public class CondReplaceNullWithFalse extends ConditionRewrite { + + public static final CondReplaceNullWithFalse INSTANCE = new CondReplaceNullWithFalse(); + + @Override + public List> buildRules() { + return buildCondRules(ExpressionRuleType.COND_REPLACE_NULL_WITH_FALSE); + } + + @Override + protected boolean needRewrite(Expression expression, boolean isInsideCondition) { + if (!super.needRewrite(expression, isInsideCondition)) { + return false; + } + + return expression.containsType(NullLiteral.class); + } + + @Override + public Expression visitNullLiteral(NullLiteral nullLiteral, Boolean isInsideCondition) { + if (isInsideCondition + && (nullLiteral.getDataType().isBooleanType() || nullLiteral.getDataType().isNullType())) { + return BooleanLiteral.FALSE; + } + return nullLiteral; + } + +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConditionRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConditionRewrite.java new file mode 100644 index 00000000000000..c28a395a97bd05 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConditionRewrite.java @@ -0,0 +1,181 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext.ExpressionSource; +import org.apache.doris.nereids.rules.expression.ExpressionRuleType; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.CompoundPredicate; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.WhenClause; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.util.PlanUtils; + +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * + * Here condition expression means the expression used in filter, join condition, case when condition, if condition. + * And the rewrite argument isInsideCondition means the expression's ancestors to the condition root + * are AND/OR/CASE WHEN/IF. + * + * for example: for 'a and not(b)' in filter, when visit 'a', isInsideCondition is true, + * while visit 'b' isInsideCondition is false because its parent NOT is not AND/OR/CASE WHEN/IF. + * + */ +public abstract class ConditionRewrite extends DefaultExpressionRewriter + implements ExpressionPatternRuleFactory { + + protected List> buildCondRules(ExpressionRuleType ruleType) { + return ImmutableList.of( + root(Expression.class) + .thenApply(ctx -> rewrite(ctx.expr, ctx.rewriteContext)) + .toRule(ruleType) + ); + } + + protected boolean needRewrite(Expression expression, boolean isInsideCondition) { + return isInsideCondition || expression.containsType(WhenClause.class, If.class); + } + + protected Expression rewrite(Expression expression, ExpressionRewriteContext context) { + return expression.accept(this, rootIsCondition(context)); + } + + // for the expression root, only filter and join expression can treat as condition + protected boolean rootIsCondition(ExpressionRewriteContext context) { + Plan plan = context.plan.orElse(null); + if (plan instanceof LogicalJoin) { + // null aware join can not treat null as false + ExpressionSource source = context.source.orElse(null); + return ((LogicalJoin) plan).getJoinType() != JoinType.NULL_AWARE_LEFT_ANTI_JOIN + && (source == ExpressionSource.JOIN_HASH_CONDITION + || source == ExpressionSource.JOIN_OTHER_CONDITION); + } else { + return PlanUtils.isConditionExpressionPlan(plan); + } + } + + @Override + public Expression visit(Expression expr, Boolean isInsideCondition) { + if (needRewrite(expr, isInsideCondition)) { + return super.visit(expr, false); + } else { + return expr; + } + } + + @Override + public Expression visitCompoundPredicate(CompoundPredicate predicate, Boolean isInsideCondition) { + if (!needRewrite(predicate, isInsideCondition)) { + return predicate; + } + boolean changed = false; + ImmutableList.Builder builder + = ImmutableList.builderWithExpectedSize(predicate.children().size()); + for (Expression child : predicate.children()) { + Expression newChild = child.accept(this, isInsideCondition); + if (newChild != child) { + changed = true; + } + if (newChild.getClass() == predicate.getClass()) { + builder.addAll(newChild.children()); + changed = true; + } else { + builder.add(newChild); + } + } + if (changed) { + return predicate.withChildren(builder.build()); + } else { + return predicate; + } + } + + @Override + public Expression visitCaseWhen(CaseWhen caseWhen, Boolean isInsideCondition) { + if (!needRewrite(caseWhen, isInsideCondition)) { + return caseWhen; + } + boolean changed = false; + ImmutableList.Builder whenClausesBuilder + = ImmutableList.builderWithExpectedSize(caseWhen.getWhenClauses().size()); + for (WhenClause whenClause : caseWhen.getWhenClauses()) { + WhenClause newWhenClause = (WhenClause) whenClause.accept(this, isInsideCondition); + if (newWhenClause != whenClause) { + changed = true; + } + whenClausesBuilder.add(newWhenClause); + } + Expression oldDefaultValue = caseWhen.getDefaultValue().orElse(null); + Expression newDefaultValue = oldDefaultValue; + if (oldDefaultValue != null) { + newDefaultValue = oldDefaultValue.accept(this, isInsideCondition); + } + if (newDefaultValue != oldDefaultValue) { + changed = true; + } + if (changed) { + return newDefaultValue != null + ? new CaseWhen(whenClausesBuilder.build(), newDefaultValue) + : new CaseWhen(whenClausesBuilder.build()); + } else { + return caseWhen; + } + } + + @Override + public Expression visitWhenClause(WhenClause whenClause, Boolean isInsideCondition) { + if (!needRewrite(whenClause, isInsideCondition)) { + return whenClause; + } + Expression newOperand = whenClause.getOperand().accept(this, true); + Expression newResult = whenClause.getResult().accept(this, isInsideCondition); + if (newOperand != whenClause.getOperand() || newResult != whenClause.getResult()) { + return new WhenClause(newOperand, newResult); + } else { + return whenClause; + } + } + + @Override + public Expression visitIf(If ifExpr, Boolean isInsideCondition) { + if (!needRewrite(ifExpr, isInsideCondition)) { + return ifExpr; + } + Expression newCondition = ifExpr.getCondition().accept(this, true); + Expression newTrueValue = ifExpr.getTrueValue().accept(this, isInsideCondition); + Expression newFalseValue = ifExpr.getFalseValue().accept(this, isInsideCondition); + if (newCondition != ifExpr.getCondition() + || newTrueValue != ifExpr.getTrueValue() + || newFalseValue != ifExpr.getFalseValue()) { + return new If(newCondition, newTrueValue, newFalseValue); + } else { + return ifExpr; + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java index bfc983915b37de..1560ab55370e64 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java @@ -64,7 +64,6 @@ import org.apache.doris.nereids.trees.expressions.functions.PropagateNullLiteral; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction; import org.apache.doris.nereids.trees.expressions.functions.scalar.Array; import org.apache.doris.nereids.trees.expressions.functions.scalar.ConnectionId; import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentCatalog; @@ -74,6 +73,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.EncryptKeyRef; import org.apache.doris.nereids.trees.expressions.functions.scalar.If; import org.apache.doris.nereids.trees.expressions.functions.scalar.LastQueryId; +import org.apache.doris.nereids.trees.expressions.functions.scalar.NullIf; import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl; import org.apache.doris.nereids.trees.expressions.functions.scalar.Password; import org.apache.doris.nereids.trees.expressions.functions.scalar.SessionUser; @@ -92,7 +92,6 @@ import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral; -import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.TypeCoercionUtils; @@ -106,12 +105,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; +import com.google.common.collect.Sets; import org.apache.commons.codec.digest.DigestUtils; -import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.function.BiFunction; import java.util.function.Predicate; @@ -186,6 +186,7 @@ public List> buildRules() { matches(SessionUser.class, this::visitSessionUser), matches(LastQueryId.class, this::visitLastQueryId), matches(Nvl.class, this::visitNvl), + matches(NullIf.class, this::visitNullIf), matches(Match.class, this::visitMatch) ); } @@ -444,7 +445,7 @@ public Expression visitAnd(And and, ExpressionRewriteContext context) { } } else { // null and null and null and ... - return new NullLiteral(BooleanType.INSTANCE); + return NullLiteral.BOOLEAN_INSTANCE; } } @@ -490,7 +491,7 @@ public Expression visitOr(Or or, ExpressionRewriteContext context) { return or.withChildren(nonFalseLiteral); } else { // null or null - return new NullLiteral(BooleanType.INSTANCE); + return NullLiteral.BOOLEAN_INSTANCE; } } @@ -551,9 +552,6 @@ public Expression visitTryCast(TryCast cast, ExpressionRewriteContext context) { @Override public Expression visitBoundFunction(BoundFunction boundFunction, ExpressionRewriteContext context) { - if (!boundFunction.foldable()) { - return boundFunction; - } boundFunction = rewriteChildren(boundFunction, context); Optional checkedExpr = preProcess(boundFunction); if (checkedExpr.isPresent()) { @@ -576,57 +574,64 @@ public Expression visitBinaryArithmetic(BinaryArithmetic binaryArithmetic, Expre public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) { CaseWhen originCaseWhen = caseWhen; caseWhen = rewriteChildren(caseWhen, context); - Expression newDefault = null; - boolean foundNewDefault = false; - - List whenClauses = new ArrayList<>(); + final Expression oldDefault = caseWhen.getDefaultValue().orElse(null); + Expression newDefault = oldDefault; + ImmutableList.Builder whenClausesBuilder + = ImmutableList.builderWithExpectedSize(caseWhen.getWhenClauses().size()); + Set uniqueOperands = Sets.newHashSet(); for (WhenClause whenClause : caseWhen.getWhenClauses()) { Expression whenOperand = whenClause.getOperand(); - - if (!(whenOperand.isLiteral())) { - whenClauses.add(new WhenClause(whenOperand, whenClause.getResult())); + if (!whenOperand.isLiteral() && uniqueOperands.add(whenOperand)) { + whenClausesBuilder.add(new WhenClause(whenOperand, whenClause.getResult())); } else if (BooleanLiteral.TRUE.equals(whenOperand)) { - foundNewDefault = true; newDefault = whenClause.getResult(); break; } } - - Expression defaultResult = null; - if (caseWhen.getDefaultValue().isPresent()) { - defaultResult = caseWhen.getDefaultValue().get(); - } - if (foundNewDefault) { - defaultResult = newDefault; + List newWhenClauses = whenClausesBuilder.build(); + Expression realTypeCoercionDefault = newDefault != null ? newDefault : new NullLiteral(caseWhen.getDataType()); + boolean allThenEqualsDefault = true; + for (WhenClause whenClause : newWhenClauses) { + if (!whenClause.getResult().equals(realTypeCoercionDefault)) { + allThenEqualsDefault = false; + break; + } } - if (whenClauses.isEmpty()) { - return TypeCoercionUtils.ensureSameResultType( - originCaseWhen, defaultResult == null ? new NullLiteral(caseWhen.getDataType()) : defaultResult, - context - ); + if (allThenEqualsDefault) { + return realTypeCoercionDefault; } - if (defaultResult == null) { - if (caseWhen.getDataType().isNullType()) { - // if caseWhen's type is NULL_TYPE, means all possible return values are nulls - // it's safe to return null literal here - return new NullLiteral(); - } else { - return TypeCoercionUtils.ensureSameResultType(originCaseWhen, new CaseWhen(whenClauses), context); + boolean hasNewChildren = newWhenClauses.size() != caseWhen.getWhenClauses().size() + || newDefault != oldDefault; + if (newWhenClauses.size() == caseWhen.getWhenClauses().size()) { + for (int i = 0; i < newWhenClauses.size(); i++) { + if (newWhenClauses.get(i) != caseWhen.getWhenClauses().get(i)) { + hasNewChildren = true; + break; + } } } - return TypeCoercionUtils.ensureSameResultType( - originCaseWhen, new CaseWhen(whenClauses, defaultResult), context - ); + if (hasNewChildren) { + caseWhen = newDefault == null + ? new CaseWhen(newWhenClauses) : new CaseWhen(newWhenClauses, newDefault); + } + return TypeCoercionUtils.ensureSameResultType(originCaseWhen, caseWhen, context); } @Override public Expression visitIf(If ifExpr, ExpressionRewriteContext context) { If originIf = ifExpr; ifExpr = rewriteChildren(ifExpr, context); - if (ifExpr.child(0) instanceof NullLiteral || ifExpr.child(0).equals(BooleanLiteral.FALSE)) { - return TypeCoercionUtils.ensureSameResultType(originIf, ifExpr.child(2), context); - } else if (ifExpr.child(0).equals(BooleanLiteral.TRUE)) { - return TypeCoercionUtils.ensureSameResultType(originIf, ifExpr.child(1), context); + Expression condition = ifExpr.getCondition(); + Expression typeCoercionTrueValue + = TypeCoercionUtils.ensureSameResultType(originIf, ifExpr.getTrueValue(), context); + Expression typeCoercionFalseValue + = TypeCoercionUtils.ensureSameResultType(originIf, ifExpr.getFalseValue(), context); + if (condition.equals(BooleanLiteral.TRUE)) { + return typeCoercionTrueValue; + } else if (condition.equals(BooleanLiteral.FALSE) || condition.isNullLiteral()) { + return typeCoercionFalseValue; + } else if (typeCoercionTrueValue.equals(typeCoercionFalseValue)) { + return typeCoercionTrueValue; } return TypeCoercionUtils.ensureSameResultType(originIf, ifExpr, context); } @@ -641,7 +646,7 @@ public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteCon // now the inPredicate contains literal only. Expression value = inPredicate.child(0); if (value.isNullLiteral()) { - return new NullLiteral(BooleanType.INSTANCE); + return NullLiteral.BOOLEAN_INSTANCE; } boolean isOptionContainsNull = false; for (Expression item : inPredicate.getOptions()) { @@ -652,7 +657,7 @@ public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteCon } } return isOptionContainsNull - ? new NullLiteral(BooleanType.INSTANCE) + ? NullLiteral.BOOLEAN_INSTANCE : BooleanLiteral.FALSE; } @@ -728,18 +733,39 @@ public Expression visitVersion(Version version, ExpressionRewriteContext context public Expression visitNvl(Nvl nvl, ExpressionRewriteContext context) { Nvl originNvl = nvl; nvl = rewriteChildren(nvl, context); - - for (Expression expr : nvl.children()) { - if (expr.isLiteral()) { - if (!expr.isNullLiteral()) { - return TypeCoercionUtils.ensureSameResultType(originNvl, expr, context); - } - } else { - return TypeCoercionUtils.ensureSameResultType(originNvl, nvl, context); + Expression first = nvl.left(); + Expression second = nvl.right(); + Expression result = nvl; + if (first.equals(second) || second.isNullLiteral() || (first.isLiteral() && !first.isNullLiteral())) { + result = first; + } else if (first.isNullLiteral()) { + result = second; + } + return TypeCoercionUtils.ensureSameResultType(originNvl, result, context); + } + + @Override + public Expression visitNullIf(NullIf nullIf, ExpressionRewriteContext context) { + NullIf originNullIf = nullIf; + nullIf = rewriteChildren(nullIf, context); + Expression first = nullIf.left(); + Expression second = nullIf.right(); + Expression result = nullIf; + // if first is null, then first = second will be null + if (first.isNullLiteral() || second.isNullLiteral()) { + result = first; + } else if (first.equals(second)) { + // even if first is null, then first = second will be null, then result is first, so the result is also null + result = new NullLiteral(originNullIf.getDataType()); + } else if (first.isLiteral() && second.isLiteral()) { + Expression isEqual = visitEqualTo(new EqualTo(first, second), context); + if (isEqual.equals(BooleanLiteral.TRUE)) { + result = new NullLiteral(originNullIf.getDataType()); + } else if (isEqual.equals(BooleanLiteral.FALSE) || isEqual.isNullLiteral()) { + result = first; } } - // all nulls - return TypeCoercionUtils.ensureSameResultType(originNvl, nvl.child(0), context); + return TypeCoercionUtils.ensureSameResultType(originNullIf, result, context); } private E rewriteChildren(E expr, ExpressionRewriteContext context) { @@ -780,7 +806,7 @@ private E rewriteChildren(E expr, ExpressionRewriteContex } private Optional preProcess(Expression expression) { - if (expression instanceof AggregateFunction || expression instanceof TableGeneratingFunction) { + if (!expression.foldable()) { return Optional.of(expression); } if (ExpressionUtils.hasNullLiteral(expression.getArguments()) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NestedCaseWhenCondToLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NestedCaseWhenCondToLiteral.java new file mode 100644 index 00000000000000..29ddf97e786db0 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NestedCaseWhenCondToLiteral.java @@ -0,0 +1,226 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionRuleType; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.CompoundPredicate; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.WhenClause; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; + +import java.util.List; +import java.util.Map; + +/** + * For nested CaseWhen/IF expression, replace the inner CaseWhen/IF condition with TRUE/FALSE literal + * when the condition also exists in the outer CaseWhen/IF conditions. + * + * on the nested CASE/IF path, a condition may exist in multiple CASE/IF branches, + * for any inner case when or if condition, its boolean value is determined by the outermost CASE/IF branch, + * that is the first occurrence of the condition on the nested CASE/IF path. + * + *
+ * 1. if it exists in outer case's current branch condition, replace it with TRUE + * e.g. + * case when A then + * (case when A then 1 else 2 end) + * ... + * end + * then inner case condition A will replace with TRUE: + * case when A then + * (case when TRUE then 1 else 2 end) + * ... + * end + *
+ * 2. if it exists in outer case's previous branch condition, replace it with FALSE + * e.g. + * case when A then ... + * when B then + * (case when A then 1 else 2 end) + * ... + * end + * then inner case condition A will replace with FALSE: + * case when A then ... + * when B then + * (case when FALSE then 1 else 2 end) + * ... + * end + *
+ */ +public class NestedCaseWhenCondToLiteral implements ExpressionPatternRuleFactory { + + public static final NestedCaseWhenCondToLiteral INSTANCE = new NestedCaseWhenCondToLiteral(); + + @Override + public List> buildRules() { + return ImmutableList.of( + root(Expression.class) + .when(this::needRewrite) + .thenApply(ctx -> rewrite(ctx.expr, ctx.rewriteContext)) + .toRule(ExpressionRuleType.NESTED_CASE_WHEN_COND_TO_LITERAL) + ); + } + + private boolean needRewrite(Expression expression) { + return expression.containsType(CaseWhen.class, If.class); + } + + private Expression rewrite(Expression expression, ExpressionRewriteContext context) { + return expression.accept(new NestedCondReplacer(), null); + } + + /** NestedCondReplacer */ + @VisibleForTesting + public static class NestedCondReplacer extends DefaultExpressionRewriter { + + // condition literals is used to record the boolean literal for a condition expression, + // 1. if a condition, if it exists in outer case/if conditions, it will be replaced with the literal. + // 2. otherwise it's the first time occur, then: + // a) when enter a case/if branch, set this condition to TRUE literal + // b) when leave a case/if branch, set this condition to FALSE literal + // c) when leave the whole case/if statement, remove this condition literal + protected final Map conditionLiterals = Maps.newHashMap(); + + @Override + public Expression visit(Expression expr, Void context) { + if (INSTANCE.needRewrite(expr)) { + return super.visit(expr, context); + } else { + return expr; + } + } + + @Override + public Expression visitCaseWhen(CaseWhen caseWhen, Void context) { + ImmutableList.Builder newWhenClausesBuilder + = ImmutableList.builderWithExpectedSize(caseWhen.arity()); + List firstOccurConds = Lists.newArrayListWithExpectedSize(caseWhen.arity()); + for (WhenClause whenClause : caseWhen.getWhenClauses()) { + Expression oldCondition = whenClause.getOperand(); + Pair replaceResult = replaceCondition(oldCondition, context); + Expression newCondition = replaceResult.first; + boolean condFirstOccur = replaceResult.second; + if (condFirstOccur) { + firstOccurConds.add(oldCondition); + conditionLiterals.put(oldCondition, BooleanLiteral.TRUE); + } + Expression newResult = whenClause.getResult().accept(this, context); + if (condFirstOccur) { + conditionLiterals.put(oldCondition, BooleanLiteral.FALSE); + } + if (whenClause.getOperand() != newCondition || whenClause.getResult() != newResult) { + newWhenClausesBuilder.add(new WhenClause(newCondition, newResult)); + } else { + newWhenClausesBuilder.add(whenClause); + } + } + Expression oldDefaultValue = caseWhen.getDefaultValue().orElse(null); + Expression newDefaultValue = oldDefaultValue; + if (newDefaultValue != null) { + newDefaultValue = newDefaultValue.accept(this, context); + } + for (Expression cond : firstOccurConds) { + conditionLiterals.remove(cond); + } + List newWhenClauses = newWhenClausesBuilder.build(); + boolean hasNewChildren = false; + if (newWhenClauses.size() != caseWhen.getWhenClauses().size()) { + hasNewChildren = true; + } else { + for (int i = 0; i < newWhenClauses.size(); i++) { + if (newWhenClauses.get(i) != caseWhen.getWhenClauses().get(i)) { + hasNewChildren = true; + break; + } + } + } + if (newDefaultValue != oldDefaultValue) { + hasNewChildren = true; + } + if (hasNewChildren) { + return newDefaultValue != null + ? new CaseWhen(newWhenClauses, newDefaultValue) + : new CaseWhen(newWhenClauses); + } else { + return caseWhen; + } + } + + @Override + public Expression visitIf(If ifExpr, Void context) { + Expression oldCondition = ifExpr.getCondition(); + Pair replaceResult = replaceCondition(oldCondition, context); + Expression newCondition = replaceResult.first; + boolean condFirstOccur = replaceResult.second; + if (condFirstOccur) { + conditionLiterals.put(oldCondition, BooleanLiteral.TRUE); + } + Expression newTrueValue = ifExpr.getTrueValue().accept(this, context); + if (condFirstOccur) { + conditionLiterals.put(oldCondition, BooleanLiteral.FALSE); + } + Expression newFalseValue = ifExpr.getFalseValue().accept(this, context); + if (condFirstOccur) { + conditionLiterals.remove(oldCondition); + } + if (newCondition != oldCondition + || newTrueValue != ifExpr.getTrueValue() + || newFalseValue != ifExpr.getFalseValue()) { + return new If(newCondition, newTrueValue, newFalseValue); + } else { + return ifExpr; + } + } + + // return newCondition + condition first occur flag + private Pair replaceCondition(Expression condition, Void context) { + if (condition.isLiteral()) { + // literal condition do not need to replace, and do not record it + return Pair.of(condition, false); + } else if (conditionLiterals.containsKey(condition)) { + return Pair.of(conditionLiterals.get(condition), false); + } else if (condition instanceof CompoundPredicate) { + ImmutableList.Builder newChildrenBuilder + = ImmutableList.builderWithExpectedSize(condition.arity()); + boolean hasNewChildren = false; + for (Expression child : condition.children()) { + Expression newChild = replaceCondition(child, context).first; + hasNewChildren = hasNewChildren || newChild != child; + newChildrenBuilder.add(newChild); + } + Expression newCondition = hasNewChildren + ? condition.withChildren(newChildrenBuilder.build()) : condition; + return Pair.of(newCondition, true); + } else { + return Pair.of(condition.accept(this, context), true); + } + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java index 6b3b53feb6b9fd..fbe8448c9a0959 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java @@ -18,46 +18,161 @@ package org.apache.doris.nereids.rules.expression.rules; import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; -import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.rules.expression.ExpressionRuleType; +import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.NullSafeEqual; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; import java.util.List; +import java.util.Optional; +import java.util.Set; /** * convert "A <=> null" to "A is null" * null <=> null : true * null <=> 1 : false * 1 <=> 2 : 1 = 2 + * + * 1. if null safe equal is in a filter / join / case when / if condition, and at least one side is not nullable, + * then null safe equal can be converted to equal. + * 2. otherwise if both sides are not nullable, then null safe equal can converted to equal too. + * */ -public class NullSafeEqualToEqual implements ExpressionPatternRuleFactory { +public class NullSafeEqualToEqual extends ConditionRewrite { public static final NullSafeEqualToEqual INSTANCE = new NullSafeEqualToEqual(); @Override public List> buildRules() { - return ImmutableList.of( - matchesType(NullSafeEqual.class).then(NullSafeEqualToEqual::rewrite) - .toRule(ExpressionRuleType.NULL_SAFE_EQUAL_TO_EQUAL) - ); + return buildCondRules(ExpressionRuleType.NULL_SAFE_EQUAL_TO_EQUAL); } - private static Expression rewrite(NullSafeEqual nullSafeEqual) { - // because the nullable info hasn't been finalized yet, the optimization is limited - if (nullSafeEqual.left().isNullLiteral() && nullSafeEqual.right().isNullLiteral()) { + // rewrite all the expression tree, not only the condition part. + @Override + protected boolean needRewrite(Expression expression, boolean isInsideCondition) { + return expression.containsType(NullSafeEqual.class); + } + + @Override + public Expression visitNullSafeEqual(NullSafeEqual nullSafeEqual, Boolean isInsideCondition) { + NullSafeEqual newNullSafeEqual = (NullSafeEqual) super.visitNullSafeEqual(nullSafeEqual, isInsideCondition); + Expression newLeft = newNullSafeEqual.left(); + Expression newRight = newNullSafeEqual.right(); + boolean canConvertToEqual = (!newLeft.nullable() && !newRight.nullable()) + || (isInsideCondition && (!newLeft.nullable() || !newRight.nullable())); + if (newLeft.equals(newRight)) { + return BooleanLiteral.TRUE; + } else if (newLeft.isNullLiteral() && newRight.isNullLiteral()) { return BooleanLiteral.TRUE; - } else if (nullSafeEqual.left().isNullLiteral()) { - return nullSafeEqual.right().isLiteral() ? BooleanLiteral.FALSE : new IsNull(nullSafeEqual.right()); - } else if (nullSafeEqual.right().isNullLiteral()) { - return nullSafeEqual.left().isLiteral() ? BooleanLiteral.FALSE : new IsNull(nullSafeEqual.left()); - } else if (nullSafeEqual.left().isLiteral() && nullSafeEqual.right().isLiteral()) { - return new EqualTo(nullSafeEqual.left(), nullSafeEqual.right()); + } else if (newLeft.isNullLiteral()) { + return !newRight.nullable() ? BooleanLiteral.FALSE : new IsNull(newRight); + } else if (newRight.isNullLiteral()) { + return !newLeft.nullable() ? BooleanLiteral.FALSE : new IsNull(newLeft); + } else if (canConvertToEqual) { + return new EqualTo(newLeft, newRight); + } else if (newRight.equals(BooleanLiteral.TRUE)) { + return simplifySafeEqualTrue(newLeft).orElse(newNullSafeEqual); + } else { + return newNullSafeEqual; + } + } + + /** + * try to simplify 'expression <=> TRUE', + * return the rewritten expression if it can be simplified, otherwise return empty. + */ + private Optional simplifySafeEqualTrue(Expression expression) { + if (expression.isLiteral()) { + return Optional.of(BooleanLiteral.of(expression.equals(BooleanLiteral.TRUE))); + } else if (!expression.nullable()) { + return Optional.of(expression); + } else if (expression instanceof PropagateNullable) { + Set conjuncts = Sets.newLinkedHashSet(); + conjuncts.add(expression); + if (tryProcessPropagateNullable(expression, conjuncts)) { + return Optional.of(ExpressionUtils.and(conjuncts)); + } + } else if (expression instanceof InPredicate) { + InPredicate in = (InPredicate) expression; + Expression compareExpr = in.getCompareExpr(); + if (!compareExpr.isConstant()) { + Set conjuncts = Sets.newLinkedHashSet(); + if (tryProcessPropagateNullable(compareExpr, conjuncts)) { + boolean allOptionNonNullLiteral = true; + ImmutableList.Builder newOptionsBuilder + = ImmutableList.builderWithExpectedSize(in.getOptions().size()); + for (Expression option : in.getOptions()) { + if (option.isNullLiteral()) { + continue; + } + if (!option.isLiteral()) { + allOptionNonNullLiteral = false; + break; + } + newOptionsBuilder.add(option); + } + if (allOptionNonNullLiteral) { + List newOptions = newOptionsBuilder.build(); + if (newOptions.isEmpty()) { + return Optional.of(BooleanLiteral.FALSE); + } + Expression newIn = newOptions.size() == in.getOptions().size() + ? in : ExpressionUtils.toInPredicateOrEqualTo(compareExpr, newOptions); + conjuncts.add(newIn); + return Optional.of(ExpressionUtils.and(conjuncts)); + } + } + } + } else if (expression instanceof CompoundPredicate) { + // process AND / OR + // (c1 and c2) <=> TRUE rewrite to (c1 <=> TRUE) and (c2 <=> TRUE) + // (c1 or c2) <=> TRUE rewrite to (c1 <=> TRUE) or (c2 <=> TRUE) + List oldChildren = expression.children(); + ImmutableList.Builder newChildrenBuilder + = ImmutableList.builderWithExpectedSize(oldChildren.size()); + boolean changed = false; + for (Expression child : expression.children()) { + // rewrite child to child <=> TRUE + Expression newChild = simplifySafeEqualTrue(child) + .orElse(new NullSafeEqual(child, BooleanLiteral.TRUE)); + if (newChild.getClass() == expression.getClass()) { + // flatten + newChildrenBuilder.addAll(newChild.children()); + changed = true; + } else { + changed = changed || child != newChild; + newChildrenBuilder.add(newChild); + } + } + return Optional.of(changed ? expression.withChildren(newChildrenBuilder.build()) : expression); + } + return Optional.empty(); + } + + private boolean tryProcessPropagateNullable(Expression expression, Set conjuncts) { + if (!expression.nullable()) { + return true; + } else if (expression instanceof SlotReference) { + conjuncts.add(ExpressionUtils.notIsNull(expression)); + return true; + } else if (expression instanceof PropagateNullable) { + for (Expression child : expression.children()) { + if (!tryProcessPropagateNullable(child, conjuncts)) { + return false; + } + } + return true; + } else { + return false; } - return nullSafeEqual; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneListPartitionEvaluator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneListPartitionEvaluator.java index e0d2df9c0f25cd..1257e9840743a6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneListPartitionEvaluator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneListPartitionEvaluator.java @@ -30,7 +30,6 @@ import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; -import org.apache.doris.nereids.types.BooleanType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -115,7 +114,7 @@ public Expression visitInPredicate(InPredicate inPredicate, Map case when c1 then f(a1, a2, ..., p1, ..., an) + * when c2 then f(a1, a2, ..., p2, ..., an) + * else f(a1, a2, ..., p3, ..., an) end + * + * For example: 2 > case when TB = 1 then 1 else 3 end + * can be rewritten to: case when TB = 1 then true else false end. + * After this rule, the expression will continue to be optimized by other rules. + * Later rule CASE_WHEN_TO_COMPOUND will rewrite it to: (TB = 1) <=> TRUE, + * later rule NULL_SAFE_EQUAL_TO_EQUAL will rewrite it to: + * a) TB = 1, if expression is in filter or join; + * b) TB = 1 and TB is not null, otherwise. + */ +public class PushIntoCaseWhenBranch implements ExpressionPatternRuleFactory { + public static PushIntoCaseWhenBranch INSTANCE = new PushIntoCaseWhenBranch(); + private static final Class[] CASE_WHEN_LIKE_CLASSES + = new Class[] {CaseWhen.class, If.class, Nvl.class, NullIf.class}; + + @Override + public List> buildRules() { + return ImmutableList.of( + matchesType(Expression.class) + .when(this::needRewrite) + .thenApply(ctx -> rewrite(ctx.expr, ctx.rewriteContext)) + .toRule(ExpressionRuleType.PUSH_INTO_CASE_WHEN_BRANCH)); + } + + private boolean needRewrite(Expression expression) { + if (!expression.containsType(CASE_WHEN_LIKE_CLASSES) || !expression.foldable()) { + return false; + } + // for expression's children, if one of them is case when/if/nvl/nullif, and the others are literals, + // then try to push rewrite expression, and push expression into the case when/if/nvl/nullif branch. + boolean hasCaseWhenLikeChild = false; + for (Expression child : expression.children()) { + if (child.isLiteral()) { + continue; + } + boolean isCaseWhenLike = isClassCaseWhenLike(child.getClass()); + if (!isCaseWhenLike) { + return false; + } + // if there are more than one case when/if/nvl/nullif child, do not rewrite + if (hasCaseWhenLikeChild) { + return false; + } + hasCaseWhenLikeChild = true; + } + return hasCaseWhenLikeChild; + } + + private boolean isClassCaseWhenLike(Class clazz) { + for (Class pushIntoClass : CASE_WHEN_LIKE_CLASSES) { + if (clazz == pushIntoClass) { + return true; + } + } + return false; + } + + private Expression rewrite(Expression expression, ExpressionRewriteContext context) { + for (int i = 0; i < expression.children().size(); i++) { + Expression child = expression.child(i); + Optional newExpr = Optional.empty(); + if (child instanceof CaseWhen) { + newExpr = tryPushIntoCaseWhen(expression, i, (CaseWhen) child, context); + } else if (child instanceof If) { + newExpr = tryPushIntoIf(expression, i, (If) child, context); + } else if (child instanceof Nvl) { + newExpr = tryPushIntoNvl(expression, i, (Nvl) child, context); + } else if (child instanceof NullIf) { + newExpr = tryPushIntoNullIf(expression, i, (NullIf) child, context); + } + if (newExpr.isPresent()) { + return newExpr.get(); + } + } + return expression; + } + + private Optional tryPushIntoCaseWhen(Expression parent, int childIndex, CaseWhen caseWhen, + ExpressionRewriteContext context) { + List branchValues + = Lists.newArrayListWithExpectedSize(caseWhen.getWhenClauses().size() + 1); + for (WhenClause whenClause : caseWhen.getWhenClauses()) { + branchValues.add(whenClause.getResult()); + } + branchValues.add(caseWhen.getDefaultValue().orElse(new NullLiteral(caseWhen.getDataType()))); + if (pushIntoBranches(parent, childIndex, branchValues, context)) { + List newWhenClauses = Lists.newArrayListWithExpectedSize(caseWhen.getWhenClauses().size()); + for (int i = 0; i < caseWhen.getWhenClauses().size(); i++) { + newWhenClauses.add(new WhenClause(caseWhen.getWhenClauses().get(i).getOperand(), branchValues.get(i))); + } + Expression defaultValue = branchValues.get(branchValues.size() - 1); + CaseWhen newCaseWhen = new CaseWhen(newWhenClauses, defaultValue); + return Optional.of(TypeCoercionUtils.ensureSameResultType(parent, newCaseWhen, context)); + } else { + return Optional.empty(); + } + } + + private Optional tryPushIntoIf(Expression parent, int childIndex, If ifExpr, + ExpressionRewriteContext context) { + List branchValues = Lists.newArrayList(ifExpr.getTrueValue(), ifExpr.getFalseValue()); + if (pushIntoBranches(parent, childIndex, branchValues, context)) { + If newIf = new If(ifExpr.getCondition(), branchValues.get(0), branchValues.get(1)); + return Optional.of(TypeCoercionUtils.ensureSameResultType(parent, newIf, context)); + } else { + return Optional.empty(); + } + } + + private Optional tryPushIntoNvl(Expression parent, int childIndex, Nvl nvl, + ExpressionRewriteContext context) { + Expression first = nvl.left(); + Expression second = nvl.right(); + boolean isConditionPlan = PlanUtils.isConditionExpressionPlan(context.plan.orElse(null)); + // after rewrite, nvl(first, second) => if(isnull(first), second, first), + // so there will exist twice 'first' in the rewritten IF expression, which may increase the computation cost. + // if the plan is not filter and not join, then push down action may not have positive effect, + // considering this, we give up the rewrite if the plan is not condition plan or first contains unique function. + if (first.containsUniqueFunction() || !isConditionPlan) { + return Optional.empty(); + } + If ifExpr = new If(new IsNull(first), second, first); + return tryPushIntoIf(parent, childIndex, ifExpr, context); + } + + private Optional tryPushIntoNullIf(Expression parent, int childIndex, NullIf nullIf, + ExpressionRewriteContext context) { + Expression first = nullIf.left(); + Expression second = nullIf.right(); + boolean isConditionPlan = PlanUtils.isConditionExpressionPlan(context.plan.orElse(null)); + // after rewrite, nullif(first, second) => if(first = second, null, first), + // so there will exist twice 'first' in the rewritten IF expression, which may increase the computation cost. + // if the plan is not filter and not join, then push down action may not have positive effect, + // considering this, we give up the rewrite if the plan is not condition plan or first contains unique function. + if (first.containsUniqueFunction() || !isConditionPlan) { + return Optional.empty(); + } + If ifExpr = new If(new EqualTo(first, second), new NullLiteral(nullIf.getDataType()), first); + return tryPushIntoIf(parent, childIndex, ifExpr, context); + } + + private boolean pushIntoBranches(Expression parent, int childIndex, + List branchValues, ExpressionRewriteContext context) { + List newChildren = Lists.newArrayList(parent.children()); + // for filter/join condition expression, we allow one non-literal branch after push down, + // because later rule CASE_WHEN_TO_COMPOUND may rewrite to AND/OR expression. + // for other plan expression, we require all branches should be literal after push down, + // just for pass the regression test nereids_rules_p0/mv/agg_without_roll_up. + final int MAX_NON_LIT_NUM = PlanUtils.isConditionExpressionPlan(context.plan.orElse(null)) ? 1 : 0; + int nonLiteralBranchNum = 0; + for (int i = 0; i < branchValues.size(); i++) { + Expression oldValue = branchValues.get(i); + newChildren.set(childIndex, oldValue); + Expression newValue = TypeCoercionUtils.ensureSameResultType( + parent, + FoldConstantRuleOnFE.evaluate(parent.withChildren(newChildren), context), + context); + if (!newValue.isLiteral()) { + nonLiteralBranchNum++; + if (nonLiteralBranchNum > MAX_NON_LIT_NUM) { + return false; + } + } + branchValues.set(i, newValue); + } + + return true; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java index 3a87149fa14dfb..5467de2a9f25d7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java @@ -17,9 +17,9 @@ package org.apache.doris.nereids.rules.expression.rules; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.trees.expressions.And; -import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; @@ -28,31 +28,32 @@ import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.LessThan; import org.apache.doris.nereids.trees.expressions.LessThanEqual; +import org.apache.doris.nereids.trees.expressions.Not; import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral; -import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.util.ExpressionUtils; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; import com.google.common.collect.BoundType; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; -import com.google.common.collect.Multimap; -import com.google.common.collect.Multimaps; +import com.google.common.collect.Maps; import com.google.common.collect.Range; import com.google.common.collect.RangeSet; import com.google.common.collect.Sets; import com.google.common.collect.TreeRangeSet; -import java.util.ArrayList; -import java.util.Collection; -import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import java.util.Set; -import java.util.function.BinaryOperator; -import java.util.stream.Collectors; /** * collect range of expression @@ -71,255 +72,707 @@ public ValueDesc visit(Expression expr, ExpressionRewriteContext context) { return new UnknownValue(context, expr); } - private ValueDesc buildRange(ExpressionRewriteContext context, ComparisonPredicate predicate) { - Expression right = predicate.child(1); - if (right.isNullLiteral()) { - return new UnknownValue(context, predicate); - } - // only handle `NumericType` and `DateLikeType` and `StringLikeType` - DataType rightDataType = right.getDataType(); - if (right instanceof ComparableLiteral - && (rightDataType.isNumericType() || rightDataType.isDateLikeType() - || rightDataType.isStringLikeType())) { - return ValueDesc.range(context, predicate); - } - return new UnknownValue(context, predicate); - } - @Override public ValueDesc visitGreaterThan(GreaterThan greaterThan, ExpressionRewriteContext context) { - return buildRange(context, greaterThan); + Optional rightLiteral = tryGetComparableLiteral(greaterThan.right()); + if (rightLiteral.isPresent()) { + return new RangeValue(context, greaterThan.left(), Range.greaterThan(rightLiteral.get())); + } else { + return new UnknownValue(context, greaterThan); + } } @Override public ValueDesc visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, ExpressionRewriteContext context) { - return buildRange(context, greaterThanEqual); + Optional rightLiteral = tryGetComparableLiteral(greaterThanEqual.right()); + if (rightLiteral.isPresent()) { + return new RangeValue(context, greaterThanEqual.left(), Range.atLeast(rightLiteral.get())); + } else { + return new UnknownValue(context, greaterThanEqual); + } } @Override public ValueDesc visitLessThan(LessThan lessThan, ExpressionRewriteContext context) { - return buildRange(context, lessThan); + Optional rightLiteral = tryGetComparableLiteral(lessThan.right()); + if (rightLiteral.isPresent()) { + return new RangeValue(context, lessThan.left(), Range.lessThan(rightLiteral.get())); + } else { + return new UnknownValue(context, lessThan); + } } @Override public ValueDesc visitLessThanEqual(LessThanEqual lessThanEqual, ExpressionRewriteContext context) { - return buildRange(context, lessThanEqual); + Optional rightLiteral = tryGetComparableLiteral(lessThanEqual.right()); + if (rightLiteral.isPresent()) { + return new RangeValue(context, lessThanEqual.left(), Range.atMost(rightLiteral.get())); + } else { + return new UnknownValue(context, lessThanEqual); + } } @Override public ValueDesc visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context) { - return buildRange(context, equalTo); + Optional rightLiteral = tryGetComparableLiteral(equalTo.right()); + if (rightLiteral.isPresent()) { + return new DiscreteValue(context, equalTo.left(), ImmutableSet.of(rightLiteral.get())); + } else { + return new UnknownValue(context, equalTo); + } } @Override public ValueDesc visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) { // only handle `NumericType` and `DateLikeType` if (inPredicate.getOptions().size() <= InPredicateDedup.REWRITE_OPTIONS_MAX_SIZE - && ExpressionUtils.isAllNonNullComparableLiteral(inPredicate.getOptions()) - && (ExpressionUtils.matchNumericType(inPredicate.getOptions()) - || ExpressionUtils.matchDateLikeType(inPredicate.getOptions()))) { - return ValueDesc.discrete(context, inPredicate); + && ExpressionUtils.isAllNonNullComparableLiteral(inPredicate.getOptions())) { + Set values = Sets.newLinkedHashSetWithExpectedSize(inPredicate.getOptions().size()); + boolean succ = true; + for (Expression value : inPredicate.getOptions()) { + Optional literal = tryGetComparableLiteral(value); + if (!literal.isPresent()) { + succ = false; + break; + } + values.add(literal.get()); + } + if (succ) { + return new DiscreteValue(context, inPredicate.getCompareExpr(), values); + } } + return new UnknownValue(context, inPredicate); } + private Optional tryGetComparableLiteral(Expression right) { + // only handle `NumericType` and `DateLikeType` and `StringLikeType` + DataType rightDataType = right.getDataType(); + if (right instanceof ComparableLiteral + && !right.isNullLiteral() + && (rightDataType.isNumericType() || rightDataType.isDateLikeType() + || rightDataType.isStringLikeType())) { + return Optional.of((ComparableLiteral) right); + } else { + return Optional.empty(); + } + } + + @Override + public ValueDesc visitNot(Not not, ExpressionRewriteContext context) { + ValueDesc childValue = not.child().accept(this, context); + if (childValue instanceof DiscreteValue) { + return new NotDiscreteValue(context, childValue.getReference(), ((DiscreteValue) childValue).values); + } else if (childValue instanceof IsNullValue) { + return new IsNotNullValue(context, childValue.getReference(), not); + } else { + return new UnknownValue(context, not); + } + } + + @Override + public ValueDesc visitIsNull(IsNull isNull, ExpressionRewriteContext context) { + return new IsNullValue(context, isNull.child()); + } + @Override public ValueDesc visitAnd(And and, ExpressionRewriteContext context) { - return simplify(context, ExpressionUtils.extractConjunction(and), - ValueDesc::intersect, true); + return processCompound(context, ExpressionUtils.extractConjunction(and), true); } @Override public ValueDesc visitOr(Or or, ExpressionRewriteContext context) { - return simplify(context, ExpressionUtils.extractDisjunction(or), - ValueDesc::union, false); + return processCompound(context, ExpressionUtils.extractDisjunction(or), false); } - private ValueDesc simplify(ExpressionRewriteContext context, List predicates, - BinaryOperator op, boolean isAnd) { - - boolean convertIsNullToEmptyValue = isAnd && predicates.stream().anyMatch(expr -> expr instanceof NullLiteral); - Multimap groupByReference - = Multimaps.newListMultimap(new LinkedHashMap<>(), ArrayList::new); + private ValueDesc processCompound(ExpressionRewriteContext context, List predicates, boolean isAnd) { + boolean hasNullExpression = false; + boolean hasIsNullExpression = false; + boolean hasNotIsNullExpression = false; + Predicate isNotNull = expression -> expression instanceof Not + && expression.child(0) instanceof IsNull + && !((Not) expression).isGeneratedIsNotNull(); for (Expression predicate : predicates) { - // EmptyValue(a) = IsNull(a) and null, it doesn't equals to IsNull(a). - // Only the and expression contains at least a null literal in its conjunctions, - // then EmptyValue(a) can equivalent to IsNull(a). - // so for expression and(IsNull(a), IsNull(b), ..., null), a, b can convert to EmptyValue. - // What's more, if a is not nullable, then EmptyValue(a) always equals to IsNull(a), - // but we don't consider this case here, we should fold IsNull(a) to FALSE using other rule. + hasNullExpression = hasNullExpression || predicate.isNullLiteral(); + hasIsNullExpression = hasIsNullExpression || predicate instanceof IsNull; + hasNotIsNullExpression = hasNotIsNullExpression || isNotNull.test(predicate); + } + boolean convertIsNullToEmptyValue = isAnd && hasNullExpression && hasIsNullExpression; + boolean convertNotIsNullToRangeAll = !isAnd && hasNullExpression && hasNotIsNullExpression; + Map, ValueDescCollector> groupByReference = Maps.newLinkedHashMap(); + int nextUniqueNum = 1; + for (Expression predicate : predicates) { + // given an expression A, no matter A is nullable or not, + // 'A is null and null' can represent as EmptyValue(A), + // 'A is not null or null' can represent as RangeAll(A). ValueDesc valueDesc = null; - if (convertIsNullToEmptyValue && predicate instanceof IsNull) { + if (predicate instanceof IsNull && convertIsNullToEmptyValue) { valueDesc = new EmptyValue(context, ((IsNull) predicate).child()); + } else if (isNotNull.test(predicate) && convertNotIsNullToRangeAll) { + valueDesc = new RangeValue(context, predicate.child(0).child(0), Range.all()); + } else if (predicate.isNullLiteral() && (convertIsNullToEmptyValue || convertNotIsNullToRangeAll)) { + continue; } else { valueDesc = predicate.accept(this, context); } - List valueDescs = (List) groupByReference.get(valueDesc.reference); - valueDescs.add(valueDesc); - } - List valuePerRefs = Lists.newArrayList(); - for (Entry> referenceValues : groupByReference.asMap().entrySet()) { - Expression reference = referenceValues.getKey(); - List valuePerReference = (List) referenceValues.getValue(); - if (!isAnd) { - valuePerReference = ValueDesc.unionDiscreteAndRange(context, reference, valuePerReference); - } + int uniqueNum = 0; - // merge per reference - ValueDesc simplifiedValue = valuePerReference.get(0); - for (int i = 1; i < valuePerReference.size(); i++) { - simplifiedValue = op.apply(simplifiedValue, valuePerReference.get(i)); + // for compound value with diff source value reference like 'a > 1 and b > 1', + // don't merge it with other values, so give them a unique num > 0. + // for other value desc, their unique num is always 0. + if (valueDesc instanceof CompoundValue && !((CompoundValue) valueDesc).isSameReference) { + nextUniqueNum++; + uniqueNum = nextUniqueNum; } - valuePerRefs.add(simplifiedValue); + Expression reference = valueDesc.reference; + groupByReference.computeIfAbsent(Pair.of(reference, uniqueNum), + key -> new ValueDescCollector()).add(valueDesc); + } + + List valuePerRefs = Lists.newArrayList(); + for (Entry, ValueDescCollector> referenceValues : groupByReference.entrySet()) { + Expression reference = referenceValues.getKey().first; + ValueDescCollector collector = referenceValues.getValue(); + ValueDesc mergedValue; + if (isAnd) { + mergedValue = intersect(context, reference, collector); + } else { + mergedValue = union(context, reference, collector); + } + valuePerRefs.add(mergedValue); } if (valuePerRefs.size() == 1) { return valuePerRefs.get(0); } - // use UnknownValue to wrap different references - return new UnknownValue(context, valuePerRefs, isAnd); + Expression reference = SimplifyRange.INSTANCE.getCompoundExpression(context, valuePerRefs, isAnd); + return new CompoundValue(context, reference, valuePerRefs, isAnd); } - /** - * value desc - */ - public abstract static class ValueDesc { - ExpressionRewriteContext context; - Expression reference; - - public ValueDesc(ExpressionRewriteContext context, Expression reference) { - this.context = context; - this.reference = reference; + private ValueDesc intersect(ExpressionRewriteContext context, Expression reference, ValueDescCollector collector) { + if (collector.hasIsNullValue) { + if (!collector.rangeValues.isEmpty() + || !collector.discreteValues.isEmpty() + || !collector.notDiscreteValues.isEmpty()) { + // TA is null and TA > 1 + // => TA is null and (null) + // => TA is null and null + // => EmptyValue(TA) + collector.rangeValues.clear(); + collector.discreteValues.clear(); + collector.notDiscreteValues.clear(); + collector.add(new EmptyValue(context, reference)); + } } - public Expression getReference() { - return reference; + List resultValues = Lists.newArrayList(); + // merge all the range values + Range mergeRangeValue = null; + if (!collector.hasEmptyValue && !collector.rangeValues.isEmpty()) { + RangeValue mergeRangeValueDesc = null; + for (RangeValue rangeValue : collector.rangeValues) { + // RangeAll(TA) and IsNotNull(TA) + // = (TA is not null or null) and (TA is not null) + // = TA is not null + // = IsNotNull(TA) + if (rangeValue.isRangeAll() && collector.isNotNullValueOpt.isPresent()) { + // Notice that if collector has only isGenerateNotNullValueOpt, we should not keep the rangeAll here + // for expression: (Not(IsNull(TA)) OR NULL) AND GeneratedNot(IsNull(TA)) + // will be converted to RangeAll(TA) AND IsNotNullValue(TA, generated=true) + // if we skip this RangeAll, the final result will be IsNotNullValue(TA, generated=true) + // then convert back to expression: GeneratedNot(IsNull(TA)), + // but later EliminateNotNull rule will remove this generated Not expression, + // then the final result will be TRUE, which is wrong. + continue; + } + if (mergeRangeValueDesc == null) { + mergeRangeValueDesc = rangeValue; + } else { + ValueDesc combineValue = mergeRangeValueDesc.intersect(rangeValue); + if (combineValue instanceof RangeValue) { + mergeRangeValueDesc = (RangeValue) combineValue; + } else { + collector.add(combineValue); + mergeRangeValueDesc = null; + // no need to process the lefts. + if (combineValue instanceof EmptyValue) { + break; + } + } + } + } + if (!collector.hasEmptyValue && mergeRangeValueDesc != null) { + mergeRangeValue = mergeRangeValueDesc.range; + } } - public ExpressionRewriteContext getExpressionRewriteContext() { - return context; + // merge all the discrete values + Set mergeDiscreteValues = null; + if (!collector.hasEmptyValue && !collector.discreteValues.isEmpty()) { + mergeDiscreteValues = Sets.newLinkedHashSet(collector.discreteValues.get(0).values); + for (int i = 1; i < collector.discreteValues.size(); i++) { + mergeDiscreteValues.retainAll(collector.discreteValues.get(i).values); + } + if (mergeDiscreteValues.isEmpty()) { + collector.add(new EmptyValue(context, reference)); + mergeDiscreteValues = null; + } } - public abstract ValueDesc union(ValueDesc other); - - /** or */ - public static ValueDesc union(ExpressionRewriteContext context, - RangeValue range, DiscreteValue discrete, boolean reverseOrder) { - if (discrete.values.stream().allMatch(x -> range.range.test(x))) { - return range; - } - List sourceValues = reverseOrder - ? ImmutableList.of(discrete, range) - : ImmutableList.of(range, discrete); - return new UnknownValue(context, sourceValues, false); - } - - /** merge discrete and ranges only, no merge other value desc */ - public static List unionDiscreteAndRange(ExpressionRewriteContext context, - Expression reference, List valueDescs) { - // Since in-predicate's options is a list, the discrete values need to kept options' order. - // If not keep options' order, the result in-predicate's option list will not equals to - // the input in-predicate, later nereids will need to simplify the new in-predicate, - // then cause dead loop. - Set discreteValues = Sets.newLinkedHashSet(); - for (ValueDesc valueDesc : valueDescs) { - if (valueDesc instanceof DiscreteValue) { - discreteValues.addAll(((DiscreteValue) valueDesc).getValues()); - } - } - - // for 'a > 8 or a = 8', then range (8, +00) can convert to [8, +00) - RangeSet rangeSet = TreeRangeSet.create(); - for (ValueDesc valueDesc : valueDescs) { - if (valueDesc instanceof RangeValue) { - Range range = ((RangeValue) valueDesc).range; - rangeSet.add(range); - if (range.hasLowerBound() - && range.lowerBoundType() == BoundType.OPEN - && discreteValues.contains(range.lowerEndpoint())) { - rangeSet.add(Range.singleton(range.lowerEndpoint())); - } - if (range.hasUpperBound() - && range.upperBoundType() == BoundType.OPEN - && discreteValues.contains(range.upperEndpoint())) { - rangeSet.add(Range.singleton(range.upperEndpoint())); - } + // merge all the not discrete values + Set mergeNotDiscreteValues = Sets.newLinkedHashSet(); + if (!collector.hasEmptyValue && !collector.notDiscreteValues.isEmpty()) { + for (NotDiscreteValue notDiscreteValue : collector.notDiscreteValues) { + mergeNotDiscreteValues.addAll(notDiscreteValue.values); + } + if (mergeRangeValue != null) { + Range finalValue = mergeRangeValue; + mergeNotDiscreteValues.removeIf(value -> !finalValue.contains(value)); + } + if (mergeDiscreteValues != null) { + Set finalValues = mergeDiscreteValues; + mergeNotDiscreteValues.removeIf(value -> !finalValues.contains(value)); + mergeDiscreteValues.removeIf(mergeNotDiscreteValues::contains); + if (mergeDiscreteValues.isEmpty()) { + collector.add(new EmptyValue(context, reference)); + mergeDiscreteValues = null; } } + } + if (!collector.hasEmptyValue) { + // merge range + discrete values + if (mergeRangeValue != null && mergeDiscreteValues != null) { + ValueDesc newMergeValue = new RangeValue(context, reference, mergeRangeValue) + .intersect(new DiscreteValue(context, reference, mergeDiscreteValues)); + resultValues.add(newMergeValue); + } else if (mergeRangeValue != null) { + resultValues.add(new RangeValue(context, reference, mergeRangeValue)); + } else if (mergeDiscreteValues != null) { + resultValues.add(new DiscreteValue(context, reference, mergeDiscreteValues)); + } + if (!collector.hasEmptyValue && !mergeNotDiscreteValues.isEmpty()) { + resultValues.add(new NotDiscreteValue(context, reference, mergeNotDiscreteValues)); + } + } - if (!rangeSet.isEmpty()) { - discreteValues.removeIf(x -> rangeSet.contains(x)); + // process empty value + if (collector.hasEmptyValue) { + if (!reference.nullable()) { + return new UnknownValue(context, BooleanLiteral.FALSE); + } + resultValues.add(new EmptyValue(context, reference)); + } + if (collector.hasIsNullValue) { + if (collector.hasIsNotNullValue()) { + return new UnknownValue(context, BooleanLiteral.FALSE); + } + // nullable's EmptyValue have contains IsNull, no need to add + if (!collector.hasEmptyValue) { + resultValues.add(new IsNullValue(context, reference)); + } + } + if (collector.hasIsNotNullValue()) { + if (collector.hasEmptyValue) { + return new UnknownValue(context, BooleanLiteral.FALSE); } + collector.isNotNullValueOpt.ifPresent(resultValues::add); + collector.isGenerateNotNullValueOpt.ifPresent(resultValues::add); + } + Optional shortCutResult = mergeCompoundValues(context, reference, resultValues, collector, true); + if (shortCutResult.isPresent()) { + return shortCutResult.get(); + } + // unknownValue should be empty + resultValues.addAll(collector.unknownValues); - List result = Lists.newArrayListWithExpectedSize(valueDescs.size()); + Preconditions.checkArgument(!resultValues.isEmpty()); + if (resultValues.size() == 1) { + return resultValues.get(0); + } else { + return new CompoundValue(context, reference, resultValues, true); + } + } + + private ValueDesc union(ExpressionRewriteContext context, Expression reference, ValueDescCollector collector) { + if (collector.hasIsNotNullValue()) { + if (!collector.rangeValues.isEmpty() + || !collector.discreteValues.isEmpty() + || !collector.notDiscreteValues.isEmpty()) { + // TA is not null or TA > 1 + // => TA is not null or (null) + // => TA is not null or null + // => RangeAll(TA) + collector.rangeValues.clear(); + collector.discreteValues.clear(); + collector.notDiscreteValues.clear(); + collector.add(new RangeValue(context, reference, Range.all())); + } + } + + List resultValues = Lists.newArrayListWithExpectedSize(collector.size() + 3); + // Since in-predicate's options is a list, the discrete values need to kept options' order. + // If not keep options' order, the result in-predicate's option list will not equals to + // the input in-predicate, later nereids will need to simplify the new in-predicate, + // then cause dead loop. + Set discreteValues = Sets.newLinkedHashSet(); + for (DiscreteValue discreteValue : collector.discreteValues) { + discreteValues.addAll(discreteValue.values); + } + + // for 'a > 8 or a = 8', then range (8, +00) can convert to [8, +00) + RangeSet rangeSet = TreeRangeSet.create(); + for (RangeValue rangeValue : collector.rangeValues) { + Range range = rangeValue.range; + rangeSet.add(range); + if (range.hasLowerBound() + && range.lowerBoundType() == BoundType.OPEN + && discreteValues.contains(range.lowerEndpoint())) { + rangeSet.add(Range.singleton(range.lowerEndpoint())); + } + if (range.hasUpperBound() + && range.upperBoundType() == BoundType.OPEN + && discreteValues.contains(range.upperEndpoint())) { + rangeSet.add(Range.singleton(range.upperEndpoint())); + } + } + + if (!rangeSet.isEmpty()) { + discreteValues.removeIf(rangeSet::contains); + } + + Set mergeNotDiscreteValues = Sets.newLinkedHashSet(); + boolean hasRangeAll = false; + if (!collector.notDiscreteValues.isEmpty()) { + mergeNotDiscreteValues.addAll(collector.notDiscreteValues.get(0).values); + // a not in (1, 2) or a not in (1, 2, 3) => a not in (1, 2) + for (int i = 1; i < collector.notDiscreteValues.size(); i++) { + mergeNotDiscreteValues.retainAll(collector.notDiscreteValues.get(i).values); + } + // a not in (1, 2, 3) or a in (1, 2, 4) => a not in (3) + mergeNotDiscreteValues.removeIf( + value -> discreteValues.contains(value) || rangeSet.contains(value)); + discreteValues.removeIf(mergeNotDiscreteValues::contains); + if (mergeNotDiscreteValues.isEmpty()) { + resultValues.add(new RangeValue(context, reference, Range.all())); + } else { + resultValues.add(new NotDiscreteValue(context, reference, mergeNotDiscreteValues)); + } + } else { if (!discreteValues.isEmpty()) { - result.add(new DiscreteValue(context, reference, discreteValues)); + resultValues.add(new DiscreteValue(context, reference, discreteValues)); } for (Range range : rangeSet.asRanges()) { - result.add(new RangeValue(context, reference, range)); + hasRangeAll = hasRangeAll || !range.hasUpperBound() && !range.hasLowerBound(); + resultValues.add(new RangeValue(context, reference, range)); } - for (ValueDesc valueDesc : valueDescs) { - if (!(valueDesc instanceof DiscreteValue) && !(valueDesc instanceof RangeValue)) { - result.add(valueDesc); + } + + if (collector.hasIsNullValue) { + if (collector.hasIsNotNullValue() || hasRangeAll) { + return new UnknownValue(context, BooleanLiteral.TRUE); + } + resultValues.add(new IsNullValue(context, reference)); + } + if (collector.hasIsNotNullValue()) { + if (collector.hasEmptyValue) { + // EmptyValue(TA) or TA is not null + // = TA is null and null or TA is not null + // = TA is not null or null + // = RangeAll(TA) + resultValues.add(new RangeValue(context, reference, Range.all())); + } else { + collector.isNotNullValueOpt.ifPresent(resultValues::add); + collector.isGenerateNotNullValueOpt.ifPresent(resultValues::add); + } + } + + Optional shortCutResult = mergeCompoundValues(context, reference, resultValues, collector, false); + if (shortCutResult.isPresent()) { + return shortCutResult.get(); + } + if (collector.hasEmptyValue) { + // for IsNotNull OR EmptyValue, need keep the EmptyValue + boolean ignoreEmptyValue = !resultValues.isEmpty() && !reference.nullable(); + for (ValueDesc valueDesc : resultValues) { + if (valueDesc instanceof CompoundValue) { + ignoreEmptyValue = ignoreEmptyValue || !((CompoundValue) valueDesc).hasNoneNullable; + } else if (valueDesc.nullable() || valueDesc instanceof IsNullValue) { + ignoreEmptyValue = true; + } + if (ignoreEmptyValue) { + break; + } + } + if (!ignoreEmptyValue) { + resultValues.add(new EmptyValue(context, reference)); + } + } + resultValues.addAll(collector.unknownValues); + Preconditions.checkArgument(!resultValues.isEmpty()); + if (resultValues.size() == 1) { + return resultValues.get(0); + } else { + return new CompoundValue(context, reference, resultValues, false); + } + } + + private Optional mergeCompoundValues(ExpressionRewriteContext context, Expression reference, + List resultValues, ValueDescCollector collector, boolean isAnd) { + // for A and (B or C): + // if A and B is false/empty, then A and (B or C) = A and C + // if B's range is bigger than A, then A and (B or C) = A + // for A or (B and C): + // if A's range is bigger than B, then A or (B and C) = A + // if A or B is true/all, then A or (B and C) = A or C + for (CompoundValue compoundValue : collector.compoundValues) { + if (isAnd != compoundValue.isAnd && compoundValue.reference.equals(reference)) { + ImmutableList.Builder newSourceValuesBuilder + = ImmutableList.builderWithExpectedSize(compoundValue.sourceValues.size()); + boolean skipWholeCompoundValue = false; + for (ValueDesc innerValue : compoundValue.sourceValues) { + IntersectType intersectType = IntersectType.OTHERS; + UnionType unionType = UnionType.OTHERS; + for (ValueDesc outerValue : resultValues) { + if (isAnd) { + skipWholeCompoundValue = skipWholeCompoundValue || innerValue.containsAll(outerValue); + IntersectType type = outerValue.getIntersectType(innerValue); + if (type == IntersectType.EMPTY_VALUE && intersectType != IntersectType.FALSE) { + intersectType = type; + } else if (type == IntersectType.FALSE) { + intersectType = type; + } + } else { + skipWholeCompoundValue = skipWholeCompoundValue || outerValue.containsAll(innerValue); + UnionType type = outerValue.getUnionType(innerValue); + if (type == UnionType.RANGE_ALL && unionType != UnionType.TRUE) { + unionType = type; + } else if (type == UnionType.TRUE) { + unionType = type; + } + } + } + if (skipWholeCompoundValue) { + break; + } + if (isAnd) { + if (intersectType == IntersectType.OTHERS) { + newSourceValuesBuilder.add(innerValue); + } else if (intersectType == IntersectType.EMPTY_VALUE) { + newSourceValuesBuilder.add(new EmptyValue(context, reference)); + } + } else { + if (unionType == UnionType.OTHERS) { + newSourceValuesBuilder.add(innerValue); + } else if (unionType == UnionType.RANGE_ALL) { + newSourceValuesBuilder.add(new RangeValue(context, reference, Range.all())); + } + } + } + if (!skipWholeCompoundValue) { + List newSourceValues = newSourceValuesBuilder.build(); + if (newSourceValues.isEmpty()) { + // when isAnd = true, A and (B or C or D) + // if A and B = FALSE, A and C = FALSE, A and D = FALSE, then newSourceValues is empty + // then A and (B or C or D) = FALSE + // when isAnd = false, A or (B and C and D) + // if A or B = TRUE, A or C = TRUE, A or D = TRUE, then newSourceValues is empty + // then A or (B and C and D) = TRUE + return Optional.of(new UnknownValue(context, BooleanLiteral.of(!isAnd))); + } else if (newSourceValues.size() == 1) { + resultValues.add(newSourceValues.get(0)); + } else { + resultValues.add(new CompoundValue(context, reference, newSourceValues, compoundValue.isAnd)); + } } + } else { + resultValues.add(compoundValue); } + } + + return Optional.empty(); + } + + /** value desc visitor */ + public interface ValueDescVisitor { + R visitEmptyValue(EmptyValue emptyValue, C context); + + R visitRangeValue(RangeValue rangeValue, C context); + + R visitDiscreteValue(DiscreteValue discreteValue, C context); + + R visitNotDiscreteValue(NotDiscreteValue notDiscreteValue, C context); + + R visitIsNullValue(IsNullValue isNullValue, C context); + + R visitIsNotNullValue(IsNotNullValue isNotNullValue, C context); + + R visitCompoundValue(CompoundValue compoundValue, C context); + + R visitUnknownValue(UnknownValue unknownValue, C context); + } + + private static class ValueDescCollector implements ValueDescVisitor { + // generated not is null != not is null + Optional isNotNullValueOpt = Optional.empty(); + Optional isGenerateNotNullValueOpt = Optional.empty(); + + boolean hasIsNullValue = false; + boolean hasEmptyValue = false; + List rangeValues = Lists.newArrayList(); + List discreteValues = Lists.newArrayList(); + List notDiscreteValues = Lists.newArrayList(); + List compoundValues = Lists.newArrayList(); + List unknownValues = Lists.newArrayList(); + + void add(ValueDesc value) { + value.accept(this, null); + } - return result; + int size() { + return rangeValues.size() + discreteValues.size() + compoundValues.size() + unknownValues.size(); } - /** intersect */ - public abstract ValueDesc intersect(ValueDesc other); + boolean hasIsNotNullValue() { + return isNotNullValueOpt.isPresent() || isGenerateNotNullValueOpt.isPresent(); + } + + @Override + public Void visitEmptyValue(EmptyValue emptyValue, Void context) { + hasEmptyValue = true; + return null; + } + + @Override + public Void visitRangeValue(RangeValue rangeValue, Void context) { + rangeValues.add(rangeValue); + return null; + } + + @Override + public Void visitDiscreteValue(DiscreteValue discreteValue, Void context) { + discreteValues.add(discreteValue); + return null; + } + + @Override + public Void visitNotDiscreteValue(NotDiscreteValue notDiscreteValue, Void context) { + notDiscreteValues.add(notDiscreteValue); + return null; + } + + @Override + public Void visitIsNullValue(IsNullValue isNullValue, Void context) { + hasIsNullValue = true; + return null; + } - /** intersect */ - public static ValueDesc intersect(ExpressionRewriteContext context, RangeValue range, DiscreteValue discrete) { - // Since in-predicate's options is a list, the discrete values need to kept options' order. - // If not keep options' order, the result in-predicate's option list will not equals to - // the input in-predicate, later nereids will need to simplify the new in-predicate, - // then cause dead loop. - Set newValues = discrete.values.stream().filter(x -> range.range.contains(x)) - .collect(Collectors.toCollection( - () -> Sets.newLinkedHashSetWithExpectedSize(discrete.values.size()))); - if (newValues.isEmpty()) { - return new EmptyValue(context, range.reference); + @Override + public Void visitIsNotNullValue(IsNotNullValue isNotNullValue, Void context) { + if (isNotNullValue.not.isGeneratedIsNotNull()) { + isGenerateNotNullValueOpt = Optional.of(isNotNullValue); } else { - return new DiscreteValue(context, range.reference, newValues); + isNotNullValueOpt = Optional.of(isNotNullValue); } + return null; } - private static ValueDesc range(ExpressionRewriteContext context, ComparisonPredicate predicate) { - ComparableLiteral value = (ComparableLiteral) predicate.right(); - if (predicate instanceof EqualTo) { - return new DiscreteValue(context, predicate.left(), Sets.newHashSet(value)); - } - Range range = null; - if (predicate instanceof GreaterThanEqual) { - range = Range.atLeast(value); - } else if (predicate instanceof GreaterThan) { - range = Range.greaterThan(value); - } else if (predicate instanceof LessThanEqual) { - range = Range.atMost(value); - } else if (predicate instanceof LessThan) { - range = Range.lessThan(value); - } + @Override + public Void visitCompoundValue(CompoundValue compoundValue, Void context) { + compoundValues.add(compoundValue); + return null; + } - return new RangeValue(context, predicate.left(), range); + @Override + public Void visitUnknownValue(UnknownValue unknownValue, Void context) { + unknownValues.add(unknownValue); + return null; } + } + + /** union two value result */ + public enum UnionType { + TRUE, // equals TRUE + RANGE_ALL, // trueOrNull(reference) + OTHERS, // other case + } - private static ValueDesc discrete(ExpressionRewriteContext context, InPredicate in) { - // Since in-predicate's options is a list, the discrete values need to kept options' order. - // If not keep options' order, the result in-predicate's option list will not equals to - // the input in-predicate, later nereids will need to simplify the new in-predicate, - // then cause dead loop. - // Set literals = (Set) Utils.fastToImmutableSet(in.getOptions()); - Set literals = in.getOptions().stream() - .map(ComparableLiteral.class::cast) - .collect(Collectors.toCollection( - () -> Sets.newLinkedHashSetWithExpectedSize(in.getOptions().size()))); - return new DiscreteValue(context, in.getCompareExpr(), literals); + /** intersect two value result */ + public enum IntersectType { + FALSE, // equals FALSE + EMPTY_VALUE, // falseOrNull(reference) + OTHERS, // other case + } + + /** + * value desc + */ + public abstract static class ValueDesc { + protected final ExpressionRewriteContext context; + protected final Expression reference; + + public ValueDesc(ExpressionRewriteContext context, Expression reference) { + this.context = context; + this.reference = reference; + } + + public ExpressionRewriteContext getExpressionRewriteContext() { + return context; + } + + public Expression getReference() { + return reference; + } + + public R accept(ValueDescVisitor visitor, C context) { + return visit(visitor, context); + } + + protected abstract R visit(ValueDescVisitor visitor, C context); + + protected abstract boolean nullable(); + + protected boolean nullForNullReference() { + return nullable(); } + + // X containsAll Y, means: + // 1) when Y is TRUE, X is TRUE; + // 2) when Y is FALSE, X can be any; + // 3) when Y is null, X is null; + // then will have: + // use in 'A and (B or C)', if B containsAll A, then rewrite it to 'A', + // use in 'A or (B and C)', if A containsAll B, then rewrite it to 'A'. + @VisibleForTesting + public final boolean containsAll(ValueDesc other) { + return containsAll(other, 0); + } + + protected abstract boolean containsAll(ValueDesc other, int depth); + + // X, Y intersectWithIsEmpty, means 'X and Y' is: + // 1) FALSE && !X.nullable() && !Y.nullable(); + // 2) EmptyValue && X.nullable() && Y.nullable()), the nullable check no loss the null + // use in 'A and (B or C)', if A, B intersectWithIsEmpty, then rewrite it to 'A and C' + @VisibleForTesting + public final IntersectType getIntersectType(ValueDesc other) { + return getIntersectType(other, 0); + } + + protected abstract IntersectType getIntersectType(ValueDesc other, int depth); + + // X, Y unionWithIsAll, means 'X union Y' is: + // 1) TRUE && !X.nullable() && !Y.nullable(); + // 2) Range.all() && X.nullable() && Y.nullable(), the nullable check no loss the null; + // use in 'A or (B and C)', if A, B unionWithIsAll, then rewrite it to 'A or C' + @VisibleForTesting + public final UnionType getUnionType(ValueDesc other) { + return getUnionType(other, 0); + } + + protected abstract UnionType getUnionType(ValueDesc other, int depth); } /** @@ -332,13 +785,49 @@ public EmptyValue(ExpressionRewriteContext context, Expression reference) { } @Override - public ValueDesc union(ValueDesc other) { - return other; + protected R visit(ValueDescVisitor visitor, C context) { + return visitor.visitEmptyValue(this, context); + } + + @Override + protected boolean nullable() { + return reference.nullable(); + } + + @Override + protected boolean containsAll(ValueDesc other, int depth) { + return other instanceof EmptyValue || (other instanceof IsNullValue && !reference.nullable()); } @Override - public ValueDesc intersect(ValueDesc other) { - return this; + protected IntersectType getIntersectType(ValueDesc other, int depth) { + if (other instanceof EmptyValue || other instanceof RangeValue + || other instanceof DiscreteValue || other instanceof NotDiscreteValue + || other instanceof IsNullValue) { + return reference.nullable() ? IntersectType.EMPTY_VALUE : IntersectType.FALSE; + } else if (other instanceof IsNotNullValue) { + return IntersectType.FALSE; + } else if (other instanceof CompoundValue) { + return other.getIntersectType(this, depth); + } else { + return IntersectType.OTHERS; + } + } + + @Override + protected UnionType getUnionType(ValueDesc other, int depth) { + if (other instanceof RangeValue) { + if (((RangeValue) other).isRangeAll()) { + return reference.nullable() ? UnionType.RANGE_ALL : UnionType.TRUE; + } + } else if (other instanceof IsNotNullValue) { + if (!reference.nullable()) { + return UnionType.TRUE; + } + } else if (other instanceof CompoundValue) { + return other.getUnionType(this, depth); + } + return UnionType.OTHERS; } } @@ -348,7 +837,8 @@ public ValueDesc intersect(ValueDesc other) { * a > 1 => (1...+∞) */ public static class RangeValue extends ValueDesc { - Range range; + + final Range range; public RangeValue(ExpressionRewriteContext context, Expression reference, Range range) { super(context, reference); @@ -360,54 +850,122 @@ public Range getRange() { } @Override - public ValueDesc union(ValueDesc other) { + protected R visit(ValueDescVisitor visitor, C context) { + return visitor.visitRangeValue(this, context); + } + + @Override + protected boolean nullable() { + return reference.nullable(); + } + + @Override + protected boolean containsAll(ValueDesc other, int depth) { if (other instanceof EmptyValue) { - return other.union(this); + return true; + } else if (other instanceof RangeValue) { + return range.encloses(((RangeValue) other).range); + } else if (other instanceof DiscreteValue) { + return range.containsAll(((DiscreteValue) other).values); + } else if (other instanceof NotDiscreteValue || other instanceof IsNotNullValue) { + return isRangeAll(); + } else if (other instanceof CompoundValue) { + return ((CompoundValue) other).isContainedAllBy(this, depth); + } else { + return false; } - if (other instanceof RangeValue) { - RangeValue o = (RangeValue) other; - if (range.isConnected(o.range)) { - return new RangeValue(context, reference, range.span(o.range)); + } + + @Override + protected IntersectType getIntersectType(ValueDesc other, int depth) { + if (other instanceof EmptyValue) { + return reference.nullable() ? IntersectType.EMPTY_VALUE : IntersectType.FALSE; + } else if (other instanceof RangeValue) { + if (intersect((RangeValue) other) instanceof EmptyValue) { + return reference.nullable() ? IntersectType.EMPTY_VALUE : IntersectType.FALSE; } - return new UnknownValue(context, ImmutableList.of(this, other), false); - } - if (other instanceof DiscreteValue) { - return union(context, this, (DiscreteValue) other, false); + } else if (other instanceof DiscreteValue) { + if (intersect((DiscreteValue) other) instanceof EmptyValue) { + return reference.nullable() ? IntersectType.EMPTY_VALUE : IntersectType.FALSE; + } + } else if (other instanceof IsNullValue) { + return reference.nullable() ? IntersectType.EMPTY_VALUE : IntersectType.FALSE; + } else if (other instanceof CompoundValue) { + return other.getIntersectType(this, depth); } - return new UnknownValue(context, ImmutableList.of(this, other), false); + return IntersectType.OTHERS; } @Override - public ValueDesc intersect(ValueDesc other) { - if (other instanceof EmptyValue) { - return other.intersect(this); + protected UnionType getUnionType(ValueDesc other, int depth) { + if ((other instanceof EmptyValue || other instanceof DiscreteValue) && isRangeAll()) { + return reference.nullable() ? UnionType.RANGE_ALL : UnionType.TRUE; + } else if (other instanceof RangeValue) { + Range otherRange = ((RangeValue) other).range; + if (range.isConnected(otherRange)) { + Range unionRange = range.span(otherRange); + if (!unionRange.hasLowerBound() && !unionRange.hasUpperBound()) { + return reference.nullable() ? UnionType.RANGE_ALL : UnionType.TRUE; + } + } + } else if (other instanceof NotDiscreteValue) { + Set notDiscreteValues = ((NotDiscreteValue) other).values; + boolean succ = true; + for (ComparableLiteral value : notDiscreteValues) { + if (!range.contains(value)) { + succ = false; + break; + } + } + if (succ) { + return reference.nullable() ? UnionType.RANGE_ALL : UnionType.TRUE; + } + } else if (other instanceof IsNullValue && !reference.nullable() && isRangeAll()) { + return UnionType.TRUE; + } else if (other instanceof IsNotNullValue) { + if (!reference.nullable()) { + return UnionType.TRUE; + } + } else if (other instanceof CompoundValue) { + return other.getUnionType(this, depth); } - if (other instanceof RangeValue) { - RangeValue o = (RangeValue) other; - if (range.isConnected(o.range)) { - Range newRange = range.intersection(o.range); - if (!newRange.isEmpty()) { - if (newRange.hasLowerBound() && newRange.hasUpperBound() - && newRange.lowerEndpoint().compareTo(newRange.upperEndpoint()) == 0 - && newRange.lowerBoundType() == BoundType.CLOSED - && newRange.lowerBoundType() == BoundType.CLOSED) { - return new DiscreteValue(context, reference, Sets.newHashSet(newRange.lowerEndpoint())); - } else { - return new RangeValue(context, reference, newRange); - } + return UnionType.OTHERS; + } + + private ValueDesc intersect(RangeValue other) { + if (range.isConnected(other.range)) { + Range newRange = range.intersection(other.range); + if (!newRange.isEmpty()) { + if (newRange.hasLowerBound() && newRange.hasUpperBound() + && newRange.lowerEndpoint().compareTo(newRange.upperEndpoint()) == 0 + && newRange.lowerBoundType() == BoundType.CLOSED + && newRange.lowerBoundType() == BoundType.CLOSED) { + return new DiscreteValue(context, reference, Sets.newHashSet(newRange.lowerEndpoint())); + } else { + return new RangeValue(context, reference, newRange); } } - return new EmptyValue(context, reference); } - if (other instanceof DiscreteValue) { - return intersect(context, this, (DiscreteValue) other); + return new EmptyValue(context, reference); + } + + private ValueDesc intersect(DiscreteValue other) { + Set intersectValues = Sets.newLinkedHashSetWithExpectedSize(other.values.size()); + for (ComparableLiteral value : other.values) { + if (range.contains(value)) { + intersectValues.add(value); + } + } + if (intersectValues.isEmpty()) { + return new EmptyValue(context, reference); + } else { + return new DiscreteValue(context, reference, intersectValues); } - return new UnknownValue(context, ImmutableList.of(this, other), true); } - @Override - public String toString() { - return range == null ? "UnknownRange" : range.toString(); + @VisibleForTesting + public boolean isRangeAll() { + return !range.hasLowerBound() && !range.hasUpperBound(); } } @@ -430,93 +988,331 @@ public Set getValues() { } @Override - public ValueDesc union(ValueDesc other) { + protected R visit(ValueDescVisitor visitor, C context) { + return visitor.visitDiscreteValue(this, context); + } + + @Override + protected boolean nullable() { + return reference.nullable(); + } + + @Override + protected boolean containsAll(ValueDesc other, int depth) { if (other instanceof EmptyValue) { - return other.union(this); + return true; + } else if (other instanceof DiscreteValue) { + return values.containsAll(((DiscreteValue) other).values); + } else if (other instanceof CompoundValue) { + return ((CompoundValue) other).isContainedAllBy(this, depth); + } else { + return false; } - if (other instanceof DiscreteValue) { - Set newValues = Sets.newLinkedHashSet(); - newValues.addAll(((DiscreteValue) other).values); - newValues.addAll(this.values); - return new DiscreteValue(context, reference, newValues); + } + + @Override + protected IntersectType getIntersectType(ValueDesc other, int depth) { + if (other instanceof EmptyValue) { + return reference.nullable() ? IntersectType.EMPTY_VALUE : IntersectType.FALSE; + } else if (other instanceof RangeValue) { + return other.getIntersectType(this, depth); + } else if (other instanceof DiscreteValue) { + Set otherValues = ((DiscreteValue) other).values; + for (ComparableLiteral value : otherValues) { + if (values.contains(value)) { + return IntersectType.OTHERS; + } + } + return reference.nullable() ? IntersectType.EMPTY_VALUE : IntersectType.FALSE; + } else if (other instanceof IsNullValue) { + return reference.nullable() ? IntersectType.EMPTY_VALUE : IntersectType.FALSE; + } else if (other instanceof CompoundValue) { + return other.getIntersectType(this, depth); + } else { + return IntersectType.OTHERS; } + } + + @Override + protected UnionType getUnionType(ValueDesc other, int depth) { if (other instanceof RangeValue) { - return union(context, (RangeValue) other, this, true); + return other.getUnionType(this, depth); + } else if (other instanceof NotDiscreteValue) { + boolean succ = true; + Set notDiscreteValues = ((NotDiscreteValue) other).values; + for (ComparableLiteral value : notDiscreteValues) { + if (!values.contains(value)) { + succ = false; + break; + } + } + if (succ) { + return reference.nullable() ? UnionType.RANGE_ALL : UnionType.TRUE; + } + } else if (other instanceof IsNotNullValue) { + if (!reference.nullable()) { + return UnionType.TRUE; + } + } else if (other instanceof CompoundValue) { + return other.getUnionType(this, depth); } - return new UnknownValue(context, ImmutableList.of(this, other), false); + return UnionType.OTHERS; + } + } + + /** + * for example: + * a not in (1,2,3) => [1,2,3] + */ + public static class NotDiscreteValue extends ValueDesc { + final Set values; + + public NotDiscreteValue(ExpressionRewriteContext context, + Expression reference, Set values) { + super(context, reference); + this.values = values; + } + + @Override + protected R visit(ValueDescVisitor visitor, C context) { + return visitor.visitNotDiscreteValue(this, context); + } + + @Override + protected boolean nullable() { + return reference.nullable(); } @Override - public ValueDesc intersect(ValueDesc other) { + protected boolean containsAll(ValueDesc other, int depth) { if (other instanceof EmptyValue) { - return other.intersect(this); - } - if (other instanceof DiscreteValue) { - Set newValues = Sets.newLinkedHashSet(); - newValues.addAll(this.values); - newValues.retainAll(((DiscreteValue) other).values); - if (newValues.isEmpty()) { - return new EmptyValue(context, reference); - } else { - return new DiscreteValue(context, reference, newValues); + return true; + } else if (other instanceof RangeValue) { + Range range = ((RangeValue) other).range; + for (ComparableLiteral value : values) { + if (range.contains(value)) { + return false; + } + } + return true; + } else if (other instanceof DiscreteValue) { + Set discreteValues = ((DiscreteValue) other).values; + for (ComparableLiteral value : values) { + if (discreteValues.contains(value)) { + return false; + } + } + return true; + } else if (other instanceof NotDiscreteValue) { + return ((NotDiscreteValue) other).values.containsAll(values); + } else if (other instanceof CompoundValue) { + return ((CompoundValue) other).isContainedAllBy(this, depth); + } else { + return false; + } + } + + @Override + protected IntersectType getIntersectType(ValueDesc other, int depth) { + if (other instanceof EmptyValue) { + return reference.nullable() ? IntersectType.EMPTY_VALUE : IntersectType.FALSE; + } else if (other instanceof DiscreteValue) { + if (values.containsAll(((DiscreteValue) other).values)) { + return reference.nullable() ? IntersectType.EMPTY_VALUE : IntersectType.FALSE; } + } else if (other instanceof IsNullValue) { + return reference.nullable() ? IntersectType.EMPTY_VALUE : IntersectType.FALSE; + } else if (other instanceof CompoundValue) { + return other.getIntersectType(this, depth); } + return IntersectType.OTHERS; + } + + @Override + protected UnionType getUnionType(ValueDesc other, int depth) { if (other instanceof RangeValue) { - return intersect(context, (RangeValue) other, this); + return other.getUnionType(this, depth); + } else if (other instanceof DiscreteValue) { + return other.getUnionType(this, depth); + } else if (other instanceof NotDiscreteValue) { + Set notDiscreteValues = ((NotDiscreteValue) other).values; + for (ComparableLiteral value : notDiscreteValues) { + if (values.contains(value)) { + return UnionType.OTHERS; + } + } + return reference.nullable() ? UnionType.RANGE_ALL : UnionType.TRUE; + } else if (other instanceof IsNotNullValue) { + if (!reference.nullable()) { + return UnionType.TRUE; + } + } else if (other instanceof CompoundValue) { + return other.getUnionType(this, depth); } - return new UnknownValue(context, ImmutableList.of(this, other), true); + return UnionType.OTHERS; + } + } + + /** + * a is null + */ + public static class IsNullValue extends ValueDesc { + + public IsNullValue(ExpressionRewriteContext context, Expression reference) { + super(context, reference); + } + + @Override + protected R visit(ValueDescVisitor visitor, C context) { + return visitor.visitIsNullValue(this, context); + } + + @Override + protected boolean nullable() { + return false; } @Override - public String toString() { - return values.toString(); + protected boolean containsAll(ValueDesc other, int depth) { + if (other instanceof EmptyValue) { + return !reference.nullable(); + } else if (other instanceof IsNullValue) { + return true; + } else if (other instanceof CompoundValue) { + return ((CompoundValue) other).isContainedAllBy(this, depth); + } else { + return false; + } + } + + @Override + protected IntersectType getIntersectType(ValueDesc other, int depth) { + if (other instanceof EmptyValue || other instanceof RangeValue + || other instanceof DiscreteValue || other instanceof NotDiscreteValue) { + return reference.nullable() ? IntersectType.EMPTY_VALUE : IntersectType.FALSE; + } else if (other instanceof IsNotNullValue) { + return IntersectType.FALSE; + } else if (other instanceof CompoundValue) { + return other.getIntersectType(this, depth); + } + return IntersectType.OTHERS; + } + + @Override + protected UnionType getUnionType(ValueDesc other, int depth) { + if (other instanceof RangeValue) { + return ((RangeValue) other).getUnionType(this, depth); + } else if (other instanceof IsNotNullValue) { + return UnionType.TRUE; + } else { + return UnionType.OTHERS; + } } } /** - * Represents processing result. + * a is not null */ - public static class UnknownValue extends ValueDesc { - private final List sourceValues; - private final boolean isAnd; + public static class IsNotNullValue extends ValueDesc { + final Not not; - private UnknownValue(ExpressionRewriteContext context, Expression expr) { - super(context, expr); - sourceValues = ImmutableList.of(); - isAnd = false; + public IsNotNullValue(ExpressionRewriteContext context, Expression reference, Not not) { + super(context, reference); + this.not = not; } - private UnknownValue(ExpressionRewriteContext context, - List sourceValues, boolean isAnd) { - super(context, getReference(context, sourceValues, isAnd)); - this.sourceValues = ImmutableList.copyOf(sourceValues); - this.isAnd = isAnd; + public Not getNotExpression() { + return this.not; + } + + @Override + protected R visit(ValueDescVisitor visitor, C context) { + return visitor.visitIsNotNullValue(this, context); + } + + @Override + protected boolean nullable() { + return false; + } + + @Override + protected boolean containsAll(ValueDesc other, int depth) { + if (other instanceof IsNotNullValue) { + return not.isGeneratedIsNotNull() == ((IsNotNullValue) other).not.isGeneratedIsNotNull(); + } else if (other instanceof CompoundValue) { + return ((CompoundValue) other).isContainedAllBy(this, depth); + } else { + return false; + } + } + + @Override + protected IntersectType getIntersectType(ValueDesc other, int depth) { + if (other instanceof EmptyValue || other instanceof IsNullValue) { + return IntersectType.FALSE; + } else if (other instanceof CompoundValue) { + return other.getIntersectType(this, depth); + } else { + return IntersectType.OTHERS; + } + } + + @Override + protected UnionType getUnionType(ValueDesc other, int depth) { + if (other instanceof EmptyValue || other instanceof RangeValue + || other instanceof DiscreteValue || other instanceof NotDiscreteValue) { + if (!reference.nullable()) { + return UnionType.TRUE; + } + } else if (other instanceof IsNullValue) { + return UnionType.TRUE; + } else if (other instanceof CompoundValue) { + return other.getUnionType(this, depth); + } + return UnionType.OTHERS; } + } - // reference is used to simplify multiple ValueDescs. - // when ValueDesc A op ValueDesc B, only A and B's references equals, - // can reduce them, like A op B = A. - // If A and B's reference not equal, A op B will always get UnknownValue(A op B). - // - // for example: - // 1. RangeValue(a < 10, reference=a) union RangeValue(a > 20, reference=a) - // = UnknownValue1(a < 10 or a > 20, reference=a) - // 2. RangeValue(a < 10, reference=a) union RangeValue(b > 20, reference=b) - // = UnknownValue2(a < 10 or b > 20, reference=(a < 10 or b > 20)) - // then given EmptyValue(, reference=a) E, - // 1. since E and UnknownValue1's reference equals, then - // E union UnknownValue1 = E.union(UnknownValue1) = UnknownValue1, - // 2. since E and UnknownValue2's reference not equals, then - // E union UnknownValue2 = UnknownValue3(E union UnknownValue2, reference=E union UnknownValue2) - private static Expression getReference(ExpressionRewriteContext context, + /** + * Represents processing compound predicate. + */ + public static class CompoundValue extends ValueDesc { + private static final int MAX_SEARCH_DEPTH = 1; + private final List sourceValues; + private final boolean isAnd; + private final Set> subClasses; + private final boolean hasNullable; + private final boolean hasNoneNullable; + private final boolean isSameReference; + + /** constructor */ + public CompoundValue(ExpressionRewriteContext context, Expression reference, List sourceValues, boolean isAnd) { - Expression reference = sourceValues.get(0).reference; - for (int i = 1; i < sourceValues.size(); i++) { - if (!reference.equals(sourceValues.get(i).reference)) { - return SimplifyRange.INSTANCE.getExpression(context, sourceValues, isAnd); + super(context, reference); + this.sourceValues = ImmutableList.copyOf(sourceValues); + this.isAnd = isAnd; + this.subClasses = Sets.newHashSet(); + this.subClasses.add(getClass()); + boolean hasNullable = false; + boolean hasNonNullable = false; + boolean isSameReference = true; + for (ValueDesc sourceValue : sourceValues) { + if (sourceValue instanceof CompoundValue) { + CompoundValue compoundSource = (CompoundValue) sourceValue; + this.subClasses.addAll(compoundSource.subClasses); + hasNullable = hasNullable || compoundSource.hasNullable; + hasNonNullable = hasNonNullable || compoundSource.hasNoneNullable; + isSameReference = isSameReference && compoundSource.isSameReference; + } else { + this.subClasses.add(sourceValue.getClass()); + hasNullable = hasNullable || sourceValue.nullable(); + hasNonNullable = hasNonNullable || !sourceValue.nullable(); } + isSameReference = isSameReference && sourceValue.getReference().equals(reference); } - return reference; + this.hasNullable = hasNullable; + this.hasNoneNullable = hasNonNullable; + this.isSameReference = isSameReference; } public List getSourceValues() { @@ -528,23 +1324,229 @@ public boolean isAnd() { } @Override - public ValueDesc union(ValueDesc other) { - // for RangeValue/DiscreteValue/UnknownValue, when union with EmptyValue, - // call EmptyValue.union(this) => this - if (other instanceof EmptyValue) { - return other.union(this); + protected R visit(ValueDescVisitor visitor, C context) { + return visitor.visitCompoundValue(this, context); + } + + @Override + protected boolean nullable() { + return hasNullable; + } + + @Override + protected boolean nullForNullReference() { + return reference.nullable() && !hasNoneNullable; + } + + @Override + protected boolean containsAll(ValueDesc other, int depth) { + // in fact, when merge the value desc for the same reference, + // all the value desc should not be unknown value + if (depth > MAX_SEARCH_DEPTH || other instanceof UnknownValue || subClasses.contains(UnknownValue.class)) { + return false; + } + if (!isAnd && (!other.nullable() || !hasNoneNullable)) { + // for OR value desc: + // 1) if other not nullable, then no need to consider other is null, this is null + // 2) if other is nullable, then when other is null, then the reference is null, + // so if this OR no non-nullable, then this is null too. + for (ValueDesc valueDesc : sourceValues) { + if (valueDesc.containsAll(other, depth + 1)) { + return true; + } + } + return false; + } else { + // when other is nullable, why OR should check all source values containsAll ? + // give an example: for an OR: (c1 or c2 or c3), suppose c1 containsAll other, + // then when other is null, the OR = null or c2 or c3, it may not be null. + // a example: 'a > 1 or a is null' not contains all 'a > 10', even if 'a > 1' contains all 'a > 10' + for (ValueDesc valueDesc : sourceValues) { + if (!valueDesc.containsAll(other, depth + 1)) { + return false; + } + } + return true; + } + } + + // check other containsAll this + private boolean isContainedAllBy(ValueDesc other, int depth) { + // do want to process the complicate cases, + // and in fact, when merge value desc for same reference, + // all the value should not contain UnknownValue. + if (depth > MAX_SEARCH_DEPTH || other instanceof UnknownValue || subClasses.contains(UnknownValue.class)) { + return false; + } + if (isAnd) { + // for C = c1 and c2 and c3, suppose other containsAll c1, then will have: + // when c1 is true, other is true, + // when c1 is null, other is null, + // so, when C is true, then c1 is true, so other is true, + // when C is null, then the reference must be null, so, c1 is null too, then other is null + for (ValueDesc valueDesc : sourceValues) { + if (other.containsAll(valueDesc, depth)) { + return true; + } + } + return false; + } else { + // for C = c1 or c2 or c3, suppose other contains c1, c2, c3. + // so when C is true, then at least one ci is true, so other is true. + // when C is null, then at least one ci is null, so other is null. + // so other will contain all C + for (ValueDesc valueDesc : sourceValues) { + if (!other.containsAll(valueDesc, depth)) { + return false; + } + } + return true; } - return new UnknownValue(context, ImmutableList.of(this, other), false); } @Override - public ValueDesc intersect(ValueDesc other) { - // for RangeValue/DiscreteValue/UnknownValue, when intersect with EmptyValue, - // call EmptyValue.intersect(this) => EmptyValue - if (other instanceof EmptyValue) { - return other.intersect(this); + protected IntersectType getIntersectType(ValueDesc other, int depth) { + if ((!nullable() && other.nullable()) || depth > MAX_SEARCH_DEPTH) { + return IntersectType.OTHERS; + } + if (isAnd) { + // process A and ((B and C) or ...) + boolean hasEmptyValue = false; + boolean hasIsNotNull = false; + boolean allOtherNullForNullReference = true; + for (ValueDesc valueDesc : sourceValues) { + IntersectType type = valueDesc.getIntersectType(other, depth + 1); + if (type == IntersectType.FALSE) { + return type; + } + if (type == IntersectType.EMPTY_VALUE) { + hasEmptyValue = true; + } else { + allOtherNullForNullReference = allOtherNullForNullReference && valueDesc.nullForNullReference(); + } + hasIsNotNull = hasIsNotNull || valueDesc instanceof IsNotNullValue; + } + if (hasEmptyValue) { + if (hasIsNotNull) { + // EmptyValue and IsNotNull = FALSE + return IntersectType.FALSE; + } + // A and ((B and C) or ...) + // if A intersect B is EMPTY_VALUE, A intersect C is OTHERS, C is nullable + if (allOtherNullForNullReference) { + return IntersectType.EMPTY_VALUE; + } + } + return IntersectType.OTHERS; + } else { + // process A and (B or C) => A and B or A and C + boolean hasEmptyValue = false; + for (ValueDesc valueDesc : sourceValues) { + IntersectType type = valueDesc.getIntersectType(other, depth + 1); + if (type == IntersectType.OTHERS) { + return type; + } + hasEmptyValue = hasEmptyValue || type == IntersectType.EMPTY_VALUE; + } + + // must hasEmptyValue or hasFalse + return hasEmptyValue ? IntersectType.EMPTY_VALUE : IntersectType.FALSE; } - return new UnknownValue(context, ImmutableList.of(this, other), true); + } + + @Override + protected UnionType getUnionType(ValueDesc other, int depth) { + if ((!nullable() && other.nullable()) || depth > MAX_SEARCH_DEPTH) { + return UnionType.OTHERS; + } + if (isAnd) { + // process `A or (B and C)`: => (A or B) and (A or C) + boolean hasRangeAll = false; + for (ValueDesc valueDesc : sourceValues) { + UnionType type = valueDesc.getUnionType(other, depth + 1); + if (type == UnionType.OTHERS) { + return type; + } + hasRangeAll = hasRangeAll || type == UnionType.RANGE_ALL; + } + // must hasRangeAll or hasTrue + return hasRangeAll ? UnionType.RANGE_ALL : UnionType.TRUE; + } else { + // process 'A or ((B or C) and ...)' + // then `this`: '(B or C)', `other`: A + boolean hasRangeAll = false; + boolean hasIsNull = false; + boolean allOtherNullForNullReference = true; + for (ValueDesc valueDesc : sourceValues) { + UnionType type = valueDesc.getUnionType(other, depth + 1); + if (type == UnionType.TRUE) { + return type; + } + if (type == UnionType.RANGE_ALL) { + hasRangeAll = true; + } else { + allOtherNullForNullReference = allOtherNullForNullReference && valueDesc.nullForNullReference(); + } + hasIsNull = hasIsNull || valueDesc instanceof IsNullValue; + } + if (hasRangeAll) { + if (hasIsNull) { + // A or ((B or C) and ....) + // if A union B is RANGE_ALL, C is IsNull + // then A or ((B or C) and ...) = A or (TRUE and ...) + // RangeAll or IsNull = TRUE + return UnionType.TRUE; + } + if (allOtherNullForNullReference) { + // A or ((B or C) and ....) + // if A union B is RANGE_ALL, and C is nullable + // then A or ((B or C) and ...) = A or (Range.all() and ...) + return UnionType.RANGE_ALL; + } + } + return UnionType.OTHERS; + } + } + } + + /** + * Represents unknown value expression. + */ + public static class UnknownValue extends ValueDesc { + + public UnknownValue(ExpressionRewriteContext context, Expression expression) { + super(context, expression); + } + + @Override + protected R visit(ValueDescVisitor visitor, C context) { + return visitor.visitUnknownValue(this, context); + } + + @Override + protected boolean nullable() { + return reference.nullable(); + } + + @Override + protected boolean nullForNullReference() { + return false; + } + + @Override + protected boolean containsAll(ValueDesc other, int depth) { + // when merge all the value desc, the value desc's reference are the same. + return other instanceof UnknownValue; + } + + @Override + protected IntersectType getIntersectType(ValueDesc other, int depth) { + return IntersectType.OTHERS; + } + + @Override + protected UnionType getUnionType(ValueDesc other, int depth) { + return UnionType.OTHERS; } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java index c08f3aafea6bc3..a25f1bb81d4e91 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java @@ -83,7 +83,7 @@ public class SimplifyArithmeticComparisonRule implements ExpressionPatternRuleFa public List> buildRules() { return ImmutableList.of( matchesType(ComparisonPredicate.class) - .thenApply(ctx -> simplify(ctx.expr, new ExpressionRewriteContext(ctx.cascadesContext))) + .thenApply(ctx -> simplify(ctx.expr, ctx.rewriteContext)) .toRule(ExpressionRuleType.SIMPLIFY_ARITHMETIC_COMPARISON) ); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyEqualBooleanLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyEqualBooleanLiteral.java new file mode 100644 index 00000000000000..0497406edd0e6c --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyEqualBooleanLiteral.java @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; +import org.apache.doris.nereids.rules.expression.ExpressionRuleType; +import org.apache.doris.nereids.trees.expressions.EqualPredicate; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Not; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; + +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * Simplify expression equal to true / false: + * 1.'expr = true' => 'expr'; + * 2.'expr = false' => 'not expr'. + * + * NOTE: This rule may downgrade the performance for InferPredicate rule, + * because InferPredicate will collect predicate `f(xxx) = literal`, + * after this rule rewrite `f(xxx) = true/false` to `f(xxx)`/`not f(xxx)`, the predicate will not be collected. + * + * But we think this rule is more useful than harmful. + * + * What's more, for InferPredicate, it will collect f(xxx) = literal, and infer f(yyy) = literal, + * but f(yyy) may be very complex, so it is not always useful, so InferPredicate may also cause downgrade. + * By the way, if InferPredicate not considering the f(yyy) = literal is complex or not, + * the better way for it is to collect all the boolean predicates, not just only the 'xx compare literal' form. + */ +public class SimplifyEqualBooleanLiteral implements ExpressionPatternRuleFactory { + public static final SimplifyEqualBooleanLiteral INSTANCE = new SimplifyEqualBooleanLiteral(); + + @Override + public List> buildRules() { + return ImmutableList.of( + matchesType(EqualTo.class) + .when(this::needRewrite) + .then(equal -> rewrite(equal, (BooleanLiteral) equal.right())) + .toRule(ExpressionRuleType.SIMPLIFY_EQUAL_BOOLEAN_LITERAL) + ); + } + + private boolean needRewrite(EqualPredicate equal) { + // we don't rewrite 'slot = true/false' to slot, because: + // 1. for delete command, the where predicate need slot = xxx; + // 2. slot = true/false can generate a uniform for this slot, later it can use in constant propagation. + return !(equal.left() instanceof SlotReference) && equal.right() instanceof BooleanLiteral; + } + + private Expression rewrite(EqualTo equal, BooleanLiteral right) { + Expression left = equal.left(); + return right.equals(BooleanLiteral.TRUE) ? left : new Not(left); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java index d47f5877a96c96..cd86283afe8ebe 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java @@ -21,11 +21,16 @@ import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.rules.expression.ExpressionRuleType; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.CompoundValue; import org.apache.doris.nereids.rules.expression.rules.RangeInference.DiscreteValue; import org.apache.doris.nereids.rules.expression.rules.RangeInference.EmptyValue; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.IsNotNullValue; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.IsNullValue; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.NotDiscreteValue; import org.apache.doris.nereids.rules.expression.rules.RangeInference.RangeValue; import org.apache.doris.nereids.rules.expression.rules.RangeInference.UnknownValue; import org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDesc; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDescVisitor; import org.apache.doris.nereids.rules.rewrite.SkipSimpleExprs; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.Expression; @@ -34,6 +39,7 @@ import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.LessThan; import org.apache.doris.nereids.trees.expressions.LessThanEqual; +import org.apache.doris.nereids.trees.expressions.Not; import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.util.ExpressionUtils; @@ -43,10 +49,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.common.collect.Range; -import org.apache.commons.lang3.NotImplementedException; import java.util.List; -import java.util.stream.Collectors; +import java.util.Set; /** * This class implements the function to simplify expression range. @@ -57,6 +62,7 @@ * a in (1,2,3) and a > 1 => a in (2,3) * a in (1,2,3) and a in (3,4,5) => a = 3 * a in (1,2,3) and a in (4,5,6) => false + * a > 10 and (a < 10 or a > 20 ) => a > 20 * * The logic is as follows: * 1. for `And` expression. @@ -67,9 +73,40 @@ * 1. a > 1 => RangeValueDesc((1...+∞)), a > 2 => RangeValueDesc((2...+∞)) * 2. (1...+∞) intersect (2...+∞) => (2...+∞) * 2. for `Or` expression (similar to `And`). - * todo: support a > 10 and (a < 10 or a > 20 ) => a > 20 + * + * How to simplify range for a expression ? + * + * An expression may contain multiple references, then for each reference, we calculate its range. + * After getting the range of each reference, we can reconstruct the expression. + * + * We use `ValueDesc` to describe the range of a reference, it includes: + * 1. EmptyValueDesc: the expression is always false or null for this reference, like `a > 1 and a < 0`. + * 2. RangeValueDesc: the expression can be represented as a range for this reference, like `a > 1`. + * 3. DiscreteValueDesc: the expression can be represented as discrete values for this reference, like `a in (1,2,3)`. + * 4. NotDiscreteValueDesc: the expression can be represented as not discrete values for this reference, + * like `a not in (1,2,3)`. + * 5. IsNullValueDesc: the expression is `is null` for this reference, like `a is null`. + * 6. IsNotNullValueDesc: the expression is `is not null` for this reference, like `a is not null`. + * 7. CompoundValueDesc: the expression is a compound expression (And/Or) for this reference, + * like `a > 10 or a in (0, 1)` + * 8. UnknownValueDesc: we cannot infer the range for this reference. + * + * The expression is a tree structure, each node is an operator (And/Or), leaf node is a simple expression. + * The `ValueDesc` is also a tree structure, each node is a `CompoundValueDesc`, + * leaf node is one of the other `ValueDesc`. + * When we want to simplify a reference's range, that is to say, we want to get the merged `ValueDesc` + * for this reference. Here is the simplify range algorithm: + * 1. Convert the expression tree to `ValueDesc` tree from bottom to top. + * 2. When converting, we can merge `ValueDesc` in the same level for those have the same reference. + * The `merged` is the most important step, it will perform intersect/union operation according to the operator, + * and return a new `ValueDesc` for the reference, and make the reference's range more precise. + * 3. After getting the `ValueDesc` tree, we can convert it back to expression tree from bottom to top. + * + * Since the merged `ValueDesc` is more precise than the original one, + * the final expression is simplified than the original one. + * */ -public class SimplifyRange implements ExpressionPatternRuleFactory { +public class SimplifyRange implements ExpressionPatternRuleFactory, ValueDescVisitor { public static final SimplifyRange INSTANCE = new SimplifyRange(); @Override @@ -83,34 +120,22 @@ public List> buildRules() { } /** rewrite */ - public static Expression rewrite(CompoundPredicate expr, ExpressionRewriteContext context) { + public Expression rewrite(CompoundPredicate expr, ExpressionRewriteContext context) { if (SkipSimpleExprs.isSimpleExpr(expr)) { return expr; } ValueDesc valueDesc = (new RangeInference()).getValue(expr, context); - return INSTANCE.getExpression(valueDesc); + return valueDesc.accept(this, null); } - private Expression getExpression(ValueDesc value) { - if (value instanceof EmptyValue) { - return getExpression((EmptyValue) value); - } else if (value instanceof DiscreteValue) { - return getExpression((DiscreteValue) value); - } else if (value instanceof RangeValue) { - return getExpression((RangeValue) value); - } else if (value instanceof UnknownValue) { - return getExpression((UnknownValue) value); - } else { - throw new NotImplementedException("not implements"); - } - } - - private Expression getExpression(EmptyValue value) { + @Override + public Expression visitEmptyValue(EmptyValue value, Void context) { Expression reference = value.getReference(); return ExpressionUtils.falseOrNull(reference); } - private Expression getExpression(RangeValue value) { + @Override + public Expression visitRangeValue(RangeValue value, Void context) { Expression reference = value.getReference(); Range range = value.getRange(); List result = Lists.newArrayList(); @@ -135,27 +160,51 @@ private Expression getExpression(RangeValue value) { } } - private Expression getExpression(DiscreteValue value) { - return ExpressionUtils.toInPredicateOrEqualTo(value.getReference(), - value.getValues().stream().map(Literal.class::cast).collect(Collectors.toList())); + @Override + public Expression visitDiscreteValue(DiscreteValue value, Void context) { + return getDiscreteExpression(value.getReference(), value.values); } - private Expression getExpression(UnknownValue value) { - List sourceValues = value.getSourceValues(); - if (sourceValues.isEmpty()) { - return value.getReference(); - } else { - return getExpression(value.getExpressionRewriteContext(), sourceValues, value.isAnd()); + @Override + public Expression visitNotDiscreteValue(NotDiscreteValue value, Void context) { + return new Not(getDiscreteExpression(value.getReference(), value.values)); + } + + @Override + public Expression visitIsNullValue(IsNullValue value, Void context) { + return new IsNull(value.getReference()); + } + + @Override + public Expression visitIsNotNullValue(IsNotNullValue value, Void context) { + return value.getNotExpression(); + } + + @Override + public Expression visitCompoundValue(CompoundValue value, Void context) { + return getCompoundExpression(value.getExpressionRewriteContext(), value.getSourceValues(), value.isAnd()); + } + + @Override + public Expression visitUnknownValue(UnknownValue value, Void context) { + return value.getReference(); + } + + private Expression getDiscreteExpression(Expression reference, Set values) { + ImmutableList.Builder options = ImmutableList.builderWithExpectedSize(values.size()); + for (ComparableLiteral value : values) { + options.add((Expression) value); } + return ExpressionUtils.toInPredicateOrEqualTo(reference, options.build()); } /** getExpression */ - public Expression getExpression(ExpressionRewriteContext context, + public Expression getCompoundExpression(ExpressionRewriteContext context, List sourceValues, boolean isAnd) { Preconditions.checkArgument(!sourceValues.isEmpty()); List sourceExprs = Lists.newArrayListWithExpectedSize(sourceValues.size()); for (ValueDesc sourceValue : sourceValues) { - Expression expr = getExpression(sourceValue); + Expression expr = sourceValue.accept(this, null); if (isAnd) { sourceExprs.addAll(ExpressionUtils.extractConjunction(expr)); } else { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagation.java index 72283f46a49e45..2f160aba3fced6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagation.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.properties.DataTrait; import org.apache.doris.nereids.properties.OrderKey; @@ -77,7 +78,7 @@ * then use them and the plan's expressions to infer more equal sets and constants uniforms, * finally use the combine uniforms to replace this plan's expression's slot with literals. */ -public class ConstantPropagation extends DefaultPlanRewriter implements CustomRewriter { +public class ConstantPropagation extends DefaultPlanRewriter implements CustomRewriter { @Override public Plan rewriteRoot(Plan plan, JobContext jobContext) { @@ -85,15 +86,14 @@ public Plan rewriteRoot(Plan plan, JobContext jobContext) { if (plan.containsType(LogicalApply.class)) { return plan; } - ExpressionRewriteContext context = new ExpressionRewriteContext(jobContext.getCascadesContext()); - return plan.accept(this, context); + return plan.accept(this, jobContext.getCascadesContext()); } @Override - public Plan visitLogicalFilter(LogicalFilter filter, ExpressionRewriteContext context) { + public Plan visitLogicalFilter(LogicalFilter filter, CascadesContext context) { filter = visitChildren(this, filter, context); Expression oldPredicate = filter.getPredicate(); - Expression newPredicate = replaceConstantsAndRewriteExpr(filter, oldPredicate, true, context); + Expression newPredicate = replaceConstantsAndRewriteExpr(filter, oldPredicate, context); if (isExprEqualIgnoreOrder(oldPredicate, newPredicate)) { return filter; } else { @@ -103,10 +103,10 @@ public Plan visitLogicalFilter(LogicalFilter filter, ExpressionR } @Override - public Plan visitLogicalHaving(LogicalHaving having, ExpressionRewriteContext context) { + public Plan visitLogicalHaving(LogicalHaving having, CascadesContext context) { having = visitChildren(this, having, context); Expression oldPredicate = having.getPredicate(); - Expression newPredicate = replaceConstantsAndRewriteExpr(having, oldPredicate, true, context); + Expression newPredicate = replaceConstantsAndRewriteExpr(having, oldPredicate, context); if (isExprEqualIgnoreOrder(oldPredicate, newPredicate)) { return having; } else { @@ -116,15 +116,16 @@ public Plan visitLogicalHaving(LogicalHaving having, ExpressionR } @Override - public Plan visitLogicalProject(LogicalProject project, ExpressionRewriteContext context) { + public Plan visitLogicalProject(LogicalProject project, CascadesContext context) { project = visitChildren(this, project, context); + ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(project, context); Pair, Map> childEqualTrait = - getChildEqualSetAndConstants(project, context); + getChildEqualSetAndConstants(project, rewriteContext); ImmutableList.Builder newProjectsBuilder = ImmutableList.builderWithExpectedSize(project.getProjects().size()); for (NamedExpression expr : project.getProjects()) { - newProjectsBuilder.add( - replaceNameExpressionConstants(expr, context, childEqualTrait.first, childEqualTrait.second)); + newProjectsBuilder.add(replaceNameExpressionConstants( + expr, rewriteContext, childEqualTrait.first, childEqualTrait.second)); } List newProjects = newProjectsBuilder.build(); @@ -132,16 +133,18 @@ public Plan visitLogicalProject(LogicalProject project, Expressi } @Override - public Plan visitLogicalSort(LogicalSort sort, ExpressionRewriteContext context) { + public Plan visitLogicalSort(LogicalSort sort, CascadesContext context) { sort = visitChildren(this, sort, context); - Pair, Map> childEqualTrait = getChildEqualSetAndConstants(sort, context); + ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(sort, context); + Pair, Map> childEqualTrait + = getChildEqualSetAndConstants(sort, rewriteContext); // for be, order key must be a column, not a literal, so `order by 100#xx` is ok, // but `order by 100` will make be core. // so after replaced, we need to remove the constant expr. ImmutableList.Builder newOrderKeysBuilder = ImmutableList.builderWithExpectedSize(sort.getOrderKeys().size()); for (OrderKey key : sort.getOrderKeys()) { - Expression newExpr = replaceConstants(key.getExpr(), false, context, + Expression newExpr = replaceConstants(key.getExpr(), false, rewriteContext, childEqualTrait.first, childEqualTrait.second); if (!newExpr.isConstant()) { newOrderKeysBuilder.add(key.withExpression(newExpr)); @@ -158,15 +161,17 @@ public Plan visitLogicalSort(LogicalSort sort, ExpressionRewrite } @Override - public Plan visitLogicalAggregate(LogicalAggregate aggregate, ExpressionRewriteContext context) { + public Plan visitLogicalAggregate(LogicalAggregate aggregate, CascadesContext context) { aggregate = visitChildren(this, aggregate, context); + ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(aggregate, context); Pair, Map> childEqualTrait = - getChildEqualSetAndConstants(aggregate, context); + getChildEqualSetAndConstants(aggregate, rewriteContext); List oldGroupByExprs = aggregate.getGroupByExpressions(); List newGroupByExprs = Lists.newArrayListWithExpectedSize(oldGroupByExprs.size()); for (Expression expr : oldGroupByExprs) { - Expression newExpr = replaceConstants(expr, false, context, childEqualTrait.first, childEqualTrait.second); + Expression newExpr + = replaceConstants(expr, false, rewriteContext, childEqualTrait.first, childEqualTrait.second); if (!newExpr.isConstant()) { newGroupByExprs.add(newExpr); } @@ -196,7 +201,7 @@ public Plan visitLogicalAggregate(LogicalAggregate aggregate, Ex for (NamedExpression expr : oldOutputExprs) { // ColumnPruning will also add all group by expression into output expressions // agg output need contains group by expression - Expression replacedExpr = replaceConstants(expr, false, context, + Expression replacedExpr = replaceConstants(expr, false, rewriteContext, childEqualTrait.first, childEqualTrait.second); Expression newOutputExpr = newGroupByExprSet.contains(expr) ? expr : replacedExpr; if (newOutputExpr instanceof NamedExpression) { @@ -225,7 +230,7 @@ public Plan visitLogicalAggregate(LogicalAggregate aggregate, Ex } @Override - public Plan visitLogicalJoin(LogicalJoin join, ExpressionRewriteContext context) { + public Plan visitLogicalJoin(LogicalJoin join, CascadesContext context) { // Combine all the join conjuncts together, may infer more constant relations. // Then after rewrite the combine conjuncts, we need split the rewritten expression into hash/other/mark // join conjuncts. But we can not extract the mark join conjuncts from the rewritten expression. @@ -241,17 +246,8 @@ public Plan visitLogicalJoin(LogicalJoin join, E hashOtherConjuncts.addAll(join.getHashJoinConjuncts()); hashOtherConjuncts.addAll(join.getOtherJoinConjuncts()); if (!hashOtherConjuncts.isEmpty()) { - // useInnerInfer = true means for a nullable column 'column_a', will extract constant relation - // (include 'nullable_a = column_b' and 'nullable_a = literal') from the expression itself, - // then use the extracted constant relation + children's constant relation to rewrite the expression. - // then its effect will result in: the special NULL (those all its ancestors are AND/OR) will be replaced - // with FALSE; - // so useInnerInfer = false will not replace the NULL with FALSE. - // For null ware left anti join, NULL can not replace with FALSE. - boolean useInnerInfer = join.getJoinType() != JoinType.NULL_AWARE_LEFT_ANTI_JOIN; Expression oldHashOtherPredicate = ExpressionUtils.and(hashOtherConjuncts); - Expression newHashOtherPredicate - = replaceConstantsAndRewriteExpr(join, oldHashOtherPredicate, useInnerInfer, context); + Expression newHashOtherPredicate = replaceConstantsAndRewriteExpr(join, oldHashOtherPredicate, context); if (!isExprEqualIgnoreOrder(oldHashOtherPredicate, newHashOtherPredicate)) { // TODO: code from FindHashConditionForJoin Pair, List> pair @@ -288,7 +284,7 @@ public Plan visitLogicalJoin(LogicalJoin join, E } @Override - public Plan visitLogicalSink(LogicalSink sink, ExpressionRewriteContext context) { + public Plan visitLogicalSink(LogicalSink sink, CascadesContext context) { sink = visitChildren(this, sink, context); // // for sql: create table t as select cast('1' as varchar(30)) // // the select will add a parent plan: result sink. the result sink contains a output slot reference, and its @@ -309,18 +305,30 @@ public Plan visitLogicalSink(LogicalSink sink, ExpressionRewrite * replace constants and rewrite expression. */ @VisibleForTesting - public Expression replaceConstantsAndRewriteExpr(LogicalPlan plan, Expression expression, - boolean useInnerInfer, ExpressionRewriteContext context) { + public Expression replaceConstantsAndRewriteExpr(LogicalPlan plan, Expression expression, CascadesContext context) { // for expression `a = 1 and a + b = 2 and b + c = 2 and c + d =2 and ...`: // propagate constant `a = 1`, then get `1 + b = 2`, after rewrite this expression, will get `b = 1`; // then propagate constant `b = 1`, then get `1 + c = 2`, after rewrite this expression, will get `c = 1`, // ... // so constant propagate and rewrite expression need to do in a loop. - Pair, Map> childEqualTrait = getChildEqualSetAndConstants(plan, context); + ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(plan, context); + Pair, Map> childEqualTrait + = getChildEqualSetAndConstants(plan, rewriteContext); Expression afterExpression = expression; + // useInnerInfer = true means for a nullable column 'column_a', will extract constant relation + // (include 'nullable_a = column_b' and 'nullable_a = literal') from the expression itself, + // then use the extracted constant relation + children's constant relation to rewrite the expression. + // then its effect will result in: the special NULL (those all its ancestors are AND/OR) will be replaced + // with FALSE; + // so useInnerInfer = false will not replace the NULL with FALSE. + // For null ware left anti join, NULL can not replace with FALSE. + boolean useInnerInfer = plan instanceof LogicalFilter + || plan instanceof LogicalHaving + || (plan instanceof LogicalJoin + && ((LogicalJoin) plan).getJoinType() != JoinType.NULL_AWARE_LEFT_ANTI_JOIN); for (int i = 0; i < 100; i++) { Expression beforeExpression = afterExpression; - afterExpression = replaceConstants(beforeExpression, useInnerInfer, context, + afterExpression = replaceConstants(beforeExpression, useInnerInfer, rewriteContext, childEqualTrait.first, childEqualTrait.second); if (isExprEqualIgnoreOrder(beforeExpression, afterExpression)) { break; @@ -330,7 +338,7 @@ public Expression replaceConstantsAndRewriteExpr(LogicalPlan plan, Expression ex } beforeExpression = afterExpression; afterExpression = ExpressionNormalizationAndOptimization.NO_MIN_MAX_RANGE_INSTANCE - .rewrite(beforeExpression, context); + .rewrite(beforeExpression, rewriteContext); } return afterExpression; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateFilter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateFilter.java index 6d28c7d030f8ab..a3426a9548dca7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateFilter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateFilter.java @@ -22,7 +22,6 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule; -import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; @@ -34,7 +33,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.util.ExpressionUtils; -import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -53,9 +51,9 @@ public List buildRules() { .thenApply(ctx -> { LogicalFilter filter = ctx.root; ImmutableSet.Builder newConjuncts = ImmutableSet.builder(); - ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + ExpressionRewriteContext context = new ExpressionRewriteContext(filter, ctx.cascadesContext); for (Expression expression : filter.getConjuncts()) { - expression = FoldConstantRule.evaluate(eliminateNullLiteral(expression), context); + expression = FoldConstantRule.evaluate(expression, context); if (expression == BooleanLiteral.FALSE || expression.isNullLiteral()) { return new LogicalEmptyRelation(ctx.statementContext.getNextRelationId(), filter.getOutput()); @@ -84,11 +82,10 @@ public static LogicalPlan eliminateFilterOnOneRowRelation( Map replaceMap = ExpressionUtils.generateReplaceMap(filter.child().getOutputs()); ImmutableSet.Builder newConjuncts = ImmutableSet.builder(); - ExpressionRewriteContext context = new ExpressionRewriteContext(cascadesContext); + ExpressionRewriteContext context = new ExpressionRewriteContext(filter, cascadesContext); for (Expression expression : filter.getConjuncts()) { Expression newExpr = ExpressionUtils.replace(expression, replaceMap); - Expression foldExpression = FoldConstantRule.evaluate(eliminateNullLiteral(newExpr), context); - + Expression foldExpression = FoldConstantRule.evaluate(newExpr, context); if (foldExpression == BooleanLiteral.FALSE || expression.isNullLiteral()) { return new LogicalEmptyRelation( cascadesContext.getStatementContext().getNextRelationId(), filter.getOutput()); @@ -104,31 +101,4 @@ public static LogicalPlan eliminateFilterOnOneRowRelation( return new LogicalFilter<>(conjuncts, filter.child()); } } - - @VisibleForTesting - public static Expression eliminateNullLiteral(Expression expression) { - if (!expression.anyMatch(e -> ((Expression) e).isNullLiteral())) { - return expression; - } - - return replaceNullToFalse(expression); - } - - // only replace null which its ancestors are all and/or - // NOTICE: NOT's type is boolean too, if replace null to false in NOT, will got NOT(NULL) = NOT(FALSE) = TRUE, - // but it is wrong, NOT(NULL) = NULL. For a filter, only the AND / OR, can keep NULL as FALSE. - private static Expression replaceNullToFalse(Expression expression) { - if (expression.isNullLiteral()) { - return BooleanLiteral.FALSE; - } - - if (expression instanceof CompoundPredicate) { - ImmutableList.Builder builder = ImmutableList.builderWithExpectedSize( - expression.children().size()); - expression.children().forEach(e -> builder.add(replaceNullToFalse(e))); - return expression.withChildren(builder.build()); - } - - return expression; - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java new file mode 100644 index 00000000000000..929ecbc272391f --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/JoinExtractOrFromCaseWhen.java @@ -0,0 +1,347 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE; +import org.apache.doris.nereids.trees.expressions.EqualPredicate; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.JoinUtils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Join extract OR expression from case when / if / nullif expressions. + * + * 1. extract conditions for one side, latter can push down the one side condition: + * + * t1 join t2 on not (case when t1.a = 1 then t2.a else t2.b) + t2.b + t2.c > 10) + * => + * t1 join t2 on not (case when t1.a = 1 then t2.a else t2.b end) + t2.b + t2.c > 10) + * AND (not (t2.a + t2.b + t2.c > 10) or not (t2.b + t2.b + t2.c > 10)) + * + * + * 2. extract conditions for both sides, which use for OR EXPANSION rule: + * the OR EXPANSION is an OR expression which all its disjuncts are hash join conditions. + * + * t1 join t2 on (case when t1.a = 1 then t2.a else t2.b end) = t1.a + t1.b + * => + * t1 join t2 on (case when t1.a = 1 then t2.a else t2.b end) = t1.a + t1.b + * AND (t2.a = t1.a + t1.b or t2.b = t1.a + t1.b) + + * Notice we don't extract more than one case when like expressions. + * because it may generate expressions with combinatorial explosion. + * + * (((case c1 then p1 else p2 end) + (case when d1 then q1 else q2 end))) + a > 10 + * => + * (p1 + q1 + a > 10) + * or (p1 + q2 + a > 10) + * or (p2 + q1 + a > 10) + * or (p2 + q2 + a > 10) + * + * so we only extract at most one case when like expression for each condition. + */ +public class JoinExtractOrFromCaseWhen implements RewriteRuleFactory { + + @Override + public List buildRules() { + return ImmutableList.of(logicalJoin() + .when(this::needRewrite) + .thenApply(ctx -> rewrite(ctx.root, new ExpressionRewriteContext(ctx.root, ctx.cascadesContext))) + .toRule(RuleType.JOIN_EXTRACT_OR_FROM_CASE_WHEN)); + } + + private boolean needRewrite(LogicalJoin join) { + if (PushDownJoinOtherCondition.needRewrite(join) || OrExpansion.INSTANCE.needRewriteJoin(join)) { + Set leftSlots = join.left().getOutputSet(); + Set rightSlots = join.right().getOutputSet(); + for (Expression expr : join.getOtherJoinConjuncts()) { + if (isConditionNeedRewrite(expr, leftSlots, rightSlots)) { + return true; + } + } + } + if (!join.isMarkJoin()) { + Set leftSlots = join.left().getOutputSet(); + Set rightSlots = join.right().getOutputSet(); + for (Expression expr : join.getHashJoinConjuncts()) { + if (isConditionNeedRewrite(expr, leftSlots, rightSlots)) { + return true; + } + } + } + return false; + } + + // 1. expr contains slots from both sides; + private boolean isConditionNeedRewrite(Expression expr, Set leftSlots, Set rightSlots) { + if (expr.containsUniqueFunction()) { + return false; + } + Set exprSlots = expr.getInputSlots(); + // all slots are from one side, no need to process it. + if (Collections.disjoint(exprSlots, leftSlots) || Collections.disjoint(exprSlots, rightSlots) + || !ExpressionUtils.containsCaseWhenLikeType(expr)) { + return false; + } + return true; + } + + private Plan rewrite(LogicalJoin join, ExpressionRewriteContext context) { + Set newOtherConditions = Sets.newLinkedHashSetWithExpectedSize(join.getOtherJoinConjuncts().size()); + newOtherConditions.addAll(join.getOtherJoinConjuncts()); + int oldCondSize = newOtherConditions.size(); + AtomicReference> leastOrExpandCondRef = new AtomicReference<>(); + for (Expression expr : join.getOtherJoinConjuncts()) { + tryAddOrExpansionHashCondition(expr, join, context, leastOrExpandCondRef); + } + for (Expression expr : join.getOtherJoinConjuncts()) { + extractExpression(expr, context, join, newOtherConditions, leastOrExpandCondRef); + } + if (!join.isMarkJoin()) { + // Notice: if join's hash conditions is not empty, then OrExpansion.needRewriteJoin will return fail + // so it will not extract OrExpansion from the hash condition, it only extract one side condition + // for hash condition: if(t1.a > 10, 1, 100) = if(t2.x > 10, 2, 200), + // we can still extract two one-side condition: + // 1) if(t1.a > 10, 1, 100) = 2 or if(t1.a > 10, 1, 100) = 200 + // 2) if(t2.x > 10, 2, 200) = 1 or if(t2.x > 10, 2, 200) = 100 + for (Expression expr : join.getHashJoinConjuncts()) { + extractExpression(expr, context, join, newOtherConditions, leastOrExpandCondRef); + } + } + if (leastOrExpandCondRef.get() != null) { + newOtherConditions.add(leastOrExpandCondRef.get().first); + } + if (newOtherConditions.size() == oldCondSize) { + return join; + } else { + return join.withJoinConjuncts(join.getHashJoinConjuncts(), ImmutableList.copyOf(newOtherConditions), + join.getJoinReorderContext()); + } + } + + private void extractExpression(Expression expr, ExpressionRewriteContext context, + LogicalJoin join, Set newOtherConditions, + AtomicReference> leastOrExpandCondRef) { + Set leftSlots = join.left().getOutputSet(); + Set rightSlots = join.right().getOutputSet(); + if (!isConditionNeedRewrite(expr, leftSlots, rightSlots)) { + return; + } + List containsLeftSlotChildIndexes = Lists.newArrayList(); + List containsRightSlotChildIndexes = Lists.newArrayList(); + for (int i = 0; i < expr.children().size(); i++) { + Set childSlots = expr.child(i).getInputSlots(); + if (!Collections.disjoint(childSlots, leftSlots)) { + containsLeftSlotChildIndexes.add(i); + } + if (!Collections.disjoint(childSlots, rightSlots)) { + containsRightSlotChildIndexes.add(i); + } + } + // all slots are from one side, no need handle + if (containsLeftSlotChildIndexes.isEmpty() || containsRightSlotChildIndexes.isEmpty()) { + return; + } + boolean canPushDownOther = PushDownJoinOtherCondition.needRewrite(join); + boolean extractedLeftSideCond = canPushDownOther && PushDownJoinOtherCondition.PUSH_DOWN_LEFT_VALID_TYPE + .contains(join.getJoinType()); + boolean extractedRightSideCond = canPushDownOther && PushDownJoinOtherCondition.PUSH_DOWN_RIGHT_VALID_TYPE + .contains(join.getJoinType()); + boolean extractOrExpansionCond = OrExpansion.INSTANCE.needRewriteJoin(join); + // eliminate all the right slots of all children, but we rewrite at most 1 case when expression, + // so require contains right slot child num not exceeds 1. + if (extractedLeftSideCond && containsRightSlotChildIndexes.size() == 1) { + doExtractExpression(expr, containsRightSlotChildIndexes.get(0), true, leftSlots, rightSlots) + .ifPresent(newOtherConditions::add); + } + // eliminate all the left slots of all children, but we rewrite at most 1 case when expression, + // so require contains left slot child num not exceeds 1. + if (extractedRightSideCond && containsLeftSlotChildIndexes.size() == 1) { + doExtractExpression(expr, containsLeftSlotChildIndexes.get(0), false, leftSlots, rightSlots) + .ifPresent(newOtherConditions::add); + } + if (extractOrExpansionCond && expr instanceof EqualPredicate) { + Optional orExpansionExpr = Optional.empty(); + if (containsLeftSlotChildIndexes.size() == 1 && containsRightSlotChildIndexes.size() == 2) { + // equal's two children all contain right slots, while only one child contains left slot, + // then we eliminate the right slot from the child which it contains left slots. + orExpansionExpr = doExtractExpression( + expr, containsLeftSlotChildIndexes.get(0), true, leftSlots, rightSlots); + } else if (containsLeftSlotChildIndexes.size() == 2 && containsRightSlotChildIndexes.size() == 1) { + // equal's two children all contain left slots, while one child contains left slot, + // then we eliminate the left slot from the child which it contains right slots. + orExpansionExpr = doExtractExpression( + expr, containsRightSlotChildIndexes.get(0), false, leftSlots, rightSlots); + } + orExpansionExpr.ifPresent( + cond -> tryAddOrExpansionHashCondition(cond, join, context, leastOrExpandCondRef)); + } + } + + // Or Expansion only use one condition, so we keep the one with least disjunctions. + private void tryAddOrExpansionHashCondition(Expression condition, LogicalJoin join, + ExpressionRewriteContext context, AtomicReference> leastOrExpandCondRef) { + // Or Expansion only works for all the disjunctions are equal predicates + List disjunctions = ExpressionUtils.extractDisjunction(condition); + List remainOtherConditions = JoinUtils.extractExpressionForHashTable( + join.left().getOutput(), join.right().getOutput(), disjunctions).second; + int hashCondLen = disjunctions.size() - remainOtherConditions.size(); + // no hash condition extracted, all are other conditions + if (hashCondLen == 0) { + return; + } + + for (Expression expr : remainOtherConditions) { + // for case when t1.a > t2.x then t1.a when t1.b > t2.y then t1. b else null end = t2.x + // then will extract E1 = (t1.a = t2.x or t1.b = t2.x or null = t2.x) + // but E1 can not use as OR Expansion condition, because null = t2.x is not a valid hash join condition. + // but after we fold null = t2.x to null, latter expression simplifier can simplify E1 + // to (t1.a = t2.x or t1.b = t2.x), then it becomes a valid OR Expansion condition. + Expression foldExpr = FoldConstantRuleOnFE.evaluate(expr, context); + if (!foldExpr.isLiteral()) { + return; + } + // foldExpr should be NULL / TRUE /FALSE, later expression simplifier can handle them. + } + + Pair leastOrExpandCond = leastOrExpandCondRef.get(); + int oldHashCondLen = leastOrExpandCond == null ? -1 : leastOrExpandCond.second; + if (oldHashCondLen == -1 || hashCondLen < oldHashCondLen) { + leastOrExpandCondRef.set(Pair.of(condition, hashCondLen)); + } + } + + // extract case when expression from `expr`'s child C at `extractChildIndex`. + // after extraction, new C will contain only slots from one side indicated by `extractedChildSlotFromLeft`, + // but new C can allow no contains any slots. + private Optional doExtractExpression(Expression expr, int extractChildIndex, + boolean extractChildSlotFromLeft, Set leftSlots, Set rightSlots) { + Expression expandChild = expr.child(extractChildIndex); + Optional> resultOpt = tryExtractCaseWhen( + expandChild, extractChildSlotFromLeft, leftSlots, rightSlots); + if (!resultOpt.isPresent()) { + return Optional.empty(); + } + + List expandTargetExpressions = resultOpt.get(); + if (expandTargetExpressions.size() <= 1) { + // if size = 1, then C don't expand, should be just the expr itself. + return Optional.empty(); + } + + List newChildren = Lists.newArrayList(expr.children()); + List disjuncts = Lists.newArrayListWithExpectedSize(expandTargetExpressions.size()); + for (Expression expandTargetExpr : expandTargetExpressions) { + newChildren.set(extractChildIndex, expandTargetExpr); + disjuncts.add(expr.withChildren(newChildren)); + } + + return Optional.of(ExpressionUtils.or(disjuncts)); + } + + // try to extract case when like expressions from expr. + // after extraction, all slots in expr are from one side indicated by `slotFromLeft`. + // if `expr`'s all slots are already from `slotFromLeft`, return expr itself, no need handle with its children. + // otherwise will recurse its children to extract case when like expressions. + private Optional> tryExtractCaseWhen(Expression expr, boolean slotFromLeft, + Set leftSlots, Set rightSlots) { + if (isAllSlotsFromLeftSide(expr, slotFromLeft, leftSlots, rightSlots)) { + return Optional.of(ImmutableList.of(expr)); + } + + if (!ExpressionUtils.containsCaseWhenLikeType(expr)) { + return Optional.empty(); + } + + // process case when like expression. + Optional> caseWhenLikeResults = ExpressionUtils.getCaseWhenLikeBranchResults(expr); + if (caseWhenLikeResults.isPresent()) { + for (Expression branchResult : caseWhenLikeResults.get()) { + if (!isAllSlotsFromLeftSide(branchResult, slotFromLeft, leftSlots, rightSlots)) { + return Optional.empty(); + } + } + return caseWhenLikeResults; + } + + int expandChildIndex = -1; + List expandChildExpressions = null; + List newChildren = Lists.newArrayListWithExpectedSize(expr.children().size()); + for (int i = 0; i < expr.children().size(); i++) { + Expression child = expr.child(i); + Optional> childExtractedOpt = tryExtractCaseWhen( + child, slotFromLeft, leftSlots, rightSlots); + if (!childExtractedOpt.isPresent()) { + return Optional.empty(); + } + List childExtracted = childExtractedOpt.get(); + if (childExtracted.size() == 1) { + Expression newChild = childExtracted.get(0); + newChildren.add(newChild); + } else { + // more than one child to expand + if (expandChildIndex != -1) { + return Optional.empty(); + } + expandChildIndex = i; + expandChildExpressions = childExtracted; + // will replace the child later, add a placeholder first + newChildren.add(child); + } + } + if (expandChildIndex == -1) { + return Optional.empty(); + } + List resultExpressions = Lists.newArrayListWithExpectedSize(expandChildExpressions.size()); + for (Expression expandChildExpr : expandChildExpressions) { + newChildren.set(expandChildIndex, expandChildExpr); + Expression newExpr = expr.withChildren(newChildren); + resultExpressions.add(newExpr); + } + return Optional.of(resultExpressions); + } + + // check whether all slots in expr are from one side, allow contains no slots. + private boolean isAllSlotsFromLeftSide(Expression expr, boolean slotFromLeft, + Set leftSlots, Set rightSlots) { + Set exprSlots = expr.getInputSlots(); + if (slotFromLeft) { + // no slots from right + return Collections.disjoint(exprSlots, rightSlots); + } else { + // no slots from left + return Collections.disjoint(exprSlots, leftSlots); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java index 4cb07fd44fd89b..cf06ee608c14dd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java @@ -121,12 +121,10 @@ public Plan visitLogicalCTEAnchor( @Override public Plan visitLogicalJoin(LogicalJoin join, OrExpandsionContext ctx) { join = (LogicalJoin) this.visit(join, ctx); - if (join.isMarkJoin() || !JoinUtils.shouldNestedLoopJoin(join)) { - return join; - } - if (!supportJoinType.contains(join.getJoinType())) { + if (!needRewriteJoin(join)) { return join; } + Preconditions.checkArgument(join.getHashJoinConjuncts().isEmpty(), "Only Expansion nest loop join without hashCond"); @@ -207,6 +205,16 @@ public Plan visitLogicalJoin(LogicalJoin join, O return null; } + /** + * check whether it need to rewrite the join + */ + public boolean needRewriteJoin(LogicalJoin join) { + if (join.isMarkJoin() || !JoinUtils.shouldNestedLoopJoin(join)) { + return false; + } + return supportJoinType.contains(join.getJoinType()); + } + private Map constructReplaceMap(LogicalCTEConsumer leftConsumer, Map leftCloneToLeft, LogicalCTEConsumer rightConsumer, Map rightCloneToRight) { Map replaced = new HashMap<>(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java index 459082926d3483..8da895d34ab7fb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java @@ -344,7 +344,8 @@ public ImmutableSet visitLogicalAggregate(LogicalAggregate PUSH_DOWN_LEFT_VALID_TYPE = ImmutableList.of( + /** + * left push support type + */ + public static final ImmutableList PUSH_DOWN_LEFT_VALID_TYPE = ImmutableList.of( JoinType.INNER_JOIN, JoinType.LEFT_SEMI_JOIN, JoinType.RIGHT_OUTER_JOIN, @@ -46,7 +49,10 @@ public class PushDownJoinOtherCondition extends OneRewriteRuleFactory { JoinType.CROSS_JOIN ); - private static final ImmutableList PUSH_DOWN_RIGHT_VALID_TYPE = ImmutableList.of( + /** + * right push support type + */ + public static final ImmutableList PUSH_DOWN_RIGHT_VALID_TYPE = ImmutableList.of( JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN, JoinType.LEFT_ANTI_JOIN, @@ -60,7 +66,8 @@ public class PushDownJoinOtherCondition extends OneRewriteRuleFactory { public Rule build() { return logicalJoin() // TODO: we may need another rule to handle on true or on false condition - .when(join -> !join.getOtherJoinConjuncts().isEmpty() && !join.isMarkJoin()) + .when(join -> !join.getOtherJoinConjuncts().isEmpty()) + .when(PushDownJoinOtherCondition::needRewrite) .then(join -> { List otherJoinConjuncts = join.getOtherJoinConjuncts(); List remainingOther = Lists.newArrayList(); @@ -93,7 +100,14 @@ && allCoveredBy(otherConjunct, join.right().getOutputSet())) { }).toRule(RuleType.PUSH_DOWN_JOIN_OTHER_CONDITION); } - private boolean allCoveredBy(Expression predicate, Set inputSlotSet) { + /** + * check need rewrite + */ + public static boolean needRewrite(LogicalJoin join) { + return !join.isMarkJoin(); + } + + private static boolean allCoveredBy(Expression predicate, Set inputSlotSet) { return inputSlotSet.containsAll(predicate.getInputSlots()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectIntoUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectIntoUnion.java index 78610c949a7546..d212c11f457a2e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectIntoUnion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectIntoUnion.java @@ -49,7 +49,7 @@ public Rule build() { .when(this::canPushProjectIntoUnion ).thenApply(ctx -> { LogicalProject p = ctx.root; - ExpressionRewriteContext expressionRewriteContext = new ExpressionRewriteContext(ctx.cascadesContext); + ExpressionRewriteContext expressionRewriteContext = new ExpressionRewriteContext(p, ctx.cascadesContext); LogicalUnion union = p.child(); ImmutableList.Builder> newConstExprs = ImmutableList.builder(); for (List constExprs : union.getConstantExprsList()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/And.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/And.java index 683960cecaa3ff..0ab5ee108c1713 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/And.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/And.java @@ -35,9 +35,9 @@ public class And extends CompoundPredicate { * @param right right child of comparison predicate */ public And(Expression left, Expression right) { - super(ExpressionUtils.mergeList( + this(ExpressionUtils.mergeList( ExpressionUtils.extractConjunction(left), - ExpressionUtils.extractConjunction(right)), "AND"); + ExpressionUtils.extractConjunction(right))); } public And(List children) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java index 3e71b3b89a01d8..e45b5013dd5386 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java @@ -36,7 +36,7 @@ public EqualTo(Expression left, Expression right) { } public EqualTo(Expression left, Expression right, boolean inferred) { - super(ImmutableList.of(left, right), "=", inferred); + this(ImmutableList.of(left, right), inferred); } private EqualTo(List children, boolean inferred) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java index 0b77118eb1835f..96bdb3ba257314 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java @@ -19,7 +19,6 @@ import org.apache.doris.nereids.exceptions.NotSupportedException; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeAcquire; import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeArithmetic; import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeExtractAndTransform; @@ -56,7 +55,7 @@ public enum ExpressionEvaluator { * Evaluate the value of the expression. */ public Expression eval(Expression expression) { - if (!(expression.isConstant() || expression.foldable()) || expression instanceof AggregateFunction) { + if (!(expression.isConstant() || expression.foldable())) { return expression; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java index 12c0252d3a302f..e12276ff57fb4d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java @@ -50,13 +50,11 @@ public Not(List child, boolean isGeneratedIsNotNull, boolean inferre } public Not(Expression child, boolean isGeneratedIsNotNull) { - super(ImmutableList.of(child)); - this.isGeneratedIsNotNull = isGeneratedIsNotNull; + this(ImmutableList.of(child), isGeneratedIsNotNull); } private Not(List child, boolean isGeneratedIsNotNull) { - super(child); - this.isGeneratedIsNotNull = isGeneratedIsNotNull; + this(child, isGeneratedIsNotNull, false); } public boolean isGeneratedIsNotNull() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java index 41ed2bbb2c68f6..e356811fdefc47 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java @@ -122,6 +122,11 @@ public int computeHashCode() { return Objects.hash(distinct, getName(), children); } + @Override + public boolean foldable() { + return false; + } + @Override public R accept(ExpressionVisitor visitor, C context) { return visitor.visitAggregateFunction(this, context); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/generator/TableGeneratingFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/generator/TableGeneratingFunction.java index b67f7c1df623c5..b5b03002bd1be6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/generator/TableGeneratingFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/generator/TableGeneratingFunction.java @@ -52,4 +52,9 @@ public R accept(ExpressionVisitor visitor, C context) { protected GeneratorFunctionParams getFunctionParams(List arguments) { return new GeneratorFunctionParams(this, getName(), arguments, isInferred()); } + + @Override + public boolean foldable() { + return false; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java index cf1323834b729a..118ed8622ec602 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java @@ -98,4 +98,19 @@ public FunctionSignature customSignature() { return null; } } + + /** get condition */ + public Expression getCondition() { + return child(0); + } + + /** get true value */ + public Expression getTrueValue() { + return child(1); + } + + /** get false value */ + public Expression getFalseValue() { + return child(2); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/NullLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/NullLiteral.java index 80f7514e0f90b9..69b38cc46e6df2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/NullLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/NullLiteral.java @@ -19,6 +19,7 @@ import org.apache.doris.analysis.LiteralExpr; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.NullType; @@ -30,9 +31,10 @@ public class NullLiteral extends Literal implements ComparableLiteral { public static final NullLiteral INSTANCE = new NullLiteral(); + public static final NullLiteral BOOLEAN_INSTANCE = new NullLiteral(BooleanType.INSTANCE); public NullLiteral() { - super(NullType.INSTANCE); + this(NullType.INSTANCE); } public NullLiteral(DataType dataType) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java index 1e563d5676dd21..1a1b3b2e84e639 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java @@ -227,7 +227,7 @@ public void computeUnique(DataTrait.Builder builder) { @Override public void computeUniform(DataTrait.Builder builder) { final Optional context = ConnectContext.get() == null ? Optional.empty() - : Optional.of(new ExpressionRewriteContext(CascadesContext.initContext( + : Optional.of(new ExpressionRewriteContext(this, CascadesContext.initContext( ConnectContext.get().getStatementContext(), this, PhysicalProperties.ANY))); for (int i = 0; i < getOutputs().size(); i++) { Optional value = Optional.empty(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 8f1ca1658de818..17e5a45b38927e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -37,6 +37,7 @@ import org.apache.doris.nereids.trees.TreeNode; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.And; +import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; @@ -51,11 +52,15 @@ import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.functions.scalar.NullIf; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl; import org.apache.doris.nereids.trees.expressions.functions.scalar.UniqueFunction; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral; @@ -69,7 +74,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation; import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; import org.apache.doris.nereids.trees.plans.visitor.ExpressionLineageReplacer; -import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.types.VariantType; import org.apache.doris.nereids.types.coercion.NumericType; import org.apache.doris.qe.ConnectContext; @@ -219,7 +223,7 @@ public static Expression and(Collection expressions) { } } - List exprList = Lists.newArrayList(distinctExpressions); + List exprList = ImmutableList.copyOf(distinctExpressions); if (exprList.isEmpty()) { return BooleanLiteral.TRUE; } else if (exprList.size() == 1) { @@ -267,7 +271,7 @@ public static Expression or(Collection expressions) { } } - List exprList = Lists.newArrayList(distinctExpressions); + List exprList = ImmutableList.copyOf(distinctExpressions); if (exprList.isEmpty()) { return BooleanLiteral.FALSE; } else if (exprList.size() == 1) { @@ -279,7 +283,7 @@ public static Expression or(Collection expressions) { public static Expression falseOrNull(Expression expression) { if (expression.nullable()) { - return new And(new IsNull(expression), new NullLiteral(BooleanType.INSTANCE)); + return new And(new IsNull(expression), NullLiteral.BOOLEAN_INSTANCE); } else { return BooleanLiteral.FALSE; } @@ -287,12 +291,16 @@ public static Expression falseOrNull(Expression expression) { public static Expression trueOrNull(Expression expression) { if (expression.nullable()) { - return new Or(new Not(new IsNull(expression)), new NullLiteral(BooleanType.INSTANCE)); + return new Or(new Not(new IsNull(expression)), NullLiteral.BOOLEAN_INSTANCE); } else { return BooleanLiteral.TRUE; } } + public static Expression notIsNull(Expression expression) { + return new Not(new IsNull(expression)); + } + public static Expression toInPredicateOrEqualTo(Expression reference, Collection values) { if (values.size() < 2) { return or(values.stream().map(value -> new EqualTo(reference, value)).collect(Collectors.toList())); @@ -601,22 +609,6 @@ public static List rewriteDownShortCircuit( return result.build(); } - private static class ExpressionReplacer - extends DefaultExpressionRewriter> { - public static final ExpressionReplacer INSTANCE = new ExpressionReplacer(); - - private ExpressionReplacer() { - } - - @Override - public Expression visit(Expression expr, Map replaceMap) { - if (replaceMap.containsKey(expr)) { - return replaceMap.get(expr); - } - return super.visit(expr, replaceMap); - } - } - /** * merge arguments into an expression array * @@ -657,26 +649,6 @@ public static boolean isAllNonNullComparableLiteral(List children) { return true; } - /** matchNumericType */ - public static boolean matchNumericType(List children) { - for (Expression child : children) { - if (!child.getDataType().isNumericType()) { - return false; - } - } - return true; - } - - /** matchDateLikeType */ - public static boolean matchDateLikeType(List children) { - for (Expression child : children) { - if (!child.getDataType().isDateLikeType()) { - return false; - } - } - return true; - } - /** hasNullLiteral */ public static boolean hasNullLiteral(List children) { for (Expression child : children) { @@ -687,16 +659,6 @@ public static boolean hasNullLiteral(List children) { return false; } - /** hasOnlyMetricType */ - public static boolean hasOnlyMetricType(List children) { - for (Expression child : children) { - if (child.getDataType().isOnlyMetricType()) { - return true; - } - } - return false; - } - /** * canInferNotNullForMarkSlot */ @@ -711,7 +673,7 @@ public static boolean canInferNotNullForMarkSlot(Expression predicate, Expressio * and in semi join, we can safely change the mark conjunct to hash conjunct */ ImmutableList literals = - ImmutableList.of(new NullLiteral(BooleanType.INSTANCE), BooleanLiteral.FALSE); + ImmutableList.of(NullLiteral.BOOLEAN_INSTANCE, BooleanLiteral.FALSE); List markJoinSlotReferenceList = new ArrayList<>((predicate.collect(MarkJoinSlotReference.class::isInstance))); int markSlotSize = markJoinSlotReferenceList.size(); @@ -1272,6 +1234,40 @@ public static String slotListShapeInfo(List materializedSlots) { return shapeBuilder.toString(); } + /** + * check whether the expression contains CaseWhen like type + */ + public static boolean containsCaseWhenLikeType(Expression expression) { + return expression.containsType(CaseWhen.class, If.class, NullIf.class, Nvl.class); + } + + /** + * get the results of each branch in CaseWhen like expression + */ + public static Optional> getCaseWhenLikeBranchResults(Expression expression) { + if (expression instanceof CaseWhen) { + CaseWhen caseWhen = (CaseWhen) expression; + ImmutableList.Builder builder + = ImmutableList.builderWithExpectedSize(caseWhen.getWhenClauses().size() + 1); + for (WhenClause whenClause : caseWhen.getWhenClauses()) { + builder.add(whenClause.getResult()); + } + builder.add(caseWhen.getDefaultValue().orElse(new NullLiteral(caseWhen.getDataType()))); + return Optional.of(builder.build()); + } else if (expression instanceof If) { + If ifExpr = (If) expression; + return Optional.of(ImmutableList.of(ifExpr.getTrueValue(), ifExpr.getFalseValue())); + } else if (expression instanceof NullIf) { + NullIf nullIf = (NullIf) expression; + return Optional.of(ImmutableList.of(new NullLiteral(nullIf.getDataType()), nullIf.left())); + } else if (expression instanceof Nvl) { + Nvl nvl = (Nvl) expression; + return Optional.of(ImmutableList.of(nvl.left(), nvl.right())); + } else { + return Optional.empty(); + } + } + /** * has aggregate function, exclude the window function */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java index 7e55a18d51166f..ef46223367d76e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java @@ -44,6 +44,8 @@ import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.algebra.Filter; +import org.apache.doris.nereids.trees.plans.algebra.Join; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; @@ -161,6 +163,11 @@ public static Optional> tryMergeProjections(List filter = new LogicalFilter(ImmutableSet.of(), + new LogicalEmptyRelation(new RelationId(1), ImmutableList.of())); + // AddMinMax run in filter plan + context = new ExpressionRewriteContext(filter, cascadesContext); + } + @Test void testNotRewrite() { executor = new ExpressionRuleExecutor(ImmutableList.of( @@ -316,6 +328,7 @@ void testAddMinMax() { ) )); + assertRewriteAfterTypeCoercion("5 * 100 >= 10 and 5 * 100 <= 5", "5 * 100 >= 10 and 5 * 100 <= 5"); assertRewriteAfterTypeCoercion("TA >= 10", "TA >= 10"); assertRewriteAfterTypeCoercion("TA between 10 and 20", "TA >= 10 and TA <= 20"); assertRewriteAfterTypeCoercion("TA between 10 and 20 or TA >= 30", @@ -389,26 +402,26 @@ void testSimplifyRangeAndAddMinMax() { assertRewriteAfterTypeCoercion("ISNULL(TA) and TA between 20 and 10", "ISNULL(TA) and null"); // assertRewriteAfterTypeCoercion("ISNULL(TA) and TA > 10", "ISNULL(TA) and null"); // should be, but not support now assertRewriteAfterTypeCoercion("ISNULL(TA) and TA > 10 and null", "ISNULL(TA) and null"); - assertRewriteAfterTypeCoercion("ISNULL(TA) or TA > 10", "ISNULL(TA) or TA > 10"); + assertRewriteAfterTypeCoercion("ISNULL(TA) or TA > 10", "TA > 10 or ISNULL(TA)"); // assertRewriteAfterTypeCoercion("(TA < 30 or TA > 40) and TA between 20 and 10", "TA IS NULL AND NULL"); // should be, but not support because flatten and assertRewriteAfterTypeCoercion("(TA < 30 or TA > 40) and TA is null and null", "TA IS NULL AND NULL"); assertRewriteAfterTypeCoercion("(TA < 30 or TA > 40) or TA between 20 and 10", "TA < 30 or TA > 40"); assertRewriteAfterTypeCoercion("TA between 10 and 20 or TA between 30 and 40 or TA between 60 and 50", - "(TA <= 20 or TA >= 30) and TA >= 10 and TA <= 40"); + "TA >= 10 and TA <= 40 and (TA <= 20 or TA >= 30)"); // should be, but not support yet, because 'TA is null and null' => UnknownValue(EmptyValue(TA) and null) //assertRewriteAfterTypeCoercion("TA between 10 and 20 or TA between 30 and 40 or TA is null and null", // "(TA <= 20 or TA >= 30) and TA >= 10 and TA <= 40"); assertRewriteAfterTypeCoercion("TA between 10 and 20 or TA between 30 and 40 or TA is null and null", - "(TA <= 20 or TA >= 30 or TA is null and null) and TA >= 10 and TA <= 40"); + "TA >= 10 and TA <= 40 and (TA <= 20 or TA >= 30)"); assertRewriteAfterTypeCoercion("TA between 10 and 20 or TA between 30 and 40 or TA is null", "TA >= 10 and TA <= 20 or TA >= 30 and TA <= 40 or TA is null"); assertRewriteAfterTypeCoercion("ISNULL(TB) and (TA between 10 and 20 or TA between 30 and 40 or TA between 60 and 50)", - "ISNULL(TB) and ((TA <= 20 or TA >= 30) and TA >= 10 and TA <= 40)"); + "ISNULL(TB) and TA >= 10 and TA <= 40 and (TA <= 20 or TA >= 30)"); assertRewriteAfterTypeCoercion("ISNULL(TB) and (TA between 10 and 20 or TA between 30 and 40 or TA is null)", "ISNULL(TB) and (TA >= 10 and TA <= 20 or TA >= 30 and TA <= 40 or TA is null)"); assertRewriteAfterTypeCoercion("TB between 20 and 10 and (TA between 10 and 20 or TA between 30 and 40 or TA between 60 and 50)", - "TB IS NULL AND NULL and (TA <= 20 or TA >= 30) and TA >= 10 and TA <= 40"); + "TB IS NULL AND NULL and TA >= 10 and TA <= 40 and (TA <= 20 or TA >= 30)"); assertRewriteAfterTypeCoercion("TA between 10 and 20 and TB between 10 and 20 or TA between 30 and 40 and TB between 30 and 40 or TA between 60 and 50 and TB between 60 and 50", "(TA <= 20 and TB <= 20 or TA >= 30 and TB >= 30 or TA is null and null and TB is null) and TA >= 10 and TA <= 40 and TB >= 10 and TB <= 40"); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java index b38ca618907db8..2be6cabba53afc 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java @@ -29,6 +29,8 @@ import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.trees.plans.RelationId; +import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.types.DataType; @@ -43,12 +45,14 @@ import org.apache.doris.nereids.util.MemoTestUtils; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import org.junit.jupiter.api.Assertions; import java.util.List; import java.util.Map; +import java.util.function.Function; public abstract class ExpressionRewriteTestHelper extends ExpressionRewrite { protected static final NereidsParser PARSER = new NereidsParser(); @@ -63,7 +67,14 @@ public ExpressionRewriteTestHelper() { context = new ExpressionRewriteContext(cascadesContext); } - protected final void assertRewrite(String expression, String expected) { + protected void setExpressionOnFilter() { + LogicalFilter filter = new LogicalFilter(ImmutableSet.of(), + new LogicalEmptyRelation(new RelationId(1), ImmutableList.of())); + // AddMinMax run in filter plan + context = new ExpressionRewriteContext(filter, cascadesContext); + } + + protected void assertRewrite(String expression, String expected) { Map mem = Maps.newHashMap(); Expression needRewriteExpression = replaceUnboundSlot(PARSER.parseExpression(expression), mem); Expression expectedExpression = replaceUnboundSlot(PARSER.parseExpression(expected), mem); @@ -91,12 +102,14 @@ protected void assertNotRewrite(Expression expression, Expression expectedExpres } protected void assertRewriteAfterTypeCoercion(String expression, String expected) { + assertRewriteAfterConvert(expression, expected, ExpressionRewriteTestHelper::typeCoercion); + } + + protected void assertRewriteAfterConvert(String expression, String expected, Function converter) { Map mem = Maps.newHashMap(); - Expression needRewriteExpression = PARSER.parseExpression(expression); - needRewriteExpression = typeCoercion(replaceUnboundSlot(needRewriteExpression, mem)); - Expression expectedExpression = PARSER.parseExpression(expected); + Expression needRewriteExpression = converter.apply(replaceUnboundSlot(PARSER.parseExpression(expression), mem)); Expression rewrittenExpression = executor.rewrite(needRewriteExpression, context); - expectedExpression = typeCoercion(replaceUnboundSlot(expectedExpression, mem)); + Expression expectedExpression = converter.apply(replaceUnboundSlot(PARSER.parseExpression(expected), mem)); Assertions.assertEquals(expectedExpression.toSql(), rewrittenExpression.toSql()); } @@ -116,8 +129,9 @@ public static Expression replaceUnboundSlot(Expression expression, Map qualifier = slot.getQualifier(); DataType dataType = getType(name.charAt(0)); + boolean notNullable = name.charAt(0) == 'X' || name.length() >= 2 && name.charAt(1) == 'X'; Column column = new Column(name, dataType.toCatalogDataType()); - mem.putIfAbsent(name, new SlotReference(exprId, name, dataType, true, qualifier, null, column, null, null)); + mem.putIfAbsent(name, new SlotReference(exprId, name, dataType, !notNullable, qualifier, null, column, null, null)); return mem.get(name); } return hasNewChildren ? expression.withChildren(children) : expression; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java index 188b229b574fec..277bad6f8dfc54 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java @@ -158,6 +158,18 @@ void testCaseWhenFold() { assertRewriteAfterTypeCoercion("case when null = 2 then 1 else 4 end", "4"); assertRewriteAfterTypeCoercion("case when null = 2 then 1 end", "null"); assertRewriteAfterTypeCoercion("case when TA = TB then 1 when TC is null then 2 end", "CASE WHEN (TA = TB) THEN 1 WHEN TC IS NULL THEN 2 END"); + assertRewriteAfterTypeCoercion("case when a > 1 then a + 1 when a > 1 then a + 10 when a > 2 then a + 2 else a + 100 end", + "case when a > 1 then a + 1 when a > 2 then a + 2 else a + 100 end"); + assertRewriteAfterTypeCoercion("case when a > 1 then a + 1 when a > 2 then a + 1 when a > 3 then a + 1 else a + 1 end", + "a + 1"); + assertRewriteAfterTypeCoercion("case when a > 1 then a + 1 when a > 2 then a + 1 when a > 3 then a + 1 end", + "case when a > 1 then a + 1 when a > 2 then a + 1 when a > 3 then a + 1 end"); + assertRewriteAfterTypeCoercion("case when null then 1 when false then 2 when a > 3 then 3 when a > 4 then 4 end", + "case when a > 3 then 3 when a > 4 then 4 end"); + assertRewriteAfterTypeCoercion("case when null then 1 when false then 2 when a > 3 then 3 when true then 0 when a > 4 then 4 end", + "case when a > 3 then 3 else 0 end"); + assertRewriteAfterTypeCoercion("case when true then 100 when a > 1 then a + 1 when a > 1 then a + 10 when a > 2 then a + 2 else a + 100 end", + "100"); // make sure the case when return datetime(6) Expression analyzedCaseWhen = ExpressionAnalyzer.analyzeFunction(null, null, PARSER.parseExpression( @@ -169,6 +181,16 @@ void testCaseWhenFold() { Assertions.assertEquals(new DateTimeV2Literal(DateTimeV2Type.of(6), "2025-04-17"), foldCaseWhen); } + @Test + void testIfFold() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); + assertRewriteAfterTypeCoercion("if(true, a + 1, a + 2)", "a + 1"); + assertRewriteAfterTypeCoercion("if(false, a + 1, a + 2)", "a + 2"); + assertRewriteAfterTypeCoercion("if(b > 0, a + 100, a + 100)", "a + 100"); + } + @Test void testInFold() { executor = new ExpressionRuleExecutor(ImmutableList.of( @@ -1462,7 +1484,8 @@ void testFoldNvl() { assertRewriteExpression("nvl(NULL, 1)", "1"); assertRewriteExpression("nvl(NULL, NULL)", "NULL"); - assertRewriteAfterTypeCoercion("nvl(IA, NULL)", "ifnull(IA, NULL)"); + assertRewriteAfterTypeCoercion("nvl(IA, NULL)", "IA"); + assertRewriteAfterTypeCoercion("nvl(IA, IA)", "IA"); assertRewriteAfterTypeCoercion("nvl(IA, 1)", "ifnull(IA, 1)"); Expression foldNvl = executor.rewrite( @@ -1472,6 +1495,33 @@ void testFoldNvl() { Assertions.assertEquals(new DateTimeV2Literal(DateTimeV2Type.of(6), "2025-04-17"), foldNvl); } + @Test + void testFoldNullIf() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + FoldConstantRule.INSTANCE + ) + )); + assertRewriteAfterTypeCoercion("nullif(a, b)", "nullif(a, b)"); + assertRewriteAfterTypeCoercion("nullif(a, a)", "null"); + assertRewriteAfterTypeCoercion("nullif(a, null)", "a"); + assertRewriteAfterTypeCoercion("nullif(null, a)", "null"); + assertRewriteAfterTypeCoercion("nullif(1, 1)", "null"); + assertRewriteAfterTypeCoercion("nullif(1, 2)", "1"); + } + + @Test + void testNonFoldable() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + FoldConstantRule.INSTANCE + ) + )); + assertRewriteAfterTypeCoercion("random(0, 1)", "random(0, 1)"); + assertRewriteAfterTypeCoercion("sum(1 + 2)", "sum(3)"); + assertRewriteAfterTypeCoercion("explode([1, 2, 3])", "explode([1, 2, 3])"); + } + private void assertRewriteExpression(String actualExpression, String expectedExpression) { ExpressionRewriteContext context = new ExpressionRewriteContext( MemoTestUtils.createCascadesContext(new UnboundRelation(new RelationId(1), ImmutableList.of("test_table")))); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java index e6e856852d0c13..1c5910f44e1615 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java @@ -23,7 +23,11 @@ import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer; import org.apache.doris.nereids.rules.expression.rules.RangeInference; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.CompoundValue; import org.apache.doris.nereids.rules.expression.rules.RangeInference.EmptyValue; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.IsNotNullValue; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.IsNullValue; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.NotDiscreteValue; import org.apache.doris.nereids.rules.expression.rules.RangeInference.RangeValue; import org.apache.doris.nereids.rules.expression.rules.RangeInference.UnknownValue; import org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDesc; @@ -58,33 +62,54 @@ public class SimplifyRangeTest extends ExpressionRewrite { private static final NereidsParser PARSER = new NereidsParser(); private ExpressionRuleExecutor executor; private ExpressionRewriteContext context; + private final Map commonMem; public SimplifyRangeTest() { CascadesContext cascadesContext = MemoTestUtils.createCascadesContext( new UnboundRelation(new RelationId(1), ImmutableList.of("tbl"))); context = new ExpressionRewriteContext(cascadesContext); + commonMem = Maps.newHashMap(); } @Test public void testRangeInference() { ValueDesc valueDesc = getValueDesc("TA IS NULL"); + Assertions.assertInstanceOf(IsNullValue.class, valueDesc); + Assertions.assertEquals("TA", valueDesc.getReference().toSql()); + + valueDesc = getValueDesc("NULL"); Assertions.assertInstanceOf(UnknownValue.class, valueDesc); - List sourceValues = ((UnknownValue) valueDesc).getSourceValues(); - Assertions.assertEquals(0, sourceValues.size()); - Assertions.assertEquals("TA IS NULL", valueDesc.getReference().toSql()); + Assertions.assertEquals("NULL", valueDesc.getReference().toSql()); + + valueDesc = getValueDesc("TA IS NOT NULL"); + Assertions.assertInstanceOf(IsNotNullValue.class, valueDesc); + Assertions.assertEquals("TA", valueDesc.getReference().toSql()); + + valueDesc = getValueDesc("TA != 10"); + Assertions.assertInstanceOf(NotDiscreteValue.class, valueDesc); + Assertions.assertEquals("TA", valueDesc.getReference().toSql()); + + valueDesc = getValueDesc("TA IS NULL AND NULL"); + Assertions.assertInstanceOf(EmptyValue.class, valueDesc); + Assertions.assertEquals("TA", valueDesc.getReference().toSql()); + + valueDesc = getValueDesc("TA IS NOT NULL OR NULL"); + Assertions.assertInstanceOf(RangeValue.class, valueDesc); + Assertions.assertEquals("TA", valueDesc.getReference().toSql()); + Assertions.assertTrue(((RangeValue) valueDesc).isRangeAll()); valueDesc = getValueDesc("TA IS NULL AND TB IS NULL AND NULL"); - Assertions.assertInstanceOf(UnknownValue.class, valueDesc); - sourceValues = ((UnknownValue) valueDesc).getSourceValues(); - Assertions.assertEquals(3, sourceValues.size()); + Assertions.assertInstanceOf(CompoundValue.class, valueDesc); + List sourceValues = ((CompoundValue) valueDesc).getSourceValues(); + Assertions.assertEquals(2, sourceValues.size()); Assertions.assertInstanceOf(EmptyValue.class, sourceValues.get(0)); Assertions.assertInstanceOf(EmptyValue.class, sourceValues.get(1)); Assertions.assertEquals("TA", sourceValues.get(0).getReference().toSql()); Assertions.assertEquals("TB", sourceValues.get(1).getReference().toSql()); valueDesc = getValueDesc("L + RANDOM(1, 10) > 8 AND L + RANDOM(1, 10) < 1"); - Assertions.assertInstanceOf(UnknownValue.class, valueDesc); - sourceValues = ((UnknownValue) valueDesc).getSourceValues(); + Assertions.assertInstanceOf(CompoundValue.class, valueDesc); + sourceValues = ((CompoundValue) valueDesc).getSourceValues(); Assertions.assertEquals(2, sourceValues.size()); for (ValueDesc value : sourceValues) { Assertions.assertInstanceOf(RangeValue.class, value); @@ -93,11 +118,109 @@ public void testRangeInference() { } @Test - public void testSimplify() { + public void testValueDescContainsAll() { + SlotReference xa = new SlotReference("xa", IntegerType.INSTANCE, false); + + checkContainsAll(true, "TA is null and null", "TA is null and null"); + checkContainsAll(false, "TA is null and null", "TA > 1"); + checkContainsAll(false, "TA is null and null", "TA = 1"); + checkContainsAll(false, "TA is null and null", "TA != 1"); + checkContainsAll(false, "TA is null and null", "TA is null"); + // XA is null and null will rewrite to 'FALSE' + // checkContainsAll(true, "XA is null and null", "XA is null"); + Assertions.assertTrue(new EmptyValue(context, xa).containsAll(new IsNullValue(context, xa))); + checkContainsAll(false, "TA is null and null", "TA = 1 or TA > 10"); + + checkContainsAll(true, "TA > 1", "TA is null and null"); + checkContainsAll(true, "TA > 1", "TA > 10"); + checkContainsAll(false, "TA > 1", "TA > 0"); + checkContainsAll(true, "TA >= 1", "TA > 1"); + checkContainsAll(false, "TA > 1", "TA >= 1"); + checkContainsAll(true, "TA > 1", "TA > 1"); + checkContainsAll(true, "TA > 1", "TA > 1 and TA < 10"); + checkContainsAll(false, "TA > 1", "TA >= 1 and TA < 10"); + checkContainsAll(true, "TA > 0", "TA in (1, 2, 3)"); + checkContainsAll(false, "TA > 0", "TA in (-1, 1, 2, 3)"); + checkContainsAll(false, "TA > 1", "TA != 0"); + checkContainsAll(false, "TA > 1", "TA != 1"); + checkContainsAll(false, "TA > 1", "TA != 2"); + checkContainsAll(true, "TA is not null or null", "TA != 2"); + checkContainsAll(false, "TA is not null or null", "TA is null"); + checkContainsAll(true, "TA is not null or null", "TA is not null"); + checkContainsAll(true, "TA is not null or null", "TA is null and null"); + checkContainsAll(false, "TA > 1", "TA is null"); + checkContainsAll(false, "TA > 1", "TA is not null"); + checkContainsAll(true, "TA > 1", "(TA > 2 and TA < 5) or (TA > 7 and TA < 9)"); + checkContainsAll(false, "TA > 1", "(TA >= 1 and TA < 5) or (TA > 7 and TA < 9)"); + checkContainsAll(true, "TA > 1", "TA > 5 and TA is not null"); + checkContainsAll(true, "TA > 1", "(TA > 5 and TA < 8) and TA is not null"); + checkContainsAll(true, "TA > 1", "TA > 5 and TA != 0"); + checkContainsAll(false, "TA > 1", "TA > 5 or TA is not null"); + checkContainsAll(false, "TA > 1", "TA > 5 or TA != 0"); + + checkContainsAll(true, "TA in (1, 2, 3)", "TA is null and null"); + checkContainsAll(false, "TA in (1, 2, 3, 4)", "TA between 2 and 3"); + checkContainsAll(true, "TA in (1, 2, 3)", "TA in (1, 2)"); + checkContainsAll(false, "TA in (1, 2, 3)", "TA in (1, 2, 4)"); + checkContainsAll(false, "TA in (1, 2, 3)", "TA not in (1, 2)"); + checkContainsAll(false, "TA in (1, 2, 3)", "TA not in (5, 6)"); + checkContainsAll(false, "TA in (1, 2, 3)", "TA is null"); + checkContainsAll(false, "TA in (1, 2, 3)", "TA is not null"); + checkContainsAll(true, "TA in (1, 2, 3)", "TA in (1, 2) and TA is not null"); + checkContainsAll(true, "TA in (1, 2, 3)", "TA in (1, 2) and TA is null"); + checkContainsAll(false, "TA in (1, 2, 3)", "TA != 1 and TA is not null"); + checkContainsAll(false, "TA in (0, 1, 2, 3)", "TA between 1 and 2 and TA is not null"); + + checkContainsAll(true, "TA not in (1, 2)", "TA is null and null"); + checkContainsAll(false, "TA not in (1, 2, 3, 4, 5)", "TA between 2 and 4"); + checkContainsAll(false, "TA not in (1, 2, 3, 4, 5)", "TA is not null or null"); + checkContainsAll(false, "TA not in (1, 2)", "TA in (1)"); + checkContainsAll(false, "TA not in (1, 2)", "TA in (1, 2)"); + checkContainsAll(true, "TA not in (1, 2)", "TA in (3, 4)"); + checkContainsAll(false, "TA not in (1, 2, 3)", "TA in (1, 4)"); + checkContainsAll(false, "TA not in (1, 2, 3)", "TA is null"); + checkContainsAll(false, "TA not in (1, 2, 3)", "TA is not null"); + checkContainsAll(false, "TA not in (1, 2, 3)", "TA is not null or null"); + checkContainsAll(true, "TA not in (1, 2)", "(TA is not null or null) and (TA is null or TA > 10)"); + + checkContainsAll(false, "TA is null", "TA is null and null"); + checkContainsAll(false, "TA is null", "TA > 10"); + checkContainsAll(false, "TA is null", "TA = 10"); + checkContainsAll(false, "TA is null", "TA != 10"); + checkContainsAll(true, "TA is null", "TA is null"); + checkContainsAll(false, "TA is null", "TA is not null"); + checkContainsAll(false, "TA is null", "TA is null and (TA > 10)"); + checkContainsAll(false, "TA is null", "TA is null or (TA > 10)"); + + checkContainsAll(false, "TA is not null", "TA is null and null"); + checkContainsAll(false, "TA is not null", "TA > 10"); + checkContainsAll(false, "TA is not null", "TA = 10"); + checkContainsAll(false, "TA is not null", "TA != 10"); + checkContainsAll(false, "TA is not null", "TA is null"); + checkContainsAll(true, "TA is not null", "TA is not null"); + checkContainsAll(false, "TA is not null", "TA is not null or null"); + checkContainsAll(true, "TA is not null", "TA is not null and (TA > 10)"); + checkContainsAll(false, "TA is not null", "TA is not null or (TA > 10)"); + + checkContainsAll(true, "TA < 1 or TA > 10", "TA is null and null"); + checkContainsAll(true, "TA < 1 or TA > 10", "TA < 0"); + checkContainsAll(false, "TA < 1 or TA > 10", "TA <= 1"); + checkContainsAll(true, "TA < 1 or TA > 10", "TA = 0"); + checkContainsAll(false, "TA < 1 or TA > 10", "TA in (0, 1)"); + checkContainsAll(true, "TA not in (1, 2, 13) or TA > 10", "TA not in (1, 2, 13, 15)"); + } + + private void checkContainsAll(boolean isContains, String expr1, String expr2) { + Assertions.assertEquals(isContains, getValueDesc(expr1).containsAll(getValueDesc(expr2))); + } + + @Test + public void testSimplifyNumeric() { executor = new ExpressionRuleExecutor(ImmutableList.of( bottomUp(SimplifyRange.INSTANCE) )); assertRewrite("TA", "TA"); + assertRewrite("TA > 10 and (TA > 20 or TA < 10)", "TA > 20"); assertRewrite("TA > 3 or TA > null", "TA > 3 OR NULL"); assertRewrite("TA > 3 or TA < null", "TA > 3 OR NULL"); assertRewrite("TA > 3 or TA = null", "TA > 3 OR NULL"); @@ -107,6 +230,7 @@ public void testSimplify() { "TA in (11, 12) OR TA <= 10 OR TA >= 13"); assertRewrite("TA > 3 or TA <> null", "TA > 3 or null"); assertRewrite("TA > 3 or TA <=> null", "TA > 3 or TA <=> null"); + assertRewrite("TA >= 0 and TA <= 3", "TA >= 0 and TA <= 3"); assertRewrite("(TA < 1 or TA > 2) or (TA >= 0 and TA <= 3)", "TA IS NOT NULL OR NULL"); assertRewrite("TA between 10 and 20 or TA between 100 and 120 or TA between 15 and 25 or TA between 115 and 125", "TA >= 10 and TA <= 25 or TA >= 100 and TA <= 125"); @@ -128,10 +252,51 @@ public void testSimplify() { assertRewriteNotNull("TA = 1 and TA > 10", "FALSE"); assertRewrite("TA = 1 and TA > 10", "TA is null and null"); assertRewrite("TA >= 1 and TA <= 1", "TA = 1"); + assertRewrite("TA = 1 and TA = 2", "TA IS NULL AND NULL"); + assertRewriteNotNull("TA = 1 and TA = 2", "FALSE"); + assertRewrite("TA not in (1) and TA not in (1)", "TA != 1"); + assertRewrite("TA not in (1, 2, 3) and TA not in (1, 4, 5)", "TA not in (1, 2, 3, 4, 5)"); + assertRewrite("TA = 1 and TA not in (2)", "TA = 1"); + assertRewrite("TA = 1 and TA not in (1, 2)", "TA is null and null"); + assertRewriteNotNull("TA = 1 and TA not in (1, 2)", "FALSE"); + assertRewrite("TA > 10 and TA not in (1, 2, 3)", "TA > 10"); + assertRewrite("TA > 10 and TA not in (1, 2, 3, 11)", "TA > 10 and TA != 11"); + assertRewrite("TA > 10 and TA not in (1, 2, 3, 11, 12)", "TA > 10 and TA NOT IN (11, 12)"); + assertRewrite("TA is null", "TA is null"); + assertRewriteNotNull("TA is null", "TA is null"); + assertRewrite("TA is not null", "TA is not null"); + assertRewrite("TA is null and TA is not null", "FALSE"); + assertRewriteNotNull("TA is null and TA is not null", "FALSE"); + assertRewrite("TA = 1 and TA != 1 and TA is null", "TA is null and null"); + assertRewriteNotNull("TA = 1 and TA != 1 and TA is null", "FALSE"); + assertRewrite("TA = 1 and TA != 1 and TA is not null", "FALSE"); + assertRewriteNotNull("TA = 1 and TA != 1 and TA is not null", "FALSE"); + assertRewrite("TA = 1 and TA != 1 and (TA > 10 or TA < 5)", "TA is null and null"); + assertRewriteNotNull("TA = 1 and TA != 1 and (TA > 10 or TA < 5)", "FALSE"); + assertRewrite("TA = 1 and TA != 1 and (TA > 10 or TA is not null)", "TA is null and null"); + assertRewrite("TA = 1 and TA != 1 and (TA > 10 or (TA < 5 and TA is not null))", "TA is null and null"); + assertRewrite("TA = 1 and TA != 1 and (TA > 10 or (TA < 5 and TA is not null) or (TA > 7 and TA is not null))", + "TA is null and null"); assertRewrite("TA > 5 or TA < 1", "TA < 1 or TA > 5"); assertRewrite("TA > 5 or TA > 1 or TA > 10", "TA > 1"); assertRewrite("TA > 5 or TA > 1 or TA < 10", "TA is not null or null"); assertRewriteNotNull("TA > 5 or TA > 1 or TA < 10", "TRUE"); + assertRewrite("TA != 1 or TA != 1", "TA != 1"); + assertRewrite("TA != 1 or TA != 2", "TA is not null or null"); + assertRewriteNotNull("TA != 1 or TA != 2", "TRUE"); + assertRewrite("TA not in (1, 2, 3) or TA not in (1, 2, 4)", "TA not in (1, 2)"); + assertRewrite("TA not in (1, 2) or TA in (2, 1)", "TA is not null or null"); + assertRewrite("TA not in (1, 2) or TA in (1)", "TA != 2"); + assertRewrite("TA not in (1, 2) or TA in (1, 2, 3)", "TA is not null or null"); + assertRewrite("TA not in (1, 3) or TA < 2", "TA != 3"); + assertRewrite("TA is null and null", "TA is null and null"); + assertRewrite("TA is null", "TA is null"); + assertRewrite("TA is null and null or TA = 1", "TA = 1"); + assertRewrite("TA is null and null or TA is null", "TA is null"); + assertRewrite("TA is null and null or (TA is null and TA > 10) ", "TA is null and null"); + assertRewrite("TA is null and null or TA is not null", "TA is not null or null"); + assertRewriteNotNull("TA != 1 or TA != 2", "TRUE"); + assertRewrite("TA is null or TA is not null", "TRUE"); assertRewrite("TA > 5 and TA > 1 and TA > 10", "TA > 10"); assertRewrite("TA > 5 and TA > 1 and TA < 10", "TA > 5 and TA < 10"); assertRewrite("TA > 1 or TA < 1", "TA < 1 or TA > 1"); @@ -149,12 +314,14 @@ public void testSimplify() { assertRewrite("(TA > 10 or TA > 20) and (TB > 10 and TB > 20)", "TA > 10 and TB > 20"); assertRewrite("((TB > 30 and TA > 40) and TA > 20) and (TB > 10 and TB > 20)", "TB > 30 and TA > 40"); assertRewrite("(TA > 10 and TB > 10) or (TB > 10 and TB > 20)", "TA > 10 and TB > 10 or TB > 20"); - assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA > 5 and TB > 10) or (TB > 10 and (TB < 10 or TB > 20))"); + assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA > 5 and TB > 10) or TB > 20"); assertRewriteNotNull("TA in (1,2,3) and TA > 10", "FALSE"); assertRewrite("TA in (1,2,3) and TA > 10", "TA is null and null"); assertRewrite("TA in (1,2,3) and TA >= 1", "TA in (1,2,3)"); assertRewrite("TA in (1,2,3) and TA > 1", "TA IN (2, 3)"); assertRewrite("TA in (1,2,3) or TA >= 1", "TA >= 1"); + assertRewrite("TA is null and (TA = 4 or TA = 5)", "TA is null and null"); + assertRewrite("(TA != 3 or TA is null) and (TA = 4 or TA = 5)", "TA in (4, 5)"); assertRewrite("TA in (1)", "TA in (1)"); assertRewrite("TA in (1,2,3) and TA < 10", "TA in (1,2,3)"); assertRewriteNotNull("TA in (1,2,3) and TA < 1", "FALSE"); @@ -179,7 +346,162 @@ public void testSimplify() { assertRewrite("(TA > 3 and TA < 1) and (TB < 5 and TB = 6)", "TA is null and null and TB is null"); assertRewrite("TA > 3 and TB < 5 and TA < 1", "TA is null and null and TB < 5"); assertRewrite("(TA > 3 and TA < 1) or TB < 5", "(TA is null and null) or TB < 5"); - assertRewrite("((IA = 1 AND SC ='1') OR SC = '1212') AND IA =1", "((IA = 1 AND SC ='1') OR SC = '1212') AND IA =1"); + + assertRewrite("TA is null and TA > 10", "TA is null and null"); + assertRewrite("TA is null and TA = 10", "TA is null and null"); + assertRewrite("TA is null and TA != 10", "TA is null and null"); + assertRewriteNotNull("TA is null and TA > 10", "FALSE"); + assertRewriteNotNull("TA is null and TA = 10", "FALSE"); + assertRewriteNotNull("TA is null and TA != 10", "FALSE"); + assertRewrite("TA is not null or TA > 10", "TA is not null or null"); + assertRewrite("TA is not null or TA = 10", "TA is not null or null"); + assertRewrite("TA is not null or TA != 10", "TA is not null or null"); + assertRewriteNotNull("TA is not null or TA > 10", "TRUE"); + assertRewriteNotNull("TA is not null or TA = 10", "TRUE"); + assertRewriteNotNull("TA is not null or TA != 10", "TRUE"); + + // A and (B or C) = A + assertRewrite("TA < 10 and (TA is not null or TA is null and null)", "TA < 10"); + assertRewrite("TA > 10 and (TA > 5 or (TA is not null and TA > 1))", "TA > 10"); + assertRewrite("TA > 10 and (TA != 4 or (TA is not null and TA > 1))", "TA > 10"); + assertRewrite("TA = 5 and (TA != 4 or (TA is not null and TA > 1))", "TA = 5"); + assertRewrite("TA = 5 and (TA in (1, 2, 5) or (TA is not null and TA > 1))", "TA = 5"); + assertRewrite("TA = 5 and (TA > 3 or (TA is not null and TA > 1))", "TA = 5"); + assertRewrite("TA not in (1, 2) and (TA not in (1) or (TA is not null and TA > 1))", "TA not in (1, 2)"); + assertRewrite("TA not in (1, 2) and (TA not in (1, 2) or (TA is not null and TA > 1))", "TA not in (1, 2)"); + assertRewrite("TA not in (2, 3) or (TA is not null and TA > 1)", "TA is not null or null"); + assertRewrite("TA not in (1, 2) and (TA not in (2, 3) or (TA is not null and TA > 1))", "TA not in (1, 2)"); + assertRewrite("TA is null and null and (TA = 10 or (TA is not null and TA > 1))", "TA is null and null"); + assertRewrite("TA is null and null and (TA != 10 or (TA is not null and TA > 1))", "TA is null and null"); + assertRewrite("TA is null and null and (TA > 20 or (TA is not null and TA > 1))", "TA is null and null"); + assertRewrite("TA is null and null and (TA is null and null or (TA is not null and TA > 1))", "TA is null and null"); + assertRewrite("TA is null and null and (TA is null or (TA is not null and TA > 1))", "TA is null and null"); + assertRewrite("TA is null and (TA is null or (TA is not null and TA > 1))", "TA is null"); + assertRewrite("TA is not null and (TA is not null or (TA is not null and TA > 1))", "TA is not null"); + + assertRewrite("TA is null and null", "TA is null and null"); + assertRewriteNotNull("TA is null and null", "FALSE"); + assertRewrite("TA is null", "TA is null"); + assertRewriteNotNull("TA is null", "TA is null"); + assertRewrite("TA is not null", "TA is not null"); + assertRewriteNotNull("TA is not null", "TA is not null"); + assertRewrite("TA is null and null or TA is null", "TA is null"); + assertRewriteNotNull("TA is null and null or TA is null", "TA is null"); + assertRewrite("TA is null and null or TA is not null", "TA is not null or null"); + assertRewriteNotNull("TA is null and null or TA is not null", "not TA is null"); + assertRewrite("TA is null or TA is not null", "TRUE"); + assertRewriteNotNull("TA is null or TA is not null", "TRUE"); + assertRewrite("(TA is null and null) and TA is null", "TA is null and null"); + assertRewriteNotNull("(TA is null and null) and TA is null", "FALSE"); + assertRewrite("TA is null and null and TA is not null", "FALSE"); + assertRewriteNotNull("TA is null and null and TA is not null", "FALSE"); + assertRewrite("TA is null and TA is not null", "FALSE"); + assertRewriteNotNull("TA is null and TA is not null", "FALSE"); + + assertRewrite("(TA is not null or null) and TA > 10", "TA > 10"); + assertRewrite("(TA is not null or null) or TA > 10", "TA is not null or null"); + + assertRewrite("(TA is null and null) and TA is null", "TA is null and null"); + assertRewrite("(TA is null and null) or TA is null", "TA is null"); + // can simplify to 'TA is null', but not supported yet + assertRewrite("(TA is null or null) and TA is null", "(TA is null or null) and TA is null"); + assertRewrite("(TA is null or null) or TA is null", "TA is null or null"); + assertRewrite("(TA is not null and null) and TA is null", "FALSE"); + assertRewrite("(TA is not null and null) or TA is null", "TA is not null and null or TA is null"); + assertRewrite("(TA is not null or null) and TA is null", "TA is null and null"); + assertRewrite("(TA is not null or null) or TA is null", "TRUE"); + assertRewrite("(TA is null and null) and TA is not null", "FALSE"); + assertRewrite("(TA is null and null) or TA is not null", "TA is not null or null"); + assertRewrite("(TA is null or null) and TA is not null", "(TA is null or null) and TA is not null"); + assertRewrite("(TA is null or null) or TA is not null", "TRUE"); + assertRewrite("(TA is not null and null) and TA is not null", "TA is not null and null"); + // can simplify to 'TA is not null', but not supported yet + assertRewrite("(TA is not null and null) or TA is not null", "TA is not null and null or TA is not null"); + // can simplify to 'TA is not null', but not supported yet + assertRewrite("(TA is not null or null) and TA is not null", "TA is not null"); + assertRewrite("(TA is not null or null) or TA is not null", "TA is not null or null"); + + assertRewrite("(XA is null and null) and XA is null", "FALSE"); + // can simplify to 'FALSE', but not supported yet + assertRewrite("(XA is null and null) or XA is null", "XA is null"); + // can simplify to 'FALSE', but not supported yet + assertRewrite("(XA is null or null) and XA is null", "(XA is null or null) and XA is null"); + // can simplify to 'null', but not supported yet + assertRewrite("(XA is null or null) or XA is null", "XA is null or null"); + assertRewrite("(XA is not null and null) and XA is null", "FALSE"); + // can simplify to 'null', but not supported yet + assertRewrite("(XA is not null and null) or XA is null", "(XA is not null and null) or XA is null"); + assertRewrite("(XA is not null or null) and XA is null", "FALSE"); + assertRewrite("(XA is not null or null) or XA is null", "TRUE"); + assertRewrite("(XA is null and null) and XA is not null", "FALSE"); + // can simplify to 'TRUE', but not supported yet + assertRewrite("(XA is null and null) or XA is not null", "XA is not null"); + // can simplify to 'NULL', but not supported yet + assertRewrite("(XA is null or null) and XA is not null", "(XA is null or null) and XA is not null"); + assertRewrite("(XA is null or null) or XA is not null", "TRUE"); + // can simplify to 'NULL', but not supported yet + assertRewrite("(XA is not null and null) and XA is not null", "XA is not null and null"); + // can simplify to 'NULL', but not supported yet + assertRewrite("(XA is not null and null) or XA is not null", "XA is not null and null or XA is not null"); + // can simplify to 'TRUE', but not supported yet + assertRewrite("(XA is not null or null) and XA is not null", "XA is not null"); + assertRewrite("(XA is not null or null) or XA is not null", "TRUE"); + + assertRewrite("TA < 10 or (TA is null or (TA != 1 and TA != 2))", "TRUE"); + assertRewrite("TA < 10 or ((TA != 1 or TA is null) and (TA != 2 or TA is null))", "TRUE"); + + assertRewrite("(TA between 10 and 20 or TA between 30 and 40) and (TA between 5 and 15 or TA between 35 and 45)", + "(TA between 10 and 20 or TA between 30 and 40) and (TA between 5 and 15 or TA between 35 and 45)"); + assertRewrite("(TA between 10 and 20 or TA > 30) and (TA between 5 and 15 or TA > 40)", + "(TA between 10 and 20 or TA > 30) and (TA between 5 and 15 or TA > 40)"); + + assertRewrite("TA < 10 and TA is not null or TA > 20 and TA is not null", + "TA < 10 and TA is not null or TA > 20 and TA is not null"); + assertRewrite("TA < 10 and TA != 0 or TA > 20 and TA != 25", "TA < 10 and TA != 0 or TA > 20 and TA != 25"); + + // A and ((B1 and B2) or (C1 and C2)) + assertRewrite("TA = 15 and (TA < 10 and TA is not null or TA > 20 and TA is not null)", "FALSE"); + assertRewrite("TA = 15 and (TA < 10 and TA is not null or TA > 20 and TA is null)", "TA is null and null"); + assertRewrite("TA = 15 and (TA < 10 and TA != 0 or TA > 20 and TA != 25)", "TA is null and null"); + assertRewriteNotNull("TA = 15 and (TA < 10 and TA is not null or TA > 20 and TA is not null)", "FALSE"); + assertRewriteNotNull("TA = 15 and (TA < 10 and TA is not null or TA > 20 and TA is null)", "FALSE"); + assertRewriteNotNull("TA = 15 and (TA < 10 and TA != 0 or TA > 20 and TA != 25)", "FALSE"); + + // A or ((B1 or B2) and (C1 or C2)) + assertRewrite("TA < 10 or ((TA != 1 or TA is null) and (TA != 2 or TA is null))", "TRUE"); + assertRewrite("TA < 10 or ((TA != 1 or TA is not null) and (TA != 2 or TA is not null))", "TA is not null or null"); + assertRewrite("TA < 10 or ((TA != 1 or TA is null) and (TA != 2 or TA is not null))", "TA is not null or null"); + assertRewrite("TA < 10 or ((TA != 1 or TA is null) and (TA is null or TA is not null))", "TRUE"); + assertRewrite("TA < 100 or (TA between 10 and 20 or TA > 30) and (TA between 5 and 15 or TA > 40)", "TA is not null or null"); + assertRewriteNotNull("(TA between 10 and 20 or TA between 30 and 40) and (TA between 5 and 15 or TA between 35 and 45)", + "(TA between 10 and 20 or TA between 30 and 40) and (TA between 5 and 15 or TA between 35 and 45)"); + assertRewriteNotNull("(TA between 10 and 20 or TA > 30) and (TA between 5 and 15 or TA > 40)", + "(TA between 10 and 20 or TA > 30) and (TA between 5 and 15 or TA > 40)"); + assertRewriteNotNull("TA < 10 or ((TA != 1 or TA is null) and (TA != 2 or TA is null))", "TRUE"); + assertRewriteNotNull("TA < 10 or ((TA != 1 or TA is not null) and (TA != 2 or TA is not null))", "TRUE"); + assertRewriteNotNull("TA < 10 or ((TA != 1 or TA is null) and (TA != 2 or TA is not null))", "TRUE"); + assertRewriteNotNull("TA < 10 or ((TA != 1 or TA is null) and (TA is null or TA is not null))", "TRUE"); + assertRewriteNotNull("TA < 100 or (TA between 10 and 20 or TA > 30) and (TA between 5 and 15 or TA > 40)", "TRUE"); + + assertRewrite("TA is not null or TA is null and null", "TA is not null or null"); + assertRewrite("TA > 100 and (TA < 10 or TA between 15 and 20)", "TA is null and null"); + assertRewrite("TA > 100 and (TA < 10 or TA between 15 and 20 or TA between 110 and 115)", "TA between 110 and 115"); + assertRewrite("TA > 100 and (TA < 10 or TA between 15 and 20 or TA is null)", "TA is null and null"); + assertRewrite("TA > 100 and (TA < 10 or TA between 15 and 20 or TA is not null)", "TA > 100"); + assertRewriteNotNull("TA is not null or TA is null and null", "TA is not null"); + assertRewriteNotNull("TA > 100 and (TA < 10 or TA between 15 and 20)", "FALSE"); + assertRewriteNotNull("TA > 100 and (TA < 10 or TA between 15 and 20 or TA is null)", "FALSE"); + assertRewriteNotNull("TA > 100 and (TA < 10 or TA between 15 and 20 or TA is not null)", "TA > 100"); + assertRewrite("TA > 100 or (TA < 120 and TA is null)", "TA > 100"); + assertRewrite("TA > 100 or (TA < 120 and TA is not null)", "TA is not null or null"); + assertRewrite("TA > 100 or (TA < 120 and TA != 80)", + "TA > 100 or ((TA is not null or null) and TA != 80)"); + assertRewrite("TA > 100 or (TA < 120 and TA != 110)", "TA is not null or null"); + assertRewriteNotNull("TA > 100 or (TA < 120 and TA is null)", "TA > 100"); + assertRewriteNotNull("TA > 100 or (TA < 120 and TA is not null)", "TRUE"); + assertRewriteNotNull("TA > 100 or (TA < 120 and TA != 80)", + "TA != 80"); + assertRewriteNotNull("TA > 100 or (TA < 120 and TA != 110)", "TRUE"); assertRewrite("TA + TC", "TA + TC"); assertRewrite("(TA + TC >= 1 and TA + TC <=3 ) or (TA + TC > 5 and TA + TC < 7)", "(TA + TC >= 1 and TA + TC <=3 ) or (TA + TC > 5 and TA + TC < 7)"); @@ -208,7 +530,8 @@ public void testSimplify() { assertRewrite("(TA + TC > 10 or TA + TC > 20) and (TB > 10 and TB > 20)", "TA + TC > 10 and TB > 20"); assertRewrite("((TB > 30 and TA + TC > 40) and TA + TC > 20) and (TB > 10 and TB > 20)", "TB > 30 and TA + TC > 40"); assertRewrite("(TA + TC > 10 and TB > 10) or (TB > 10 and TB > 20)", "TA + TC > 10 and TB > 10 or TB > 20"); - assertRewrite("((TA + TC > 10 or TA + TC > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA + TC > 5 and TB > 10) or (TB > 10 and (TB < 10 or TB > 20))"); + assertRewrite("((TA + TC > 10 or TA + TC > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", + "(TA + TC > 5 and TB > 10) or TB > 20"); assertRewriteNotNull("TA + TC in (1,2,3) and TA + TC > 10", "FALSE"); assertRewrite("TA + TC in (1,2,3) and TA + TC > 10", "(TA + TC) is null and null"); assertRewrite("TA + TC in (1,2,3) and TA + TC >= 1", "TA + TC in (1,2,3)"); @@ -230,25 +553,31 @@ public void testSimplify() { assertRewrite("TA + TC = 1 and TA + TC = 3", "(TA + TC) is null and null"); assertRewriteNotNull("TA + TC in (1) and TA + TC in (3)", "FALSE"); assertRewrite("TA + TC in (1) and TA + TC in (3)", "(TA + TC) is null and null"); - assertRewrite("TA + TC in (1) and TA + TC in (1)", "TA + TC in (1)"); + assertRewrite("TA + TC in (1) and TA + TC in (1)", "TA + TC = 1"); assertRewriteNotNull("(TA + TC > 3 and TA + TC < 1) and TB < 5", "FALSE"); assertRewrite("(TA + TC > 3 and TA + TC < 1) and TB < 5", "(TA + TC) is null and null and TB < 5"); assertRewrite("(TA + TC > 3 and TA + TC < 1) or TB < 5", "((TA + TC) is null and null) OR TB < 5"); - assertRewrite("(TA + TC > 3 OR TA < 1) AND TB = 2 AND IA =1", "(TA + TC > 3 OR TA < 1) AND TB = 2 AND IA =1"); - assertRewrite("SA = '20250101' and SA < '20200101'", "SA is null and null"); - assertRewrite("SA > '20250101' and SA > '20260110'", "SA > '20260110'"); // random is non-foldable, so the two random(1, 10) are distinct, cann't merge range for them. Expression expr = rewriteExpression("X + random(1, 10) > 10 AND X + random(1, 10) < 1", true); Assertions.assertEquals("AND[((X + random(1, 10)) > 10),((X + random(1, 10)) < 1)]", expr.toSql()); - expr = rewrite("TA + random(1, 10) between 10 and 20", Maps.newHashMap()); Assertions.assertEquals("AND[((cast(TA as BIGINT) + random(1, 10)) >= 10),((cast(TA as BIGINT) + random(1, 10)) <= 20)]", expr.toSql()); expr = rewrite("TA + random(1, 10) between 20 and 10", Maps.newHashMap()); Assertions.assertEquals("AND[(cast(TA as BIGINT) + random(1, 10)) IS NULL,NULL]", expr.toSql()); } + @Test + public void testSimplifyString() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyRange.INSTANCE) + )); + assertRewrite("SA = '20250101' and SA < '20200101'", "SA is null and null"); + assertRewrite("SA > '20250101' and SA > '20260110'", "SA > '20260110'"); + assertRewrite("((IA = 1 AND SC ='1') OR SC = '1212') AND IA =1", "((IA = 1 AND SC ='1') OR SC = '1212') AND IA =1"); + } + @Test public void testSimplifyDate() { executor = new ExpressionRuleExecutor(ImmutableList.of( @@ -414,9 +743,18 @@ public void testSimplifyDateTime() { "(CA is null and null) OR CB < timestamp '2024-01-05 00:50:00'"); } + @Test + public void testMixTypes() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyRange.INSTANCE) + )); + assertRewrite("(TA > 1 and FALSE or FALSE and SA > 'aaaa') and TB is null", "FALSE"); + assertRewrite("(TA > 1 and FALSE or FALSE and SA > 'aaaa') and (TA > 1 and FALSE or FALSE and SA > 'aaaa') and TB is null", + "FALSE"); + } + private ValueDesc getValueDesc(String expression) { - Map mem = Maps.newHashMap(); - Expression parseExpression = replaceUnboundSlot(PARSER.parseExpression(expression), mem); + Expression parseExpression = replaceUnboundSlot(PARSER.parseExpression(expression), commonMem); parseExpression = typeCoercion(parseExpression); return (new RangeInference()).getValue(parseExpression, context); } @@ -466,7 +804,8 @@ private Expression replaceUnboundSlot(Expression expression, Map m } if (expression instanceof UnboundSlot) { String name = ((UnboundSlot) expression).getName(); - mem.putIfAbsent(name, new SlotReference(name, getType(name.charAt(0)))); + boolean notNullable = name.charAt(0) == 'X' || name.length() >= 2 && name.charAt(1) == 'X'; + mem.putIfAbsent(name, new SlotReference(name, getType(name.charAt(0)), !notNullable)); return mem.get(name); } return hasNewChildren ? expression.withChildren(children) : expression; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToCompoundPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToCompoundPredicateTest.java new file mode 100644 index 00000000000000..d15048122829d2 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToCompoundPredicateTest.java @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; +import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; +import org.apache.doris.nereids.trees.plans.RelationId; +import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.junit.jupiter.api.Test; + +class CaseWhenToCompoundPredicateTest extends ExpressionRewriteTestHelper { + + @Test + void testCaseWhen() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + CaseWhenToCompoundPredicate.INSTANCE + ) + )); + assertRewriteAfterTypeCoercion("case when a = 1 then true end", "(a = 1 <=> TRUE) or null"); + assertRewriteAfterTypeCoercion("case when a = 1 then true else null end", "(a = 1 <=> TRUE) or null"); + assertRewriteAfterTypeCoercion("case when a = 1 then true else false end", "(a = 1 <=> TRUE) or false"); + assertRewriteAfterTypeCoercion("case when a = 1 then true else true end", "(a = 1 <=> TRUE) or true"); + assertRewriteAfterTypeCoercion("case when a = 1 then true else b = 1 end", "(a = 1 <=> TRUE) or b = 1"); + assertRewriteAfterTypeCoercion("case when a = 1 then true when b = 1 then true when c = 1 then true end", + "(a = 1 <=> TRUE) or (b = 1 <=> TRUE) or (c = 1 <=> TRUE) or null"); + assertRewriteAfterTypeCoercion("case when a = 1 then false when b = 1 then false when c = 1 then false end", + "not(a = 1 <=> TRUE) and not (b = 1 <=> TRUE) and not(c = 1 <=> TRUE) and null"); + assertRewriteAfterTypeCoercion("case when a = 1 then true when b = 1 then false when c = 1 then true end", + "(a = 1 <=> TRUE) or (not (b = 1 <=> TRUE) and ((c = 1 <=> TRUE) or null))"); + } + + @Test + void testIf() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + CaseWhenToCompoundPredicate.INSTANCE + ) + )); + + assertRewriteAfterTypeCoercion("if(a = 1, true, a > b)", "(a = 1 <=> TRUE) or a > b"); + assertRewriteAfterTypeCoercion("if(a = 1, false, a > b)", "not (a = 1 <=> TRUE) and a > b"); + assertRewriteAfterTypeCoercion("if(a = 1, b = 1, true)", "if(a = 1, b = 1, true)"); + assertRewriteAfterTypeCoercion("if(a = 1, b = 1, false)", "if(a = 1, b = 1, false)"); + assertRewriteAfterTypeCoercion("if(a = 1, b = 1, null)", "if(a = 1, b = 1, null)"); + } + + @Test + void testIfInCond() { + LogicalFilter filter = new LogicalFilter(ImmutableSet.of(), + new LogicalEmptyRelation(new RelationId(1), ImmutableList.of())); + ExpressionRewriteContext oldContext = context; + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + CaseWhenToCompoundPredicate.INSTANCE + ) + )); + try { + context = new ExpressionRewriteContext(filter, cascadesContext); + assertRewriteAfterTypeCoercion("if(a = 1, b = 1, true)", "not((a = 1) <=> true) or b = 1"); + assertRewriteAfterTypeCoercion("if(a = 1, b = 1, false)", "a = 1 and b = 1"); + assertRewriteAfterTypeCoercion("if(a = 1, b = 1, null)", "a = 1 and b = 1"); + } finally { + context = oldContext; + } + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/CondReplaceNullWithFalseTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/CondReplaceNullWithFalseTest.java new file mode 100644 index 00000000000000..ff5fae18c07146 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/CondReplaceNullWithFalseTest.java @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; +import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; +import org.apache.doris.nereids.trees.plans.RelationId; +import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.junit.jupiter.api.Test; + +class CondReplaceNullWithFalseTest extends ExpressionRewriteTestHelper { + + @Test + void testInsideCondition() { + LogicalFilter filter = new LogicalFilter(ImmutableSet.of(), + new LogicalEmptyRelation(new RelationId(1), ImmutableList.of())); + context = new ExpressionRewriteContext(filter, cascadesContext); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(CondReplaceNullWithFalse.INSTANCE) + )); + + assertRewriteAfterTypeCoercion("null", "false"); + assertRewriteAfterTypeCoercion("not(null)", "not(null)"); + assertRewriteAfterTypeCoercion("null and true", "false and true"); + assertRewriteAfterTypeCoercion("null or true", "false or true"); + assertRewriteAfterTypeCoercion("case when null and true then null and true else null end", + "case when false and true then false and true else false end"); + assertRewriteAfterTypeCoercion("if(null and true, null and true, null and true)", + "if(false and true, false and true, false and true)"); + assertRewriteAfterTypeCoercion("not(case when null and true then null and true else null end)", + "not(case when false and true then null and true else null end)"); + assertRewriteAfterTypeCoercion("not(if(null and true, null and true, null and true))", + "not(if(false and true, null and true, null and true))"); + + assertRewriteAfterTypeCoercion( + "case when null then null" + + " when null and a = 1 and not(null) or " + + " (case when a = 2 and null then null " + + " when null then not(null) " + + " else null or a=3" + + " end) " + + " then (case when null then null else null end) " + + " else null end", + + "case when false then false" + + " when false and a = 1 and not(null) or " + + " (case when a = 2 and false then false " + + " when false then not(null) " + + " else false or a=3" + + " end) " + + " then (case when false then false else false end) " + + " else false end" + ); + + assertRewriteAfterTypeCoercion( + "if(" + + " null and not(null) and if(null and not(null), null and true, null)," + + " null and not(null)," + + " if(a = 1 and null, null and true, null)" + + ")", + + "if(" + + " false and not(null) and if(false and not(null), false and true, false)," + + " false and not(null)," + + " if(a = 1 and false, false and true, false)" + + ")" + ); + } + + @Test + void testNotInCondition() { + context = new ExpressionRewriteContext(cascadesContext); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(CondReplaceNullWithFalse.INSTANCE) + )); + + assertRewriteAfterTypeCoercion("null", "null"); + assertRewriteAfterTypeCoercion("not(null)", "not(null)"); + assertRewriteAfterTypeCoercion("null and true", "null and true"); + assertRewriteAfterTypeCoercion("null or true", "null or true"); + assertRewriteAfterTypeCoercion("case when null and true then null and true else null end", + "case when false and true then null and true else null end"); + assertRewriteAfterTypeCoercion("if(null and true, null and true, null and true)", + "if(false and true, null and true, null and true)"); + assertRewriteAfterTypeCoercion("not(case when null and true then null and true else null end)", + "not(case when false and true then null and true else null end)"); + assertRewriteAfterTypeCoercion("not(if(null and true, null and true, null and true))", + "not(if(false and true, null and true, null and true))"); + assertRewriteAfterTypeCoercion("case when null and true then true and null end", "case when false and true then true and null end"); + + assertRewriteAfterTypeCoercion( + "case when null then null" + + " when null and a = 1 and not(null) or " + + " (case when a = 2 and null then null " + + " when null then not(null) " + + " else null or a=3" + + " end) " + + " then (case when null then null else null end) " + + " else null end", + + "case when false then null" + + " when false and a = 1 and not(null) or " + + " (case when a = 2 and false then false " + + " when false then not(null) " + + " else false or a=3" + + " end) " + + " then (case when false then null else null end) " + + " else null end" + ); + + assertRewriteAfterTypeCoercion( + "if(" + + " null and not(null) and if(null and not(null), null and true, null)," + + " null and not(null)," + + " if(a = 1 and null, null, null)" + + ")", + + "if(" + + " false and not(null) and if(false and not(null), false and true, false)," + + " null and not(null)," + + " if(a = 1 and false, null, null)" + + ")" + ); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/ExpressionRewriteSqlTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/ExpressionRewriteSqlTest.java index 7628620fec6640..434c693b68ad64 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/ExpressionRewriteSqlTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/ExpressionRewriteSqlTest.java @@ -71,7 +71,7 @@ public void testSimplifyRangeAndExtractCommonFactor() { .rewrite() .matches( logicalFilter().when(f -> f.getPredicate().toSql().equals( - "AND[(score > 1),(id > 1)]" + "AND[(id > 1),(score > 1)]" ))); sql = "select * from T1 where id > 1 and id < 0 or score > 1 and score < 0"; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NestedCaseWhenCondToLiteralTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NestedCaseWhenCondToLiteralTest.java new file mode 100644 index 00000000000000..04af10abaa9e23 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NestedCaseWhenCondToLiteralTest.java @@ -0,0 +1,334 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; +import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.WhenClause; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.types.IntegerType; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Maps; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +class NestedCaseWhenCondToLiteralTest extends ExpressionRewriteTestHelper { + + @Test + void testNestedCaseWhen() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NestedCaseWhenCondToLiteral.INSTANCE) + )); + + assertRewriteAfterTypeCoercion( + "case when a > 1 then 1" + + " when a > 2 then" + + " (case when a > 1 then 2" + + " when a > 2 then 3" + + " when a > 1 and a > 1 and a > 2 and a > 2 and a > 3 then 100" + + " when a > 3 then (case when a > 1 then 4" + + " when a > 2 then 5" + + " when a > 3 then 6" + + " end)" + + " when a > 1 and a > 1 and a > 2 and a > 2 and a > 3 then 101" + + " end)" + + " when (case when a > 1 then a > 1" + + " when a > 2 then a > 2" + + " when a > 3 then a > 3" + + " when a > 1 then a > 1" + + " end) then 100" + + " when a > 3 then 7" + + " when a > 1 then 8" + + " else (case when a > 1 then 9" + + " when a > 2 then 10" + + " when a > 3 then 11" + + " when a > 4 then 12" + + " else (case when a > 1 then 13" + + " when a > 2 then 14" + + " when a > 3 then 15" + + " when a > 4 then 16" + + " when a > 5 then (case when a > 1 then 17 when a > 5 then 18 end)" + + " end)" + + " end)" + + " end", + "case when a > 1 then 1" + + " when a > 2 then" + + " (case when false then 2" + + " when true then 3" + + " when false and false and true and true and a > 3 then 100" + + " when a > 3 then (case when false then 4" + + " when true then 5" + + " when true then 6" + + " end)" + + " when false then 101" + + " end)" + + " when (case when false then a > 1" + + " when false then a > 2" + + " when a > 3 then a > 3" + + " when false then a > 1" + + " end) then 100" + + " when a > 3 then 7" + + " when false then 8" + + " else (case when false then 9" + + " when false then 10" + + " when false then 11" + + " when a > 4 then 12" + + " else (case when false then 13" + + " when false then 14" + + " when false then 15" + + " when false then 16" + + " when a > 5 then (case when false then 17 when true then 18 end)" + + " end)" + + " end)" + + " end" + ); + } + + @Test + void testNestedIf() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NestedCaseWhenCondToLiteral.INSTANCE) + )); + assertRewriteAfterTypeCoercion( + "if(" + + " a > 1," + + " if(" + + " a > 1," + + " if(" + + " a > 2," + + " if(a > 2,a + 2,a + 3)," + + " if(" + + " a > 1," + + " if(a > 2,a + 3,a + 4)," + + " if(a > 2,a + 5,a + 6)" + + " )" + + " )," + + " if(a > 1,a + 1,a + 2)" + + " )," + + " if(" + + " a > 1," + + " a + 5," + + " if(a > 2,a + 6,a + 7)" + + " )" + + ")", + "if(" + + " a > 1," + + " if(" + + " true," + + " if(" + + " a > 2," + + " if(true,a + 2,a + 3)," + + " if(" + + " true," + + " if(false,a + 3,a + 4)," + + " if(false,a + 5,a + 6)" + + " )" + + " )," + + " if(true,a + 1,a + 2)" + + " )," + + " if(" + + " false," + + " a + 5," + + " if(a > 2,a + 6,a + 7)" + + " )" + + ")" + ); + } + + @Test + void testNestedCaseWhenReplacer() { + // case when a > 1 then 101 + // when a > 2 then (case when a > 1 then 102 + // when a > 2 then 103 + // when a > 3 then 104 + // when a > 4 then 105 + // else 106 + // end) + // when a > 3 then 107 + // when a > 4 then 108 + // when a > 4 then 109 + // else 110 + // end + SlotReference a = new SlotReference("a", IntegerType.INSTANCE); + Expression c1 = new GreaterThan(a, IntegerLiteral.of(1)); + Expression c2 = new GreaterThan(a, IntegerLiteral.of(2)); + Expression c3 = new GreaterThan(a, IntegerLiteral.of(3)); + Expression c4 = new GreaterThan(a, IntegerLiteral.of(4)); + Expression i101 = IntegerLiteral.of(101); + Expression i102 = IntegerLiteral.of(102); + Expression i103 = IntegerLiteral.of(103); + Expression i104 = IntegerLiteral.of(104); + Expression i105 = IntegerLiteral.of(105); + Expression i106 = IntegerLiteral.of(106); + Expression i107 = IntegerLiteral.of(107); + Expression i108 = IntegerLiteral.of(108); + Expression i109 = IntegerLiteral.of(109); + Expression i110 = IntegerLiteral.of(110); + Expression innerCaseWhen = new CaseWhen( + ImmutableList.of( + new WhenClause(c1, i102), + new WhenClause(c2, i103), + new WhenClause(c3, i104), + new WhenClause(c4, i105)), + i106); + Expression outerCaseWhen = new CaseWhen( + ImmutableList.of( + new WhenClause(c1, i101), + new WhenClause(c2, innerCaseWhen), + new WhenClause(c3, i107), + new WhenClause(c4, i108), + new WhenClause(c4, i109)), + i110); + TestNestedCondReplacer replacer = new TestNestedCondReplacer(); + outerCaseWhen.accept(replacer, null); + replacer.checkExpressionReplaceLiterals(outerCaseWhen, + ImmutableList.of(), + ImmutableList.of()); + replacer.checkExpressionReplaceLiterals(i101, + ImmutableList.of(c1), + ImmutableList.of()); + replacer.checkExpressionReplaceLiterals(i102, + ImmutableList.of(c2), + ImmutableList.of(c1)); + replacer.checkExpressionReplaceLiterals(i103, + ImmutableList.of(c2), + ImmutableList.of(c1)); + replacer.checkExpressionReplaceLiterals(i104, + ImmutableList.of(c2, c3), + ImmutableList.of(c1)); + replacer.checkExpressionReplaceLiterals(i105, + ImmutableList.of(c2, c4), + ImmutableList.of(c1, c3)); + replacer.checkExpressionReplaceLiterals(i106, + ImmutableList.of(c2), + ImmutableList.of(c1, c3, c4)); + replacer.checkExpressionReplaceLiterals(i107, + ImmutableList.of(c3), + ImmutableList.of(c1, c2)); + replacer.checkExpressionReplaceLiterals(i108, + ImmutableList.of(c4), + ImmutableList.of(c1, c2, c3)); + replacer.checkExpressionReplaceLiterals(i109, + ImmutableList.of(), + ImmutableList.of(c1, c2, c3, c4)); + replacer.checkExpressionReplaceLiterals(i110, + ImmutableList.of(), + ImmutableList.of(c1, c2, c3, c4)); + + // after rewrite, the condition literals should clear + Assertions.assertEquals(Maps.newHashMap(), replacer.conditionLiterals); + } + + @Test + void testNestedIfReplacer() { + // if(a > 1, + // if(a > 2, + // if(a > 3, 301, 302), + // if(a > 4, 303, 304) + // ), + // if(a > 5, 305, 306) + // ) + SlotReference a = new SlotReference("a", IntegerType.INSTANCE); + Expression c1 = new GreaterThan(a, IntegerLiteral.of(1)); + Expression c2 = new GreaterThan(a, IntegerLiteral.of(2)); + Expression c3 = new GreaterThan(a, IntegerLiteral.of(3)); + Expression c4 = new GreaterThan(a, IntegerLiteral.of(4)); + Expression c5 = new GreaterThan(a, IntegerLiteral.of(5)); + Expression i301 = IntegerLiteral.of(301); + Expression i302 = IntegerLiteral.of(302); + Expression i303 = IntegerLiteral.of(303); + Expression i304 = IntegerLiteral.of(304); + Expression i305 = IntegerLiteral.of(305); + Expression i306 = IntegerLiteral.of(306); + Expression innerIf1 = new If(c3, i301, i302); + Expression innerIf2 = new If(c4, i303, i304); + Expression innerIf = new If(c2, innerIf1, innerIf2); + Expression outerIf = new If(c1, innerIf, new If(c5, i305, i306)); + TestNestedCondReplacer replacer = new TestNestedCondReplacer(); + outerIf.accept(replacer, null); + replacer.checkExpressionReplaceLiterals(outerIf, + ImmutableList.of(), + ImmutableList.of()); + replacer.checkExpressionReplaceLiterals(innerIf, + ImmutableList.of(c1), + ImmutableList.of()); + replacer.checkExpressionReplaceLiterals(innerIf1, + ImmutableList.of(c1, c2), + ImmutableList.of()); + replacer.checkExpressionReplaceLiterals(i301, + ImmutableList.of(c1, c2, c3), + ImmutableList.of()); + replacer.checkExpressionReplaceLiterals(i302, + ImmutableList.of(c1, c2), + ImmutableList.of(c3)); + replacer.checkExpressionReplaceLiterals(innerIf2, + ImmutableList.of(c1), + ImmutableList.of(c2)); + + // after rewrite, the condition literals should clear + Assertions.assertEquals(Maps.newHashMap(), replacer.conditionLiterals); + } + + private static class TestNestedCondReplacer extends NestedCaseWhenCondToLiteral.NestedCondReplacer { + private final Map> expressionReplaceMap = Maps.newHashMap(); + + @Override + public Expression visit(Expression expr, Void context) { + recordReplaceLiteral(expr); + return super.visit(expr, context); + } + + @Override + public Expression visitCaseWhen(CaseWhen caseWhen, Void context) { + recordReplaceLiteral(caseWhen); + return super.visitCaseWhen(caseWhen, context); + } + + @Override + public Expression visitIf(If ifExpr, Void context) { + recordReplaceLiteral(ifExpr); + return super.visitIf(ifExpr, context); + } + + private void recordReplaceLiteral(Expression expr) { + expressionReplaceMap.put(expr, Maps.newHashMap(conditionLiterals)); + } + + private void checkExpressionReplaceLiterals(Expression expression, + List trueConditions, List falseConditions) { + Map expectedReplaceMap = Maps.newHashMap(); + for (Expression trueCondition : trueConditions) { + expectedReplaceMap.put(trueCondition, BooleanLiteral.TRUE); + } + for (Expression falseCondition : falseConditions) { + expectedReplaceMap.put(falseCondition, BooleanLiteral.FALSE); + } + Assertions.assertEquals(expectedReplaceMap, expressionReplaceMap.get(expression)); + } + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java index d4adc821880b60..c2896e50fe6d51 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.expression.rules; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; import org.apache.doris.nereids.trees.expressions.EqualTo; @@ -26,9 +27,13 @@ import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.plans.RelationId; +import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.types.StringType; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import org.junit.jupiter.api.Test; class NullSafeEqualToEqualTest extends ExpressionRewriteTestHelper { @@ -41,7 +46,7 @@ void testNullSafeEqualToIsNull() { )); SlotReference slot = new SlotReference("a", StringType.INSTANCE, true); assertRewrite(new NullSafeEqual(slot, NullLiteral.INSTANCE), new IsNull(slot)); - slot = new SlotReference("a", StringType.INSTANCE, false); + slot = new SlotReference("a", StringType.INSTANCE, true); assertRewrite(new NullSafeEqual(slot, NullLiteral.INSTANCE), new IsNull(slot)); } @@ -56,7 +61,7 @@ void testNullSafeEqualToFalse() { // "NULL <=> Null" to true @Test - void testNullSafeEqualToTrue() { + void testNullSafeEqualNull() { executor = new ExpressionRuleExecutor(ImmutableList.of( bottomUp(NullSafeEqualToEqual.INSTANCE) )); @@ -87,13 +92,13 @@ void testNullSafeEqualNotChangedRightNullable() { // "A<=>B" not changed @Test - void testNullSafeEqualNotChangedBothNullable() { + void testNullSafeEqualChangedBothNotNullable() { executor = new ExpressionRuleExecutor(ImmutableList.of( bottomUp(NullSafeEqualToEqual.INSTANCE) )); SlotReference a = new SlotReference("a", StringType.INSTANCE, false); SlotReference b = new SlotReference("b", StringType.INSTANCE, false); - assertRewrite(new NullSafeEqual(a, b), new NullSafeEqual(a, b)); + assertRewrite(new NullSafeEqual(a, b), new EqualTo(a, b)); } // "1 <=> 0" to "1 = 0" @@ -106,4 +111,70 @@ void testNullSafeEqualToEqual() { IntegerLiteral b = new IntegerLiteral(1); assertRewrite(new NullSafeEqual(a, b), new EqualTo(a, b)); } + + @Test + void testInsideCondition() { + LogicalFilter filter = new LogicalFilter(ImmutableSet.of(), + new LogicalEmptyRelation(new RelationId(1), ImmutableList.of())); + ExpressionRewriteContext oldContext = context; + try { + context = new ExpressionRewriteContext(filter, cascadesContext); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NullSafeEqualToEqual.INSTANCE) + )); + + assertRewriteAfterTypeCoercion("a <=> a", "TRUE"); + assertRewriteAfterTypeCoercion("a <=> b", "a <=> b"); + assertRewriteAfterTypeCoercion("a <=> count(b)", "a = count(b)"); + assertRewriteAfterTypeCoercion("a <=> 3", "a = 3"); + assertRewriteAfterTypeCoercion("count(a) <=> count(b)", "count(a) = count(b)"); + assertRewriteAfterTypeCoercion("a <=> null", "a is null"); + assertRewriteAfterTypeCoercion("null <=> 3", "FALSE"); + assertRewriteAfterTypeCoercion("not(a <=> 3)", "not(a <=> 3)"); + assertRewriteAfterTypeCoercion("if(a <=> 3, a <=> 4, a <=> 5)", "if(a = 3, a = 4, a = 5)"); + assertRewriteAfterTypeCoercion("not(if(a <=> 3, a <=> 4, a <=> 5))", "not(if(a = 3, a <=> 4, a <=> 5))"); + } finally { + context = oldContext; + } + } + + @Test + void testNotInsideCondition() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NullSafeEqualToEqual.INSTANCE) + )); + + assertRewriteAfterTypeCoercion("a <=> 3", "a <=> 3"); + assertRewriteAfterTypeCoercion("a <=> b", "a <=> b"); + assertRewriteAfterTypeCoercion("a <=> count(b)", "a <=> count(b)"); + assertRewriteAfterTypeCoercion("count(a) <=> count(b)", "count(a) = count(b)"); + assertRewriteAfterTypeCoercion("a <=> null", "a is null"); + assertRewriteAfterTypeCoercion("null <=> 3", "false"); + assertRewriteAfterTypeCoercion("not(a <=> 3)", "not(a <=> 3)"); + assertRewriteAfterTypeCoercion("if(a <=> 3, a <=> 4, a <=> 5)", "if(a = 3, a <=> 4, a <=> 5)"); + assertRewriteAfterTypeCoercion("not(if(a <=> 3, a <=> 4, a <=> 5))", "not(if(a = 3, a <=> 4, a <=> 5))"); + } + + @Test + void testNullSafeEqualToTrue() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NullSafeEqualToEqual.INSTANCE) + )); + + assertRewriteAfterTypeCoercion("Ba <=> true", "Ba <=> true"); + assertRewriteAfterTypeCoercion("null <=> true", "false"); + assertRewriteAfterTypeCoercion("a > 1 <=> true", "a > 1 and a is not null"); + assertRewriteAfterTypeCoercion("Xa > 1 <=> true", "Xa > 1 = true"); + assertRewriteAfterTypeCoercion("Xa > null <=> true", "Xa > null <=> true"); + assertRewriteAfterTypeCoercion("a + b > c - d <=> true", "a + b > c - d and a is not null and b is not null and c is not null and d is not null"); + assertRewriteAfterTypeCoercion("(a in (1, 2, c)) <=> true", "(a in (1, 2, c)) <=> true"); + assertRewriteAfterTypeCoercion("(a in (1, 2, 3, null)) <=> true", "a is not null and a in (1, 2, 3)"); + assertRewriteAfterTypeCoercion("(a in (null, null, null)) <=> true", "false"); + assertRewriteAfterTypeCoercion("(a + b in (1, 2, 3, null)) <=> true", "a is not null and b is not null and a + b in (1, 2, 3)"); + assertRewriteAfterTypeCoercion("(a > 1 and b > 1 and (c > d or e > 1)) <=> true", + "a > 1 and a is not null and b > 1 and b is not null and (c > d and c is not null and d is not null or e > 1 and e is not null)"); + assertRewriteAfterTypeCoercion("(a / b > 1) <=> true", "(a / b > 1) <=> true"); + assertRewriteAfterTypeCoercion("(a / b > 1 and c > 1 or (d > 1)) <=> true", + "(a / b > 1) <=> true and (c > 1 and c is not null) or (d > 1 and d is not null)"); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/PushIntoCaseWhenBranchTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/PushIntoCaseWhenBranchTest.java new file mode 100644 index 00000000000000..3cef0f3e02d461 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/PushIntoCaseWhenBranchTest.java @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; +import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +class PushIntoCaseWhenBranchTest extends ExpressionRewriteTestHelper { + + public PushIntoCaseWhenBranchTest() { + setExpressionOnFilter(); + } + + @Test + void testPushIntoCaseWhen() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(PushIntoCaseWhenBranch.INSTANCE) + )); + + assertRewriteAfterTypeCoercion("cast(case when TA = 1 then 1 else 2 end as bigint)", "case when TA = 1 then 1 else 2 end"); + assertRewriteAfterTypeCoercion("TA > case when TB = 1 then 1 else 3 end", "TA > case when TB = 1 then 1 else 3 end"); + assertRewriteAfterTypeCoercion("2 > case when TB = 1 then 1 else 3 end", "case when TB = 1 then true else false end"); + assertRewriteAfterTypeCoercion("2 > case when TB = 1 then TC else TD end", "2 > case when TB = 1 then TC else TD end"); + assertRewriteAfterTypeCoercion("2 > case when TB = 1 then 1 else TD end", "case when TB = 1 then true else 2 > TD end"); + assertRewriteAfterTypeCoercion("2 > case when TB = 1 then TC end", "case when TB = 1 then 2 > TC else null end"); + } + + @Test + void testPushIntoIf() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(PushIntoCaseWhenBranch.INSTANCE) + )); + + assertRewriteAfterTypeCoercion("cast(if(TA = 1, 1, 2) as bigint)", "if(TA = 1, 1, 2)"); + assertRewriteAfterTypeCoercion("TA > if(TB = 1, 1, 3)", "TA > if(TB = 1, 1, 3)"); + assertRewriteAfterTypeCoercion("2 > if(TB = 1, 1, 3)", "if(TB = 1, true, false)"); + assertRewriteAfterTypeCoercion("10 < if(TA = 1, 1, 100) and 2 > if(TB = 1, 1, 3)", + "if(TA = 1, false, true) and if(TB = 1, true, false)"); + assertRewriteAfterTypeCoercion("2 > if(if(TB = 1, 10, TA) > 15, 1, TC)", + "if(if(TB = 1, false, TA > 15), true, 2 > TC)"); + assertRewriteAfterTypeCoercion("2 > if(TB = 1, TC, TD)", "2 > if(TB = 1, TC, TD)"); + assertRewriteAfterTypeCoercion("2 > if(TB = 1, 1, TD)", "if(TB = 1, true, 2 > TD)"); + assertRewriteAfterTypeCoercion("2 > if(TB = 1, TC, NULL)", "if(TB = 1, 2 > TC, null)"); + } + + @Test + void testPushIntoNvl() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(PushIntoCaseWhenBranch.INSTANCE) + )); + assertRewriteAfterTypeCoercion("cast(nvl(TA, TB) as bigint)", "cast(nvl(TA, TB) as bigint)"); + assertRewriteAfterTypeCoercion("cast(nvl(TA, 1) as bigint)", "if(TA is null, 1, cast(TA as bigint))"); + assertRewriteAfterTypeCoercion("a > nvl(b, c)", "a > nvl(b, c)"); + assertRewriteAfterTypeCoercion("2 > nvl(b, c)", "2 > nvl(b, c)"); + assertRewriteAfterTypeCoercion("2 > nvl(null, c)", "if(null is null, 2 > c, null)"); + assertRewriteAfterTypeCoercion("2 > nvl(b, null)", "if(b is null, null, 2 > b)"); + assertRewriteAfterTypeCoercion("2 > nvl(a + b, null)", "if(a + b is null, null, 2 > a + b)"); + assertRewriteAfterTypeCoercion("2 > nvl(a + b + random(1, 10), null)", "2 > nvl(a + b + random(1, 10), null)"); + assertRewriteAfterTypeCoercion("2 > nvl(null, a + b + random(1, 10))", "if(null is null, 2 > a + b + random(1, 10), null)"); + } + + @Test + void testPushIntoNullIf() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(PushIntoCaseWhenBranch.INSTANCE) + )); + assertRewriteAfterTypeCoercion("cast(nullif(TA, TB) as bigint)", "if(TA = TB, NULL, cast(TA as bigint))"); + assertRewriteAfterTypeCoercion("cast(nullif(TA, 1) as bigint)", "if(TA = 1, null, cast(TA as bigint))"); + assertRewriteAfterTypeCoercion("a > nullif(b, c)", "a > nullif(b, c)"); + assertRewriteAfterTypeCoercion("2 > nullif(b, c)", "if(b = c, null, 2 > b)"); + assertRewriteAfterTypeCoercion("2 > nullif(b + random(1, 10), c)", "2 > nullif(b + random(1, 10), c)"); + assertRewriteAfterTypeCoercion("2 > nullif(b, c + random(1, 10))", "if(b = c + random(1, 10), null, 2 > b)"); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyEqualBooleanLiteralTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyEqualBooleanLiteralTest.java new file mode 100644 index 00000000000000..02db668cfa1f8d --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyEqualBooleanLiteralTest.java @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; +import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +class SimplifyEqualBooleanLiteralTest extends ExpressionRewriteTestHelper { + + @Test + void testEqualToTrue() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + SimplifyEqualBooleanLiteral.INSTANCE + ) + )); + + assertRewriteAfterTypeCoercion("a > 1 = true", "a > 1"); + assertRewriteAfterTypeCoercion("Ba = true", "Ba = true"); + } + + @Test + void testEqualToFalse() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + SimplifyEqualBooleanLiteral.INSTANCE + ) + )); + + assertRewriteAfterTypeCoercion("(a > 1) = false", "not(a > 1)"); + assertRewriteAfterTypeCoercion("Ba = false", "Ba = false"); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagationTest.java index eac9e0ecc25a18..a1ef8f8d7b414e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagationTest.java @@ -22,7 +22,6 @@ import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.properties.OrderKey; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; @@ -36,6 +35,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; @@ -46,6 +46,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalSort; import org.apache.doris.nereids.types.BigIntType; @@ -70,7 +71,7 @@ class ConstantPropagationTest { private final ConstantPropagation executor = new ConstantPropagation(); private final NereidsParser parser = new NereidsParser(); - private final ExpressionRewriteContext exprRewriteContext; + private final CascadesContext cascadesContext; private final JobContext jobContext; @@ -82,11 +83,11 @@ class ConstantPropagationTest { private final SlotReference scoreSid; private final SlotReference scoreCid; private final SlotReference scoreGrade; + private final LogicalFilter filter; ConstantPropagationTest() { - CascadesContext cascadesContext = MemoTestUtils.createCascadesContext( + cascadesContext = MemoTestUtils.createCascadesContext( new UnboundRelation(new RelationId(1), ImmutableList.of("tbl"))); - exprRewriteContext = new ExpressionRewriteContext(cascadesContext); jobContext = new JobContext(cascadesContext, null, Double.MAX_VALUE); student = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of("")); @@ -98,6 +99,8 @@ class ConstantPropagationTest { scoreSid = (SlotReference) score.getOutput().get(0); scoreCid = (SlotReference) score.getOutput().get(1); scoreGrade = (SlotReference) score.getOutput().get(2); + + filter = new LogicalFilter<>(ImmutableSet.of(BooleanLiteral.TRUE), student); } @Test @@ -166,17 +169,17 @@ void testExpressionNotReplace() { // for `a is not null`, if this Not isGeneratedIsNotNull, then will not rewrite it SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true); Expression expr1 = ExpressionUtils.and(new EqualTo(a, new IntegerLiteral(1)), new Not(new IsNull(a), false)); - Expression rewrittenExpr1 = executor.replaceConstantsAndRewriteExpr(student, expr1, true, exprRewriteContext); + Expression rewrittenExpr1 = rewriteExpression(expr1, true); Expression expectExpr1 = new EqualTo(a, new IntegerLiteral(1)); Assertions.assertEquals(expectExpr1, rewrittenExpr1); Expression expr2 = ExpressionUtils.and(new EqualTo(a, new IntegerLiteral(1)), new Not(new IsNull(a), true)); - Expression rewrittenExpr2 = executor.replaceConstantsAndRewriteExpr(student, expr2, true, exprRewriteContext); + Expression rewrittenExpr2 = rewriteExpression(expr2, true); Assertions.assertEquals(expr2, rewrittenExpr2); // for `a match_any xx`, don't replace it, because the match require left child is column, not literal SlotReference b = new SlotReference("b", StringType.INSTANCE, true); Expression expr3 = ExpressionUtils.and(new EqualTo(b, new StringLiteral("hello")), new MatchAny(b, new StringLiteral("%ll%"))); - Expression rewrittenExpr3 = executor.replaceConstantsAndRewriteExpr(student, expr3, true, exprRewriteContext); + Expression rewrittenExpr3 = rewriteExpression(expr3, true); Assertions.assertEquals(expr3, rewrittenExpr3); } @@ -439,10 +442,14 @@ private void assertRewrite(String expression, String expected, boolean useInnerI Expression rewriteExpression = parser.parseExpression(expression); rewriteExpression = ExpressionRewriteTestHelper.typeCoercion( ExpressionRewriteTestHelper.replaceUnboundSlot(rewriteExpression, Maps.newHashMap())); - rewriteExpression = executor.replaceConstantsAndRewriteExpr(student, rewriteExpression, - useInnerInferConstants, exprRewriteContext); + rewriteExpression = rewriteExpression(rewriteExpression, useInnerInferConstants); Expression expectedExpression = parser.parseExpression(expected); Assertions.assertEquals(expectedExpression.toSql(), rewriteExpression.toSql()); } + private Expression rewriteExpression(Expression expression, boolean useInnerInferConstants) { + LogicalPlan plan = useInnerInferConstants ? filter : student; + return executor.replaceConstantsAndRewriteExpr(plan, expression, cascadesContext); + } + } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateFilterTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateFilterTest.java index 692f6532541015..a295b42d6ca2d4 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateFilterTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateFilterTest.java @@ -24,13 +24,11 @@ import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.Not; import org.apache.doris.nereids.trees.expressions.Or; -import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; -import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.LogicalPlanBuilder; import org.apache.doris.nereids.util.MemoPatternMatchSupported; @@ -38,7 +36,6 @@ import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.util.Arrays; @@ -186,25 +183,4 @@ void testEliminateOneFilterFalse() { logicalFilter(logicalOlapScan()).when(f -> f.getPredicate() instanceof GreaterThan) ); } - - @Test - void testEliminateNullLiteral() { - Expression a = new SlotReference("a", IntegerType.INSTANCE); - Expression b = new SlotReference("b", IntegerType.INSTANCE); - Expression one = Literal.of(1); - Expression two = Literal.of(2); - Expression expression = new And(Arrays.asList( - new And(new GreaterThan(a, one), new NullLiteral(IntegerType.INSTANCE)), - new Or(Arrays.asList(new GreaterThan(b, two), new NullLiteral(IntegerType.INSTANCE), - new EqualTo(a, new NullLiteral(IntegerType.INSTANCE)))), - new Not(new And(new GreaterThan(a, one), new NullLiteral(IntegerType.INSTANCE))) - )); - Expression expectExpression = new And(Arrays.asList( - new And(new GreaterThan(a, one), BooleanLiteral.FALSE), - new Or(Arrays.asList(new GreaterThan(b, two), BooleanLiteral.FALSE, - new EqualTo(a, new NullLiteral(IntegerType.INSTANCE)))), - new Not(new And(new GreaterThan(a, one), new NullLiteral(IntegerType.INSTANCE))) - )); - Assertions.assertEquals(expectExpression, new EliminateFilter().eliminateNullLiteral(expression)); - } } diff --git a/regression-test/data/nereids_rules_p0/adjust_nullable/test_subquery_nullable.out b/regression-test/data/nereids_rules_p0/adjust_nullable/test_subquery_nullable.out index 348814d39f1149..cdf32d50f56b4f 100644 --- a/regression-test/data/nereids_rules_p0/adjust_nullable/test_subquery_nullable.out +++ b/regression-test/data/nereids_rules_p0/adjust_nullable/test_subquery_nullable.out @@ -127,7 +127,7 @@ PhysicalResultSink -- !correlate_notop_notnullable_agg_scalar_subquery_shape -- PhysicalResultSink --PhysicalDistribute[DistributionSpecGather] -----PhysicalProject[AND[any_value(count(x)) IS NULL,NULL] AS `x > 10 and x < 1`, a AS `a`, assert_true(OR[count(*) IS NULL,(count(*) <= 1)], 'correlate scalar subquery must return only 1 row') AS `assert_true(OR[count(*) IS NULL,(count(*) <= 1)], 'correlate scalar subquery must return only 1 row')`, assert_true(OR[count(*) IS NULL,(count(*) <= 1)], 'correlate scalar subquery must return only 1 row') AS `assert_true(OR[count(*) IS NULL,(count(*) <= 1)], 'correlate scalar subquery must return only 1 row')`] +----PhysicalProject[AND[any_value(count(x)) IS NULL,NULL] AS `x > 10 and x < 1`, a AS `a`, assert_true(OR[(count(*) <= 1),count(*) IS NULL], 'correlate scalar subquery must return only 1 row') AS `assert_true(OR[count(*) IS NULL,(count(*) <= 1)], 'correlate scalar subquery must return only 1 row')`, assert_true(OR[(count(*) <= 1),count(*) IS NULL], 'correlate scalar subquery must return only 1 row') AS `assert_true(OR[count(*) IS NULL,(count(*) <= 1)], 'correlate scalar subquery must return only 1 row')`] ------hashJoin[LEFT_OUTER_JOIN broadcast] hashCondition=((expr_(cast(a as BIGINT) + cast(b as BIGINT)) = cast(x as BIGINT))) otherCondition=() --------PhysicalProject[(cast(a as BIGINT) + cast(b as BIGINT)) AS `expr_(cast(a as BIGINT) + cast(b as BIGINT))`, test_subquery_nullable_t1.a] ----------PhysicalOlapScan[test_subquery_nullable_t1] diff --git a/regression-test/data/nereids_rules_p0/case_when_rules/test_case_when_rules.out b/regression-test/data/nereids_rules_p0/case_when_rules/test_case_when_rules.out new file mode 100644 index 00000000000000..686bcfd06f13bd --- /dev/null +++ b/regression-test/data/nereids_rules_p0/case_when_rules/test_case_when_rules.out @@ -0,0 +1,56 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !sql_1_shape -- +PhysicalResultSink +--PhysicalProject +----filter(a IN (1, 3)) +------PhysicalOlapScan[tbl_test_case_when_rules] + +-- !sql_1_result -- +1 10 101 +3 30 103 + +-- !sql_2_shape -- +PhysicalResultSink +--PhysicalProject +----filter((tbl_test_case_when_rules.a = 4)) +------PhysicalOlapScan[tbl_test_case_when_rules] + +-- !sql_2_result -- +4 40 104 + +-- !sql_3_shape -- +PhysicalResultSink +--PhysicalProject +----filter(OR[( not a IN (1, 2, 3)),a IS NULL]) +------PhysicalOlapScan[tbl_test_case_when_rules] + +-- !sql_3_result -- +\N 0 107 +4 40 104 +5 50 107 +6 60 107 + +-- !sql_4_shape -- +PhysicalResultSink +--PhysicalProject +----filter(OR[( not (a = 3)),a IS NULL]) +------PhysicalOlapScan[tbl_test_case_when_rules] + +-- !sql_4_result -- +\N 0 107 +1 10 101 +2 20 102 +4 40 104 +5 50 107 +6 60 107 + +-- !sql_5_shape -- +PhysicalResultSink +--PhysicalProject +----filter(a IN (1, 2)) +------PhysicalOlapScan[tbl_test_case_when_rules] + +-- !sql_5_result -- +1 10 101 +2 20 102 + diff --git a/regression-test/data/nereids_rules_p0/expression/test_simplify_range.out b/regression-test/data/nereids_rules_p0/expression/test_simplify_range.out new file mode 100644 index 00000000000000..cda4a2fb326469 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/expression/test_simplify_range.out @@ -0,0 +1,31 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !sql_1_shape -- +PhysicalResultSink +--PhysicalProject +----PhysicalOlapScan[test_simplify_range_tbl_1] + +-- !sql_1_result -- +\N +false + +-- !sql_2_shape -- +PhysicalResultSink +--PhysicalProject +----PhysicalOlapScan[test_simplify_range_tbl_1] + +-- !sql_2_result -- +\N \N +10 20 + +-- !sql_3_shape -- +PhysicalResultSink +--PhysicalProject +----hashAgg[GLOBAL] +------hashAgg[LOCAL] +--------PhysicalProject +----------filter(( not (cast(b as BIGINT) * 0) IS NULL)) +------------PhysicalOlapScan[test_simplify_range_tbl_1] + +-- !sql_3_result -- +0 20 + diff --git a/regression-test/data/nereids_rules_p0/filter_push_down/push_down_filter_other_condition.out b/regression-test/data/nereids_rules_p0/filter_push_down/push_down_filter_other_condition.out index 370abe414b5f62..51c751b763006c 100644 --- a/regression-test/data/nereids_rules_p0/filter_push_down/push_down_filter_other_condition.out +++ b/regression-test/data/nereids_rules_p0/filter_push_down/push_down_filter_other_condition.out @@ -229,7 +229,7 @@ PhysicalResultSink -- !pushdown_left_outer_join_subquery_outer -- PhysicalResultSink ---NestedLoopJoin[INNER_JOIN]OR[(t1.id = t2.id),AND[id IS NULL,(id > 1)]] +--hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id)) otherCondition=() ----PhysicalOlapScan[t1] ----PhysicalAssertNumRows ------PhysicalOlapScan[t2] diff --git a/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out b/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out index 3a1170d265c34b..42a6c1029948b1 100644 --- a/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out +++ b/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out @@ -136,24 +136,25 @@ PhysicalResultSink -- !test_if_predicate -- PhysicalResultSink --hashJoin[INNER_JOIN] hashCondition=((t1.d_int = t2.d_int)) otherCondition=() -----PhysicalOlapScan[extend_infer_t1] -----filter(if(( not d_int IN (10, 20)), TRUE, FALSE)) +----filter(( not d_int IN (10, 20))) +------PhysicalOlapScan[extend_infer_t1] +----filter(( not d_int IN (10, 20))) ------PhysicalOlapScan[extend_infer_t1] -- !test_if_and_in_predicate -- PhysicalResultSink ---hashJoin[INNER_JOIN] hashCondition=((t1.d_int = t2.d_int)) otherCondition=() -----filter(( not (if((d_int = 5), TRUE, FALSE) = FALSE))) +--NestedLoopJoin[INNER_JOIN] +----filter((t1.d_int = 5)) ------PhysicalOlapScan[extend_infer_t1] -----filter(( not (if((d_int = 5), TRUE, FALSE) = FALSE))) +----filter((t2.d_int = 5)) ------PhysicalOlapScan[extend_infer_t1] -- !test_if_and_in_predicate_not -- PhysicalResultSink ---hashJoin[INNER_JOIN] hashCondition=((t1.d_int = t2.d_int)) otherCondition=() -----filter(( not (if((d_int = 5), TRUE, FALSE) = FALSE))) +--NestedLoopJoin[INNER_JOIN] +----filter((t1.d_int = 5)) ------PhysicalOlapScan[extend_infer_t1] -----filter(( not (if((d_int = 5), TRUE, FALSE) = FALSE))) +----filter((t2.d_int = 5)) ------PhysicalOlapScan[extend_infer_t1] -- !test_multi_slot_in_predicate1 -- @@ -173,9 +174,10 @@ PhysicalResultSink -- !test_case_when_predicate -- PhysicalResultSink ---hashJoin[INNER_JOIN] hashCondition=((t1.d_int = t2.d_int)) otherCondition=() -----PhysicalOlapScan[extend_infer_t1] -----filter(CASE WHEN (d_int = 1) THEN TRUE WHEN (d_int = 2) THEN FALSE ELSE FALSE END) +--NestedLoopJoin[INNER_JOIN] +----filter((t1.d_int = 1)) +------PhysicalOlapScan[extend_infer_t1] +----filter((t2.d_int = 1)) ------PhysicalOlapScan[extend_infer_t1] -- !test_datetimev2_predicate -- diff --git a/regression-test/data/nereids_rules_p0/infer_predicate/pull_up_predicate_literal.out b/regression-test/data/nereids_rules_p0/infer_predicate/pull_up_predicate_literal.out index 54d9564bab1fda..adc6bb8087c0bf 100644 --- a/regression-test/data/nereids_rules_p0/infer_predicate/pull_up_predicate_literal.out +++ b/regression-test/data/nereids_rules_p0/infer_predicate/pull_up_predicate_literal.out @@ -897,7 +897,7 @@ PhysicalResultSink ----------PhysicalProject ------------PhysicalStorageLayerAggregate[test_pull_up_predicate_literal] ------PhysicalProject ---------filter((cast(d_char as BOOLEAN) = TRUE)) +--------filter(cast(d_char as BOOLEAN)) ----------PhysicalOlapScan[test_types] -- !const_value_and_join_column_type92 -- @@ -909,7 +909,7 @@ PhysicalResultSink ----------PhysicalProject ------------PhysicalStorageLayerAggregate[test_pull_up_predicate_literal] ------PhysicalProject ---------filter((cast(d_varchar as BOOLEAN) = TRUE)) +--------filter(cast(d_varchar as BOOLEAN)) ----------PhysicalOlapScan[test_types] -- !const_value_and_join_column_type93 -- @@ -921,7 +921,7 @@ PhysicalResultSink ----------PhysicalProject ------------PhysicalStorageLayerAggregate[test_pull_up_predicate_literal] ------PhysicalProject ---------filter((cast(d_string as BOOLEAN) = TRUE)) +--------filter(cast(d_string as BOOLEAN)) ----------PhysicalOlapScan[test_types] -- !const_value_and_join_column_type96 -- diff --git a/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out b/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out new file mode 100644 index 00000000000000..269a1f29a95b76 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.out @@ -0,0 +1,293 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !case_when_one_side_1 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN]( not (CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) END = 95)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a] +--------filter(OR[( not (a = 95)),( not ((a + b) = 95))]) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[(x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !case_when_one_side_2 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN]( not (CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) END = ((a - b) - 5))) +------PhysicalProject[((a - b) - 5) AS `((a - b) - 5)`, (a + b) AS `(a + b)`, t1.a] +--------filter(OR[( not (a = ((a - b) - 5))),( not ((a + b) = ((a - b) - 5)))]) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[(x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !case_when_one_side_3 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN]( not (CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE b END = 95)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a, t1.b] +--------filter(OR[( not (a = 95)),( not ((a + b) = 95)),( not (b = 95))]) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[(x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !case_when_one_side_4 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN]( not (CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE b END = ((a - b) - 5))) +------PhysicalProject[((a - b) - 5) AS `((a - b) - 5)`, (a + b) AS `(a + b)`, t1.a, t1.b] +--------filter(OR[( not (a = ((a - b) - 5))),( not ((a + b) = ((a - b) - 5))),( not (b = ((a - b) - 5)))]) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[(x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !case_when_one_side_5 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN]((CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE b END + random(1, 100)) = (a - b)) +------PhysicalProject[(a + b) AS `(a + b)`, (a - b) AS `(a - b)`, t1.a, t1.b] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[(x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !case_when_one_side_6 -- +PhysicalResultSink +--NestedLoopJoin[CROSS_JOIN] +----PhysicalProject[t1.a] +------filter(( not (CASE WHEN (a > 10) THEN a WHEN (b > 10) THEN (a + b) END = 95))) +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +----PhysicalProject[t2.x] +------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !case_when_one_side_7 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN]((CASE WHEN (x > 0) THEN a WHEN (x < 10) THEN (a + 1) END + CASE WHEN (x > 1) THEN (a + 1) WHEN (x < 10) THEN (a + 10) END) > (a + b)) +------PhysicalProject[(a + 1) AS `(a + 1)`, (a + 10) AS `(a + 10)`, (a + b) AS `(a + b)`, t1.a] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[(x < 10) AS `(x < 10)`, (x > 0) AS `(x > 0)`, (x > 1) AS `(x > 1)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !case_when_one_side_8 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN]((CASE WHEN (x > 0) THEN a WHEN (x < 10) THEN (a + 1) END + CASE WHEN (a > 1) THEN (a + 1) WHEN (a < 10) THEN (a + 10) END) > (a + b)) +------PhysicalProject[(a + 1) AS `(a + 1)`, (a + b) AS `(a + b)`, CASE WHEN (a > 1) THEN (a + 1) WHEN (a < 10) THEN (a + 10) END AS `CASE WHEN (a > 1) THEN (a + 1) WHEN (a < 10) THEN (a + 10) END`, t1.a] +--------filter(OR[((t1.a + CASE WHEN (a > 1) THEN (a + 1) WHEN (a < 10) THEN (a + 10) END) > (t1.a + t1.b)),((t1.a + CASE WHEN (a > 1) THEN (a + 1) WHEN (a < 10) THEN (a + 10) END) > ((t1.a + t1.b) - 1))]) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[(x < 10) AS `(x < 10)`, (x > 0) AS `(x > 0)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !case_when_one_side_9 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN](CASE WHEN (x > 0) THEN 105 WHEN (x < 10) THEN 10005 ELSE NULL END > (a + b)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a] +--------filter(((t1.a + t1.b) < 10005)) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[CASE WHEN (x > 0) THEN 105 WHEN (x < 10) THEN 10005 ELSE NULL END AS `CASE WHEN (x > 0) THEN 105 WHEN (x < 10) THEN 10005 ELSE NULL END`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !case_when_one_side_10 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[LEFT_OUTER_JOIN](CASE WHEN (x > 0) THEN 105 WHEN (x < 10) THEN 10005 ELSE NULL END > (a + b)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[CASE WHEN (x > 0) THEN 105 WHEN (x < 10) THEN 10005 ELSE NULL END AS `CASE WHEN (x > 0) THEN 105 WHEN (x < 10) THEN 10005 ELSE NULL END`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !case_when_one_side_11 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[RIGHT_OUTER_JOIN](CASE WHEN (a > 0) THEN 105 WHEN (b < 10) THEN 10005 ELSE NULL END > (x + y)) +------PhysicalProject[CASE WHEN (a > 0) THEN 105 WHEN (b < 10) THEN 10005 ELSE NULL END AS `CASE WHEN (a > 0) THEN 105 WHEN (b < 10) THEN 10005 ELSE NULL END`, t1.a] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[(x + y) AS `(x + y)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !hash_cond_for_one_side_1 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----hashJoin[INNER_JOIN] hashCondition=((expr_if((a > 1), 1, 100) = expr_if((x > 2), 1, 200))) otherCondition=() +------PhysicalProject[if((a > 1), 1, 100) AS `expr_if((a > 1), 1, 100)`, t1.a] +--------filter((t1.a > 1)) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[if((x > 2), 1, 200) AS `expr_if((x > 2), 1, 200)`, t2.x] +--------filter((t2.x > 2)) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !case_when_two_side_1 -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalProject[(a + b) AS `(a + b)`, t1.a, t1.b] +------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +--PhysicalCteAnchor ( cteId=CTEId#1 ) +----PhysicalCteProducer ( cteId=CTEId#1 ) +------PhysicalProject[((x + y) + -4) AS `((x + y) + -4)`, (x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] +----PhysicalResultSink +------PhysicalUnion +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.a = .((x + y) + -4))) otherCondition=((CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE b END = .((x + y) + -4))) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.(a + b) = .((x + y) + -4))) otherCondition=((CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE b END = .((x + y) + -4)) and OR[( not (a = ((x + y) + -4))),(a = ((x + y) + -4)) IS NULL]) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.b = .((x + y) + -4))) otherCondition=((CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE b END = .((x + y) + -4)) and OR[( not ((a + b) = ((x + y) + -4))),((a + b) = ((x + y) + -4)) IS NULL] and OR[( not (a = ((x + y) + -4))),(a = ((x + y) + -4)) IS NULL]) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) + +-- !case_when_two_side_2 -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalProject[(a + b) AS `(a + b)`, t1.a] +------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +--PhysicalCteAnchor ( cteId=CTEId#1 ) +----PhysicalCteProducer ( cteId=CTEId#1 ) +------PhysicalProject[((x + y) + -4) AS `((x + y) + -4)`, (x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] +----PhysicalResultSink +------PhysicalUnion +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.a = .((x + y) + -4))) otherCondition=((CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) END = .((x + y) + -4))) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.(a + b) = .((x + y) + -4))) otherCondition=((CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) END = .((x + y) + -4)) and OR[( not (a = ((x + y) + -4))),(a = ((x + y) + -4)) IS NULL]) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) + +-- !case_when_two_side_3 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN]((CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE b END + random(1, 10)) = ((x + y) + 1)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a, t1.b] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[((x + y) + 1) AS `((x + y) + 1)`, (x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !case_when_two_side_4 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN](CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE (x + b) END = ((x + y) + 1)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a, t1.b] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[((x + y) + 1) AS `((x + y) + 1)`, (x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !case_when_two_side_5 -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalProject[(a + b) AS `(a + b)`, t1.a, t1.b] +------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +--PhysicalCteAnchor ( cteId=CTEId#1 ) +----PhysicalCteProducer ( cteId=CTEId#1 ) +------PhysicalProject[((x + y) + -4) AS `((x + y) + -4)`, (x - y) AS `(x - y)`, (x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] +----PhysicalResultSink +------PhysicalUnion +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.a = .(x - y))) otherCondition=((CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE b END = .((x + y) + -4)) and (if((a > x), a, b) = .(x - y))) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.b = .(x - y))) otherCondition=((CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE b END = .((x + y) + -4)) and (if((a > x), a, b) = .(x - y)) and OR[( not (a = (x - y))),(a = (x - y)) IS NULL]) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) + +-- !case_when_two_side_6 -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalProject[a IS NULL AS `a IS NULL`, coalesce(a, b) AS `coalesce(a, b)`, t1.a] +------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +--PhysicalCteAnchor ( cteId=CTEId#1 ) +----PhysicalCteProducer ( cteId=CTEId#1 ) +------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] +----PhysicalResultSink +------PhysicalUnion +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.x = .coalesce(a, b))) otherCondition=((if(a IS NULL, x, y) = .coalesce(a, b))) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.y = .coalesce(a, b))) otherCondition=((if(a IS NULL, x, y) = .coalesce(a, b)) and OR[( not (x = coalesce(a, b))),(x = coalesce(a, b)) IS NULL]) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) + +-- !case_when_two_side_7 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN](CASE WHEN (x > 10) THEN a WHEN (y > 10) THEN (a + b) ELSE 100 END = ((x + y) + -4)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[((x + y) + -4) AS `((x + y) + -4)`, (x > 10) AS `(x > 10)`, (y > 10) AS `(y > 10)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !if_one_side_1 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN](if((a > x), a, b) = (a + b)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a, t1.b] +--------filter(OR[(t1.a = (t1.a + t1.b)),(t1.b = (t1.a + t1.b))]) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !if_two_side_1 -- +PhysicalCteAnchor ( cteId=CTEId#0 ) +--PhysicalCteProducer ( cteId=CTEId#0 ) +----PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +--PhysicalCteAnchor ( cteId=CTEId#1 ) +----PhysicalCteProducer ( cteId=CTEId#1 ) +------PhysicalProject[(x + y) AS `(x + y)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] +----PhysicalResultSink +------PhysicalUnion +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.a = .(x + y))) otherCondition=((if((a > x), a, b) = .(x + y))) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) +--------PhysicalProject[.a, .x] +----------hashJoin[INNER_JOIN] hashCondition=((.b = .(x + y))) otherCondition=((if((a > x), a, b) = .(x + y)) and OR[( not (a = (x + y))),(a = (x + y)) IS NULL]) +------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------PhysicalCteConsumer ( cteId=CTEId#1 ) + +-- !ifnull_one_side_1 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN](ifnull(a, x) = (a + b)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !ifnull_two_side_1 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN](ifnull(a, x) = (x + y)) +------PhysicalProject[t1.a] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[(x + y) AS `(x + y)`, t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !nullif_one_side_1 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----NestedLoopJoin[INNER_JOIN](nullif(a, x) = (a + b)) +------PhysicalProject[(a + b) AS `(a + b)`, t1.a] +--------filter((t1.a = (t1.a + t1.b))) +----------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[t2.x] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + +-- !nullif_two_side_1 -- +PhysicalResultSink +--PhysicalProject[t1.a, t2.x] +----hashJoin[INNER_JOIN] hashCondition=((t1.a = expr_(x + y))) otherCondition=((nullif(a, x) = (t2.x + t2.y))) +------PhysicalProject[t1.a] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_1] +------PhysicalProject[(x + y) AS `expr_(x + y)`, tbl_join_extract_or_from_case_when_2.x, tbl_join_extract_or_from_case_when_2.y] +--------PhysicalOlapScan[tbl_join_extract_or_from_case_when_2] + diff --git a/regression-test/data/nereids_rules_p0/predicate_infer/infer_predicate.out b/regression-test/data/nereids_rules_p0/predicate_infer/infer_predicate.out index e0ff329e8e46d5..826c453b7f4801 100644 --- a/regression-test/data/nereids_rules_p0/predicate_infer/infer_predicate.out +++ b/regression-test/data/nereids_rules_p0/predicate_infer/infer_predicate.out @@ -195,7 +195,7 @@ PhysicalResultSink -- !infer_predicate_join_with_case_when -- PhysicalResultSink --hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id)) otherCondition=() -----filter((if((score > 170), 'high', 'low') = 'high')) +----filter((t1.score > 170)) ------PhysicalOlapScan[t] ----PhysicalOlapScan[t] @@ -354,14 +354,14 @@ PhysicalResultSink PhysicalResultSink --hashJoin[INNER_JOIN] hashCondition=((t12.id = t34.id)) otherCondition=() ----hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id)) otherCondition=() -------filter(( not (id = 3)) and ( not (id = 4)) and (t1.id < 9) and (t1.id > 1)) +------filter(( not id IN (3, 4)) and (t12.id < 9) and (t12.id > 1)) --------PhysicalOlapScan[t1] -------filter(( not (id = 3)) and ( not (id = 4)) and (t2.id < 9) and (t2.id > 1)) +------filter(( not id IN (3, 4)) and (t2.id < 9) and (t2.id > 1)) --------PhysicalOlapScan[t2] ----hashJoin[INNER_JOIN] hashCondition=((t3.id = t4.id)) otherCondition=() -------filter(( not (id = 3)) and ( not (id = 4)) and (t34.id < 9) and (t34.id > 1)) +------filter(( not id IN (3, 4)) and (t3.id < 9) and (t3.id > 1)) --------PhysicalOlapScan[t3] -------filter(( not (id = 3)) and ( not (id = 4)) and (t4.id < 9) and (t4.id > 1)) +------filter(( not id IN (3, 4)) and (t4.id < 9) and (t4.id > 1)) --------PhysicalOlapScan[t4] -- !infer8 -- diff --git a/regression-test/data/query_p0/virtual_slot_ref/adjust_virtual_slot_nullable.out b/regression-test/data/query_p0/virtual_slot_ref/adjust_virtual_slot_nullable.out index 88f5fc54ed3f0c..d2096c4635d858 100644 --- a/regression-test/data/query_p0/virtual_slot_ref/adjust_virtual_slot_nullable.out +++ b/regression-test/data/query_p0/virtual_slot_ref/adjust_virtual_slot_nullable.out @@ -4,14 +4,14 @@ PhysicalResultSink --hashJoin[INNER_JOIN] hashCondition=((t1.c_int = t2.c_int)) otherCondition=() ----PhysicalOlapScan[tbl_adjust_virtual_slot_nullable_1] ----PhysicalProject -------filter(OR[( not dayofmonth(c_date) IN (1, 3)),( not dayofmonth(c_date) IN (2, 3))]) +------filter(OR[( not dayofmonth(c_date) IN (1, 3)),( not (cast(dayofmonth(c_date) as INT) = c_int))]) --------PhysicalOlapScan[tbl_adjust_virtual_slot_nullable_2] -- !left_join_result -- -1 2020-01-01 1 2022-02-01 1 2020-01-01 1 2022-02-02 -1 2020-01-02 1 2022-02-01 +1 2020-01-01 1 2022-02-03 1 2020-01-02 1 2022-02-02 -1 2020-01-03 1 2022-02-01 +1 2020-01-02 1 2022-02-03 1 2020-01-03 1 2022-02-02 +1 2020-01-03 1 2022-02-03 diff --git a/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query21.out b/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query21.out index db506f0acaa0e9..7d5ea073befc01 100644 --- a/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query21.out +++ b/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query21.out @@ -4,7 +4,7 @@ PhysicalResultSink --PhysicalTopN[MERGE_SORT] ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] ---------filter((if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) <= 1.5) and (if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) >= cast((2.000000 / 3.0) as DOUBLE))) +--------filter(((cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)) <= 1.5) and (if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) >= cast((2.000000 / 3.0) as DOUBLE)) and (x.inv_before > 0)) ----------hashAgg[GLOBAL] ------------PhysicalDistribute[DistributionSpecHash] --------------hashAgg[LOCAL] diff --git a/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query34.out b/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query34.out index a12df581f106ae..3c40cf00bc54bf 100644 --- a/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query34.out +++ b/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query34.out @@ -25,7 +25,7 @@ PhysicalResultSink ------------------------------filter(s_county IN ('Barrow County', 'Daviess County', 'Franklin Parish', 'Luce County', 'Richland County', 'Walker County', 'Williamson County', 'Ziebach County')) --------------------------------PhysicalOlapScan[store] ------------------------PhysicalProject ---------------------------filter((household_demographics.hd_vehicle_count > 0) and (if((hd_vehicle_count > 0), (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)), NULL) > 1.2) and hd_buy_potential IN ('0-500', '1001-5000')) +--------------------------filter(((cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)) > 1.2) and (household_demographics.hd_vehicle_count > 0) and hd_buy_potential IN ('0-500', '1001-5000')) ----------------------------PhysicalOlapScan[household_demographics] ------------PhysicalProject --------------PhysicalOlapScan[customer] diff --git a/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query39.out b/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query39.out index d906073878075f..7f3fd870ff4eb8 100644 --- a/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query39.out +++ b/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query39.out @@ -3,7 +3,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --PhysicalCteProducer ( cteId=CTEId#0 ) ----PhysicalProject -------filter((if((mean = 0.0), 0.0, (stdev / mean)) > 1.0)) +------filter(( not (mean = 0.0)) and ((foo.stdev / foo.mean) > 1.0)) --------hashAgg[GLOBAL] ----------PhysicalDistribute[DistributionSpecHash] ------------hashAgg[LOCAL] diff --git a/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query47.out b/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query47.out index c04dc536e6bbef..4eab6ffa832a9d 100644 --- a/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query47.out +++ b/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query47.out @@ -39,6 +39,6 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------hashJoin[INNER_JOIN shuffle] hashCondition=((v1.i_brand = v1_lag.i_brand) and (v1.i_category = v1_lag.i_category) and (v1.rn = expr_(rn + 1)) and (v1.s_company_name = v1_lag.s_company_name) and (v1.s_store_name = v1_lag.s_store_name)) otherCondition=() build RFs:RF3 i_category->[i_category];RF4 i_brand->[i_brand];RF5 s_store_name->[s_store_name];RF6 s_company_name->[s_company_name];RF7 rn->[(rn + 1)] --------------------PhysicalProject ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF3 RF4 RF5 RF6 RF7 ---------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2001)) +--------------------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / v2.avg_monthly_sales) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2001)) ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query53.out b/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query53.out index 6cc3c447d13449..c9b3389d6f9f04 100644 --- a/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query53.out +++ b/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query53.out @@ -5,7 +5,7 @@ PhysicalResultSink ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] --------PhysicalProject -----------filter((if((avg_quarterly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_quarterly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_quarterly_sales), NULL) > 0.100000)) +----------filter(((cast(abs((sum_sales - cast(avg_quarterly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_quarterly_sales) > 0.100000) and (tmp1.avg_quarterly_sales > 0.0000)) ------------PhysicalWindow --------------PhysicalQuickSort[LOCAL_SORT] ----------------PhysicalProject diff --git a/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query57.out b/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query57.out index 12bdb0c331739c..c09d3f986a5a10 100644 --- a/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query57.out +++ b/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query57.out @@ -39,6 +39,6 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------hashJoin[INNER_JOIN shuffle] hashCondition=((v1.cc_name = v1_lag.cc_name) and (v1.i_brand = v1_lag.i_brand) and (v1.i_category = v1_lag.i_category) and (v1.rn = expr_(rn + 1))) otherCondition=() build RFs:RF3 i_category->[i_category];RF4 i_brand->[i_brand];RF5 cc_name->[cc_name];RF6 rn->[(rn + 1)] --------------------PhysicalProject ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF3 RF4 RF5 RF6 ---------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 1999)) +--------------------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / v2.avg_monthly_sales) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 1999)) ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query63.out b/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query63.out index 94e35cb7458980..40eac890870e04 100644 --- a/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query63.out +++ b/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query63.out @@ -5,7 +5,7 @@ PhysicalResultSink ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] --------PhysicalProject -----------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000)) +----------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_monthly_sales) > 0.100000) and (tmp1.avg_monthly_sales > 0.0000)) ------------PhysicalWindow --------------PhysicalQuickSort[LOCAL_SORT] ----------------PhysicalProject diff --git a/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query73.out b/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query73.out index ef94345e10145a..0b2d1c094ce101 100644 --- a/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query73.out +++ b/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query73.out @@ -25,7 +25,7 @@ PhysicalResultSink ------------------------------filter(s_county IN ('Barrow County', 'Daviess County', 'Fairfield County', 'Walker County')) --------------------------------PhysicalOlapScan[store] ------------------------PhysicalProject ---------------------------filter((household_demographics.hd_vehicle_count > 0) and (if((hd_vehicle_count > 0), (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)), NULL) > 1.0) and hd_buy_potential IN ('501-1000', 'Unknown')) +--------------------------filter(((cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)) > 1.0) and (household_demographics.hd_vehicle_count > 0) and hd_buy_potential IN ('501-1000', 'Unknown')) ----------------------------PhysicalOlapScan[household_demographics] ------------PhysicalProject --------------PhysicalOlapScan[customer] diff --git a/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query89.out b/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query89.out index 8b0e89ca99b508..6f43c2b24cdc99 100644 --- a/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query89.out +++ b/regression-test/data/shape_check/tpcds_sf100/noStatsRfPrune/query89.out @@ -6,7 +6,7 @@ PhysicalResultSink ------PhysicalDistribute[DistributionSpecGather] --------PhysicalTopN[LOCAL_SORT] ----------PhysicalProject -------------filter((if(( not (avg_monthly_sales = 0.0000)), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000)) +------------filter(( not (avg_monthly_sales = 0.0000)) and ((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_monthly_sales) > 0.100000)) --------------PhysicalWindow ----------------PhysicalQuickSort[LOCAL_SORT] ------------------PhysicalProject diff --git a/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query21.out b/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query21.out index 6a3b7ecf26ca2f..fd13184f321b14 100644 --- a/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query21.out +++ b/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query21.out @@ -4,7 +4,7 @@ PhysicalResultSink --PhysicalTopN[MERGE_SORT] ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] ---------filter((if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) <= 1.5) and (if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) >= cast((2.000000 / 3.0) as DOUBLE))) +--------filter(((cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)) <= 1.5) and (if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) >= cast((2.000000 / 3.0) as DOUBLE)) and (x.inv_before > 0)) ----------hashAgg[GLOBAL] ------------PhysicalDistribute[DistributionSpecHash] --------------hashAgg[LOCAL] diff --git a/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query34.out b/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query34.out index ed1f30d2c00dbe..b1a87ab9083269 100644 --- a/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query34.out +++ b/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query34.out @@ -25,7 +25,7 @@ PhysicalResultSink ------------------------------filter(s_county IN ('Barrow County', 'Daviess County', 'Franklin Parish', 'Luce County', 'Richland County', 'Walker County', 'Williamson County', 'Ziebach County')) --------------------------------PhysicalOlapScan[store] ------------------------PhysicalProject ---------------------------filter((household_demographics.hd_vehicle_count > 0) and (if((hd_vehicle_count > 0), (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)), NULL) > 1.2) and hd_buy_potential IN ('0-500', '1001-5000')) +--------------------------filter(((cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)) > 1.2) and (household_demographics.hd_vehicle_count > 0) and hd_buy_potential IN ('0-500', '1001-5000')) ----------------------------PhysicalOlapScan[household_demographics] ------------PhysicalProject --------------PhysicalOlapScan[customer] diff --git a/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query39.out b/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query39.out index 90c507f9c536f5..7e590db4e6d5d7 100644 --- a/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query39.out +++ b/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query39.out @@ -3,7 +3,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --PhysicalCteProducer ( cteId=CTEId#0 ) ----PhysicalProject -------filter((if((mean = 0.0), 0.0, (stdev / mean)) > 1.0)) +------filter(( not (mean = 0.0)) and ((foo.stdev / foo.mean) > 1.0)) --------hashAgg[GLOBAL] ----------PhysicalDistribute[DistributionSpecHash] ------------hashAgg[LOCAL] diff --git a/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query47.out b/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query47.out index 2f76e9211ef4df..0aa0c8ac22ce9a 100644 --- a/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query47.out +++ b/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query47.out @@ -39,6 +39,6 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------hashJoin[INNER_JOIN shuffle] hashCondition=((v1.i_brand = v1_lag.i_brand) and (v1.i_category = v1_lag.i_category) and (v1.rn = expr_(rn + 1)) and (v1.s_company_name = v1_lag.s_company_name) and (v1.s_store_name = v1_lag.s_store_name)) otherCondition=() build RFs:RF3 i_category->[i_category];RF4 i_brand->[i_brand];RF5 s_store_name->[s_store_name];RF6 s_company_name->[s_company_name];RF7 rn->[(rn + 1)] --------------------PhysicalProject ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF3 RF4 RF5 RF6 RF7 ---------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2001)) +--------------------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / v2.avg_monthly_sales) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2001)) ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query53.out b/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query53.out index 00c45d333ec013..018fd75b77d2b3 100644 --- a/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query53.out +++ b/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query53.out @@ -5,7 +5,7 @@ PhysicalResultSink ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] --------PhysicalProject -----------filter((if((avg_quarterly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_quarterly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_quarterly_sales), NULL) > 0.100000)) +----------filter(((cast(abs((sum_sales - cast(avg_quarterly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_quarterly_sales) > 0.100000) and (tmp1.avg_quarterly_sales > 0.0000)) ------------PhysicalWindow --------------PhysicalQuickSort[LOCAL_SORT] ----------------PhysicalProject diff --git a/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query57.out b/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query57.out index 697bd284f5701e..9a6c48bb489f4b 100644 --- a/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query57.out +++ b/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query57.out @@ -39,6 +39,6 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------hashJoin[INNER_JOIN shuffle] hashCondition=((v1.cc_name = v1_lag.cc_name) and (v1.i_brand = v1_lag.i_brand) and (v1.i_category = v1_lag.i_category) and (v1.rn = expr_(rn + 1))) otherCondition=() build RFs:RF3 i_category->[i_category];RF4 i_brand->[i_brand];RF5 cc_name->[cc_name];RF6 rn->[(rn + 1)] --------------------PhysicalProject ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF3 RF4 RF5 RF6 ---------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 1999)) +--------------------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / v2.avg_monthly_sales) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 1999)) ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query63.out b/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query63.out index 210a97832febf1..27938e87c4a439 100644 --- a/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query63.out +++ b/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query63.out @@ -5,7 +5,7 @@ PhysicalResultSink ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] --------PhysicalProject -----------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000)) +----------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_monthly_sales) > 0.100000) and (tmp1.avg_monthly_sales > 0.0000)) ------------PhysicalWindow --------------PhysicalQuickSort[LOCAL_SORT] ----------------PhysicalProject diff --git a/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query73.out b/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query73.out index 969743c9f9f833..d9585feee5c7a4 100644 --- a/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query73.out +++ b/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query73.out @@ -25,7 +25,7 @@ PhysicalResultSink ------------------------------filter(s_county IN ('Barrow County', 'Daviess County', 'Fairfield County', 'Walker County')) --------------------------------PhysicalOlapScan[store] ------------------------PhysicalProject ---------------------------filter((household_demographics.hd_vehicle_count > 0) and (if((hd_vehicle_count > 0), (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)), NULL) > 1.0) and hd_buy_potential IN ('501-1000', 'Unknown')) +--------------------------filter(((cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)) > 1.0) and (household_demographics.hd_vehicle_count > 0) and hd_buy_potential IN ('501-1000', 'Unknown')) ----------------------------PhysicalOlapScan[household_demographics] ------------PhysicalProject --------------PhysicalOlapScan[customer] diff --git a/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query89.out b/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query89.out index 552fbbd3aaf5ae..e3b2b344232b68 100644 --- a/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query89.out +++ b/regression-test/data/shape_check/tpcds_sf100/no_stats_shape/query89.out @@ -6,7 +6,7 @@ PhysicalResultSink ------PhysicalDistribute[DistributionSpecGather] --------PhysicalTopN[LOCAL_SORT] ----------PhysicalProject -------------filter((if(( not (avg_monthly_sales = 0.0000)), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000)) +------------filter(( not (avg_monthly_sales = 0.0000)) and ((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_monthly_sales) > 0.100000)) --------------PhysicalWindow ----------------PhysicalQuickSort[LOCAL_SORT] ------------------PhysicalProject diff --git a/regression-test/data/shape_check/tpcds_sf100/rf_prune/query21.out b/regression-test/data/shape_check/tpcds_sf100/rf_prune/query21.out index 991b448adf9f0c..6c42c4a588a411 100644 --- a/regression-test/data/shape_check/tpcds_sf100/rf_prune/query21.out +++ b/regression-test/data/shape_check/tpcds_sf100/rf_prune/query21.out @@ -4,7 +4,7 @@ PhysicalResultSink --PhysicalTopN[MERGE_SORT] ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] ---------filter((if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) <= 1.5) and (if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) >= cast((2.000000 / 3.0) as DOUBLE))) +--------filter(((cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)) <= 1.5) and (if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) >= cast((2.000000 / 3.0) as DOUBLE)) and (x.inv_before > 0)) ----------hashAgg[GLOBAL] ------------PhysicalDistribute[DistributionSpecHash] --------------hashAgg[LOCAL] diff --git a/regression-test/data/shape_check/tpcds_sf100/rf_prune/query34.out b/regression-test/data/shape_check/tpcds_sf100/rf_prune/query34.out index b1fdb6d566299a..e15f034b260bfe 100644 --- a/regression-test/data/shape_check/tpcds_sf100/rf_prune/query34.out +++ b/regression-test/data/shape_check/tpcds_sf100/rf_prune/query34.out @@ -5,9 +5,7 @@ PhysicalResultSink ----PhysicalDistribute[DistributionSpecGather] ------PhysicalQuickSort[LOCAL_SORT] --------PhysicalProject -----------hashJoin[INNER_JOIN shuffleBucket] hashCondition=((dn.ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 ss_customer_sk->[c_customer_sk] -------------PhysicalProject ---------------PhysicalOlapScan[customer] apply RFs: RF3 +----------hashJoin[INNER_JOIN bucketShuffle] hashCondition=((dn.ss_customer_sk = customer.c_customer_sk)) otherCondition=() ------------filter((dn.cnt <= 20) and (dn.cnt >= 15)) --------------hashAgg[GLOBAL] ----------------PhysicalDistribute[DistributionSpecHash] @@ -15,18 +13,20 @@ PhysicalResultSink --------------------PhysicalProject ----------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_store_sk = store.s_store_sk)) otherCondition=() build RFs:RF2 s_store_sk->[ss_store_sk] ------------------------PhysicalProject ---------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF1 d_date_sk->[ss_sold_date_sk] +--------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk)) otherCondition=() build RFs:RF1 hd_demo_sk->[ss_hdemo_sk] ----------------------------PhysicalProject -------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk)) otherCondition=() build RFs:RF0 hd_demo_sk->[ss_hdemo_sk] +------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF0 d_date_sk->[ss_sold_date_sk] --------------------------------PhysicalProject ----------------------------------PhysicalOlapScan[store_sales] apply RFs: RF0 RF1 RF2 --------------------------------PhysicalProject -----------------------------------filter((household_demographics.hd_vehicle_count > 0) and (if((hd_vehicle_count > 0), (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)), NULL) > 1.2) and hd_buy_potential IN ('0-500', '1001-5000')) -------------------------------------PhysicalOlapScan[household_demographics] +----------------------------------filter((date_dim.d_dom <= 28) and (date_dim.d_dom >= 1) and OR[(date_dim.d_dom <= 3),(date_dim.d_dom >= 25)] and d_year IN (1998, 1999, 2000)) +------------------------------------PhysicalOlapScan[date_dim] ----------------------------PhysicalProject -------------------------------filter((date_dim.d_dom <= 28) and (date_dim.d_dom >= 1) and OR[(date_dim.d_dom <= 3),(date_dim.d_dom >= 25)] and d_year IN (1998, 1999, 2000)) ---------------------------------PhysicalOlapScan[date_dim] +------------------------------filter(((cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)) > 1.2) and (household_demographics.hd_vehicle_count > 0) and hd_buy_potential IN ('0-500', '1001-5000')) +--------------------------------PhysicalOlapScan[household_demographics] ------------------------PhysicalProject --------------------------filter(s_county IN ('Barrow County', 'Daviess County', 'Franklin Parish', 'Luce County', 'Richland County', 'Walker County', 'Williamson County', 'Ziebach County')) ----------------------------PhysicalOlapScan[store] +------------PhysicalProject +--------------PhysicalOlapScan[customer] diff --git a/regression-test/data/shape_check/tpcds_sf100/rf_prune/query39.out b/regression-test/data/shape_check/tpcds_sf100/rf_prune/query39.out index 899b1a5e0bdd99..157c57a9772d08 100644 --- a/regression-test/data/shape_check/tpcds_sf100/rf_prune/query39.out +++ b/regression-test/data/shape_check/tpcds_sf100/rf_prune/query39.out @@ -3,7 +3,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --PhysicalCteProducer ( cteId=CTEId#0 ) ----PhysicalProject -------filter((if((mean = 0.0), 0.0, (stdev / mean)) > 1.0)) +------filter(( not (mean = 0.0)) and ((foo.stdev / foo.mean) > 1.0)) --------hashAgg[GLOBAL] ----------PhysicalDistribute[DistributionSpecHash] ------------hashAgg[LOCAL] diff --git a/regression-test/data/shape_check/tpcds_sf100/rf_prune/query47.out b/regression-test/data/shape_check/tpcds_sf100/rf_prune/query47.out index 0a70cbcf51c3a2..28b894140cde27 100644 --- a/regression-test/data/shape_check/tpcds_sf100/rf_prune/query47.out +++ b/regression-test/data/shape_check/tpcds_sf100/rf_prune/query47.out @@ -37,7 +37,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------hashJoin[INNER_JOIN shuffle] hashCondition=((v1.i_brand = v1_lag.i_brand) and (v1.i_category = v1_lag.i_category) and (v1.rn = expr_(rn + 1)) and (v1.s_company_name = v1_lag.s_company_name) and (v1.s_store_name = v1_lag.s_store_name)) otherCondition=() build RFs:RF3 i_category->[i_category];RF4 i_brand->[i_brand];RF5 s_store_name->[s_store_name];RF6 s_company_name->[s_company_name];RF7 rn->[(rn + 1)] --------------------PhysicalProject ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF3 RF4 RF5 RF6 RF7 ---------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2001)) +--------------------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / v2.avg_monthly_sales) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2001)) ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) ----------------PhysicalProject ------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/shape_check/tpcds_sf100/rf_prune/query53.out b/regression-test/data/shape_check/tpcds_sf100/rf_prune/query53.out index 89dc632eb527c4..2a4f8114325899 100644 --- a/regression-test/data/shape_check/tpcds_sf100/rf_prune/query53.out +++ b/regression-test/data/shape_check/tpcds_sf100/rf_prune/query53.out @@ -5,7 +5,7 @@ PhysicalResultSink ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] --------PhysicalProject -----------filter((if((avg_quarterly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_quarterly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_quarterly_sales), NULL) > 0.100000)) +----------filter(((cast(abs((sum_sales - cast(avg_quarterly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_quarterly_sales) > 0.100000) and (tmp1.avg_quarterly_sales > 0.0000)) ------------PhysicalWindow --------------PhysicalQuickSort[LOCAL_SORT] ----------------PhysicalDistribute[DistributionSpecHash] diff --git a/regression-test/data/shape_check/tpcds_sf100/rf_prune/query57.out b/regression-test/data/shape_check/tpcds_sf100/rf_prune/query57.out index ca1f63bfb07616..04ecacce598a56 100644 --- a/regression-test/data/shape_check/tpcds_sf100/rf_prune/query57.out +++ b/regression-test/data/shape_check/tpcds_sf100/rf_prune/query57.out @@ -37,7 +37,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------hashJoin[INNER_JOIN shuffle] hashCondition=((v1.cc_name = v1_lag.cc_name) and (v1.i_brand = v1_lag.i_brand) and (v1.i_category = v1_lag.i_category) and (v1.rn = expr_(rn + 1))) otherCondition=() build RFs:RF3 i_category->[i_category];RF4 i_brand->[i_brand];RF5 cc_name->[cc_name];RF6 rn->[(rn + 1)] --------------------PhysicalProject ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF3 RF4 RF5 RF6 ---------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 1999)) +--------------------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / v2.avg_monthly_sales) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 1999)) ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) ----------------PhysicalProject ------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/shape_check/tpcds_sf100/rf_prune/query63.out b/regression-test/data/shape_check/tpcds_sf100/rf_prune/query63.out index 9653f6c52199aa..0f7bde97f68765 100644 --- a/regression-test/data/shape_check/tpcds_sf100/rf_prune/query63.out +++ b/regression-test/data/shape_check/tpcds_sf100/rf_prune/query63.out @@ -5,7 +5,7 @@ PhysicalResultSink ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] --------PhysicalProject -----------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000)) +----------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_monthly_sales) > 0.100000) and (tmp1.avg_monthly_sales > 0.0000)) ------------PhysicalWindow --------------PhysicalQuickSort[LOCAL_SORT] ----------------PhysicalDistribute[DistributionSpecHash] diff --git a/regression-test/data/shape_check/tpcds_sf100/rf_prune/query73.out b/regression-test/data/shape_check/tpcds_sf100/rf_prune/query73.out index bfc42f79bbc570..a0f08ce747aabb 100644 --- a/regression-test/data/shape_check/tpcds_sf100/rf_prune/query73.out +++ b/regression-test/data/shape_check/tpcds_sf100/rf_prune/query73.out @@ -24,7 +24,7 @@ PhysicalResultSink ----------------------------------filter((date_dim.d_dom <= 2) and (date_dim.d_dom >= 1) and d_year IN (2000, 2001, 2002)) ------------------------------------PhysicalOlapScan[date_dim] ----------------------------PhysicalProject -------------------------------filter((household_demographics.hd_vehicle_count > 0) and (if((hd_vehicle_count > 0), (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)), NULL) > 1.0) and hd_buy_potential IN ('501-1000', 'Unknown')) +------------------------------filter(((cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)) > 1.0) and (household_demographics.hd_vehicle_count > 0) and hd_buy_potential IN ('501-1000', 'Unknown')) --------------------------------PhysicalOlapScan[household_demographics] ------------------------PhysicalProject --------------------------filter(s_county IN ('Barrow County', 'Daviess County', 'Fairfield County', 'Walker County')) diff --git a/regression-test/data/shape_check/tpcds_sf100/rf_prune/query89.out b/regression-test/data/shape_check/tpcds_sf100/rf_prune/query89.out index 8b0e89ca99b508..6f43c2b24cdc99 100644 --- a/regression-test/data/shape_check/tpcds_sf100/rf_prune/query89.out +++ b/regression-test/data/shape_check/tpcds_sf100/rf_prune/query89.out @@ -6,7 +6,7 @@ PhysicalResultSink ------PhysicalDistribute[DistributionSpecGather] --------PhysicalTopN[LOCAL_SORT] ----------PhysicalProject -------------filter((if(( not (avg_monthly_sales = 0.0000)), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000)) +------------filter(( not (avg_monthly_sales = 0.0000)) and ((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_monthly_sales) > 0.100000)) --------------PhysicalWindow ----------------PhysicalQuickSort[LOCAL_SORT] ------------------PhysicalProject diff --git a/regression-test/data/shape_check/tpcds_sf100/shape/query21.out b/regression-test/data/shape_check/tpcds_sf100/shape/query21.out index e80000c6353128..187958e5af2ee4 100644 --- a/regression-test/data/shape_check/tpcds_sf100/shape/query21.out +++ b/regression-test/data/shape_check/tpcds_sf100/shape/query21.out @@ -4,7 +4,7 @@ PhysicalResultSink --PhysicalTopN[MERGE_SORT] ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] ---------filter((if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) <= 1.5) and (if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) >= cast((2.000000 / 3.0) as DOUBLE))) +--------filter(((cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)) <= 1.5) and (if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) >= cast((2.000000 / 3.0) as DOUBLE)) and (x.inv_before > 0)) ----------hashAgg[GLOBAL] ------------PhysicalDistribute[DistributionSpecHash] --------------hashAgg[LOCAL] diff --git a/regression-test/data/shape_check/tpcds_sf100/shape/query34.out b/regression-test/data/shape_check/tpcds_sf100/shape/query34.out index b1fdb6d566299a..96b175bf9c86c8 100644 --- a/regression-test/data/shape_check/tpcds_sf100/shape/query34.out +++ b/regression-test/data/shape_check/tpcds_sf100/shape/query34.out @@ -5,9 +5,7 @@ PhysicalResultSink ----PhysicalDistribute[DistributionSpecGather] ------PhysicalQuickSort[LOCAL_SORT] --------PhysicalProject -----------hashJoin[INNER_JOIN shuffleBucket] hashCondition=((dn.ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 ss_customer_sk->[c_customer_sk] -------------PhysicalProject ---------------PhysicalOlapScan[customer] apply RFs: RF3 +----------hashJoin[INNER_JOIN bucketShuffle] hashCondition=((dn.ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[ss_customer_sk] ------------filter((dn.cnt <= 20) and (dn.cnt >= 15)) --------------hashAgg[GLOBAL] ----------------PhysicalDistribute[DistributionSpecHash] @@ -15,18 +13,20 @@ PhysicalResultSink --------------------PhysicalProject ----------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_store_sk = store.s_store_sk)) otherCondition=() build RFs:RF2 s_store_sk->[ss_store_sk] ------------------------PhysicalProject ---------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF1 d_date_sk->[ss_sold_date_sk] +--------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk)) otherCondition=() build RFs:RF1 hd_demo_sk->[ss_hdemo_sk] ----------------------------PhysicalProject -------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk)) otherCondition=() build RFs:RF0 hd_demo_sk->[ss_hdemo_sk] +------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF0 d_date_sk->[ss_sold_date_sk] --------------------------------PhysicalProject -----------------------------------PhysicalOlapScan[store_sales] apply RFs: RF0 RF1 RF2 +----------------------------------PhysicalOlapScan[store_sales] apply RFs: RF0 RF1 RF2 RF3 --------------------------------PhysicalProject -----------------------------------filter((household_demographics.hd_vehicle_count > 0) and (if((hd_vehicle_count > 0), (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)), NULL) > 1.2) and hd_buy_potential IN ('0-500', '1001-5000')) -------------------------------------PhysicalOlapScan[household_demographics] +----------------------------------filter((date_dim.d_dom <= 28) and (date_dim.d_dom >= 1) and OR[(date_dim.d_dom <= 3),(date_dim.d_dom >= 25)] and d_year IN (1998, 1999, 2000)) +------------------------------------PhysicalOlapScan[date_dim] ----------------------------PhysicalProject -------------------------------filter((date_dim.d_dom <= 28) and (date_dim.d_dom >= 1) and OR[(date_dim.d_dom <= 3),(date_dim.d_dom >= 25)] and d_year IN (1998, 1999, 2000)) ---------------------------------PhysicalOlapScan[date_dim] +------------------------------filter(((cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)) > 1.2) and (household_demographics.hd_vehicle_count > 0) and hd_buy_potential IN ('0-500', '1001-5000')) +--------------------------------PhysicalOlapScan[household_demographics] ------------------------PhysicalProject --------------------------filter(s_county IN ('Barrow County', 'Daviess County', 'Franklin Parish', 'Luce County', 'Richland County', 'Walker County', 'Williamson County', 'Ziebach County')) ----------------------------PhysicalOlapScan[store] +------------PhysicalProject +--------------PhysicalOlapScan[customer] diff --git a/regression-test/data/shape_check/tpcds_sf100/shape/query39.out b/regression-test/data/shape_check/tpcds_sf100/shape/query39.out index b7ca740e55c672..f43f8407d82f31 100644 --- a/regression-test/data/shape_check/tpcds_sf100/shape/query39.out +++ b/regression-test/data/shape_check/tpcds_sf100/shape/query39.out @@ -3,7 +3,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --PhysicalCteProducer ( cteId=CTEId#0 ) ----PhysicalProject -------filter((if((mean = 0.0), 0.0, (stdev / mean)) > 1.0)) +------filter(( not (mean = 0.0)) and ((foo.stdev / foo.mean) > 1.0)) --------hashAgg[GLOBAL] ----------PhysicalDistribute[DistributionSpecHash] ------------hashAgg[LOCAL] diff --git a/regression-test/data/shape_check/tpcds_sf100/shape/query47.out b/regression-test/data/shape_check/tpcds_sf100/shape/query47.out index fba9743e0f2353..d1e34966d3dedc 100644 --- a/regression-test/data/shape_check/tpcds_sf100/shape/query47.out +++ b/regression-test/data/shape_check/tpcds_sf100/shape/query47.out @@ -37,7 +37,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------hashJoin[INNER_JOIN shuffle] hashCondition=((v1.i_brand = v1_lag.i_brand) and (v1.i_category = v1_lag.i_category) and (v1.rn = expr_(rn + 1)) and (v1.s_company_name = v1_lag.s_company_name) and (v1.s_store_name = v1_lag.s_store_name)) otherCondition=() build RFs:RF3 i_category->[i_category];RF4 i_brand->[i_brand];RF5 s_store_name->[s_store_name];RF6 s_company_name->[s_company_name];RF7 rn->[(rn + 1)] --------------------PhysicalProject ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF3 RF4 RF5 RF6 RF7 RF8 RF9 RF10 RF11 RF12 ---------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2001)) +--------------------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / v2.avg_monthly_sales) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2001)) ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF8 RF9 RF10 RF11 RF12 ----------------PhysicalProject ------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/shape_check/tpcds_sf100/shape/query53.out b/regression-test/data/shape_check/tpcds_sf100/shape/query53.out index 04920e65ac6894..5a591d64cc5d6d 100644 --- a/regression-test/data/shape_check/tpcds_sf100/shape/query53.out +++ b/regression-test/data/shape_check/tpcds_sf100/shape/query53.out @@ -5,7 +5,7 @@ PhysicalResultSink ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] --------PhysicalProject -----------filter((if((avg_quarterly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_quarterly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_quarterly_sales), NULL) > 0.100000)) +----------filter(((cast(abs((sum_sales - cast(avg_quarterly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_quarterly_sales) > 0.100000) and (tmp1.avg_quarterly_sales > 0.0000)) ------------PhysicalWindow --------------PhysicalQuickSort[LOCAL_SORT] ----------------PhysicalDistribute[DistributionSpecHash] diff --git a/regression-test/data/shape_check/tpcds_sf100/shape/query57.out b/regression-test/data/shape_check/tpcds_sf100/shape/query57.out index 152c2c884779f7..1686241e4651f1 100644 --- a/regression-test/data/shape_check/tpcds_sf100/shape/query57.out +++ b/regression-test/data/shape_check/tpcds_sf100/shape/query57.out @@ -37,7 +37,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------hashJoin[INNER_JOIN shuffle] hashCondition=((v1.cc_name = v1_lag.cc_name) and (v1.i_brand = v1_lag.i_brand) and (v1.i_category = v1_lag.i_category) and (v1.rn = expr_(rn + 1))) otherCondition=() build RFs:RF3 i_category->[i_category];RF4 i_brand->[i_brand];RF5 cc_name->[cc_name];RF6 rn->[(rn + 1)] --------------------PhysicalProject ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF3 RF4 RF5 RF6 RF7 RF8 RF9 RF10 ---------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 1999)) +--------------------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / v2.avg_monthly_sales) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 1999)) ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF7 RF8 RF9 RF10 ----------------PhysicalProject ------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/shape_check/tpcds_sf100/shape/query63.out b/regression-test/data/shape_check/tpcds_sf100/shape/query63.out index d4fb4990da98b8..5bff62d9cae36d 100644 --- a/regression-test/data/shape_check/tpcds_sf100/shape/query63.out +++ b/regression-test/data/shape_check/tpcds_sf100/shape/query63.out @@ -5,7 +5,7 @@ PhysicalResultSink ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] --------PhysicalProject -----------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000)) +----------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_monthly_sales) > 0.100000) and (tmp1.avg_monthly_sales > 0.0000)) ------------PhysicalWindow --------------PhysicalQuickSort[LOCAL_SORT] ----------------PhysicalDistribute[DistributionSpecHash] diff --git a/regression-test/data/shape_check/tpcds_sf100/shape/query73.out b/regression-test/data/shape_check/tpcds_sf100/shape/query73.out index bfc42f79bbc570..a0f08ce747aabb 100644 --- a/regression-test/data/shape_check/tpcds_sf100/shape/query73.out +++ b/regression-test/data/shape_check/tpcds_sf100/shape/query73.out @@ -24,7 +24,7 @@ PhysicalResultSink ----------------------------------filter((date_dim.d_dom <= 2) and (date_dim.d_dom >= 1) and d_year IN (2000, 2001, 2002)) ------------------------------------PhysicalOlapScan[date_dim] ----------------------------PhysicalProject -------------------------------filter((household_demographics.hd_vehicle_count > 0) and (if((hd_vehicle_count > 0), (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)), NULL) > 1.0) and hd_buy_potential IN ('501-1000', 'Unknown')) +------------------------------filter(((cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)) > 1.0) and (household_demographics.hd_vehicle_count > 0) and hd_buy_potential IN ('501-1000', 'Unknown')) --------------------------------PhysicalOlapScan[household_demographics] ------------------------PhysicalProject --------------------------filter(s_county IN ('Barrow County', 'Daviess County', 'Fairfield County', 'Walker County')) diff --git a/regression-test/data/shape_check/tpcds_sf100/shape/query89.out b/regression-test/data/shape_check/tpcds_sf100/shape/query89.out index 552fbbd3aaf5ae..e3b2b344232b68 100644 --- a/regression-test/data/shape_check/tpcds_sf100/shape/query89.out +++ b/regression-test/data/shape_check/tpcds_sf100/shape/query89.out @@ -6,7 +6,7 @@ PhysicalResultSink ------PhysicalDistribute[DistributionSpecGather] --------PhysicalTopN[LOCAL_SORT] ----------PhysicalProject -------------filter((if(( not (avg_monthly_sales = 0.0000)), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000)) +------------filter(( not (avg_monthly_sales = 0.0000)) and ((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_monthly_sales) > 0.100000)) --------------PhysicalWindow ----------------PhysicalQuickSort[LOCAL_SORT] ------------------PhysicalProject diff --git a/regression-test/data/shape_check/tpcds_sf1000/hint/query21.out b/regression-test/data/shape_check/tpcds_sf1000/hint/query21.out index b3f351af394272..cce40a238b383a 100644 --- a/regression-test/data/shape_check/tpcds_sf1000/hint/query21.out +++ b/regression-test/data/shape_check/tpcds_sf1000/hint/query21.out @@ -4,7 +4,7 @@ PhysicalResultSink --PhysicalTopN[MERGE_SORT] ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] ---------filter((if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) <= 1.5) and (if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) >= cast((2.000000 / 3.0) as DOUBLE))) +--------filter(((cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)) <= 1.5) and (if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) >= cast((2.000000 / 3.0) as DOUBLE)) and (x.inv_before > 0)) ----------hashAgg[GLOBAL] ------------PhysicalDistribute[DistributionSpecHash] --------------hashAgg[LOCAL] diff --git a/regression-test/data/shape_check/tpcds_sf1000/hint/query34.out b/regression-test/data/shape_check/tpcds_sf1000/hint/query34.out index 846a8160171a4d..5cc9aba0108680 100644 --- a/regression-test/data/shape_check/tpcds_sf1000/hint/query34.out +++ b/regression-test/data/shape_check/tpcds_sf1000/hint/query34.out @@ -28,7 +28,7 @@ PhysicalResultSink --------------------------------filter((date_dim.d_dom <= 28) and (date_dim.d_dom >= 1) and OR[(date_dim.d_dom <= 3),(date_dim.d_dom >= 25)] and d_year IN (2000, 2001, 2002)) ----------------------------------PhysicalOlapScan[date_dim] --------------------------PhysicalProject -----------------------------filter((household_demographics.hd_vehicle_count > 0) and (if((hd_vehicle_count > 0), (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)), NULL) > 1.2) and hd_buy_potential IN ('0-500', '1001-5000')) +----------------------------filter(((cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)) > 1.2) and (household_demographics.hd_vehicle_count > 0) and hd_buy_potential IN ('0-500', '1001-5000')) ------------------------------PhysicalOlapScan[household_demographics] Hint log: diff --git a/regression-test/data/shape_check/tpcds_sf1000/hint/query39.out b/regression-test/data/shape_check/tpcds_sf1000/hint/query39.out index 064bbe33338a8d..8ad91b8ce2e54d 100644 --- a/regression-test/data/shape_check/tpcds_sf1000/hint/query39.out +++ b/regression-test/data/shape_check/tpcds_sf1000/hint/query39.out @@ -3,7 +3,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --PhysicalCteProducer ( cteId=CTEId#0 ) ----PhysicalProject -------filter((if((mean = 0.0), 0.0, (stdev / mean)) > 1.0)) +------filter(( not (mean = 0.0)) and ((foo.stdev / foo.mean) > 1.0)) --------hashAgg[GLOBAL] ----------PhysicalDistribute[DistributionSpecHash] ------------hashAgg[LOCAL] diff --git a/regression-test/data/shape_check/tpcds_sf1000/hint/query47.out b/regression-test/data/shape_check/tpcds_sf1000/hint/query47.out index 9a2258d852cebf..8dabd618a637a5 100644 --- a/regression-test/data/shape_check/tpcds_sf1000/hint/query47.out +++ b/regression-test/data/shape_check/tpcds_sf1000/hint/query47.out @@ -37,7 +37,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------hashJoin[INNER_JOIN shuffle] hashCondition=((v1.i_brand = v1_lag.i_brand) and (v1.i_category = v1_lag.i_category) and (v1.rn = expr_(rn + 1)) and (v1.s_company_name = v1_lag.s_company_name) and (v1.s_store_name = v1_lag.s_store_name)) otherCondition=() build RFs:RF3 i_category->[i_category];RF4 i_brand->[i_brand];RF5 s_store_name->[s_store_name];RF6 s_company_name->[s_company_name];RF7 rn->[(rn + 1)] --------------------PhysicalProject ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF3 RF4 RF5 RF6 RF7 RF8 RF9 RF10 RF11 RF12 ---------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2000)) +--------------------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / v2.avg_monthly_sales) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2000)) ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF8 RF9 RF10 RF11 RF12 ----------------PhysicalProject ------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/shape_check/tpcds_sf1000/hint/query53.out b/regression-test/data/shape_check/tpcds_sf1000/hint/query53.out index 1a8edda84fe3f4..db100ec66db0d0 100644 --- a/regression-test/data/shape_check/tpcds_sf1000/hint/query53.out +++ b/regression-test/data/shape_check/tpcds_sf1000/hint/query53.out @@ -5,7 +5,7 @@ PhysicalResultSink ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] --------PhysicalProject -----------filter((if((avg_quarterly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_quarterly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_quarterly_sales), NULL) > 0.100000)) +----------filter(((cast(abs((sum_sales - cast(avg_quarterly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_quarterly_sales) > 0.100000) and (tmp1.avg_quarterly_sales > 0.0000)) ------------PhysicalWindow --------------PhysicalQuickSort[LOCAL_SORT] ----------------PhysicalDistribute[DistributionSpecHash] diff --git a/regression-test/data/shape_check/tpcds_sf1000/hint/query57.out b/regression-test/data/shape_check/tpcds_sf1000/hint/query57.out index a2a31216a3d458..6347cd590b5947 100644 --- a/regression-test/data/shape_check/tpcds_sf1000/hint/query57.out +++ b/regression-test/data/shape_check/tpcds_sf1000/hint/query57.out @@ -37,7 +37,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------hashJoin[INNER_JOIN shuffle] hashCondition=((v1.cc_name = v1_lag.cc_name) and (v1.i_brand = v1_lag.i_brand) and (v1.i_category = v1_lag.i_category) and (v1.rn = expr_(rn + 1))) otherCondition=() build RFs:RF3 i_category->[i_category];RF4 i_brand->[i_brand];RF5 cc_name->[cc_name];RF6 rn->[(rn + 1)] --------------------PhysicalProject ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF3 RF4 RF5 RF6 RF7 RF8 RF9 RF10 ---------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2001)) +--------------------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / v2.avg_monthly_sales) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2001)) ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF7 RF8 RF9 RF10 ----------------PhysicalProject ------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/shape_check/tpcds_sf1000/hint/query63.out b/regression-test/data/shape_check/tpcds_sf1000/hint/query63.out index 1ca1f9d9304401..f459c41f9bd4df 100644 --- a/regression-test/data/shape_check/tpcds_sf1000/hint/query63.out +++ b/regression-test/data/shape_check/tpcds_sf1000/hint/query63.out @@ -5,7 +5,7 @@ PhysicalResultSink ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] --------PhysicalProject -----------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000)) +----------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_monthly_sales) > 0.100000) and (tmp1.avg_monthly_sales > 0.0000)) ------------PhysicalWindow --------------PhysicalQuickSort[LOCAL_SORT] ----------------PhysicalDistribute[DistributionSpecHash] diff --git a/regression-test/data/shape_check/tpcds_sf1000/hint/query73.out b/regression-test/data/shape_check/tpcds_sf1000/hint/query73.out index 4deeef486cf33e..1165b3341a5602 100644 --- a/regression-test/data/shape_check/tpcds_sf1000/hint/query73.out +++ b/regression-test/data/shape_check/tpcds_sf1000/hint/query73.out @@ -28,7 +28,7 @@ PhysicalResultSink --------------------------------filter((store.s_county = 'Williamson County')) ----------------------------------PhysicalOlapScan[store] --------------------------PhysicalProject -----------------------------filter((household_demographics.hd_vehicle_count > 0) and (if((hd_vehicle_count > 0), (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)), NULL) > 1.0) and hd_buy_potential IN ('1001-5000', '5001-10000')) +----------------------------filter(((cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)) > 1.0) and (household_demographics.hd_vehicle_count > 0) and hd_buy_potential IN ('1001-5000', '5001-10000')) ------------------------------PhysicalOlapScan[household_demographics] Hint log: diff --git a/regression-test/data/shape_check/tpcds_sf1000/hint/query89.out b/regression-test/data/shape_check/tpcds_sf1000/hint/query89.out index a1e8e771fe13f5..f090e2f10febb9 100644 --- a/regression-test/data/shape_check/tpcds_sf1000/hint/query89.out +++ b/regression-test/data/shape_check/tpcds_sf1000/hint/query89.out @@ -6,7 +6,7 @@ PhysicalResultSink ------PhysicalDistribute[DistributionSpecGather] --------PhysicalTopN[LOCAL_SORT] ----------PhysicalProject -------------filter((if(( not (avg_monthly_sales = 0.0000)), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000)) +------------filter(( not (avg_monthly_sales = 0.0000)) and ((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_monthly_sales) > 0.100000)) --------------PhysicalWindow ----------------PhysicalQuickSort[LOCAL_SORT] ------------------PhysicalProject diff --git a/regression-test/data/shape_check/tpcds_sf1000/shape/query21.out b/regression-test/data/shape_check/tpcds_sf1000/shape/query21.out index f68b978b0b2ba6..7aaa027dd961ca 100644 --- a/regression-test/data/shape_check/tpcds_sf1000/shape/query21.out +++ b/regression-test/data/shape_check/tpcds_sf1000/shape/query21.out @@ -4,7 +4,7 @@ PhysicalResultSink --PhysicalTopN[MERGE_SORT] ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] ---------filter((if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) <= 1.5) and (if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) >= cast((2.000000 / 3.0) as DOUBLE))) +--------filter(((cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)) <= 1.5) and (if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) >= cast((2.000000 / 3.0) as DOUBLE)) and (x.inv_before > 0)) ----------hashAgg[GLOBAL] ------------PhysicalDistribute[DistributionSpecHash] --------------hashAgg[LOCAL] diff --git a/regression-test/data/shape_check/tpcds_sf1000/shape/query34.out b/regression-test/data/shape_check/tpcds_sf1000/shape/query34.out index c75d4fc3e18155..b9f187a3aaabee 100644 --- a/regression-test/data/shape_check/tpcds_sf1000/shape/query34.out +++ b/regression-test/data/shape_check/tpcds_sf1000/shape/query34.out @@ -13,9 +13,9 @@ PhysicalResultSink ----------------PhysicalDistribute[DistributionSpecHash] ------------------hashAgg[LOCAL] --------------------PhysicalProject -----------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF2 d_date_sk->[ss_sold_date_sk] +----------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk)) otherCondition=() build RFs:RF2 hd_demo_sk->[ss_hdemo_sk] ------------------------PhysicalProject ---------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk)) otherCondition=() build RFs:RF1 hd_demo_sk->[ss_hdemo_sk] +--------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF1 d_date_sk->[ss_sold_date_sk] ----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_store_sk = store.s_store_sk)) otherCondition=() build RFs:RF0 s_store_sk->[ss_store_sk] --------------------------------PhysicalProject @@ -24,9 +24,9 @@ PhysicalResultSink ----------------------------------filter((store.s_county = 'Williamson County')) ------------------------------------PhysicalOlapScan[store] ----------------------------PhysicalProject -------------------------------filter((household_demographics.hd_vehicle_count > 0) and (if((hd_vehicle_count > 0), (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)), NULL) > 1.2) and hd_buy_potential IN ('0-500', '1001-5000')) ---------------------------------PhysicalOlapScan[household_demographics] +------------------------------filter((date_dim.d_dom <= 28) and (date_dim.d_dom >= 1) and OR[(date_dim.d_dom <= 3),(date_dim.d_dom >= 25)] and d_year IN (2000, 2001, 2002)) +--------------------------------PhysicalOlapScan[date_dim] ------------------------PhysicalProject ---------------------------filter((date_dim.d_dom <= 28) and (date_dim.d_dom >= 1) and OR[(date_dim.d_dom <= 3),(date_dim.d_dom >= 25)] and d_year IN (2000, 2001, 2002)) -----------------------------PhysicalOlapScan[date_dim] +--------------------------filter(((cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)) > 1.2) and (household_demographics.hd_vehicle_count > 0) and hd_buy_potential IN ('0-500', '1001-5000')) +----------------------------PhysicalOlapScan[household_demographics] diff --git a/regression-test/data/shape_check/tpcds_sf1000/shape/query39.out b/regression-test/data/shape_check/tpcds_sf1000/shape/query39.out index 7b00628d966265..94a93dbec9a898 100644 --- a/regression-test/data/shape_check/tpcds_sf1000/shape/query39.out +++ b/regression-test/data/shape_check/tpcds_sf1000/shape/query39.out @@ -3,7 +3,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --PhysicalCteProducer ( cteId=CTEId#0 ) ----PhysicalProject -------filter((if((mean = 0.0), 0.0, (stdev / mean)) > 1.0)) +------filter(( not (mean = 0.0)) and ((foo.stdev / foo.mean) > 1.0)) --------hashAgg[GLOBAL] ----------PhysicalDistribute[DistributionSpecHash] ------------hashAgg[LOCAL] diff --git a/regression-test/data/shape_check/tpcds_sf1000/shape/query47.out b/regression-test/data/shape_check/tpcds_sf1000/shape/query47.out index 0428d0e8670918..411afc9057d077 100644 --- a/regression-test/data/shape_check/tpcds_sf1000/shape/query47.out +++ b/regression-test/data/shape_check/tpcds_sf1000/shape/query47.out @@ -37,7 +37,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------hashJoin[INNER_JOIN shuffle] hashCondition=((v1.i_brand = v1_lag.i_brand) and (v1.i_category = v1_lag.i_category) and (v1.rn = expr_(rn + 1)) and (v1.s_company_name = v1_lag.s_company_name) and (v1.s_store_name = v1_lag.s_store_name)) otherCondition=() build RFs:RF3 i_category->[i_category];RF4 i_brand->[i_brand];RF5 s_store_name->[s_store_name];RF6 s_company_name->[s_company_name];RF7 rn->[(rn + 1)] --------------------PhysicalProject ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF3 RF4 RF5 RF6 RF7 RF8 RF9 RF10 RF11 RF12 ---------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2000)) +--------------------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / v2.avg_monthly_sales) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2000)) ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF8 RF9 RF10 RF11 RF12 ----------------PhysicalProject ------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/shape_check/tpcds_sf1000/shape/query53.out b/regression-test/data/shape_check/tpcds_sf1000/shape/query53.out index d2467a65e93e09..f5db2a514a2d8f 100644 --- a/regression-test/data/shape_check/tpcds_sf1000/shape/query53.out +++ b/regression-test/data/shape_check/tpcds_sf1000/shape/query53.out @@ -5,7 +5,7 @@ PhysicalResultSink ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] --------PhysicalProject -----------filter((if((avg_quarterly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_quarterly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_quarterly_sales), NULL) > 0.100000)) +----------filter(((cast(abs((sum_sales - cast(avg_quarterly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_quarterly_sales) > 0.100000) and (tmp1.avg_quarterly_sales > 0.0000)) ------------PhysicalWindow --------------PhysicalQuickSort[LOCAL_SORT] ----------------PhysicalDistribute[DistributionSpecHash] diff --git a/regression-test/data/shape_check/tpcds_sf1000/shape/query57.out b/regression-test/data/shape_check/tpcds_sf1000/shape/query57.out index 00c01451579574..593429bacdcaab 100644 --- a/regression-test/data/shape_check/tpcds_sf1000/shape/query57.out +++ b/regression-test/data/shape_check/tpcds_sf1000/shape/query57.out @@ -37,7 +37,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------hashJoin[INNER_JOIN shuffle] hashCondition=((v1.cc_name = v1_lag.cc_name) and (v1.i_brand = v1_lag.i_brand) and (v1.i_category = v1_lag.i_category) and (v1.rn = expr_(rn + 1))) otherCondition=() build RFs:RF3 i_category->[i_category];RF4 i_brand->[i_brand];RF5 cc_name->[cc_name];RF6 rn->[(rn + 1)] --------------------PhysicalProject ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF3 RF4 RF5 RF6 RF7 RF8 RF9 RF10 ---------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2001)) +--------------------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / v2.avg_monthly_sales) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2001)) ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF7 RF8 RF9 RF10 ----------------PhysicalProject ------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/shape_check/tpcds_sf1000/shape/query63.out b/regression-test/data/shape_check/tpcds_sf1000/shape/query63.out index bbbb80bc4b68e0..7807b765fc36d4 100644 --- a/regression-test/data/shape_check/tpcds_sf1000/shape/query63.out +++ b/regression-test/data/shape_check/tpcds_sf1000/shape/query63.out @@ -5,7 +5,7 @@ PhysicalResultSink ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] --------PhysicalProject -----------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000)) +----------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_monthly_sales) > 0.100000) and (tmp1.avg_monthly_sales > 0.0000)) ------------PhysicalWindow --------------PhysicalQuickSort[LOCAL_SORT] ----------------PhysicalDistribute[DistributionSpecHash] diff --git a/regression-test/data/shape_check/tpcds_sf1000/shape/query73.out b/regression-test/data/shape_check/tpcds_sf1000/shape/query73.out index 52c88ab966b1c9..6b13afc6aa87cd 100644 --- a/regression-test/data/shape_check/tpcds_sf1000/shape/query73.out +++ b/regression-test/data/shape_check/tpcds_sf1000/shape/query73.out @@ -27,6 +27,6 @@ PhysicalResultSink ------------------------------filter((store.s_county = 'Williamson County')) --------------------------------PhysicalOlapScan[store] ------------------------PhysicalProject ---------------------------filter((household_demographics.hd_vehicle_count > 0) and (if((hd_vehicle_count > 0), (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)), NULL) > 1.0) and hd_buy_potential IN ('1001-5000', '5001-10000')) +--------------------------filter(((cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)) > 1.0) and (household_demographics.hd_vehicle_count > 0) and hd_buy_potential IN ('1001-5000', '5001-10000')) ----------------------------PhysicalOlapScan[household_demographics] diff --git a/regression-test/data/shape_check/tpcds_sf1000/shape/query89.out b/regression-test/data/shape_check/tpcds_sf1000/shape/query89.out index e18adf036aa80e..6ddd6f367f06b1 100644 --- a/regression-test/data/shape_check/tpcds_sf1000/shape/query89.out +++ b/regression-test/data/shape_check/tpcds_sf1000/shape/query89.out @@ -6,7 +6,7 @@ PhysicalResultSink ------PhysicalDistribute[DistributionSpecGather] --------PhysicalTopN[LOCAL_SORT] ----------PhysicalProject -------------filter((if(( not (avg_monthly_sales = 0.0000)), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000)) +------------filter(( not (avg_monthly_sales = 0.0000)) and ((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_monthly_sales) > 0.100000)) --------------PhysicalWindow ----------------PhysicalQuickSort[LOCAL_SORT] ------------------PhysicalProject diff --git a/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query21.out b/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query21.out index 31448491385145..91dd8ce1e25c76 100644 --- a/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query21.out +++ b/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query21.out @@ -4,7 +4,7 @@ PhysicalResultSink --PhysicalTopN[MERGE_SORT] ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] ---------filter((if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) <= 1.5) and (if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) >= cast((2.000000 / 3.0) as DOUBLE))) +--------filter(((cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)) <= 1.5) and (if((inv_before > 0), (cast(inv_after as DOUBLE) / cast(inv_before as DOUBLE)), NULL) >= cast((2.000000 / 3.0) as DOUBLE)) and (x.inv_before > 0)) ----------hashAgg[GLOBAL] ------------PhysicalDistribute[DistributionSpecHash] --------------hashAgg[LOCAL] diff --git a/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query34.out b/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query34.out index f5c833d21c1256..a8d9933686ff75 100644 --- a/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query34.out +++ b/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query34.out @@ -25,7 +25,7 @@ PhysicalResultSink ------------------------------filter(s_county IN ('Arthur County', 'Halifax County', 'Lunenburg County', 'Oglethorpe County', 'Perry County', 'Salem County', 'Sumner County', 'Terrell County')) --------------------------------PhysicalOlapScan[store] ------------------------PhysicalProject ---------------------------filter((household_demographics.hd_vehicle_count > 0) and (if((hd_vehicle_count > 0), (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)), NULL) > 1.2) and hd_buy_potential IN ('>10000', 'Unknown')) +--------------------------filter(((cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)) > 1.2) and (household_demographics.hd_vehicle_count > 0) and hd_buy_potential IN ('>10000', 'Unknown')) ----------------------------PhysicalOlapScan[household_demographics] ------------PhysicalProject --------------PhysicalOlapScan[customer] diff --git a/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query39.out b/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query39.out index 81c822f42d871d..a997e4f9259337 100644 --- a/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query39.out +++ b/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query39.out @@ -3,7 +3,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) --PhysicalCteProducer ( cteId=CTEId#0 ) ----PhysicalProject -------filter((if((mean = 0.0), 0.0, (stdev / mean)) > 1.0)) +------filter(( not (mean = 0.0)) and ((foo.stdev / foo.mean) > 1.0)) --------hashAgg[GLOBAL] ----------PhysicalDistribute[DistributionSpecHash] ------------hashAgg[LOCAL] diff --git a/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query47.out b/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query47.out index d1c72cdf6f5766..08961666074d35 100644 --- a/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query47.out +++ b/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query47.out @@ -39,6 +39,6 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------hashJoin[INNER_JOIN shuffle] hashCondition=((v1.i_brand = v1_lag.i_brand) and (v1.i_category = v1_lag.i_category) and (v1.rn = expr_(rn + 1)) and (v1.s_company_name = v1_lag.s_company_name) and (v1.s_store_name = v1_lag.s_store_name)) otherCondition=() build RFs:RF3 i_category->[i_category];RF4 i_brand->[i_brand];RF5 s_store_name->[s_store_name];RF6 s_company_name->[s_company_name];RF7 rn->[(rn + 1)] --------------------PhysicalProject ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF3 RF4 RF5 RF6 RF7 ---------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 1999)) +--------------------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / v2.avg_monthly_sales) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 1999)) ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query53.out b/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query53.out index 08f3ba6e090871..20dce00d33c7c9 100644 --- a/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query53.out +++ b/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query53.out @@ -5,7 +5,7 @@ PhysicalResultSink ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] --------PhysicalProject -----------filter((if((avg_quarterly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_quarterly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_quarterly_sales), NULL) > 0.100000)) +----------filter(((cast(abs((sum_sales - cast(avg_quarterly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_quarterly_sales) > 0.100000) and (tmp1.avg_quarterly_sales > 0.0000)) ------------PhysicalWindow --------------PhysicalQuickSort[LOCAL_SORT] ----------------PhysicalProject diff --git a/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query57.out b/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query57.out index 697bd284f5701e..9a6c48bb489f4b 100644 --- a/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query57.out +++ b/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query57.out @@ -39,6 +39,6 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------hashJoin[INNER_JOIN shuffle] hashCondition=((v1.cc_name = v1_lag.cc_name) and (v1.i_brand = v1_lag.i_brand) and (v1.i_category = v1_lag.i_category) and (v1.rn = expr_(rn + 1))) otherCondition=() build RFs:RF3 i_category->[i_category];RF4 i_brand->[i_brand];RF5 cc_name->[cc_name];RF6 rn->[(rn + 1)] --------------------PhysicalProject ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF3 RF4 RF5 RF6 ---------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 1999)) +--------------------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / v2.avg_monthly_sales) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 1999)) ----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query63.out b/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query63.out index 1ad6a72b1612f8..22b94a0bb8a827 100644 --- a/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query63.out +++ b/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query63.out @@ -5,7 +5,7 @@ PhysicalResultSink ----PhysicalDistribute[DistributionSpecGather] ------PhysicalTopN[LOCAL_SORT] --------PhysicalProject -----------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000)) +----------filter(((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_monthly_sales) > 0.100000) and (tmp1.avg_monthly_sales > 0.0000)) ------------PhysicalWindow --------------PhysicalQuickSort[LOCAL_SORT] ----------------PhysicalProject diff --git a/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query73.out b/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query73.out index 70d93fd9695e48..09e8cd80f2ce53 100644 --- a/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query73.out +++ b/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query73.out @@ -25,7 +25,7 @@ PhysicalResultSink ------------------------------filter(s_county IN ('Bronx County', 'Furnas County', 'Lea County', 'Pennington County')) --------------------------------PhysicalOlapScan[store] ------------------------PhysicalProject ---------------------------filter((household_demographics.hd_vehicle_count > 0) and (if((hd_vehicle_count > 0), (cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)), NULL) > 1.0) and hd_buy_potential IN ('5001-10000', '>10000')) +--------------------------filter(((cast(hd_dep_count as DOUBLE) / cast(hd_vehicle_count as DOUBLE)) > 1.0) and (household_demographics.hd_vehicle_count > 0) and hd_buy_potential IN ('5001-10000', '>10000')) ----------------------------PhysicalOlapScan[household_demographics] ------------PhysicalProject --------------PhysicalOlapScan[customer] diff --git a/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query89.out b/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query89.out index db4f8eeac80499..b17c9a89da3b1d 100644 --- a/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query89.out +++ b/regression-test/data/shape_check/tpcds_sf10t_orc/shape/query89.out @@ -6,7 +6,7 @@ PhysicalResultSink ------PhysicalDistribute[DistributionSpecGather] --------PhysicalTopN[LOCAL_SORT] ----------PhysicalProject -------------filter((if(( not (avg_monthly_sales = 0.0000)), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000)) +------------filter(( not (avg_monthly_sales = 0.0000)) and ((cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / tmp1.avg_monthly_sales) > 0.100000)) --------------PhysicalWindow ----------------PhysicalQuickSort[LOCAL_SORT] ------------------PhysicalProject diff --git a/regression-test/suites/external_table_p0/jdbc/test_mysql_jdbc_catalog.groovy b/regression-test/suites/external_table_p0/jdbc/test_mysql_jdbc_catalog.groovy index 2235c4b016a418..5f918658b73110 100644 --- a/regression-test/suites/external_table_p0/jdbc/test_mysql_jdbc_catalog.groovy +++ b/regression-test/suites/external_table_p0/jdbc/test_mysql_jdbc_catalog.groovy @@ -394,18 +394,18 @@ suite("test_mysql_jdbc_catalog", "p0,external,mysql,external_docker,external_doc contains "QUERY: SELECT `timestamp0` FROM `doris_test`.`dt` WHERE (`timestamp0` > '2022-01-01 00:00:00')" } explain { - sql ("select k6, k8 from test1 where nvl(k6, null) = 1;") + sql ("select k6, k8 from test1 where nvl(k6, 1) = k6;") - contains "QUERY: SELECT `k6`, `k8` FROM `doris_test`.`test1` WHERE ((ifnull(`k6`, NULL) = 1))" + contains "QUERY: SELECT `k6`, `k8` FROM `doris_test`.`test1` WHERE ((ifnull(`k6`, 1) = `k6`))" } explain { - sql ("select k6, k8 from test1 where nvl(nvl(k6, null),null) = 1;") + sql ("select k6, k8 from test1 where nvl(k6, nvl(k6, 1)) = k6;") - contains "QUERY: SELECT `k6`, `k8` FROM `doris_test`.`test1` WHERE ((ifnull(ifnull(`k6`, NULL), NULL) = 1))" + contains "QUERY: SELECT `k6`, `k8` FROM `doris_test`.`test1` WHERE ((ifnull(`k6`, ifnull(`k6`, 1)) = `k6`))" } sql """ set enable_ext_func_pred_pushdown = "false"; """ explain { - sql ("select k6, k8 from test1 where nvl(k6, null) = 1 and k8 = 1;") + sql ("select k6, k8 from test1 where nvl(k6, 1) = k6 and k8 = 1;") contains "QUERY: SELECT `k6`, `k8` FROM `doris_test`.`test1` WHERE ((`k8` = 1))" } diff --git a/regression-test/suites/external_table_p0/jdbc/test_oracle_jdbc_catalog.groovy b/regression-test/suites/external_table_p0/jdbc/test_oracle_jdbc_catalog.groovy index 27a9249d10e8c8..75bd21bc186c50 100644 --- a/regression-test/suites/external_table_p0/jdbc/test_oracle_jdbc_catalog.groovy +++ b/regression-test/suites/external_table_p0/jdbc/test_oracle_jdbc_catalog.groovy @@ -202,8 +202,8 @@ suite("test_oracle_jdbc_catalog", "p0,external,oracle,external_docker,external_d // test nvl explain { - sql("SELECT * FROM STUDENT WHERE nvl(score, 0) < 95;") - contains """SELECT "ID", "NAME", "AGE", "SCORE" FROM "DORIS_TEST"."STUDENT" WHERE ((nvl("SCORE", 0.0) < 95.0))""" + sql("SELECT * FROM STUDENT WHERE nvl(score, 0) < score;") + contains """SELECT "ID", "NAME", "AGE", "SCORE" FROM "DORIS_TEST"."STUDENT" WHERE ((nvl("SCORE", 0.0) < "SCORE"))""" } order_qt_raw """ select * from TEST_RAW order by ID; """ @@ -417,38 +417,38 @@ suite("test_oracle_jdbc_catalog", "p0,external,oracle,external_docker,external_d sql "use oracle_function_rules.DORIS_TEST" explain { - sql """select id from STUDENT where abs(id) > 0 and ifnull(id, 3) = 3;""" - contains """QUERY: SELECT "ID" FROM "DORIS_TEST"."STUDENT" WHERE ((abs("ID") > 0)) AND ((nvl("ID", 3) = 3))""" - contains """PREDICATES: ((abs(ID[#0]) > 0) AND (ifnull(ID[#0], 3) = 3))""" + sql """select id from STUDENT where abs(id) > 0 and nvl(id, 3) = id;""" + contains """QUERY: SELECT "ID" FROM "DORIS_TEST"."STUDENT" WHERE ((abs("ID") > 0)) AND ((nvl("ID", 3) = "ID"))""" + contains """PREDICATES: ((abs(ID[#0]) > 0) AND (ifnull(ID[#0], 3) = ID[#0]))""" } sql """alter catalog oracle_function_rules set properties("function_rules" = '');""" explain { - sql """select id from STUDENT where abs(id) > 0 and ifnull(id, 3) = 3;""" - contains """QUERY: SELECT "ID" FROM "DORIS_TEST"."STUDENT" WHERE ((nvl("ID", 3) = 3))""" - contains """PREDICATES: ((abs(ID[#0]) > 0) AND (ifnull(ID[#0], 3) = 3))""" + sql """select id from STUDENT where abs(id) > 0 and ifnull(id, 3) = id;""" + contains """QUERY: SELECT "ID" FROM "DORIS_TEST"."STUDENT" WHERE ((nvl("ID", 3) = "ID"))""" + contains """PREDICATES: ((abs(ID[#0]) > 0) AND (ifnull(ID[#0], 3) = ID[#0]))""" } sql """alter catalog oracle_function_rules set properties("function_rules" = '{"pushdown" : {"supported": ["abs"], "unsupported" : []}}')""" explain { - sql """select id from STUDENT where abs(id) > 0 and ifnull(id, 3) = 3;""" - contains """QUERY: SELECT "ID" FROM "DORIS_TEST"."STUDENT" WHERE ((abs("ID") > 0)) AND ((nvl("ID", 3) = 3))""" - contains """PREDICATES: ((abs(ID[#0]) > 0) AND (ifnull(ID[#0], 3) = 3))""" + sql """select id from STUDENT where abs(id) > 0 and ifnull(id, 3) = id;""" + contains """QUERY: SELECT "ID" FROM "DORIS_TEST"."STUDENT" WHERE ((abs("ID") > 0)) AND ((nvl("ID", 3) = "ID"))""" + contains """PREDICATES: ((abs(ID[#0]) > 0) AND (ifnull(ID[#0], 3) = ID[#0]))""" } // test rewrite sql """alter catalog oracle_function_rules set properties("function_rules" = '{"pushdown" : {"supported": ["abs"]}, "rewrite" : {"abs" : "abs2"}}');""" explain { - sql """select id from STUDENT where abs(id) > 0 and ifnull(id, 3) = 3;""" - contains """QUERY: SELECT "ID" FROM "DORIS_TEST"."STUDENT" WHERE ((abs2("ID") > 0)) AND ((nvl("ID", 3) = 3))""" - contains """PREDICATES: ((abs(ID[#0]) > 0) AND (ifnull(ID[#0], 3) = 3))""" + sql """select id from STUDENT where abs(id) > 0 and ifnull(id, 3) = id;""" + contains """QUERY: SELECT "ID" FROM "DORIS_TEST"."STUDENT" WHERE ((abs2("ID") > 0)) AND ((nvl("ID", 3) = "ID"))""" + contains """PREDICATES: ((abs(ID[#0]) > 0) AND (ifnull(ID[#0], 3) = ID[#0]))""" } // reset function rules sql """alter catalog oracle_function_rules set properties("function_rules" = '');""" explain { - sql """select id from STUDENT where abs(id) > 0 and ifnull(id, 3) = 3;""" - contains """QUERY: SELECT "ID" FROM "DORIS_TEST"."STUDENT" WHERE ((nvl("ID", 3) = 3))""" - contains """PREDICATES: ((abs(ID[#0]) > 0) AND (ifnull(ID[#0], 3) = 3))""" + sql """select id from STUDENT where abs(id) > 0 and ifnull(id, 3) = id;""" + contains """QUERY: SELECT "ID" FROM "DORIS_TEST"."STUDENT" WHERE ((nvl("ID", 3) = "ID"))""" + contains """PREDICATES: ((abs(ID[#0]) > 0) AND (ifnull(ID[#0], 3) = ID[#0]))""" } // test invalid config diff --git a/regression-test/suites/mv_p0/where/k123_nereids/k123_nereids.groovy b/regression-test/suites/mv_p0/where/k123_nereids/k123_nereids.groovy index 281e18d68a2f2e..6be708922324c1 100644 --- a/regression-test/suites/mv_p0/where/k123_nereids/k123_nereids.groovy +++ b/regression-test/suites/mv_p0/where/k123_nereids/k123_nereids.groovy @@ -83,7 +83,7 @@ suite ("k123p_nereids") { qt_select_mv_constant """select bitmap_empty() from d_table where true;""" - mv_rewrite_success_without_check_chosen("select k2 from d_table where k1=1 and (k1>2 or k1 < 0) order by k2;", "kwh1") + mv_rewrite_success_without_check_chosen("select k2 from d_table where k1=1 and (k1>2 or k1 * k1 > 10) order by k2;", "kwh1") qt_select_mv "select k2 from d_table where k1=1 and (k1>2 or k1 < 0) order by k2;" diff --git a/regression-test/suites/nereids_p0/sql_functions/conditional_functions/test_nullif.groovy b/regression-test/suites/nereids_p0/sql_functions/conditional_functions/test_nullif.groovy index c4d469543fdeec..77f08fb5e8ca0b 100644 --- a/regression-test/suites/nereids_p0/sql_functions/conditional_functions/test_nullif.groovy +++ b/regression-test/suites/nereids_p0/sql_functions/conditional_functions/test_nullif.groovy @@ -90,9 +90,9 @@ suite("test_nullif") { sql "use nereids_test_query_db" def tableName1 = "test" qt_if_nullif1 """select if(null, -1, 10) a, if(null, "hello", "worlk") b""" - qt_if_nullif2 """select if(k1 > 5, true, false) a from baseall order by k1""" + qt_if_nullif2 """select /*+ SET_VAR(fe_debug=false) */ if(k1 > 5, true, false) a from baseall order by k1""" qt_if_nullif3 """select if(k1, 10, -1) a from baseall order by k1""" - qt_if_nullif4 """select if(length(k6) >= 5, true, false) a from baseall order by k1""" + qt_if_nullif4 """select /*+ SET_VAR(fe_debug=false) */ if(length(k6) >= 5, true, false) a from baseall order by k1""" qt_if_nullif5 """select if(k6 like "fa%", -1, 10) a from baseall order by k6""" qt_if_nullif6 """select if(k6 like "%e", "hello", "world") a from baseall order by k6""" qt_if_nullif7 """select if(k6, -1, 0) a from baseall order by k6""" diff --git a/regression-test/suites/nereids_rules_p0/case_when_rules/test_case_when_rules.groovy b/regression-test/suites/nereids_rules_p0/case_when_rules/test_case_when_rules.groovy new file mode 100644 index 00000000000000..41ec4302eec090 --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/case_when_rules/test_case_when_rules.groovy @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_case_when_rules") { + + // this test will invoke many rules, including: + // PUSH_INTO_CASE_WHEN_BRNACH + // CASE_WHEN_TO_COMPOUND_PREDICATE + // NULL_SAFE_EQUAL_TO_EQUAL + // SIMPLIFY_RANGE + + sql """ + SET ignore_shape_nodes='PhysicalDistribute'; + SET runtime_filter_mode=OFF; + SET disable_join_reorder=true; + drop table if exists tbl_test_case_when_rules force; + create table tbl_test_case_when_rules(a bigint, b bigint) properties('replication_num' = '1'); + insert into tbl_test_case_when_rules values(null, 0), (1, 10), (2, 20), (3, 30), (4, 40), (5, 50), (6, 60) + """ + + explainAndOrderResult 'sql_1', ''' + select * + from ( + select a, b, case when a = 1 then 101 when a = 2 then 102 when a = 3 then 103 when a = 4 then 104 end as k + from tbl_test_case_when_rules + ) s + where k = 101 or k = 103 or k = 105 + ''' + + explainAndOrderResult 'sql_2', ''' + select * + from ( + select a, b, case when a = 1 then 101 when a = 2 then 102 when a = 3 then 103 when a = 4 then 104 end as k + from tbl_test_case_when_rules + ) s + where k > 103 + ''' + + explainAndOrderResult 'sql_3', ''' + select * + from ( + select a, b, case when a = 1 then 101 when a = 2 then 102 when a = 3 then 103 when a = 4 then 104 else 107 end as k + from tbl_test_case_when_rules + ) s + where k > 103 + ''' + + explainAndOrderResult 'sql_4', ''' + select * + from ( + select a, b, case when a = 1 then 101 when a = 2 then 102 when a = 3 then 103 when a = 4 then 104 else 107 end as k + from tbl_test_case_when_rules + ) s + where k != 103 + ''' + + explainAndOrderResult 'sql_5', ''' + select * + from ( + select a, b, case when a = 1 then 101 when a = 2 then 102 when a = 3 then 103 when a = 4 then 104 else 107 end as k + from tbl_test_case_when_rules + ) s + where k < 103 + ''' +} diff --git a/regression-test/suites/nereids_rules_p0/expression/test_simplify_range.groovy b/regression-test/suites/nereids_rules_p0/expression/test_simplify_range.groovy index 89cf3581b40d03..e5c5aec3c28eed 100644 --- a/regression-test/suites/nereids_rules_p0/expression/test_simplify_range.groovy +++ b/regression-test/suites/nereids_rules_p0/expression/test_simplify_range.groovy @@ -17,15 +17,35 @@ suite('test_simplify_range') { def tbl_1 = 'test_simplify_range_tbl_1' - sql "set disable_nereids_rules='PRUNE_EMPTY_PARTITION'" + sql ''' + SET ignore_shape_nodes='PhysicalDistribute'; + SET disable_nereids_rules='PRUNE_EMPTY_PARTITION'; + SET runtime_filter_mode=OFF; + SET disable_join_reorder=true; + ''' sql "DROP TABLE IF EXISTS ${tbl_1} FORCE" - sql "CREATE TABLE ${tbl_1}(a DECIMAL(16,8), b INT) PROPERTIES ('replication_num' = '1')" - sql "INSERT INTO ${tbl_1} VALUES(null, 10)" - test { - sql "SELECT a BETWEEN 100.02 and 40.123 OR a IN (54.0402) AND b < 10 FROM ${tbl_1}" - result([[null]]) - } + sql "CREATE TABLE ${tbl_1}(a DECIMAL(16,8), b INT, c bigint) PROPERTIES ('replication_num' = '1')" + sql "INSERT INTO ${tbl_1} VALUES(null, 10, 20), (1, null, null)" + + explainAndOrderResult 'sql_1', """ + SELECT a BETWEEN 100.02 and 40.123 OR a IN (54.0402) AND b < 10 + FROM ${tbl_1} + """ + + explainAndOrderResult 'sql_2', """ + SELECT b, c + FROM ${tbl_1} + WHERE a < 10 or ((a != 1 or a is null) and (a != 2 or a is null)) + """ + + explainAndOrderResult 'sql_3', """ + SELECT b * 0 AS b, SUM(c) + FROM ${tbl_1} + GROUP BY b + HAVING NOT b * 0 > b * 0 * 6; + """ + sql "DROP TABLE IF EXISTS ${tbl_1} FORCE" sql """ diff --git a/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy b/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy new file mode 100644 index 00000000000000..708c7fa6eea8c7 --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/join_extract_or_from_case_when/join_extract_or_from_case_when.groovy @@ -0,0 +1,201 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite('join_extract_or_from_case_when') { + sql """ + SET disable_nereids_rules='PRUNE_EMPTY_PARTITION'; + SET detail_shape_nodes='PhysicalProject,PhysicalHashAggregate'; + SET ignore_shape_nodes='PhysicalDistribute'; + SET runtime_filter_mode=OFF; + SET disable_join_reorder=true; + DROP TABLE IF EXISTS tbl_join_extract_or_from_case_when_1 FORCE; + DROP TABLE IF EXISTS tbl_join_extract_or_from_case_when_2 FORCE; + CREATE TABLE tbl_join_extract_or_from_case_when_1 (a bigint, b bigint) properties('replication_num' = '1'); + CREATE TABLE tbl_join_extract_or_from_case_when_2 (x bigint, y bigint) properties('replication_num' = '1'); + """ + + qt_case_when_one_side_1 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on not((case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b end) + 5 = 100); + """ + + qt_case_when_one_side_2 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on not((case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b end) + 5 = t1.a - t1.b); + """ + + qt_case_when_one_side_3 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on not((case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b else t1.b end) + 5 = 100); + """ + + qt_case_when_one_side_4 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on not((case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b else t1.b end) + 5 = t1.a - t1.b); + """ + + // random will not extract + qt_case_when_one_side_5 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b else t1.b end + random(1, 100) = t1.a - t1.b; + """ + + // the origin expression not contains two sides will not extract + qt_case_when_one_side_6 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on not((case when t1.a > 10 then t1.a when t1.b > 10 then t1.a + t1.b end) + 5 = 100); + """ + + // two case when branch contains both side slots will not rewrite + qt_case_when_one_side_7 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on case when t2.x > 0 then t1.a when t2.x < 10 then t1.a + 1 end + + case when t2.x > 1 then t1.a + 1 when t2.x < 10 then t1.a + 10 end > t1.a + t1.b + """ + + qt_case_when_one_side_8 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on case when t2.x > 0 then t1.a when t2.x < 10 then t1.a + 1 end + + case when t1.a > 1 then t1.a + 1 when t1.a < 10 then t1.a + 10 end > t1.a + t1.b + """ + + // case when's result are all literal + qt_case_when_one_side_9 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on case when t2.x > 0 then 100 when t2.x < 10 then 10000 end + 5 > t1.a + t1.b + """ + + // not push down because not meet push down other requirement + qt_case_when_one_side_10 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 left join tbl_join_extract_or_from_case_when_2 t2 + on case when t2.x > 0 then 100 when t2.x < 10 then 10000 end + 5 > t1.a + t1.b + """ + + // not push down because not meet push down other requirement + qt_case_when_one_side_11 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 right join tbl_join_extract_or_from_case_when_2 t2 + on case when t1.a > 0 then 100 when t1.b < 10 then 10000 end + 5 > t2.x + t2.y + """ + + qt_hash_cond_for_one_side_1 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on if(t1.a > 1, 1, 100) = if(t2.x > 2, 1, 200) + """ + + qt_case_when_two_side_1 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on (case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b else t1.b end) + 5 = t2.x + t2.y + 1; + """ + + qt_case_when_two_side_2 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on (case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b end) + 5 = t2.x + t2.y + 1; + """ + + // random will not extract + qt_case_when_two_side_3 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b else t1.b end + random(1, 10) = t2.x + t2.y + 1; + """ + + // any case when branch contains slot from both sides will not extract + qt_case_when_two_side_4 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b else t2.x + t1.b end = t2.x + t2.y + 1; + """ + + // extract the least OR EXPANSION hash condition + qt_case_when_two_side_5 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on if(t1.a > t2.x, t1.a, t1.b) = t2.x - t2.y and (case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b else t1.b end) + 5 = t2.x + t2.y + 1; + """ + + qt_case_when_two_side_6 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on case when t1.a is null then t2.x else t2.y end = COALESCE(t1.a, t1.b); + """ + + // any case when branch no contains slot from other side will not extract + qt_case_when_two_side_7 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on (case when t2.x > 10 then t1.a when t2.y > 10 then t1.a + t1.b else 100 end) + 5 = t2.x + t2.y + 1; + """ + + qt_if_one_side_1 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on if(t1.a > t2.x, t1.a, t1.b) = t1.a + t1.b; + """ + + qt_if_two_side_1 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on if(t1.a > t2.x, t1.a, t1.b) = t2.x + t2.y; + """ + + // in fact, IFNULL will nerver rewrite becase the rule require the + // case when expression contains both side slots and all the branch results contains only one side slots. + qt_ifnull_one_side_1 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on ifnull(t1.a, t2.x) = t1.a + t1.b; + """ + + // in fact, IFNULL will nerver rewrite becase the rule require the + // case when expression contains both side slots and all the branch results contains only one side slots. + qt_ifnull_two_side_1 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on ifnull(t1.a, t2.x) = t2.x + t2.y; + """ + + qt_nullif_one_side_1 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on nullif(t1.a, t2.x) = t1.a + t1.b; + """ + + qt_nullif_two_side_1 """explain shape plan + select t1.a, t2.x + from tbl_join_extract_or_from_case_when_1 t1 join tbl_join_extract_or_from_case_when_2 t2 + on nullif(t1.a, t2.x) = t2.x + t2.y; + """ + + sql """ + DROP TABLE IF EXISTS tbl_join_extract_or_from_case_when_1 FORCE; + DROP TABLE IF EXISTS tbl_join_extract_or_from_case_when_2 FORCE; + """ +} diff --git a/regression-test/suites/nereids_rules_p0/partition_prune/one_col_list_partition.groovy b/regression-test/suites/nereids_rules_p0/partition_prune/one_col_list_partition.groovy index 4de512376d1378..a174d275705759 100644 --- a/regression-test/suites/nereids_rules_p0/partition_prune/one_col_list_partition.groovy +++ b/regression-test/suites/nereids_rules_p0/partition_prune/one_col_list_partition.groovy @@ -224,7 +224,7 @@ suite("one_col_list_partition") { contains("VEMPTYSET") } explain { - sql "SELECT * FROM one_col_list_partition_date WHERE if(a>1, dt<'2001-1-01 00:00:00', dt<'2001-1-01 00:00:00')" + sql "SELECT * FROM one_col_list_partition_date WHERE if(a>1, dt<'2001-1-01 00:00:00', dt<'2001-1-03 00:00:00')" contains("partitions=8/9 (p1,p2,p3,p4,p5,p6,p7,p8)") } explain { @@ -246,4 +246,4 @@ suite("one_col_list_partition") { else '2023-01-01 00:00:00' end <'2021-01-06 00:00:00' ;""" contains("partitions=3/9 (p1,p2,p3)") } -} \ No newline at end of file +} diff --git a/regression-test/suites/query_p0/sql_functions/conditional_functions/test_nullif.groovy b/regression-test/suites/query_p0/sql_functions/conditional_functions/test_nullif.groovy index c3110d69f7aa9e..3a4a2d4d0a95a9 100644 --- a/regression-test/suites/query_p0/sql_functions/conditional_functions/test_nullif.groovy +++ b/regression-test/suites/query_p0/sql_functions/conditional_functions/test_nullif.groovy @@ -88,9 +88,9 @@ suite("test_nullif") { sql "use test_query_db" def tableName1 = "test" qt_if_nullif1 """select if(null, -1, 10) a, if(null, "hello", "worlk") b""" - qt_if_nullif2 """select if(k1 > 5, true, false) a from baseall order by k1""" + qt_if_nullif2 """select /*+ SET_VAR(fe_debug=false) */ if(k1 > 5, true, false) a from baseall order by k1""" qt_if_nullif3 """select if(k1, 10, -1) a from baseall order by k1""" - qt_if_nullif4 """select if(length(k6) >= 5, true, false) a from baseall order by k1""" + qt_if_nullif4 """select /*+ SET_VAR(fe_debug=false) */ if(length(k6) >= 5, true, false) a from baseall order by k1""" qt_if_nullif5 """select if(k6 like "fa%", -1, 10) a from baseall order by k6""" qt_if_nullif6 """select if(k6 like "%e", "hello", "world") a from baseall order by k6""" qt_if_nullif7 """select if(k6, -1, 0) a from baseall order by k6""" diff --git a/regression-test/suites/query_p0/virtual_slot_ref/adjust_virtual_slot_nullable.groovy b/regression-test/suites/query_p0/virtual_slot_ref/adjust_virtual_slot_nullable.groovy index df2a3087906aff..04273e48403a17 100644 --- a/regression-test/suites/query_p0/virtual_slot_ref/adjust_virtual_slot_nullable.groovy +++ b/regression-test/suites/query_p0/virtual_slot_ref/adjust_virtual_slot_nullable.groovy @@ -51,7 +51,7 @@ suite("adjust_virtual_slot_nullable") { NOT ( day(t2.c_date) IN (1, 3) AND - day(t2.c_date) IN (2, 3, 3) + day(t2.c_date) = t2.c_int ); """