diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java index e85c6eb8dae6d9..5c5b98b26d797f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java @@ -186,7 +186,13 @@ public Plan visitLogicalAggregate(LogicalAggregate aggregate, De List outputExpressions = aggregate.getOutputExpressions().stream() .map(o -> (NamedExpression) ExpressionDeepCopier.INSTANCE.deepCopy(o, context)) .collect(ImmutableList.toImmutableList()); - return aggregate.withChildGroupByAndOutput(groupByExpressions, outputExpressions, child); + LogicalAggregate copiedAggregate = aggregate.withChildGroupByAndOutput(groupByExpressions, + outputExpressions, child); + Optional> childRepeat = + copiedAggregate.collectFirst(LogicalRepeat.class::isInstance); + return childRepeat.isPresent() ? aggregate.withChildGroupByAndOutputAndSourceRepeat( + groupByExpressions, outputExpressions, child, childRepeat) + : aggregate.withChildGroupByAndOutput(groupByExpressions, outputExpressions, child); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java index 7401dbbaea7d17..07a2d1b7d97cda 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java @@ -352,6 +352,13 @@ public LogicalAggregate withChildGroupByAndOutput(List groupBy hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), newChild); } + public LogicalAggregate withChildGroupByAndOutputAndSourceRepeat(List groupByExprList, + List outputExpressionList, Plan newChild, + Optional> sourceRepeat) { + return new LogicalAggregate<>(groupByExprList, outputExpressionList, normalized, ordinalIsResolved, generated, + hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), newChild); + } + public LogicalAggregate withChildAndOutput(CHILD_TYPE child, List outputExpressionList) { return new LogicalAggregate<>(groupByExpressions, outputExpressionList, normalized, ordinalIsResolved, diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopierTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopierTest.java index fcbb4fbc0c2e8b..bc2dbe097f0443 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopierTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopierTest.java @@ -17,22 +17,96 @@ package org.apache.doris.nereids.trees.copier; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; import org.apache.doris.nereids.util.PlanConstructor; +import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + public class LogicalPlanDeepCopierTest { @Test public void testDeepCopyOlapScan() { LogicalOlapScan relationPlan = PlanConstructor.newLogicalOlapScan(0, "a", 0); relationPlan = (LogicalOlapScan) relationPlan.withOperativeSlots(relationPlan.getOutput()); - LogicalOlapScan aCopy = (LogicalOlapScan) relationPlan.accept(LogicalPlanDeepCopier.INSTANCE, new DeepCopierContext()); + LogicalOlapScan aCopy = + (LogicalOlapScan) relationPlan.accept(LogicalPlanDeepCopier.INSTANCE, new DeepCopierContext()); for (Slot opSlot : aCopy.getOperativeSlots()) { Assertions.assertTrue(aCopy.getOutputSet().contains(opSlot)); } } + + @Test + public void testDeepCopyAggregateWithSourceRepeat() { + LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t", 0); + List groupingKeys = scan.getOutput().subList(0, 1); + List> groupingSets = ImmutableList.of( + ImmutableList.of(groupingKeys.get(0)), + ImmutableList.of() + ); + LogicalRepeat repeat = new LogicalRepeat<>( + groupingSets, + scan.getOutput().stream().map(NamedExpression.class::cast).collect(Collectors.toList()), + scan + ); + List groupByExprs = repeat.getOutput().subList(0, 1).stream() + .map(e -> (NamedExpression) e) + .collect(ImmutableList.toImmutableList()); + List outputExprs = repeat.getOutput(); + LogicalAggregate aggregate = new LogicalAggregate( + groupByExprs, + outputExprs, + repeat + ); + aggregate = aggregate.withSourceRepeat(repeat); + DeepCopierContext context = new DeepCopierContext(); + LogicalAggregate copiedAggregate = (LogicalAggregate) aggregate.accept( + LogicalPlanDeepCopier.INSTANCE, + context + ); + Assertions.assertTrue(copiedAggregate.getSourceRepeat().isPresent()); + + Optional> copiedRepeat = + copiedAggregate.collectFirst(LogicalRepeat.class::isInstance); + Assertions.assertTrue(copiedRepeat.isPresent()); + Assertions.assertSame(copiedAggregate.getSourceRepeat().get(), copiedRepeat.get()); + + Assertions.assertNotSame(aggregate, copiedAggregate); + Assertions.assertNotSame(repeat, copiedRepeat.get()); + } + + @Test + public void testDeepCopyAggregateWithoutSourceRepeat() { + LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t", 0); + List groupByExprs = scan.getOutput().subList(0, 1).stream() + .map(e -> (Expression) e) + .collect(ImmutableList.toImmutableList()); + List outputExprs = scan.getOutput(); + + LogicalAggregate aggregate = new LogicalAggregate( + groupByExprs, + outputExprs, + scan + ); + DeepCopierContext context = new DeepCopierContext(); + LogicalAggregate copiedAggregate = (LogicalAggregate) aggregate.accept( + LogicalPlanDeepCopier.INSTANCE, + context + ); + Assertions.assertFalse(copiedAggregate.getSourceRepeat().isPresent()); + Assertions.assertNotSame(aggregate, copiedAggregate); + Assertions.assertEquals(aggregate.getGroupByExpressions().size(), + copiedAggregate.getGroupByExpressions().size()); + } }