Skip to content

Commit fd5ffc1

Browse files
morrySnowYour Name
authored andcommitted
[fix](enforcer) shuffle if has continuous project or filter on cte consumer (#58964)
Related PR: #21412 Problem Summary: This pull request improves the handling of distribution properties (specifically "must shuffle") for `PhysicalProject` and `PhysicalFilter` nodes in the query planner, and adds comprehensive unit tests to ensure correctness. The main logic ensures that when certain child nodes require shuffling, the planner correctly adjusts the distribution requirements, especially in the presence of `Project`, `Filter`, and `Limit` nodes. Key changes include: **Distribution Property Handling Enhancements:** * Added logic in `ChildrenPropertiesRegulator` to check if a child node under a `PhysicalProject` or `PhysicalFilter` requires a "must shuffle" distribution, and to adjust the children’s properties accordingly. This is done via the new `mustShuffleUnderProjectOrFilter` method. * Included `PhysicalLimit` in the set of nodes that can trigger a shuffle requirement, by updating imports and logic. **Testing Improvements:** * Added a new test class `ChildrenPropertiesRegulatorTest.java` with detailed unit tests for the handling of "must shuffle" properties under `Project`, `Filter`, and `Limit` nodes. These tests use mocks to simulate various plan trees and assert correct distribution specification propagation. **Regression Test Coverage:** * Added a new regression test in `cte.groovy` to verify correct behavior when multiple `Project` nodes are present on a CTE consumer, ensuring the planner handles such cases as expected. These changes collectively make the planner more robust in handling complex plan trees with respect to distribution requirements, and ensure correctness through thorough testing.
1 parent 9ac212e commit fd5ffc1

File tree

3 files changed

+211
-0
lines changed

3 files changed

+211
-0
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter;
3838
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
3939
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
40+
import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit;
4041
import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
4142
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
4243
import org.apache.doris.nereids.trees.plans.physical.PhysicalPartitionTopN;
@@ -277,6 +278,17 @@ public List<List<PhysicalProperties>> visitPhysicalFilter(PhysicalFilter<? exten
277278
if (children.get(0).getPlan() instanceof PhysicalDistribute) {
278279
return ImmutableList.of();
279280
}
281+
DistributionSpec distributionSpec = originChildrenProperties.get(0).getDistributionSpec();
282+
// process must shuffle
283+
if (distributionSpec instanceof DistributionSpecMustShuffle) {
284+
Plan child = filter.child();
285+
Plan realChild = getChildPhysicalPlan(child);
286+
if (realChild instanceof PhysicalProject
287+
|| realChild instanceof PhysicalFilter
288+
|| realChild instanceof PhysicalLimit) {
289+
visit(filter, context);
290+
}
291+
}
280292
return ImmutableList.of(originChildrenProperties);
281293
}
282294

@@ -308,6 +320,19 @@ private boolean isBucketShuffleDownGrade(Plan oneSidePlan, DistributionSpecHash
308320
}
309321
}
310322

