diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java index 75624068bc765f..96fb96adf9b339 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java @@ -19,7 +19,10 @@ import org.apache.doris.qe.SessionVariable; -class CostV1 implements Cost { +/** + * Cost V1. + */ +public class CostV1 implements Cost { private static final CostV1 INFINITE = new CostV1(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java index f6868a819ccaf1..4c774a3de68b91 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java @@ -39,6 +39,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; +import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit; import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; import org.apache.doris.nereids.trees.plans.physical.PhysicalPartitionTopN; @@ -202,6 +203,17 @@ public Boolean visitPhysicalFilter(PhysicalFilter filter, Void c if (children.get(0).getPlan() instanceof PhysicalDistribute) { return false; } + DistributionSpec distributionSpec = childrenProperties.get(0).getDistributionSpec(); + // process must shuffle + if (distributionSpec instanceof DistributionSpecMustShuffle) { + Plan child = filter.child(); + Plan realChild = getChildPhysicalPlan(child); + if (realChild instanceof PhysicalProject + || realChild instanceof PhysicalFilter + || realChild instanceof PhysicalLimit) { + visit(filter, context); + } + } return true; } @@ -234,6 +246,19 @@ private boolean isBucketShuffleDownGrade(Plan oneSidePlan, DistributionSpecHash } } + private Plan getChildPhysicalPlan(Plan plan) { + if (!(plan instanceof GroupPlan)) { + return null; + } + GroupPlan groupPlan = (GroupPlan) plan; + if (groupPlan == null || groupPlan.getGroup() == null + || groupPlan.getGroup().getPhysicalExpressions().isEmpty()) { + return null; + } else { + return groupPlan.getGroup().getPhysicalExpressions().get(0).getPlan(); + } + } + private PhysicalOlapScan findDownGradeBucketShuffleCandidate(GroupPlan groupPlan) { if (groupPlan == null || groupPlan.getGroup() == null || groupPlan.getGroup().getPhysicalExpressions().isEmpty()) { @@ -467,6 +492,20 @@ public Boolean visitPhysicalProject(PhysicalProject project, Voi if (children.get(0).getPlan() instanceof PhysicalDistribute) { return false; } + DistributionSpec distributionSpec = childrenProperties.get(0).getDistributionSpec(); + // process must shuffle + if (distributionSpec instanceof DistributionSpecMustShuffle) { + Plan child = project.child(); + Plan realChild = getChildPhysicalPlan(child); + if (realChild instanceof PhysicalLimit) { + visit(project, context); + } else if (realChild instanceof PhysicalProject) { + PhysicalProject physicalProject = (PhysicalProject) realChild; + if (!project.canMergeProjections(physicalProject)) { + visit(project, context); + } + } + } return true; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulatorTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulatorTest.java new file mode 100644 index 00000000000000..1fefa7753ca394 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulatorTest.java @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.properties; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.cost.Cost; +import org.apache.doris.nereids.cost.CostCalculator; +import org.apache.doris.nereids.cost.CostV1; +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.memo.Group; +import org.apache.doris.nereids.memo.GroupExpression; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter; +import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit; +import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.MockedStatic; +import org.mockito.Mockito; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +public class ChildrenPropertiesRegulatorTest { + + private List children; + private JobContext mockedJobContext; + private List childrenOutputProperties = Lists.newArrayList(PhysicalProperties.MUST_SHUFFLE); + + @BeforeEach + public void setUp() { + Group childGroup = Mockito.mock(Group.class); + Mockito.when(childGroup.getLogicalProperties()).thenReturn(Mockito.mock(LogicalProperties.class)); + GroupExpression child = Mockito.mock(GroupExpression.class); + Mockito.when(child.getOutputProperties(Mockito.any())).thenReturn(PhysicalProperties.MUST_SHUFFLE); + Mockito.when(child.getOwnerGroup()).thenReturn(childGroup); + Map>> lct = Maps.newHashMap(); + lct.put(PhysicalProperties.MUST_SHUFFLE, Pair.of(CostV1.zero(), Lists.newArrayList())); + Mockito.when(child.getLowestCostTable()).thenReturn(lct); + children = Lists.newArrayList(child); + + mockedJobContext = Mockito.mock(JobContext.class); + Mockito.when(mockedJobContext.getCascadesContext()).thenReturn(Mockito.mock(CascadesContext.class)); + + } + + @Test + public void testMustShuffleProjectProjectCanNotMerge() { + testMustShuffleProject(PhysicalProject.class, DistributionSpecExecutionAny.class, false); + + } + + @Test + public void testMustShuffleProjectProjectCanMerge() { + testMustShuffleProject(PhysicalProject.class, DistributionSpecMustShuffle.class, true); + + } + + @Test + public void testMustShuffleProjectFilter() { + testMustShuffleProject(PhysicalFilter.class, DistributionSpecMustShuffle.class, true); + + } + + @Test + public void testMustShuffleProjectLimit() { + testMustShuffleProject(PhysicalLimit.class, DistributionSpecExecutionAny.class, true); + } + + public void testMustShuffleProject(Class childClazz, + Class distributeClazz, + boolean canMergeChildProject) { + try (MockedStatic mockedCostCalculator = Mockito.mockStatic(CostCalculator.class)) { + mockedCostCalculator.when(() -> CostCalculator.calculateCost(Mockito.any(), Mockito.any(), + Mockito.anyList())).thenReturn(CostV1.zero()); + mockedCostCalculator.when(() -> CostCalculator.addChildCost(Mockito.any(), Mockito.any(), Mockito.any(), + Mockito.any(), Mockito.anyInt())).thenReturn(CostV1.zero()); + + // project, cannot merge + Plan mockedChild = Mockito.mock(childClazz); + Mockito.when(mockedChild.withGroupExpression(Mockito.any())).thenReturn(mockedChild); + Group mockedGroup = Mockito.mock(Group.class); + List physicalExpressions = Lists.newArrayList(new GroupExpression(mockedChild)); + Mockito.when(mockedGroup.getPhysicalExpressions()).thenReturn(physicalExpressions); + GroupPlan mockedGroupPlan = Mockito.mock(GroupPlan.class); + Mockito.when(mockedGroupPlan.getGroup()).thenReturn(mockedGroup); + PhysicalProject parentPlan = new PhysicalProject<>(Lists.newArrayList(), null, mockedGroupPlan); + GroupExpression parent = new GroupExpression(parentPlan); + parentPlan = parentPlan.withGroupExpression(Optional.of(parent)); + parentPlan = Mockito.spy(parentPlan); + Mockito.doReturn(canMergeChildProject).when(parentPlan).canMergeProjections(Mockito.any()); + parent = Mockito.spy(parent); + Mockito.doReturn(parentPlan).when(parent).getPlan(); + List childrenProperties = new ArrayList<>(childrenOutputProperties); + ChildrenPropertiesRegulator regulator = new ChildrenPropertiesRegulator(parent, children, + childrenProperties, null, mockedJobContext); + regulator.adjustChildrenProperties(); + PhysicalProperties result = childrenProperties.get(0); + Assertions.assertInstanceOf(distributeClazz, result.getDistributionSpec()); + } + } + + @Test + public void testMustShuffleFilterProject() { + testMustShuffleFilter(PhysicalProject.class); + } + + @Test + public void testMustShuffleFilterFilter() { + testMustShuffleFilter(PhysicalFilter.class); + } + + @Test + public void testMustShuffleFilterLimit() { + testMustShuffleFilter(PhysicalLimit.class); + } + + private void testMustShuffleFilter(Class childClazz) { + try (MockedStatic mockedCostCalculator = Mockito.mockStatic(CostCalculator.class)) { + mockedCostCalculator.when(() -> CostCalculator.calculateCost(Mockito.any(), Mockito.any(), + Mockito.anyList())).thenReturn(CostV1.zero()); + mockedCostCalculator.when(() -> CostCalculator.addChildCost(Mockito.any(), Mockito.any(), Mockito.any(), + Mockito.any(), Mockito.anyInt())).thenReturn(CostV1.zero()); + + // project, cannot merge + Plan mockedChild = Mockito.mock(childClazz); + Mockito.when(mockedChild.withGroupExpression(Mockito.any())).thenReturn(mockedChild); + Group mockedGroup = Mockito.mock(Group.class); + List physicalExpressions = Lists.newArrayList(new GroupExpression(mockedChild)); + Mockito.when(mockedGroup.getPhysicalExpressions()).thenReturn(physicalExpressions); + GroupPlan mockedGroupPlan = Mockito.mock(GroupPlan.class); + Mockito.when(mockedGroupPlan.getGroup()).thenReturn(mockedGroup); + GroupExpression parent = new GroupExpression(new PhysicalFilter<>(Sets.newHashSet(), null, mockedGroupPlan)); + List childrenProperties = new ArrayList<>(childrenOutputProperties); + ChildrenPropertiesRegulator regulator = new ChildrenPropertiesRegulator(parent, children, + childrenProperties, null, mockedJobContext); + regulator.adjustChildrenProperties(); + PhysicalProperties result = childrenProperties.get(0); + Assertions.assertInstanceOf(DistributionSpecExecutionAny.class, result.getDistributionSpec()); + } + } +} diff --git a/regression-test/suites/nereids_syntax_p0/cte.groovy b/regression-test/suites/nereids_syntax_p0/cte.groovy index 5402ffb8e2108e..f6b990b4f4dedc 100644 --- a/regression-test/suites/nereids_syntax_p0/cte.groovy +++ b/regression-test/suites/nereids_syntax_p0/cte.groovy @@ -334,5 +334,10 @@ suite("cte") { sql """ WITH cte_0 AS ( SELECT 1 AS a ), cte_1 AS ( SELECT 1 AS a ) select * from cte_0, cte_1 union select * from cte_0, cte_1 """ + + // test more than one project on cte consumer + sql """ + with a as (select 1 c1) select *, uuid() from a union all select c2, c2 from (select c1 + 1, uuid() c2 from a) x ; + """ }