Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
import org.apache.doris.nereids.rules.rewrite.InferSetOperatorDistinct;
import org.apache.doris.nereids.rules.rewrite.InitJoinOrder;
import org.apache.doris.nereids.rules.rewrite.InlineLogicalView;
import org.apache.doris.nereids.rules.rewrite.JoinExtractOrFromCaseWhen;
import org.apache.doris.nereids.rules.rewrite.LimitAggToTopNAgg;
import org.apache.doris.nereids.rules.rewrite.LimitSortToTopN;
import org.apache.doris.nereids.rules.rewrite.LogicalResultSinkToShortCircuitPointQuery;
Expand Down Expand Up @@ -323,7 +324,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
new PushFilterInsideJoin(),
new FindHashConditionForJoin(),
new ConvertInnerOrCrossJoin(),
new EliminateNullAwareLeftAntiJoin()
new EliminateNullAwareLeftAntiJoin(),
new JoinExtractOrFromCaseWhen()
),
// push down SEMI Join
bottomUp(
Expand Down Expand Up @@ -553,7 +555,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
new PushFilterInsideJoin(),
new FindHashConditionForJoin(),
new ConvertInnerOrCrossJoin(),
new EliminateNullAwareLeftAntiJoin()
new EliminateNullAwareLeftAntiJoin(),
new JoinExtractOrFromCaseWhen()
),
// push down SEMI Join
bottomUp(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ public enum RuleType {
REWRITE_PARTITION_TOPN_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_QUALIFY_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_TOPN_EXPRESSION(RuleTypeClass.REWRITE),
JOIN_EXTRACT_OR_FROM_CASE_WHEN(RuleTypeClass.REWRITE),
EXTRACT_FILTER_FROM_JOIN(RuleTypeClass.REWRITE),
REORDER_JOIN(RuleTypeClass.REWRITE),
INIT_JOIN_ORDER(RuleTypeClass.REWRITE),
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,10 @@ public Plan visitLogicalCTEAnchor(
@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, OrExpandsionContext ctx) {
join = (LogicalJoin<? extends Plan, ? extends Plan>) this.visit(join, ctx);
if (join.isMarkJoin() || !JoinUtils.shouldNestedLoopJoin(join)) {
return join;
}
if (!supportJoinType.contains(join.getJoinType())) {
if (!needRewriteJoin(join)) {
return join;
}

Preconditions.checkArgument(join.getHashJoinConjuncts().isEmpty(),
"Only Expansion nest loop join without hashCond");

Expand Down Expand Up @@ -207,6 +205,16 @@ public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, O
return null;
}

/**
* check whether it need to rewrite the join
*/
public boolean needRewriteJoin(LogicalJoin<? extends Plan, ? extends Plan> join) {
if (join.isMarkJoin() || !JoinUtils.shouldNestedLoopJoin(join)) {
return false;
}
return supportJoinType.contains(join.getJoinType());
}

private Map<Slot, Slot> constructReplaceMap(LogicalCTEConsumer leftConsumer, Map<Slot, Slot> leftCloneToLeft,
LogicalCTEConsumer rightConsumer, Map<Slot, Slot> rightCloneToRight) {
Map<Slot, Slot> replaced = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@
* Push the other join conditions in LogicalJoin to children.
*/
public class PushDownJoinOtherCondition extends OneRewriteRuleFactory {
private static final ImmutableList<JoinType> PUSH_DOWN_LEFT_VALID_TYPE = ImmutableList.of(
/**
* left push support type
*/
public static final ImmutableList<JoinType> PUSH_DOWN_LEFT_VALID_TYPE = ImmutableList.of(
JoinType.INNER_JOIN,
JoinType.LEFT_SEMI_JOIN,
JoinType.RIGHT_OUTER_JOIN,
Expand All @@ -46,7 +49,10 @@ public class PushDownJoinOtherCondition extends OneRewriteRuleFactory {
JoinType.CROSS_JOIN
);

private static final ImmutableList<JoinType> PUSH_DOWN_RIGHT_VALID_TYPE = ImmutableList.of(
/**
* right push support type
*/
public static final ImmutableList<JoinType> PUSH_DOWN_RIGHT_VALID_TYPE = ImmutableList.of(
JoinType.INNER_JOIN,
JoinType.LEFT_OUTER_JOIN,
JoinType.LEFT_ANTI_JOIN,
Expand All @@ -60,7 +66,8 @@ public class PushDownJoinOtherCondition extends OneRewriteRuleFactory {
public Rule build() {
return logicalJoin()
// TODO: we may need another rule to handle on true or on false condition
.when(join -> !join.getOtherJoinConjuncts().isEmpty() && !join.isMarkJoin())
.when(join -> !join.getOtherJoinConjuncts().isEmpty())
.when(PushDownJoinOtherCondition::needRewrite)
.then(join -> {
List<Expression> otherJoinConjuncts = join.getOtherJoinConjuncts();
List<Expression> remainingOther = Lists.newArrayList();
Expand Down Expand Up @@ -93,7 +100,14 @@ && allCoveredBy(otherConjunct, join.right().getOutputSet())) {
}).toRule(RuleType.PUSH_DOWN_JOIN_OTHER_CONDITION);
}

private boolean allCoveredBy(Expression predicate, Set<Slot> inputSlotSet) {
/**
* check need rewrite
*/
public static boolean needRewrite(LogicalJoin<Plan, Plan> join) {
return !join.isMarkJoin();
}

private static boolean allCoveredBy(Expression predicate, Set<Slot> inputSlotSet) {
return inputSlotSet.containsAll(predicate.getInputSlots());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.doris.nereids.trees.TreeNode;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
Expand All @@ -51,11 +52,15 @@
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.functions.scalar.NullIf;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl;
import org.apache.doris.nereids.trees.expressions.functions.scalar.UniqueFunction;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral;
Expand Down Expand Up @@ -1220,6 +1225,40 @@ public static String slotListShapeInfo(List<Slot> materializedSlots) {
return shapeBuilder.toString();
}

/**
* check whether the expression contains CaseWhen like type
*/
public static boolean containsCaseWhenLikeType(Expression expression) {
return expression.containsType(CaseWhen.class, If.class, NullIf.class, Nvl.class);
}

/**
* get the results of each branch in CaseWhen like expression
*/
public static Optional<List<Expression>> getCaseWhenLikeBranchResults(Expression expression) {
if (expression instanceof CaseWhen) {
CaseWhen caseWhen = (CaseWhen) expression;
ImmutableList.Builder<Expression> builder
= ImmutableList.builderWithExpectedSize(caseWhen.getWhenClauses().size() + 1);
for (WhenClause whenClause : caseWhen.getWhenClauses()) {
builder.add(whenClause.getResult());
}
builder.add(caseWhen.getDefaultValue().orElse(new NullLiteral(caseWhen.getDataType())));
return Optional.of(builder.build());
} else if (expression instanceof If) {
If ifExpr = (If) expression;
return Optional.of(ImmutableList.of(ifExpr.getTrueValue(), ifExpr.getFalseValue()));
} else if (expression instanceof NullIf) {
NullIf nullIf = (NullIf) expression;
return Optional.of(ImmutableList.of(new NullLiteral(nullIf.getDataType()), nullIf.left()));
} else if (expression instanceof Nvl) {
Nvl nvl = (Nvl) expression;
return Optional.of(ImmutableList.of(nvl.left(), nvl.right()));
} else {
return Optional.empty();
}
}

/**
* has aggregate function, exclude the window function
*/
Expand Down
Loading
Loading