diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotDescriptor.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotDescriptor.java index 5d0ec929ad6fc3..93bf38b820e6dc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotDescriptor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotDescriptor.java @@ -363,7 +363,7 @@ public String getExplainString(String prefix) { .append(", nullable=").append(isNullable) .append(", isAutoIncrement=").append(isAutoInc) .append(", subColPath=").append(subColPath) - .append(", virtualColumn=").append(virtualColumn) + .append(", virtualColumn=").append(virtualColumn == null ? null : virtualColumn.toSql()) .append("}") .toString(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagation.java index c8c839af995a03..72283f46a49e45 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagation.java @@ -374,7 +374,8 @@ private Expression replaceConstants(Expression expression, boolean useInnerInfer return replaceOrConstants((Or) expression, useInnerInfer, context, parentEqualSet, parentConstants); } else if (!parentConstants.isEmpty() && expression.anyMatch(e -> e instanceof Slot && parentConstants.containsKey(e))) { - Expression newExpr = ExpressionUtils.replaceIf(expression, parentConstants, this::canReplaceExpression); + Expression newExpr = ExpressionUtils.replaceIf( + expression, parentConstants, this::canReplaceExpression, true); if (!newExpr.equals(expression)) { newExpr = FoldConstantRule.evaluate(newExpr, context); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScan.java index e7188dbffaefc8..f2490aa71f61fc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScan.java @@ -21,6 +21,7 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.ArrayItemReference; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.Expression; @@ -29,11 +30,9 @@ import org.apache.doris.nereids.trees.expressions.Match; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.scalar.DecodeAsVarchar; -import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsBigInt; -import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsInt; -import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsLargeInt; -import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsSmallInt; +import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeString; import org.apache.doris.nereids.trees.expressions.functions.scalar.IsIpAddressInRange; import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda; import org.apache.doris.nereids.trees.expressions.functions.scalar.MultiMatch; @@ -198,14 +197,15 @@ public List buildRules() { private Plan pushDown(LogicalFilter filter, LogicalOlapScan logicalOlapScan, Optional> optionalProject) { // 1. extract repeated sub-expressions from filter conjuncts - // 2. generate virtual columns and add them to scan + // 2. generate virtual columns // 3. replace filter and project + // 4. add useful virtual columns to scan Map replaceMap = Maps.newHashMap(); - ImmutableList.Builder virtualColumnsBuilder = ImmutableList.builder(); + Map virtualColumnsMap = Maps.newHashMap(); // Extract repeated sub-expressions - extractRepeatedSubExpressions(filter, optionalProject, replaceMap, virtualColumnsBuilder); + extractRepeatedSubExpressions(filter, optionalProject, replaceMap, virtualColumnsMap); if (replaceMap.isEmpty()) { return null; @@ -216,17 +216,41 @@ private Plan pushDown(LogicalFilter filter, LogicalOlapScan log replaceMap.size(), replaceMap.keySet()); } - // Create new scan with virtual columns - logicalOlapScan = logicalOlapScan.withVirtualColumns(virtualColumnsBuilder.build()); - // Replace expressions in filter and project - Set conjuncts = ExpressionUtils.replace(filter.getConjuncts(), replaceMap); - Plan plan = filter.withConjunctsAndChild(conjuncts, logicalOlapScan); + Map counterMap = Maps.newHashMap(); + Set conjuncts = ExpressionUtils.replaceWithCounter(filter.getConjuncts(), replaceMap, counterMap); + List projections = null; + if (optionalProject.isPresent()) { + LogicalProject project = optionalProject.get(); + projections = ExpressionUtils.replaceWithCounter( + (List) project.getProjects(), replaceMap, counterMap); + } + + // generate a map that only contains the expression really used in conjuncts and projections + Map realReplacedMap = Maps.newHashMap(); + for (Map.Entry entry : counterMap.entrySet()) { + realReplacedMap.put(entry.getKey(), replaceMap.get(entry.getKey())); + } + // use replace map to replace virtual column expression + for (Map.Entry entry : virtualColumnsMap.entrySet()) { + Expression value = entry.getValue(); + NamedExpression afterReplacement = (NamedExpression) ExpressionUtils.replaceIf( + value, replaceMap, e -> !e.equals(value.child(0)), false); + if (afterReplacement != value) { + virtualColumnsMap.put(entry.getKey(), afterReplacement); + } + } + // replace virtual columns with other virtual columns + ImmutableList.Builder virtualColumnsBuilder = ImmutableList.builder(); + for (Map.Entry entry : replaceMap.entrySet()) { + virtualColumnsBuilder.add(virtualColumnsMap.get(entry.getKey())); + } + + logicalOlapScan = logicalOlapScan.withVirtualColumns(virtualColumnsBuilder.build()); + Plan plan = filter.withConjunctsAndChild(conjuncts, logicalOlapScan); if (optionalProject.isPresent()) { LogicalProject project = optionalProject.get(); - List projections = ExpressionUtils.replace( - (List) project.getProjects(), replaceMap); plan = project.withProjectsAndChild(projections, plan); } else { plan = new LogicalProject<>((List) filter.getOutput(), plan); @@ -240,7 +264,7 @@ private Plan pushDown(LogicalFilter filter, LogicalOlapScan log private void extractRepeatedSubExpressions(LogicalFilter filter, Optional> optionalProject, Map replaceMap, - ImmutableList.Builder virtualColumnsBuilder) { + Map virtualColumnsMap) { // Collect all expressions from filter and project Set allExpressions = new HashSet<>(); @@ -278,7 +302,7 @@ private void extractRepeatedSubExpressions(LogicalFilter filter Expression expr = entry.getKey(); Alias alias = new Alias(expr); replaceMap.put(expr, alias.toSlot()); - virtualColumnsBuilder.add(alias); + virtualColumnsMap.put(expr, alias); if (LOG.isDebugEnabled()) { LOG.debug("Created virtual column for expression: {} with type: {}", @@ -288,7 +312,7 @@ private void extractRepeatedSubExpressions(LogicalFilter filter // Logging for debugging if (LOG.isDebugEnabled()) { - logger.debug("Extracted virtual columns: {}", virtualColumnsBuilder.build()); + logger.debug("Extracted virtual columns: {}", virtualColumnsMap.values()); } } @@ -316,24 +340,17 @@ private void collectSubExpressions(Expression expr, Map exp return; } - if (skipResult.shouldSkipCounting() || skipResult.isNotBeneficial()) { - // Examples for SKIP_COUNTING: CAST(x AS VARCHAR) - // Examples for SKIP_NOT_BENEFICIAL: - // - encode_as_bigint(x), decode_as_varchar(x) - // - x > 10, x IN (1,2,3), x IS NULL (ColumnPredicate convertible) - // - is_ip_address_in_range(ip, '192.168.1.0/24'), multi_match(text, 'query') (index pushdown) - // - expressions containing lambda functions - // These expressions are not counted but we continue processing their children - for (Expression child : expr.children()) { - collectSubExpressions(child, expressionCounts, insideLambda); + // CONTINUE case: Examples like x + y, func(a, b), (x + y) * z + // Only count expressions that meet minimum complexity requirements + if (!(skipResult.shouldSkipCounting() || skipResult.isNotBeneficial())) { + if (expr.getDepth() >= MIN_EXPRESSION_DEPTH && expr.children().size() > 0) { + expressionCounts.put(expr, expressionCounts.getOrDefault(expr, 0) + 1); } - return; } - // CONTINUE case: Examples like x + y, func(a, b), (x + y) * z - // Only count expressions that meet minimum complexity requirements - if (expr.getDepth() >= MIN_EXPRESSION_DEPTH && expr.children().size() > 0) { - expressionCounts.put(expr, expressionCounts.getOrDefault(expr, 0) + 1); + // if the Expression has been collected, we do not collect it's children again + if (expressionCounts.getOrDefault(expr, 0) > 1) { + return; } // Recursively process children @@ -352,30 +369,31 @@ private void collectSubExpressions(Expression expr, Map exp * @return SkipResult indicating how to handle this expression */ private SkipResult shouldSkipExpression(Expression expr, boolean insideLambda) { - // Skip simple slots and literals as they don't benefit from being pushed down - if (expr instanceof Slot || expr.isConstant()) { - return SkipResult.TERMINATE; - } - // Skip expressions inside lambda functions - they shouldn't be optimized if (insideLambda) { + if (expr.containsType(ArrayItemReference.class)) { + return SkipResult.SKIP_NOT_BENEFICIAL; + } + } + + // Skip simple slots and literals as they don't benefit from being pushed down + if (expr instanceof Slot || expr.isConstant()) { return SkipResult.TERMINATE; } - // Skip CAST expressions - they shouldn't be optimized as common sub-expressions + // Skip CAST and WhenClause expressions - they shouldn't be optimized as common sub-expressions // but we still need to process their children - if (expr instanceof Cast) { + if (expr instanceof Cast || expr instanceof WhenClause) { return SkipResult.SKIP_COUNTING; } // Skip expressions with decode_as_varchar or encode_as_bigint as root - if (expr instanceof DecodeAsVarchar || expr instanceof EncodeAsBigInt || expr instanceof EncodeAsInt - || expr instanceof EncodeAsLargeInt || expr instanceof EncodeAsSmallInt) { + if (expr instanceof DecodeAsVarchar || expr instanceof EncodeString) { return SkipResult.SKIP_NOT_BENEFICIAL; } // Skip expressions that contain lambda functions anywhere in the tree - if (containsLambdaFunction(expr)) { + if (expr instanceof Lambda) { return SkipResult.SKIP_NOT_BENEFICIAL; } @@ -389,23 +407,6 @@ private SkipResult shouldSkipExpression(Expression expr, boolean insideLambda) { return SkipResult.CONTINUE; } - /** - * Check if an expression contains lambda functions - */ - private boolean containsLambdaFunction(Expression expr) { - if (expr instanceof Lambda) { - return true; - } - - for (Expression child : expr.children()) { - if (containsLambdaFunction(child)) { - return true; - } - } - - return false; - } - /** * Result type for expression skip decisions */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java index c7a6ad3fc8e1d6..74215be9c4ea6b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java @@ -153,17 +153,20 @@ default NODE_TYPE rewriteDownShortCircuit(Function rewrite * border predicate are rewritten, and if a node not match predicate, then its descendant will not rewrite. */ default NODE_TYPE rewriteDownShortCircuitDown(Function rewriteFunction, - Predicate predicate) { + Predicate predicate, boolean stopWhenNotMatched) { NODE_TYPE currentNode = (NODE_TYPE) this; - if (!predicate.test(this)) { + boolean matched = predicate.test(this); + if (stopWhenNotMatched && !matched) { return currentNode; } - currentNode = rewriteFunction.apply(currentNode); + if (matched) { + currentNode = rewriteFunction.apply(currentNode); + } if (currentNode == this) { Builder newChildren = ImmutableList.builderWithExpectedSize(arity()); boolean changed = false; for (NODE_TYPE child : children()) { - NODE_TYPE newChild = child.rewriteDownShortCircuitDown(rewriteFunction, predicate); + NODE_TYPE newChild = child.rewriteDownShortCircuitDown(rewriteFunction, predicate, stopWhenNotMatched); if (child != newChild) { changed = true; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index f7d4e3488f6090..15cb6753c84fe6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -423,11 +423,59 @@ public static NamedExpression replaceNameExpression(NamedExpression expr, * Replace expression node with predicate in the expression tree by `replaceMap` in top-down manner. */ public static Expression replaceIf(Expression expr, Map replaceMap, - Predicate predicate) { + Predicate predicate, boolean stopWhenNotMatched) { return expr.rewriteDownShortCircuitDown(e -> { Expression replacedExpr = replaceMap.get(e); return replacedExpr == null ? e : replacedExpr; - }, predicate); + }, predicate, stopWhenNotMatched); + } + + public static Set replaceWithCounter(Set exprs, + Map replaceMap, Map counterMap) { + ImmutableSet.Builder result = ImmutableSet.builderWithExpectedSize(exprs.size()); + for (Expression expr : exprs) { + result.add(replaceWithCounter(expr, replaceMap, counterMap)); + } + return result.build(); + } + + public static List replaceWithCounter(List exprs, + Map replaceMap, + Map counterMap) { + ImmutableList.Builder result = ImmutableList.builderWithExpectedSize(exprs.size()); + for (Expression expr : exprs) { + result.add(replaceWithCounter(expr, replaceMap, counterMap)); + } + return result.build(); + } + + /** + * Replace expression node in the expression tree by `replaceMap` in top-down manner. + * This function gives counter map to record replace count. + * For example. + *
+     * input expression: a > 1
+     * replaceMap: a -> b + c
+     *
+     * output:
+     * b + c > 1
+     * 
+ */ + public static Expression replaceWithCounter(Expression expr, + Map replaceMap, + Map counterMap) { + return expr.rewriteDownShortCircuit(e -> { + Expression replacedExpr = replaceMap.get(e); + if (replacedExpr != null) { + if (!counterMap.containsKey(e)) { + counterMap.put(e, 1); + } else { + counterMap.put(e, counterMap.get(e) + 1); + } + return replacedExpr; + } + return e; + }); } /** diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScanTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScanTest.java index 9a1e385261426b..00447a3a2aac37 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScanTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScanTest.java @@ -20,8 +20,11 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.ArrayItemReference; +import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.IsNull; @@ -30,6 +33,10 @@ import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Not; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.WhenClause; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Array; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFilter; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMap; import org.apache.doris.nereids.trees.expressions.functions.scalar.Concat; import org.apache.doris.nereids.trees.expressions.functions.scalar.IsIpAddressInRange; import org.apache.doris.nereids.trees.expressions.functions.scalar.L2Distance; @@ -44,18 +51,27 @@ import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.types.VarcharType; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; +import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Maps; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import java.lang.reflect.Method; import java.util.List; +import java.util.Map; +import java.util.Optional; /** * Test for PushDownVirtualColumnsIntoOlapScan rule. */ -public class PushDownVirtualColumnsIntoOlapScanTest { +public class PushDownVirtualColumnsIntoOlapScanTest implements MemoPatternMatchSupported { @Test public void testExtractRepeatedSubExpressions() { @@ -98,7 +114,7 @@ public void testExtractRepeatedSubExpressions() { List rules = rule.buildRules(); // Test that rules are created - assert rules.size() == 2; + Assertions.assertEquals(2, rules.size()); // Test rule application on the actual plan structures boolean projectFilterScanRuleMatches = false; @@ -113,8 +129,8 @@ public void testExtractRepeatedSubExpressions() { } } - assert projectFilterScanRuleMatches : "Should have rule for Project->Filter->Scan pattern"; - assert filterScanRuleMatches : "Should have rule for Filter->Scan pattern"; + Assertions.assertTrue(projectFilterScanRuleMatches, "Should have rule for Project->Filter->Scan pattern"); + Assertions.assertTrue(filterScanRuleMatches, "Should have rule for Filter->Scan pattern"); } @Test @@ -140,17 +156,18 @@ public void testExtractDistanceFunctions() { List rules = rule.buildRules(); // Should create appropriate rules - assert rules.size() == 2; + Assertions.assertEquals(2, rules.size()); // Verify the filter contains the distance function - assert filter.getConjuncts().contains(distanceFilter) : "Filter should contain distance function"; - assert filter.child() == scan : "Filter should have scan as child"; + Assertions.assertTrue(filter.getConjuncts().contains(distanceFilter), + "Filter should contain distance function"); + Assertions.assertEquals(scan, filter.child(), "Filter should have scan as child"); // Verify distance function structure - assert distanceFilter.left() instanceof L2Distance : "Should have L2Distance function"; + Assertions.assertInstanceOf(L2Distance.class, distanceFilter.left(), "Should have L2Distance function"); L2Distance distFunc = (L2Distance) distanceFilter.left(); - assert distFunc.child(0) == vector1 : "First argument should be vector1"; - assert distFunc.child(1) == vector2 : "Second argument should be vector2"; + Assertions.assertEquals(vector1, distFunc.child(0), "First argument should be vector1"); + Assertions.assertEquals(vector2, distFunc.child(1), "First argument should be vector2"); } @Test @@ -186,13 +203,13 @@ public void testComplexRepeatedExpressions() { List rules = rule.buildRules(); // Should create appropriate rules for complex expressions - assert rules.size() == 2; + Assertions.assertEquals(2, rules.size()); // Verify the filter structure - assert filter.getConjuncts().size() == 2 : "Filter should have 2 conjuncts"; - assert filter.getConjuncts().contains(gt) : "Filter should contain greater than condition"; - assert filter.getConjuncts().contains(lt) : "Filter should contain less than condition"; - assert filter.child() == scan : "Filter should have scan as child"; + Assertions.assertEquals(2, filter.getConjuncts().size(), "Filter should have 2 conjuncts"); + Assertions.assertTrue(filter.getConjuncts().contains(gt), "Filter should contain greater than condition"); + Assertions.assertTrue(filter.getConjuncts().contains(lt), "Filter should contain less than condition"); + Assertions.assertEquals(scan, filter.child(), "Filter should have scan as child"); // Verify complex expressions are structurally equivalent (though different objects) // Both should be Add expressions with Multiply as left child @@ -202,26 +219,62 @@ public void testComplexRepeatedExpressions() { assert complexExpr2.left() instanceof Multiply : "Left side should be Multiply"; } + @Test + public void testSkipWhenClause() { + // Test that WhenClause expressions are not optimized as common sub-expressions + // SELECT * FROM table WHERE CASE WHEN x = 1 THEN 'abc' ELSE WHEN x = 1 THEN 'abc' END != 'def' + + LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + SlotReference x = (SlotReference) scan.getOutput().get(0); + + // Create repeated CAST expressions + WhenClause whenClause = new WhenClause(x, new StringLiteral("abc")); + CaseWhen caseWhen = new CaseWhen(ImmutableList.of(whenClause, whenClause)); + + // Create OLAP scan + + // Create filter with repeated CAST expressions + LogicalFilter filter = new LogicalFilter<>( + ImmutableSet.of(caseWhen), scan); + + // Apply the rule + PushDownVirtualColumnsIntoOlapScan rule = new PushDownVirtualColumnsIntoOlapScan(); + List rules = rule.buildRules(); + + // Test that the rule can match the filter pattern (without executing transformation) + boolean hasMatchingRule = false; + for (Rule r : rules) { + if (r.getPattern().matchPlanTree(filter)) { + hasMatchingRule = true; + break; + } + } + + // WhenClause expressions should NOT be optimized, but the rule should still match the pattern + Assertions.assertTrue(hasMatchingRule, "Rule should match the filter pattern"); + + PlanChecker.from(MemoTestUtils.createConnectContext(), filter) + .applyTopDown(rules) + .matches(logicalOlapScan().when(o -> o.getVirtualColumns().isEmpty())); + } + @Test public void testSkipCastExpressions() { // Test that CAST expressions are not optimized as common sub-expressions // SELECT * FROM table WHERE CAST(x AS VARCHAR) = 'abc' AND CAST(x AS VARCHAR) != 'def' - DataType intType = IntegerType.INSTANCE; - DataType varcharType = VarcharType.SYSTEM_DEFAULT; - SlotReference x = new SlotReference("x", intType); + // Create OLAP scan + LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + SlotReference x = (SlotReference) scan.getOutput().get(0); // Create repeated CAST expressions - Cast cast1 = new Cast(x, varcharType); - Cast cast2 = new Cast(x, varcharType); + Cast cast1 = new Cast(x, VarcharType.SYSTEM_DEFAULT); + Cast cast2 = new Cast(x, VarcharType.SYSTEM_DEFAULT); // Create filter conditions using the repeated CAST expression EqualTo eq = new EqualTo(cast1, new StringLiteral("abc")); Not neq = new Not(new EqualTo(cast2, new StringLiteral("def"))); - // Create OLAP scan - LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0); - // Create filter with repeated CAST expressions LogicalFilter filter = new LogicalFilter<>( ImmutableSet.of(eq, neq), scan); @@ -231,7 +284,7 @@ public void testSkipCastExpressions() { List rules = rule.buildRules(); // Test rule creation - assert rules.size() == 2; + Assertions.assertEquals(2, rules.size()); // Test that the rule can match the filter pattern (without executing transformation) boolean hasMatchingRule = false; @@ -243,7 +296,11 @@ public void testSkipCastExpressions() { } // CAST expressions should NOT be optimized, but the rule should still match the pattern - assert hasMatchingRule : "Rule should match the filter pattern"; + Assertions.assertTrue(hasMatchingRule, "Rule should match the filter pattern"); + + PlanChecker.from(MemoTestUtils.createConnectContext(), filter) + .applyTopDown(rules) + .matches(logicalOlapScan().when(o -> o.getVirtualColumns().isEmpty())); } @Test @@ -251,31 +308,34 @@ public void testSkipLambdaExpressions() { // Test that expressions inside lambda functions are not optimized // This is a simplified test since creating actual lambda expressions is complex - DataType intType = IntegerType.INSTANCE; - SlotReference x = new SlotReference("x", intType); - SlotReference y = new SlotReference("y", intType); - - // Create a repeated expression that would normally be optimized - Add xyAdd1 = new Add(x, y); - Add xyAdd2 = new Add(x, y); - - // Create filter conditions - one normal, one that would be inside a lambda context - GreaterThan gt1 = new GreaterThan(xyAdd1, new IntegerLiteral(10)); - GreaterThan gt2 = new GreaterThan(xyAdd2, new IntegerLiteral(20)); + ConnectContext connectContext = MemoTestUtils.createConnectContext(); // Create OLAP scan LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + SlotReference x = (SlotReference) scan.getOutput().get(0); + SlotReference y = (SlotReference) scan.getOutput().get(1); + Add xyAdd = new Add(x, y); + + // Create a lambda expression + Array arr = new Array(y); + ArrayItemReference refA = new ArrayItemReference("a", arr); + Add lambdaAdd = new Add(refA.toSlot(), xyAdd); + Lambda lambda = new Lambda(ImmutableList.of("a"), lambdaAdd, ImmutableList.of(refA)); + + // Create two expression contain lambda + ArrayFilter arrayFilter = new ArrayFilter(lambda); + ArrayMap arrayMap = new ArrayMap(lambda); // Create filter LogicalFilter filter = new LogicalFilter<>( - ImmutableSet.of(gt1, gt2), scan); + ImmutableSet.of(new EqualTo(arrayFilter, arrayMap)), scan); // Apply the rule PushDownVirtualColumnsIntoOlapScan rule = new PushDownVirtualColumnsIntoOlapScan(); List rules = rule.buildRules(); // Test rule creation - assert rules.size() == 2; + Assertions.assertEquals(2, rules.size()); // This test verifies the rule structure but actual lambda testing would require // more complex expression trees with lambda functions @@ -286,7 +346,22 @@ public void testSkipLambdaExpressions() { break; } } - assert hasFilterScanRule : "Should have rule that matches filter->scan pattern"; + Assertions.assertTrue(hasFilterScanRule, "Should have rule that matches filter->scan pattern"); + + PlanChecker.from(connectContext, filter) + .applyTopDown(rules) + .applyCustom(new ColumnPruning()) + .matches(logicalOlapScan() + .when(o -> o.getVirtualColumns().size() == 2) + .when(o -> { + for (NamedExpression virtualColumn : o.getVirtualColumns()) { + Expression c = virtualColumn.child(0); + if (!(c instanceof ArrayMap) && !c.equals(arr)) { + return false; + } + } + return true; + })); } @Test @@ -328,16 +403,15 @@ public void testMixedComplexExpressions() { // Apply the rule PushDownVirtualColumnsIntoOlapScan rule = new PushDownVirtualColumnsIntoOlapScan(); List rules = rule.buildRules(); - // Test rule creation - assert rules.size() == 2; + Assertions.assertEquals(2, rules.size()); // Verify filter structure - assert filter.getConjuncts().size() == 4 : "Filter should have 4 conjuncts"; - assert filter.getConjuncts().contains(gt) : "Filter should contain greater than condition"; - assert filter.getConjuncts().contains(lt) : "Filter should contain less than condition"; - assert filter.getConjuncts().contains(eq) : "Filter should contain equality condition"; - assert filter.getConjuncts().contains(neq) : "Filter should contain not equal condition"; + Assertions.assertEquals(4, filter.getConjuncts().size(), "Filter should have 4 conjuncts"); + Assertions.assertTrue(filter.getConjuncts().contains(gt), "Filter should contain greater than condition"); + Assertions.assertTrue(filter.getConjuncts().contains(lt), "Filter should contain less than condition"); + Assertions.assertTrue(filter.getConjuncts().contains(eq), "Filter should contain equality condition"); + Assertions.assertTrue(filter.getConjuncts().contains(neq), "Filter should contain not equal condition"); // Test that rules can match the pattern boolean hasMatchingRule = false; @@ -347,7 +421,7 @@ public void testMixedComplexExpressions() { break; } } - assert hasMatchingRule : "Should have rule that matches the filter pattern"; + Assertions.assertTrue(hasMatchingRule, "Should have rule that matches the filter pattern"); } @Test @@ -377,7 +451,7 @@ public void testNoOptimizationWhenNoRepeatedExpressions() { List rules = rule.buildRules(); // Test rule creation - assert rules.size() == 2; + Assertions.assertEquals(2, rules.size()); // Test that the rule can match the filter pattern (without executing transformation) boolean hasMatchingRule = false; @@ -389,7 +463,7 @@ public void testNoOptimizationWhenNoRepeatedExpressions() { } // No optimization should occur since there are no repeated expressions, but rule should match - assert hasMatchingRule : "Rule should match the filter pattern"; + Assertions.assertTrue(hasMatchingRule, "Should have rule that matches the filter pattern"); } @Test @@ -420,7 +494,7 @@ public void testRulePatternMatching() { PushDownVirtualColumnsIntoOlapScan rule = new PushDownVirtualColumnsIntoOlapScan(); List rules = rule.buildRules(); - assert rules.size() == 2 : "Should create exactly 2 rules"; + Assertions.assertEquals(2, rules.size()); // Test pattern matching int projectFilterScanMatches = 0; @@ -435,12 +509,12 @@ public void testRulePatternMatching() { } } - assert projectFilterScanMatches == 1 : "Should have exactly 1 rule for Project->Filter->Scan"; - assert filterScanMatches == 1 : "Should have exactly 1 rule for Filter->Scan"; + Assertions.assertEquals(1, projectFilterScanMatches, "Should have exactly 1 rule for Project->Filter->Scan"); + Assertions.assertEquals(1, filterScanMatches, "Should have exactly 1 rule for Filter->Scan"); } @Test - public void testCanConvertToColumnPredicate_ComparisonPredicates() { + public void testCanConvertToColumnPredicate_ComparisonPredicates() throws Exception { // Test that comparison predicates can be converted to ColumnPredicate DataType intType = IntegerType.INSTANCE; @@ -454,27 +528,20 @@ public void testCanConvertToColumnPredicate_ComparisonPredicates() { PushDownVirtualColumnsIntoOlapScan rule = new PushDownVirtualColumnsIntoOlapScan(); - try { - java.lang.reflect.Method method = rule.getClass().getDeclaredMethod("canConvertToColumnPredicate", - org.apache.doris.nereids.trees.expressions.Expression.class); - method.setAccessible(true); - - boolean result1 = (boolean) method.invoke(rule, eq); - assert result1 : "EqualTo should be convertible to ColumnPredicate"; - - boolean result2 = (boolean) method.invoke(rule, gt); - assert result2 : "GreaterThan should be convertible to ColumnPredicate"; + Method method = rule.getClass().getDeclaredMethod("canConvertToColumnPredicate", Expression.class); + method.setAccessible(true); + boolean result1 = (boolean) method.invoke(rule, eq); + Assertions.assertTrue(result1, "EqualTo should be convertible to ColumnPredicate"); - boolean result3 = (boolean) method.invoke(rule, lt); - assert result3 : "LessThan should be convertible to ColumnPredicate"; + boolean result2 = (boolean) method.invoke(rule, gt); + Assertions.assertTrue(result2, "GreaterThan should be convertible to ColumnPredicate"); - } catch (Exception e) { - throw new RuntimeException("Failed to test canConvertToColumnPredicate", e); - } + boolean result3 = (boolean) method.invoke(rule, lt); + Assertions.assertTrue(result3, "LessThan should be convertible to ColumnPredicate"); } @Test - public void testCanConvertToColumnPredicate_InAndNullPredicates() { + public void testCanConvertToColumnPredicateInAndNullPredicates() throws Exception { // Test IN and IS NULL predicates DataType intType = IntegerType.INSTANCE; @@ -489,24 +556,18 @@ public void testCanConvertToColumnPredicate_InAndNullPredicates() { PushDownVirtualColumnsIntoOlapScan rule = new PushDownVirtualColumnsIntoOlapScan(); - try { - java.lang.reflect.Method method = rule.getClass().getDeclaredMethod("canConvertToColumnPredicate", - org.apache.doris.nereids.trees.expressions.Expression.class); - method.setAccessible(true); - - boolean result1 = (boolean) method.invoke(rule, inPred); - assert result1 : "IN predicate should be convertible to ColumnPredicate"; + Method method = rule.getClass().getDeclaredMethod("canConvertToColumnPredicate", Expression.class); + method.setAccessible(true); - boolean result2 = (boolean) method.invoke(rule, isNull); - assert result2 : "IS NULL should be convertible to ColumnPredicate"; + boolean result1 = (boolean) method.invoke(rule, inPred); + Assertions.assertTrue(result1, "IN predicate should be convertible to ColumnPredicate"); - } catch (Exception e) { - throw new RuntimeException("Failed to test canConvertToColumnPredicate with IN/NULL", e); - } + boolean result2 = (boolean) method.invoke(rule, isNull); + Assertions.assertTrue(result2, "IS NULL should be convertible to ColumnPredicate"); } @Test - public void testIsIndexPushdownFunction_IpAddressInRange() { + public void testIsIndexPushdownFunctionIpAddressInRange() throws Exception { // Test IP address range function detection DataType varcharType = VarcharType.SYSTEM_DEFAULT; @@ -517,21 +578,16 @@ public void testIsIndexPushdownFunction_IpAddressInRange() { PushDownVirtualColumnsIntoOlapScan rule = new PushDownVirtualColumnsIntoOlapScan(); - try { - java.lang.reflect.Method method = rule.getClass().getDeclaredMethod("isIndexPushdownFunction", - org.apache.doris.nereids.trees.expressions.Expression.class); - method.setAccessible(true); + Method method = rule.getClass().getDeclaredMethod("isIndexPushdownFunction", Expression.class); + method.setAccessible(true); - boolean result = (boolean) method.invoke(rule, ipRangeFunc); - assert result : "IsIpAddressInRange should be detected as index pushdown function"; + boolean result = (boolean) method.invoke(rule, ipRangeFunc); + Assertions.assertTrue(result, "IsIpAddressInRange should be detected as index pushdown function"); - } catch (Exception e) { - throw new RuntimeException("Failed to test isIndexPushdownFunction with IP range", e); - } } @Test - public void testIsIndexPushdownFunction_MultiMatch() { + public void testIsIndexPushdownFunctionMultiMatch() throws Exception { // Test multi-match function detection DataType varcharType = VarcharType.SYSTEM_DEFAULT; @@ -543,24 +599,17 @@ public void testIsIndexPushdownFunction_MultiMatch() { PushDownVirtualColumnsIntoOlapScan rule = new PushDownVirtualColumnsIntoOlapScan(); - try { - java.lang.reflect.Method method = rule.getClass().getDeclaredMethod("isIndexPushdownFunction", - org.apache.doris.nereids.trees.expressions.Expression.class); - method.setAccessible(true); - - boolean result1 = (boolean) method.invoke(rule, multiMatchFunc); - assert result1 : "MultiMatch should be detected as index pushdown function"; + Method method = rule.getClass().getDeclaredMethod("isIndexPushdownFunction", Expression.class); + method.setAccessible(true); + boolean result1 = (boolean) method.invoke(rule, multiMatchFunc); + Assertions.assertTrue(result1, "MultiMatch should be detected as index pushdown function"); - boolean result2 = (boolean) method.invoke(rule, multiMatchAnyFunc); - assert result2 : "MultiMatchAny should be detected as index pushdown function"; - - } catch (Exception e) { - throw new RuntimeException("Failed to test isIndexPushdownFunction with MultiMatch", e); - } + boolean result2 = (boolean) method.invoke(rule, multiMatchAnyFunc); + Assertions.assertTrue(result2, "MultiMatchAny should be detected as index pushdown function"); } @Test - public void testContainsIndexPushdownFunction_NestedExpression() { + public void testContainsIndexPushdownFunctionNestedExpression() throws Exception { // Test detection of index pushdown functions in nested expressions DataType varcharType = VarcharType.SYSTEM_DEFAULT; @@ -576,68 +625,55 @@ public void testContainsIndexPushdownFunction_NestedExpression() { PushDownVirtualColumnsIntoOlapScan rule = new PushDownVirtualColumnsIntoOlapScan(); - try { - java.lang.reflect.Method method = rule.getClass().getDeclaredMethod("containsIndexPushdownFunction", - org.apache.doris.nereids.trees.expressions.Expression.class); - method.setAccessible(true); - - // Test expression containing index pushdown function - boolean result1 = (boolean) method.invoke(rule, ipRangeFunc); - assert result1 : "Expression containing IsIpAddressInRange should be detected"; - - // Test expression not containing index pushdown function - boolean result2 = (boolean) method.invoke(rule, countCondition); - assert !result2 : "Regular comparison should not be detected as containing index pushdown function"; + Method method = rule.getClass().getDeclaredMethod("containsIndexPushdownFunction", Expression.class); + method.setAccessible(true); + // Test expression containing index pushdown function + boolean result1 = (boolean) method.invoke(rule, ipRangeFunc); + Assertions.assertTrue(result1, "Expression containing IsIpAddressInRange should be detected"); - } catch (Exception e) { - throw new RuntimeException("Failed to test containsIndexPushdownFunction", e); - } + // Test expression not containing index pushdown function + boolean result2 = (boolean) method.invoke(rule, countCondition); + Assertions.assertFalse(result2, + "Regular comparison should not be detected as containing index pushdown function"); } @Test - public void testIsSupportedVirtualColumnType() { - try { - PushDownVirtualColumnsIntoOlapScan rule = new PushDownVirtualColumnsIntoOlapScan(); - - // Use reflection to access the private method - java.lang.reflect.Method method = rule.getClass() - .getDeclaredMethod("isSupportedVirtualColumnType", - org.apache.doris.nereids.trees.expressions.Expression.class); - method.setAccessible(true); - - DataType intType = IntegerType.INSTANCE; - DataType varcharType = VarcharType.createVarcharType(100); + public void testIsSupportedVirtualColumnType() throws Exception { + PushDownVirtualColumnsIntoOlapScan rule = new PushDownVirtualColumnsIntoOlapScan(); - // Test supported types - SlotReference intSlot = new SlotReference("int_col", intType); - SlotReference varcharSlot = new SlotReference("varchar_col", varcharType); - - // Test basic arithmetic expression with supported types (should return integer) - Add intAddition = new Add(intSlot, new IntegerLiteral(1)); - boolean intSupported = (boolean) method.invoke(rule, intAddition); - assert intSupported : "Integer arithmetic expression should be supported for virtual columns"; - - // Test string concatenation (should return varchar) - Concat stringConcat = new Concat(varcharSlot, new StringLiteral("_suffix")); - boolean stringSupported = (boolean) method.invoke(rule, stringConcat); - assert stringSupported : "String expression should be supported for virtual columns"; - - // Test a complex expression with multiple operations - Multiply complexMath = new Multiply( - new Add(intSlot, new IntegerLiteral(5)), - new IntegerLiteral(2) - ); - boolean complexSupported = (boolean) method.invoke(rule, complexMath); - assert complexSupported : "Complex arithmetic expression should be supported for virtual columns"; - - // Test a CAST expression to string (should be supported) - Cast castToString = new Cast(intSlot, VarcharType.createVarcharType(50)); - boolean castSupported = (boolean) method.invoke(rule, castToString); - assert castSupported : "CAST to supported type should be supported for virtual columns"; + // Use reflection to access the private method + Method method = rule.getClass().getDeclaredMethod("isSupportedVirtualColumnType", Expression.class); + method.setAccessible(true); - } catch (Exception e) { - throw new RuntimeException("Failed to test isSupportedVirtualColumnType", e); - } + DataType intType = IntegerType.INSTANCE; + DataType varcharType = VarcharType.createVarcharType(100); + + // Test supported types + SlotReference intSlot = new SlotReference("int_col", intType); + SlotReference varcharSlot = new SlotReference("varchar_col", varcharType); + // Test basic arithmetic expression with supported types (should return integer) + Add intAddition = new Add(intSlot, new IntegerLiteral(1)); + boolean intSupported = (boolean) method.invoke(rule, intAddition); + Assertions.assertTrue(intSupported, "Integer arithmetic expression should be supported for virtual columns"); + + // Test string concatenation (should return varchar) + Concat stringConcat = new Concat(varcharSlot, new StringLiteral("_suffix")); + boolean stringSupported = (boolean) method.invoke(rule, stringConcat); + Assertions.assertTrue(stringSupported, "String expression should be supported for virtual columns"); + + // Test a complex expression with multiple operations + Multiply complexMath = new Multiply( + new Add(intSlot, new IntegerLiteral(5)), + new IntegerLiteral(2) + ); + boolean complexSupported = (boolean) method.invoke(rule, complexMath); + Assertions.assertTrue(complexSupported, + "Complex arithmetic expression should be supported for virtual columns"); + + // Test a CAST expression to string (should be supported) + Cast castToString = new Cast(intSlot, VarcharType.createVarcharType(50)); + boolean castSupported = (boolean) method.invoke(rule, castToString); + Assertions.assertTrue(castSupported, "CAST to supported type should be supported for virtual columns"); } @Test @@ -646,9 +682,7 @@ public void testUnsupportedVirtualColumnType() { PushDownVirtualColumnsIntoOlapScan rule = new PushDownVirtualColumnsIntoOlapScan(); // Use reflection to access the private method - java.lang.reflect.Method method = rule.getClass() - .getDeclaredMethod("isSupportedVirtualColumnType", - org.apache.doris.nereids.trees.expressions.Expression.class); + Method method = rule.getClass().getDeclaredMethod("isSupportedVirtualColumnType", Expression.class); method.setAccessible(true); // Test expression that might return an unsupported type @@ -661,7 +695,7 @@ public void testUnsupportedVirtualColumnType() { Lambda lambdaExpr = new Lambda(ImmutableList.of("int_col"), new Add(intSlot, new IntegerLiteral(1))); boolean lambdaSupported = (boolean) method.invoke(rule, lambdaExpr); - assert !lambdaSupported : "Lambda expressions should not be supported for virtual columns"; + Assertions.assertFalse(lambdaSupported, "Lambda expressions should not be supported for virtual columns"); } catch (Exception e) { // Expected for some unsupported expressions @@ -670,79 +704,65 @@ public void testUnsupportedVirtualColumnType() { } @Test - public void testVirtualColumnTypeFilteringInExtraction() { + public void testVirtualColumnTypeFilteringInExtraction() throws Exception { // Test that the extractRepeatedSubExpressions method properly filters out // expressions with unsupported types during virtual column creation - try { - PushDownVirtualColumnsIntoOlapScan rule = new PushDownVirtualColumnsIntoOlapScan(); + PushDownVirtualColumnsIntoOlapScan rule = new PushDownVirtualColumnsIntoOlapScan(); - // Use reflection to access the private extractRepeatedSubExpressions method - java.lang.reflect.Method extractMethod = rule.getClass() - .getDeclaredMethod("extractRepeatedSubExpressions", - org.apache.doris.nereids.trees.plans.logical.LogicalFilter.class, - java.util.Optional.class, - java.util.Map.class, - com.google.common.collect.ImmutableList.Builder.class); - extractMethod.setAccessible(true); + // Use reflection to access the private extractRepeatedSubExpressions method + Method extractMethod = rule.getClass() + .getDeclaredMethod("extractRepeatedSubExpressions", + LogicalFilter.class, + Optional.class, + Map.class, + Map.class); + extractMethod.setAccessible(true); - DataType intType = IntegerType.INSTANCE; - SlotReference x = new SlotReference("x", intType); - SlotReference y = new SlotReference("y", intType); + DataType intType = IntegerType.INSTANCE; + SlotReference x = new SlotReference("x", intType); + SlotReference y = new SlotReference("y", intType); - // Create expressions that should be supported (arithmetic operations return int) - Add supportedAdd1 = new Add(x, y); - Add supportedAdd2 = new Add(x, y); - Add supportedAdd3 = new Add(x, y); + // Create expressions that should be supported (arithmetic operations return int) + Add supportedAdd1 = new Add(x, y); + Add supportedAdd2 = new Add(x, y); + Add supportedAdd3 = new Add(x, y); - // Create filter conditions using the repeated supported expression - GreaterThan gt1 = new GreaterThan(supportedAdd1, new IntegerLiteral(10)); - GreaterThan gt2 = new GreaterThan(supportedAdd2, new IntegerLiteral(20)); - GreaterThan gt3 = new GreaterThan(supportedAdd3, new IntegerLiteral(30)); + // Create filter conditions using the repeated supported expression + GreaterThan gt1 = new GreaterThan(supportedAdd1, new IntegerLiteral(10)); + GreaterThan gt2 = new GreaterThan(supportedAdd2, new IntegerLiteral(20)); + GreaterThan gt3 = new GreaterThan(supportedAdd3, new IntegerLiteral(30)); - // Create OLAP scan and filter - LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0); - LogicalFilter filter = new LogicalFilter<>( - ImmutableSet.of(gt1, gt2, gt3), scan); - - // Test the extraction method - java.util.Map replaceMap = - new java.util.HashMap<>(); - com.google.common.collect.ImmutableList.Builder - virtualColumnsBuilder = com.google.common.collect.ImmutableList.builder(); - - // Call the extraction method - extractMethod.invoke(rule, filter, java.util.Optional.empty(), replaceMap, virtualColumnsBuilder); - - // Verify that virtual columns were created for supported expressions - java.util.List virtualColumns = - virtualColumnsBuilder.build(); - - // Since Add(x, y) appears 3 times and returns int (supported type), - // it should be included in virtual columns - assert !virtualColumns.isEmpty() : "Should create virtual columns for repeated supported expressions"; - assert replaceMap.size() > 0 : "Should have replacements for supported expressions"; - - // Test that the virtual column expression has a supported type - if (!virtualColumns.isEmpty()) { - org.apache.doris.nereids.trees.expressions.NamedExpression virtualCol = virtualColumns.get(0); - if (virtualCol instanceof org.apache.doris.nereids.trees.expressions.Alias) { - org.apache.doris.nereids.trees.expressions.Alias alias = - (org.apache.doris.nereids.trees.expressions.Alias) virtualCol; - org.apache.doris.nereids.trees.expressions.Expression expr = alias.child(); - - // The expression should be supported by isSupportedVirtualColumnType - java.lang.reflect.Method typeCheckMethod = rule.getClass() - .getDeclaredMethod("isSupportedVirtualColumnType", - org.apache.doris.nereids.trees.expressions.Expression.class); - typeCheckMethod.setAccessible(true); - boolean isSupported = (boolean) typeCheckMethod.invoke(rule, expr); - assert isSupported : "Virtual column expression should have supported type"; - } + // Create OLAP scan and filter + LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + LogicalFilter filter = new LogicalFilter<>(ImmutableSet.of(gt1, gt2, gt3), scan); + + // Test the extraction method + Map replaceMap = Maps.newHashMap(); + Map virtualColumnsMap = Maps.newHashMap(); + + // Call the extraction method + extractMethod.invoke(rule, filter, Optional.empty(), replaceMap, virtualColumnsMap); + + // Verify that virtual columns were created for supported expressions + // Since Add(x, y) appears 3 times and returns int (supported type), + // it should be included in virtual columns + Assertions.assertFalse(virtualColumnsMap.isEmpty(), + "Should create virtual columns for repeated supported expressions"); + Assertions.assertTrue(replaceMap.size() > 0, "Should have replacements for supported expressions"); + + // Test that the virtual column expression has a supported type + for (NamedExpression virtualCol : virtualColumnsMap.values()) { + if (virtualCol instanceof Alias) { + Alias alias = (Alias) virtualCol; + Expression expr = alias.child(); + + // The expression should be supported by isSupportedVirtualColumnType + Method typeCheckMethod = rule.getClass() + .getDeclaredMethod("isSupportedVirtualColumnType", Expression.class); + typeCheckMethod.setAccessible(true); + boolean isSupported = (boolean) typeCheckMethod.invoke(rule, expr); + Assertions.assertTrue(isSupported, "Virtual column expression should have supported type"); } - - } catch (Exception e) { - throw new RuntimeException("Failed to test virtual column type filtering", e); } } @@ -753,17 +773,15 @@ public void testTypeFilteringWithMixedExpressions() { PushDownVirtualColumnsIntoOlapScan rule = new PushDownVirtualColumnsIntoOlapScan(); // Use reflection to access private methods - java.lang.reflect.Method extractMethod = rule.getClass() - .getDeclaredMethod("extractRepeatedSubExpressions", - org.apache.doris.nereids.trees.plans.logical.LogicalFilter.class, - java.util.Optional.class, - java.util.Map.class, - com.google.common.collect.ImmutableList.Builder.class); + Method extractMethod = rule.getClass().getDeclaredMethod("extractRepeatedSubExpressions", + LogicalFilter.class, + Optional.class, + Map.class, + Map.class); extractMethod.setAccessible(true); - java.lang.reflect.Method typeCheckMethod = rule.getClass() - .getDeclaredMethod("isSupportedVirtualColumnType", - org.apache.doris.nereids.trees.expressions.Expression.class); + Method typeCheckMethod = rule.getClass() + .getDeclaredMethod("isSupportedVirtualColumnType", Expression.class); typeCheckMethod.setAccessible(true); DataType intType = IntegerType.INSTANCE; @@ -782,15 +800,13 @@ public void testTypeFilteringWithMixedExpressions() { boolean supportedIsSupported = (boolean) typeCheckMethod.invoke(rule, supportedExpr1); boolean unsupportedIsSupported1 = (boolean) typeCheckMethod.invoke(rule, unsupportedExpr1); boolean unsupportedIsSupported2 = (boolean) typeCheckMethod.invoke(rule, unsupportedExpr2); - - assert supportedIsSupported : "Add expression should be supported"; - assert !unsupportedIsSupported1 : "Lambda expression 1 should not be supported"; - assert !unsupportedIsSupported2 : "Lambda expression 2 should not be supported"; + Assertions.assertTrue(supportedIsSupported, "Add expression should be supported"); + Assertions.assertFalse(unsupportedIsSupported1, "Lambda expression 1 should not be supported"); + Assertions.assertFalse(unsupportedIsSupported2, "Lambda expression 2 should not be supported"); // Verify that both unsupported expressions have the same type checking result - assert unsupportedIsSupported1 == unsupportedIsSupported2 : - "Both lambda expressions should have the same support status"; - + Assertions.assertEquals(unsupportedIsSupported1, unsupportedIsSupported2, + "Both lambda expressions should have the same support status"); // Create filter conditions using both types GreaterThan gt1 = new GreaterThan(supportedExpr1, new IntegerLiteral(10)); GreaterThan gt2 = new GreaterThan(supportedExpr2, new IntegerLiteral(20)); @@ -798,31 +814,21 @@ public void testTypeFilteringWithMixedExpressions() { // since they require specific context, so we focus on the type checking LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0); - LogicalFilter filter = new LogicalFilter<>( - ImmutableSet.of(gt1, gt2), scan); + LogicalFilter filter = new LogicalFilter<>(ImmutableSet.of(gt1, gt2), scan); // Test extraction - java.util.Map replaceMap = - new java.util.HashMap<>(); - com.google.common.collect.ImmutableList.Builder - virtualColumnsBuilder = com.google.common.collect.ImmutableList.builder(); - - extractMethod.invoke(rule, filter, java.util.Optional.empty(), replaceMap, virtualColumnsBuilder); - - // Verify results: only supported expressions should create virtual columns - java.util.List virtualColumns = - virtualColumnsBuilder.build(); + Map replaceMap = Maps.newHashMap(); + Map virtualColumnsMap = Maps.newHashMap(); + extractMethod.invoke(rule, filter, Optional.empty(), replaceMap, virtualColumnsMap); // Should have virtual columns only for supported expressions - for (org.apache.doris.nereids.trees.expressions.NamedExpression virtualCol : virtualColumns) { - if (virtualCol instanceof org.apache.doris.nereids.trees.expressions.Alias) { - org.apache.doris.nereids.trees.expressions.Alias alias = - (org.apache.doris.nereids.trees.expressions.Alias) virtualCol; - org.apache.doris.nereids.trees.expressions.Expression expr = alias.child(); + for (NamedExpression virtualCol : virtualColumnsMap.values()) { + if (virtualCol instanceof Alias) { + Alias alias = (Alias) virtualCol; + Expression expr = alias.child(); boolean isSupported = (boolean) typeCheckMethod.invoke(rule, expr); - assert isSupported : "All virtual column expressions should have supported types"; + Assertions.assertTrue(isSupported, "All virtual column expressions should have supported types"); } }