323+
private Plan getChildPhysicalPlan(Plan plan) {
324+
if (!(plan instanceof GroupPlan)) {
325+
return null;
326+
}
327+
GroupPlan groupPlan = (GroupPlan) plan;
328+
if (groupPlan == null || groupPlan.getGroup() == null
329+
|| groupPlan.getGroup().getPhysicalExpressions().isEmpty()) {
330+
return null;
331+
} else {
332+
return groupPlan.getGroup().getPhysicalExpressions().get(0).getPlan();
333+
}
334+
}
335+
311336
private PhysicalOlapScan findDownGradeBucketShuffleCandidate(GroupPlan groupPlan) {
312337
if (groupPlan == null || groupPlan.getGroup() == null
313338
|| groupPlan.getGroup().getPhysicalExpressions().isEmpty()) {
@@ -574,6 +599,20 @@ public List<List<PhysicalProperties>> visitPhysicalProject(PhysicalProject<? ext
574599
if (children.get(0).getPlan() instanceof PhysicalDistribute) {
575600
return ImmutableList.of();
576601
}
602+
DistributionSpec distributionSpec = originChildrenProperties.get(0).getDistributionSpec();
603+
// process must shuffle
604+
if (distributionSpec instanceof DistributionSpecMustShuffle) {
605+
Plan child = project.child();
606+
Plan realChild = getChildPhysicalPlan(child);
607+
if (realChild instanceof PhysicalLimit) {
608+
visit(project, context);
609+
} else if (realChild instanceof PhysicalProject) {
610+
PhysicalProject physicalProject = (PhysicalProject) realChild;
611+
if (!project.canMergeChildProjections(physicalProject)) {
612+
visit(project, context);
613+
}
614+
}
615+
}
577616
return ImmutableList.of(originChildrenProperties);
578617
}
579618

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package org.apache.doris.nereids.properties;
19+
20+
import org.apache.doris.common.Pair;
21+
import org.apache.doris.nereids.CascadesContext;
22+
import org.apache.doris.nereids.cost.Cost;
23+
import org.apache.doris.nereids.cost.CostCalculator;
24+
import org.apache.doris.nereids.jobs.JobContext;
25+
import org.apache.doris.nereids.memo.Group;
26+
import org.apache.doris.nereids.memo.GroupExpression;
27+
import org.apache.doris.nereids.trees.plans.GroupPlan;
28+
import org.apache.doris.nereids.trees.plans.Plan;
29+
import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter;
30+
import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit;
31+
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
32+
33+
import com.google.common.collect.Lists;
34+
import com.google.common.collect.Maps;
35+
import com.google.common.collect.Sets;
36+
import org.junit.jupiter.api.Assertions;
37+
import org.junit.jupiter.api.BeforeEach;
38+
import org.junit.jupiter.api.Test;
39+
import org.mockito.MockedStatic;
40+
import org.mockito.Mockito;
41+
42+
import java.util.ArrayList;
43+
import java.util.BitSet;
44+
import java.util.List;
45+
import java.util.Map;
46+
import java.util.Optional;
47+
48+
public class ChildrenPropertiesRegulatorTest {
49+
50+
private List<GroupExpression> children;
51+
private JobContext mockedJobContext;
52+
private List<PhysicalProperties> originOutputChildrenProperties
53+
= Lists.newArrayList(PhysicalProperties.MUST_SHUFFLE);
54+
55+
@BeforeEach
56+
public void setUp() {
57+
Group childGroup = Mockito.mock(Group.class);
58+
Mockito.when(childGroup.getLogicalProperties()).thenReturn(Mockito.mock(LogicalProperties.class));
59+
GroupExpression child = Mockito.mock(GroupExpression.class);
60+
Mockito.when(child.getOutputProperties(Mockito.any())).thenReturn(PhysicalProperties.MUST_SHUFFLE);
61+
Mockito.when(child.getOwnerGroup()).thenReturn(childGroup);
62+
Map<PhysicalProperties, Pair<Cost, List<PhysicalProperties>>> lct = Maps.newHashMap();
63+
lct.put(PhysicalProperties.MUST_SHUFFLE, Pair.of(Cost.zero(), Lists.newArrayList()));
64+
Mockito.when(child.getLowestCostTable()).thenReturn(lct);
65+
children = Lists.newArrayList(child);
66+
67+
mockedJobContext = Mockito.mock(JobContext.class);
68+
Mockito.when(mockedJobContext.getCascadesContext()).thenReturn(Mockito.mock(CascadesContext.class));
69+
70+
}
71+
72+
@Test
73+
public void testMustShuffleProjectProjectCanNotMerge() {
74+
testMustShuffleProject(PhysicalProject.class, DistributionSpecExecutionAny.class, false);
75+
76+
}
77+
78+
@Test
79+
public void testMustShuffleProjectProjectCanMerge() {
80+
testMustShuffleProject(PhysicalProject.class, DistributionSpecMustShuffle.class, true);
81+
82+
}
83+
84+
@Test
85+
public void testMustShuffleProjectFilter() {
86+
testMustShuffleProject(PhysicalFilter.class, DistributionSpecMustShuffle.class, true);
87+
88+
}
89+
90+
@Test
91+
public void testMustShuffleProjectLimit() {
92+
testMustShuffleProject(PhysicalLimit.class, DistributionSpecExecutionAny.class, true);
93+
}
94+
95+
public void testMustShuffleProject(Class<? extends Plan> childClazz,
96+
Class<? extends DistributionSpec> distributeClazz,
97+
boolean canMergeChildProject) {
98+
try (MockedStatic<CostCalculator> mockedCostCalculator = Mockito.mockStatic(CostCalculator.class)) {
99+
mockedCostCalculator.when(() -> CostCalculator.calculateCost(Mockito.any(), Mockito.any(),
100+
Mockito.anyList())).thenReturn(Cost.zero());
101+
mockedCostCalculator.when(() -> CostCalculator.addChildCost(Mockito.any(), Mockito.any(), Mockito.any(),
102+
Mockito.any(), Mockito.anyInt())).thenReturn(Cost.zero());
103+
104+
// project, cannot merge
105+
Plan mockedChild = Mockito.mock(childClazz);
106+
Mockito.when(mockedChild.withGroupExpression(Mockito.any())).thenReturn(mockedChild);
107+
Group mockedGroup = Mockito.mock(Group.class);
108+
List<GroupExpression> physicalExpressions = Lists.newArrayList(new GroupExpression(mockedChild));
109+
Mockito.when(mockedGroup.getPhysicalExpressions()).thenReturn(physicalExpressions);
110+
GroupPlan mockedGroupPlan = Mockito.mock(GroupPlan.class);
111+
Mockito.when(mockedGroupPlan.getGroup()).thenReturn(mockedGroup);
112+
// let AbstractTreeNode's init happy
113+
Mockito.when(mockedGroupPlan.getAllChildrenTypes()).thenReturn(new BitSet());
114+
PhysicalProject parentPlan = new PhysicalProject<>(Lists.newArrayList(), null, mockedGroupPlan);
115+
GroupExpression parent = new GroupExpression(parentPlan);
116+
parentPlan = parentPlan.withGroupExpression(Optional.of(parent));
117+
parentPlan = Mockito.spy(parentPlan);
118+
Mockito.doReturn(canMergeChildProject).when(parentPlan).canMergeChildProjections(Mockito.any());
119+
parent = Mockito.spy(parent);
120+
Mockito.doReturn(parentPlan).when(parent).getPlan();
121+
ChildrenPropertiesRegulator regulator = new ChildrenPropertiesRegulator(parent, children,
122+
new ArrayList<>(originOutputChildrenProperties), null, mockedJobContext);
123+
PhysicalProperties result = regulator.adjustChildrenProperties().get(0).get(0);
124+
Assertions.assertInstanceOf(distributeClazz, result.getDistributionSpec());
125+
}
126+
}
127+
128+
@Test
129+
public void testMustShuffleFilterProject() {
130+
testMustShuffleFilter(PhysicalProject.class);
131+
}
132+
133+
@Test
134+
public void testMustShuffleFilterFilter() {
135+
testMustShuffleFilter(PhysicalFilter.class);
136+
}
137+
138+
@Test
139+
public void testMustShuffleFilterLimit() {
140+
testMustShuffleFilter(PhysicalLimit.class);
141+
}
142+
143+
private void testMustShuffleFilter(Class<? extends Plan> childClazz) {
144+
try (MockedStatic<CostCalculator> mockedCostCalculator = Mockito.mockStatic(CostCalculator.class)) {
145+
mockedCostCalculator.when(() -> CostCalculator.calculateCost(Mockito.any(), Mockito.any(),
146+
Mockito.anyList())).thenReturn(Cost.zero());
147+
mockedCostCalculator.when(() -> CostCalculator.addChildCost(Mockito.any(), Mockito.any(), Mockito.any(),
148+
Mockito.any(), Mockito.anyInt())).thenReturn(Cost.zero());
149+
150+
// project, cannot merge
151+
Plan mockedChild = Mockito.mock(childClazz);
152+
Mockito.when(mockedChild.withGroupExpression(Mockito.any())).thenReturn(mockedChild);
153+
Group mockedGroup = Mockito.mock(Group.class);
154+
List<GroupExpression> physicalExpressions = Lists.newArrayList(new GroupExpression(mockedChild));
155+
Mockito.when(mockedGroup.getPhysicalExpressions()).thenReturn(physicalExpressions);
156+
GroupPlan mockedGroupPlan = Mockito.mock(GroupPlan.class);
157+
Mockito.when(mockedGroupPlan.getGroup()).thenReturn(mockedGroup);
158+
// let AbstractTreeNode's init happy
159+
Mockito.when(mockedGroupPlan.getAllChildrenTypes()).thenReturn(new BitSet());
160+
GroupExpression parent = new GroupExpression(new PhysicalFilter<>(Sets.newHashSet(), null, mockedGroupPlan));
161+
ChildrenPropertiesRegulator regulator = new ChildrenPropertiesRegulator(parent, children,
162+
new ArrayList<>(originOutputChildrenProperties), null, mockedJobContext);
163+
PhysicalProperties result = regulator.adjustChildrenProperties().get(0).get(0);
164+
Assertions.assertInstanceOf(DistributionSpecExecutionAny.class, result.getDistributionSpec());
165+
}
166+
}
167+
}

regression-test/suites/nereids_syntax_p0/cte.groovy

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,5 +334,10 @@ suite("cte") {
334334
sql """
335335
WITH cte_0 AS ( SELECT 1 AS a ), cte_1 AS ( SELECT 1 AS a ) select * from cte_0, cte_1 union select * from cte_0, cte_1
336336
"""
337+
338+
// test more than one project on cte consumer
339+
sql """
340+
with a as (select 1 c1) select *, uuid() from a union all select c2, c2 from (select c1 + 1, uuid() c2 from a) x ;
341+
"""
337342
}
338343

0 commit comments

Comments
 (0)