Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ public String getExplainString(String prefix) {
.append(", nullable=").append(isNullable)
.append(", isAutoIncrement=").append(isAutoInc)
.append(", subColPath=").append(subColPath)
.append(", virtualColumn=").append(virtualColumn)
.append(", virtualColumn=").append(virtualColumn == null ? null : virtualColumn.toSql())
.append("}")
.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,8 @@ private Expression replaceConstants(Expression expression, boolean useInnerInfer
return replaceOrConstants((Or) expression, useInnerInfer, context, parentEqualSet, parentConstants);
} else if (!parentConstants.isEmpty()
&& expression.anyMatch(e -> e instanceof Slot && parentConstants.containsKey(e))) {
Expression newExpr = ExpressionUtils.replaceIf(expression, parentConstants, this::canReplaceExpression);
Expression newExpr = ExpressionUtils.replaceIf(
expression, parentConstants, this::canReplaceExpression, true);
if (!newExpr.equals(expression)) {
newExpr = FoldConstantRule.evaluate(newExpr, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ArrayItemReference;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
Expand All @@ -29,11 +30,9 @@
import org.apache.doris.nereids.trees.expressions.Match;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DecodeAsVarchar;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsBigInt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsInt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsLargeInt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsSmallInt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeString;
import org.apache.doris.nereids.trees.expressions.functions.scalar.IsIpAddressInRange;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MultiMatch;
Expand Down Expand Up @@ -198,14 +197,15 @@ public List<Rule> buildRules() {
private Plan pushDown(LogicalFilter<LogicalOlapScan> filter, LogicalOlapScan logicalOlapScan,
Optional<LogicalProject<?>> optionalProject) {
// 1. extract repeated sub-expressions from filter conjuncts
// 2. generate virtual columns and add them to scan
// 2. generate virtual columns
// 3. replace filter and project
// 4. add useful virtual columns to scan

Map<Expression, Expression> replaceMap = Maps.newHashMap();
ImmutableList.Builder<NamedExpression> virtualColumnsBuilder = ImmutableList.builder();
Map<Expression, NamedExpression> virtualColumnsMap = Maps.newHashMap();

// Extract repeated sub-expressions
extractRepeatedSubExpressions(filter, optionalProject, replaceMap, virtualColumnsBuilder);
extractRepeatedSubExpressions(filter, optionalProject, replaceMap, virtualColumnsMap);

if (replaceMap.isEmpty()) {
return null;
Expand All @@ -216,17 +216,41 @@ private Plan pushDown(LogicalFilter<LogicalOlapScan> filter, LogicalOlapScan log
replaceMap.size(), replaceMap.keySet());
}

// Create new scan with virtual columns
logicalOlapScan = logicalOlapScan.withVirtualColumns(virtualColumnsBuilder.build());

// Replace expressions in filter and project
Set<Expression> conjuncts = ExpressionUtils.replace(filter.getConjuncts(), replaceMap);
Plan plan = filter.withConjunctsAndChild(conjuncts, logicalOlapScan);
Map<Expression, Integer> counterMap = Maps.newHashMap();
Set<Expression> conjuncts = ExpressionUtils.replaceWithCounter(filter.getConjuncts(), replaceMap, counterMap);
List<NamedExpression> projections = null;
if (optionalProject.isPresent()) {
LogicalProject<?> project = optionalProject.get();
projections = ExpressionUtils.replaceWithCounter(
(List) project.getProjects(), replaceMap, counterMap);
}

// generate a map that only contains the expression really used in conjuncts and projections
Map<Expression, Expression> realReplacedMap = Maps.newHashMap();
for (Map.Entry<Expression, Integer> entry : counterMap.entrySet()) {
realReplacedMap.put(entry.getKey(), replaceMap.get(entry.getKey()));
}
// use replace map to replace virtual column expression
for (Map.Entry<Expression, NamedExpression> entry : virtualColumnsMap.entrySet()) {
Expression value = entry.getValue();
NamedExpression afterReplacement = (NamedExpression) ExpressionUtils.replaceIf(
value, replaceMap, e -> !e.equals(value.child(0)), false);
if (afterReplacement != value) {
virtualColumnsMap.put(entry.getKey(), afterReplacement);
}
}

// replace virtual columns with other virtual columns
ImmutableList.Builder<NamedExpression> virtualColumnsBuilder = ImmutableList.builder();
for (Map.Entry<Expression, Expression> entry : replaceMap.entrySet()) {
virtualColumnsBuilder.add(virtualColumnsMap.get(entry.getKey()));
}

logicalOlapScan = logicalOlapScan.withVirtualColumns(virtualColumnsBuilder.build());
Plan plan = filter.withConjunctsAndChild(conjuncts, logicalOlapScan);
if (optionalProject.isPresent()) {
LogicalProject<?> project = optionalProject.get();
List<NamedExpression> projections = ExpressionUtils.replace(
(List) project.getProjects(), replaceMap);
plan = project.withProjectsAndChild(projections, plan);
} else {
plan = new LogicalProject<>((List) filter.getOutput(), plan);
Expand All @@ -240,7 +264,7 @@ private Plan pushDown(LogicalFilter<LogicalOlapScan> filter, LogicalOlapScan log
private void extractRepeatedSubExpressions(LogicalFilter<LogicalOlapScan> filter,
Optional<LogicalProject<?>> optionalProject,
Map<Expression, Expression> replaceMap,
ImmutableList.Builder<NamedExpression> virtualColumnsBuilder) {
Map<Expression, NamedExpression> virtualColumnsMap) {

// Collect all expressions from filter and project
Set<Expression> allExpressions = new HashSet<>();
Expand Down Expand Up @@ -278,7 +302,7 @@ private void extractRepeatedSubExpressions(LogicalFilter<LogicalOlapScan> filter
Expression expr = entry.getKey();
Alias alias = new Alias(expr);
replaceMap.put(expr, alias.toSlot());
virtualColumnsBuilder.add(alias);
virtualColumnsMap.put(expr, alias);

if (LOG.isDebugEnabled()) {
LOG.debug("Created virtual column for expression: {} with type: {}",
Expand All @@ -288,7 +312,7 @@ private void extractRepeatedSubExpressions(LogicalFilter<LogicalOlapScan> filter

// Logging for debugging
if (LOG.isDebugEnabled()) {
logger.debug("Extracted virtual columns: {}", virtualColumnsBuilder.build());
logger.debug("Extracted virtual columns: {}", virtualColumnsMap.values());
}
}

Expand Down Expand Up @@ -316,24 +340,17 @@ private void collectSubExpressions(Expression expr, Map<Expression, Integer> exp
return;
}

if (skipResult.shouldSkipCounting() || skipResult.isNotBeneficial()) {
// Examples for SKIP_COUNTING: CAST(x AS VARCHAR)
// Examples for SKIP_NOT_BENEFICIAL:
// - encode_as_bigint(x), decode_as_varchar(x)
// - x > 10, x IN (1,2,3), x IS NULL (ColumnPredicate convertible)
// - is_ip_address_in_range(ip, '192.168.1.0/24'), multi_match(text, 'query') (index pushdown)
// - expressions containing lambda functions
// These expressions are not counted but we continue processing their children
for (Expression child : expr.children()) {
collectSubExpressions(child, expressionCounts, insideLambda);
// CONTINUE case: Examples like x + y, func(a, b), (x + y) * z
// Only count expressions that meet minimum complexity requirements
if (!(skipResult.shouldSkipCounting() || skipResult.isNotBeneficial())) {
if (expr.getDepth() >= MIN_EXPRESSION_DEPTH && expr.children().size() > 0) {
expressionCounts.put(expr, expressionCounts.getOrDefault(expr, 0) + 1);
}
return;
}

// CONTINUE case: Examples like x + y, func(a, b), (x + y) * z
// Only count expressions that meet minimum complexity requirements
if (expr.getDepth() >= MIN_EXPRESSION_DEPTH && expr.children().size() > 0) {
expressionCounts.put(expr, expressionCounts.getOrDefault(expr, 0) + 1);
// if the Expression has been collected, we do not collect it's children again
if (expressionCounts.getOrDefault(expr, 0) > 1) {
return;
}

// Recursively process children
Expand All @@ -352,30 +369,31 @@ private void collectSubExpressions(Expression expr, Map<Expression, Integer> exp
* @return SkipResult indicating how to handle this expression
*/
private SkipResult shouldSkipExpression(Expression expr, boolean insideLambda) {
// Skip simple slots and literals as they don't benefit from being pushed down
if (expr instanceof Slot || expr.isConstant()) {
return SkipResult.TERMINATE;
}

// Skip expressions inside lambda functions - they shouldn't be optimized
if (insideLambda) {
if (expr.containsType(ArrayItemReference.class)) {
return SkipResult.SKIP_NOT_BENEFICIAL;
}
}

// Skip simple slots and literals as they don't benefit from being pushed down
if (expr instanceof Slot || expr.isConstant()) {
return SkipResult.TERMINATE;
}

// Skip CAST expressions - they shouldn't be optimized as common sub-expressions
// Skip CAST and WhenClause expressions - they shouldn't be optimized as common sub-expressions
// but we still need to process their children
if (expr instanceof Cast) {
if (expr instanceof Cast || expr instanceof WhenClause) {
return SkipResult.SKIP_COUNTING;
}

// Skip expressions with decode_as_varchar or encode_as_bigint as root
if (expr instanceof DecodeAsVarchar || expr instanceof EncodeAsBigInt || expr instanceof EncodeAsInt
|| expr instanceof EncodeAsLargeInt || expr instanceof EncodeAsSmallInt) {
if (expr instanceof DecodeAsVarchar || expr instanceof EncodeString) {
return SkipResult.SKIP_NOT_BENEFICIAL;
}

// Skip expressions that contain lambda functions anywhere in the tree
if (containsLambdaFunction(expr)) {
if (expr instanceof Lambda) {
return SkipResult.SKIP_NOT_BENEFICIAL;
}

Expand All @@ -389,23 +407,6 @@ private SkipResult shouldSkipExpression(Expression expr, boolean insideLambda) {
return SkipResult.CONTINUE;
}

/**
* Check if an expression contains lambda functions
*/
private boolean containsLambdaFunction(Expression expr) {
if (expr instanceof Lambda) {
return true;
}

for (Expression child : expr.children()) {
if (containsLambdaFunction(child)) {
return true;
}
}

return false;
}

/**
* Result type for expression skip decisions
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,17 +153,20 @@ default NODE_TYPE rewriteDownShortCircuit(Function<NODE_TYPE, NODE_TYPE> rewrite
* border predicate are rewritten, and if a node not match predicate, then its descendant will not rewrite.
*/
default NODE_TYPE rewriteDownShortCircuitDown(Function<NODE_TYPE, NODE_TYPE> rewriteFunction,
Predicate predicate) {
Predicate predicate, boolean stopWhenNotMatched) {
NODE_TYPE currentNode = (NODE_TYPE) this;
if (!predicate.test(this)) {
boolean matched = predicate.test(this);
if (stopWhenNotMatched && !matched) {
return currentNode;
}
currentNode = rewriteFunction.apply(currentNode);
if (matched) {
currentNode = rewriteFunction.apply(currentNode);
}
if (currentNode == this) {
Builder<NODE_TYPE> newChildren = ImmutableList.builderWithExpectedSize(arity());
boolean changed = false;
for (NODE_TYPE child : children()) {
NODE_TYPE newChild = child.rewriteDownShortCircuitDown(rewriteFunction, predicate);
NODE_TYPE newChild = child.rewriteDownShortCircuitDown(rewriteFunction, predicate, stopWhenNotMatched);
if (child != newChild) {
changed = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,11 +423,59 @@ public static NamedExpression replaceNameExpression(NamedExpression expr,
* Replace expression node with predicate in the expression tree by `replaceMap` in top-down manner.
*/
public static Expression replaceIf(Expression expr, Map<? extends Expression, ? extends Expression> replaceMap,
Predicate<Expression> predicate) {
Predicate<Expression> predicate, boolean stopWhenNotMatched) {
return expr.rewriteDownShortCircuitDown(e -> {
Expression replacedExpr = replaceMap.get(e);
return replacedExpr == null ? e : replacedExpr;
}, predicate);
}, predicate, stopWhenNotMatched);
}

public static Set<Expression> replaceWithCounter(Set<Expression> exprs,
Map<? extends Expression, ? extends Expression> replaceMap, Map<Expression, Integer> counterMap) {
ImmutableSet.Builder<Expression> result = ImmutableSet.builderWithExpectedSize(exprs.size());
for (Expression expr : exprs) {
result.add(replaceWithCounter(expr, replaceMap, counterMap));
}
return result.build();
}

public static List<Expression> replaceWithCounter(List<Expression> exprs,
Map<? extends Expression, ? extends Expression> replaceMap,
Map<Expression, Integer> counterMap) {
ImmutableList.Builder<Expression> result = ImmutableList.builderWithExpectedSize(exprs.size());
for (Expression expr : exprs) {
result.add(replaceWithCounter(expr, replaceMap, counterMap));
}
return result.build();
}

/**
* Replace expression node in the expression tree by `replaceMap` in top-down manner.
* This function gives counter map to record replace count.
* For example.
* <pre>
* input expression: a > 1
* replaceMap: a -> b + c
*
* output:
* b + c > 1
* </pre>
*/
public static Expression replaceWithCounter(Expression expr,
Map<? extends Expression, ? extends Expression> replaceMap,
Map<Expression, Integer> counterMap) {
return expr.rewriteDownShortCircuit(e -> {
Expression replacedExpr = replaceMap.get(e);
if (replacedExpr != null) {
if (!counterMap.containsKey(e)) {
counterMap.put(e, 1);
} else {
counterMap.put(e, counterMap.get(e) + 1);
}
return replacedExpr;
}
return e;
});
}

/**
Expand Down
Loading