From 0eefa88ab936d025dfb3ec7e339568b9dfc2a05f Mon Sep 17 00:00:00 2001 From: jackwener Date: Wed, 20 Dec 2023 15:19:06 +0800 Subject: [PATCH] [test](Nereids): add test for scalar agg --- .../PushDownFilterThroughAggregationTest.java | 33 ++++++++++++++++--- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java index 830921b9a2b7b2..36cb8cee8d41ec 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java @@ -27,6 +27,8 @@ import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; @@ -91,6 +93,28 @@ public void pushDownPredicateOneFilterTest() { ); } + @Test + void scalarAgg() { + LogicalPlan plan = new LogicalPlanBuilder(scan) + .agg(ImmutableList.of(), ImmutableList.of((new Sum(scan.getOutput().get(0))).alias("sum"))) + .filter(new If(Literal.of(false), Literal.of(false), Literal.of(false))) + .project(ImmutableList.of(0)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushDownFilterThroughAggregation()) + .printlnTree() + .matches( + logicalProject( + logicalFilter( + logicalAggregate( + logicalOlapScan() + ) + ) + ) + ); + } + /*- * origin plan: * project @@ -174,7 +198,8 @@ public void pushDownPredicateGroupWithRepeatTest() { logicalAggregate( logicalFilter( logicalRepeat() - ).when(filter -> filter.getConjuncts().equals(ImmutableSet.of(filterPredicateId))) + ).when(filter -> filter.getConjuncts() + .equals(ImmutableSet.of(filterPredicateId))) ) ) ); @@ -195,9 +220,9 @@ public void pushDownPredicateGroupWithRepeatTest() { .matches( logicalProject( logicalFilter( - logicalAggregate( - logicalRepeat() - ) + logicalAggregate( + logicalRepeat() + ) ).when(filter -> filter.getConjuncts().equals(ImmutableSet.of(filterPredicateId))) ) );