Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ public List<Rule> buildRules() {
LogicalProject<LogicalOlapScan> project = topN.child();
LogicalOlapScan scan = project.child();
return pushDown(topN, project, scan, Optional.empty());
}).toRule(RuleType.PUSH_DOWN_VIRTUAL_COLUMNS_INTO_OLAP_SCAN),
}).toRule(RuleType.PUSH_DOWN_VECTOR_TOPN_INTO_OLAP_SCAN),
logicalTopN(logicalProject(logicalFilter(logicalOlapScan())))
.when(t -> t.getOrderKeys().size() == 1).then(topN -> {
LogicalProject<LogicalFilter<LogicalOlapScan>> project = topN.child();
LogicalFilter<LogicalOlapScan> filter = project.child();
LogicalOlapScan scan = filter.child();
return pushDown(topN, project, scan, Optional.of(filter));
}).toRule(RuleType.PUSH_DOWN_VIRTUAL_COLUMNS_INTO_OLAP_SCAN)
}).toRule(RuleType.PUSH_DOWN_VECTOR_TOPN_INTO_OLAP_SCAN)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DecodeAsVarchar;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeString;
import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.IsIpAddressInRange;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MultiMatch;
Expand Down Expand Up @@ -267,9 +268,10 @@ private void extractRepeatedSubExpressions(LogicalFilter<LogicalOlapScan> filter

for (Expression expr : allExpressions) {
// Skip expressions that contain lambda functions anywhere in the tree
if (expr.anyMatch(e -> e instanceof Lambda)) {
if (expr.anyMatch(e -> e instanceof Lambda)
|| expr.anyMatch(e -> e instanceof GroupingScalarFunction)) {
if (LOG.isDebugEnabled()) {
LOG.debug("Skipping expression containing lambda: {}", expr.toSql());
LOG.debug("Skipping expression containing lambda/grouping: {}", expr.toSql());
}
continue;
}
Expand Down Expand Up @@ -348,6 +350,11 @@ private void collectSubExpressions(Expression expr, Map<Expression, Integer> exp
* @return SkipResult indicating how to handle this expression
*/
private SkipResult shouldSkipExpression(Expression expr) {
// Grouping scalar functions can't be materialized into project/virtual columns.
// If an expression tree contains any grouping function, skip it entirely.
if (expr.anyMatch(e -> e instanceof GroupingScalarFunction)) {
return SkipResult.TERMINATE;
}
// Skip simple slots and literals as they don't benefit from being pushed down
if (expr instanceof Slot || expr.isConstant()) {
return SkipResult.TERMINATE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFilter;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMap;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Concat;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Grouping;
import org.apache.doris.nereids.trees.expressions.functions.scalar.IsIpAddressInRange;
import org.apache.doris.nereids.trees.expressions.functions.scalar.L2Distance;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
Expand Down Expand Up @@ -832,6 +833,62 @@ public void testTypeFilteringWithMixedExpressions() {
}
}

@Test
public void testSkipGroupingFunctionsInFilter() {
// Ensure expressions containing grouping() are completely skipped
LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
SlotReference x = (SlotReference) scan.getOutput().get(0);

Grouping grouping1 = new Grouping(x);
Grouping grouping2 = new Grouping(x);
// even if repeated, should not extract
LogicalFilter<LogicalOlapScan> filter = new LogicalFilter<>(
ImmutableSet.of(new EqualTo(grouping1, grouping2)), scan);
PlanChecker.from(MemoTestUtils.createConnectContext(), filter)
.applyTopDown(new PushDownVirtualColumnsIntoOlapScan())
.matches(logicalOlapScan().when(o -> o.getVirtualColumns().isEmpty()));
}

@Test
public void testSkipGroupingFunctionsInProject() {
// Ensure grouping() in project is not extracted or altered by CSE
LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
SlotReference x = (SlotReference) scan.getOutput().get(0);

Grouping grouping1 = new Grouping(x);
Grouping grouping2 = new Grouping(x);

LogicalFilter<LogicalOlapScan> filter = new LogicalFilter<>(ImmutableSet.of(), scan);
List<NamedExpression> projects = ImmutableList.of(
new Alias(grouping1, "g1"),
new Alias(grouping2, "g2"),
new Alias(x, "x")
);
LogicalProject<LogicalFilter<LogicalOlapScan>> project = new LogicalProject<>(projects, filter);

Plan result = PlanChecker.from(MemoTestUtils.createConnectContext(), project)
.applyTopDown(new PushDownVirtualColumnsIntoOlapScan())
.getPlan();

Assertions.assertInstanceOf(LogicalProject.class, result);
LogicalProject<?> resProject = (LogicalProject<?>) result;
Assertions.assertInstanceOf(LogicalFilter.class, resProject.child());
LogicalFilter<?> resFilter = (LogicalFilter<?>) resProject.child();
Assertions.assertInstanceOf(LogicalOlapScan.class, resFilter.child());
LogicalOlapScan resScan = (LogicalOlapScan) resFilter.child();
Assertions.assertTrue(resScan.getVirtualColumns().isEmpty(),
"Grouping in project must not trigger virtual column extraction");

// Ensure grouping aliases remain present
List<String> aliasNames = resProject.getProjects().stream().map(ne -> {
if (ne instanceof Alias) {
return ((Alias) ne).getName();
}
return "";
}).collect(Collectors.toList());
Assertions.assertTrue(aliasNames.contains("g1") && aliasNames.contains("g2"));
}

@Test
void testOnceUniqueFunction() {
LogicalOlapScan olapScan = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(),
Expand Down
Loading