diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java index 32a5202bee4f66..5e1346c299114a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java @@ -281,8 +281,7 @@ public List> visitPhysicalFilter(PhysicalFilter> visitPhysicalProject(PhysicalProject children; private JobContext mockedJobContext; private List originOutputChildrenProperties = Lists.newArrayList(PhysicalProperties.MUST_SHUFFLE); @BeforeEach public void setUp() { - Group childGroup = Mockito.mock(Group.class); - Mockito.when(childGroup.getLogicalProperties()).thenReturn(Mockito.mock(LogicalProperties.class)); - GroupExpression child = Mockito.mock(GroupExpression.class); - Mockito.when(child.getOutputProperties(Mockito.any())).thenReturn(PhysicalProperties.MUST_SHUFFLE); - Mockito.when(child.getOwnerGroup()).thenReturn(childGroup); - Map>> lct = Maps.newHashMap(); - lct.put(PhysicalProperties.MUST_SHUFFLE, Pair.of(Cost.zero(), Lists.newArrayList())); - Mockito.when(child.getLowestCostTable()).thenReturn(lct); - children = Lists.newArrayList(child); - mockedJobContext = Mockito.mock(JobContext.class); Mockito.when(mockedJobContext.getCascadesContext()).thenReturn(Mockito.mock(CascadesContext.class)); - } @Test public void testMustShuffleProjectProjectCanNotMerge() { testMustShuffleProject(PhysicalProject.class, DistributionSpecExecutionAny.class, false); - } @Test public void testMustShuffleProjectProjectCanMerge() { testMustShuffleProject(PhysicalProject.class, DistributionSpecMustShuffle.class, true); - } @Test public void testMustShuffleProjectFilter() { testMustShuffleProject(PhysicalFilter.class, DistributionSpecMustShuffle.class, true); - } @Test @@ -111,6 +96,19 @@ public void testMustShuffleProject(Class childClazz, Mockito.when(mockedGroupPlan.getGroup()).thenReturn(mockedGroup); // let AbstractTreeNode's init happy Mockito.when(mockedGroupPlan.getAllChildrenTypes()).thenReturn(new BitSet()); + + List children; + Group childGroup = Mockito.mock(Group.class); + Mockito.when(childGroup.getLogicalProperties()).thenReturn(Mockito.mock(LogicalProperties.class)); + GroupExpression child = Mockito.mock(GroupExpression.class); + Mockito.when(child.getOutputProperties(Mockito.any())).thenReturn(PhysicalProperties.MUST_SHUFFLE); + Mockito.when(child.getOwnerGroup()).thenReturn(childGroup); + Map>> lct = Maps.newHashMap(); + lct.put(PhysicalProperties.MUST_SHUFFLE, Pair.of(Cost.zero(), Lists.newArrayList())); + Mockito.when(child.getLowestCostTable()).thenReturn(lct); + Mockito.when(child.getPlan()).thenReturn(mockedChild); + children = Lists.newArrayList(child); + PhysicalProject parentPlan = new PhysicalProject<>(Lists.newArrayList(), null, mockedGroupPlan); GroupExpression parent = new GroupExpression(parentPlan); parentPlan = parentPlan.withGroupExpression(Optional.of(parent)); @@ -157,6 +155,19 @@ private void testMustShuffleFilter(Class childClazz) { Mockito.when(mockedGroupPlan.getGroup()).thenReturn(mockedGroup); // let AbstractTreeNode's init happy Mockito.when(mockedGroupPlan.getAllChildrenTypes()).thenReturn(new BitSet()); + + List children; + Group childGroup = Mockito.mock(Group.class); + Mockito.when(childGroup.getLogicalProperties()).thenReturn(Mockito.mock(LogicalProperties.class)); + GroupExpression child = Mockito.mock(GroupExpression.class); + Mockito.when(child.getOutputProperties(Mockito.any())).thenReturn(PhysicalProperties.MUST_SHUFFLE); + Mockito.when(child.getOwnerGroup()).thenReturn(childGroup); + Map>> lct = Maps.newHashMap(); + lct.put(PhysicalProperties.MUST_SHUFFLE, Pair.of(Cost.zero(), Lists.newArrayList())); + Mockito.when(child.getLowestCostTable()).thenReturn(lct); + Mockito.when(child.getPlan()).thenReturn(mockedChild); + children = Lists.newArrayList(child); + GroupExpression parent = new GroupExpression(new PhysicalFilter<>(Sets.newHashSet(), null, mockedGroupPlan)); ChildrenPropertiesRegulator regulator = new ChildrenPropertiesRegulator(parent, children, new ArrayList<>(originOutputChildrenProperties), null, mockedJobContext);