diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyByUniform.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyByUniform.java index d5b33dc5488423..15435a0850408a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyByUniform.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyByUniform.java @@ -32,6 +32,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; import org.apache.doris.nereids.util.Utils; @@ -81,6 +82,11 @@ public Plan visit(Plan plan, Map replaceMap) { public Plan visitLogicalAggregate(LogicalAggregate aggregate, Map replaceMap) { aggregate = visitChildren(this, aggregate, replaceMap); aggregate = (LogicalAggregate) exprIdReplacer.rewriteExpr(aggregate, replaceMap); + if (aggregate.getSourceRepeat().isPresent()) { + LogicalRepeat sourceRepeat = (LogicalRepeat) exprIdReplacer.rewriteExpr( + aggregate.getSourceRepeat().get(), replaceMap); + aggregate = aggregate.withSourceRepeat(sourceRepeat); + } if (aggregate.getGroupByExpressions().isEmpty() || aggregate.getSourceRepeat().isPresent()) { return aggregate; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java index 5e065fa3724b08..a6dc37f990e815 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java @@ -28,12 +28,16 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; +import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import org.apache.doris.nereids.trees.plans.Plan; import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Map; +import java.util.Optional; /** replace SlotReference ExprId in logical plans */ public class ExprIdRewriter extends ExpressionRewrite { @@ -74,6 +78,25 @@ public Plan rewriteExpr(Plan plan, Map replaceMap) { * SlotReference:a#0 -> a#3, a#1 -> a#7 * */ public static class ReplaceRule implements ExpressionPatternRuleFactory { + private static final DefaultExpressionRewriter> SLOT_REPLACER = + new DefaultExpressionRewriter>() { + @Override + public Expression visitSlotReference(SlotReference slot, Map replaceMap) { + ExprId newId = replaceMap.get(slot.getExprId()); + if (newId == null) { + return slot; + } + ExprId lastId = newId; + while (true) { + newId = replaceMap.get(lastId); + if (newId == null) { + return slot.withExprId(lastId); + } else { + lastId = newId; + } + } + } + }; private final Map replaceMap; public ReplaceRule(Map replaceMap) { @@ -85,14 +108,30 @@ public List> buildRules() { return ImmutableList.of( matchesType(SlotReference.class).thenApply(ctx -> { Slot slot = ctx.expr; - if (replaceMap.containsKey(slot.getExprId())) { - ExprId newId = replaceMap.get(slot.getExprId()); - while (replaceMap.containsKey(newId)) { - newId = replaceMap.get(newId); + return slot.accept(SLOT_REPLACER, replaceMap); + }), + matchesType(VirtualSlotReference.class).thenApply(ctx -> { + VirtualSlotReference virtualSlot = ctx.expr; + return virtualSlot.accept(new DefaultExpressionRewriter>() { + @Override + public Expression visitVirtualReference(VirtualSlotReference virtualSlot, + Map replaceMap) { + Optional originExpression = virtualSlot.getOriginExpression(); + if (!originExpression.isPresent()) { + return virtualSlot; + } + GroupingScalarFunction groupingScalarFunction = originExpression.get(); + GroupingScalarFunction rewrittenFunction = + (GroupingScalarFunction) groupingScalarFunction.accept( + SLOT_REPLACER, replaceMap); + if (!rewrittenFunction.children().equals(groupingScalarFunction.children())) { + return virtualSlot.withOriginExpressionAndComputeLongValueMethod( + Optional.of(rewrittenFunction), + rewrittenFunction::computeVirtualSlotValue); + } + return virtualSlot; } - return slot.withExprId(newId); - } - return slot; + }, replaceMap); }) ); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/VirtualSlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/VirtualSlotReference.java index 42be621045977b..bac559f407d0ea 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/VirtualSlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/VirtualSlotReference.java @@ -156,6 +156,13 @@ public VirtualSlotReference withExprId(ExprId exprId) { originExpression, computeLongValueMethod); } + public VirtualSlotReference withOriginExpressionAndComputeLongValueMethod( + Optional originExpression, + Function> computeLongValueMethod) { + return new VirtualSlotReference(exprId, name.get(), dataType, nullable, qualifier, + originExpression, computeLongValueMethod); + } + @Override public Slot withIndexInSql(Pair index) { return this; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java index 9bc7fbfd5e1fc3..e59025442dabf7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java @@ -311,6 +311,12 @@ public LogicalAggregate withNormalized(List normalizedGroupBy, hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), normalizedChild); } + public LogicalAggregate withSourceRepeat(LogicalRepeat sourceRepeat) { + return new LogicalAggregate<>(groupByExpressions, outputExpressions, normalized, ordinalIsResolved, + generated, hasPushed, Optional.ofNullable(sourceRepeat), + Optional.empty(), Optional.empty(), child()); + } + private boolean isUniqueGroupByUnique(NamedExpression namedExpression) { if (namedExpression.children().size() != 1) { return false; diff --git a/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_group_by_key_by_uniform.groovy b/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_group_by_key_by_uniform.groovy index b7e2dccaa7735d..d3fd6a4c293322 100644 --- a/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_group_by_key_by_uniform.groovy +++ b/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_group_by_key_by_uniform.groovy @@ -237,4 +237,96 @@ suite("eliminate_group_by_key_by_uniform") { qt_to_limit_join_project_shape "explain shape plan select 1 as c1 from test1 t1 inner join (select * from test2 where b=105) t2 on t1.a=t2.a group by c1;" qt_to_limit_project_uniform_shape "explain shape plan select 1 as c1 from eli_gbk_by_uniform_t group by c1" qt_to_limit_multi_group_by_shape "explain shape plan select 2 as c1 from eli_gbk_by_uniform_t where a=1 group by c1,a" + + // test when has repeat above agg + + sql """drop table if exists test_event""" + sql """ + CREATE TABLE `test_event` ( + `@dt` DATETIME NOT NULL COMMENT '', + `@event_name` VARCHAR(255) NOT NULL COMMENT '', + `@user_id` VARCHAR(100) NOT NULL COMMENT '', + `@event_time` DATETIME NOT NULL COMMENT '', + `@event_property_1` VARCHAR(255) NULL + ) + ENGINE=OLAP + DUPLICATE KEY(`@dt`, `@event_name`, `@user_id`) + COMMENT '' + PARTITION BY RANGE(`@dt`) + ( + PARTITION p202509 VALUES [('2025-09-01 00:00:00'), ('2025-10-05 00:00:00')) + ) + DISTRIBUTED BY HASH(`@user_id`) BUCKETS 10 + PROPERTIES ( + "replication_num" = "1", + "dynamic_partition.enable" = "true", + "dynamic_partition.time_unit" = "MONTH", + "dynamic_partition.start" = "-2147483648", + "dynamic_partition.end" = "3", + "dynamic_partition.prefix" = "p", + "dynamic_partition.buckets" = "10" + ); + """ + + sql """ + INSERT INTO `test_event` (`@dt`, `@event_name`, `@user_id`, `@event_time`, `@event_property_1`) + VALUES + ('2025-09-03 10:00:00', 'shop_buy', 'user_A', '2025-09-03 10:00:00', 'prop_A1'), + ('2025-09-03 10:01:00', 'shop_buy', 'user_A', '2025-09-03 10:01:00', 'prop_A2'), + ('2025-09-04 15:30:00', 'shop_buy', 'user_A', '2025-09-04 15:30:00', 'prop_A3'), + ('2025-09-05 08:00:00', 'shop_buy', 'user_B', '2025-09-05 08:00:00', 'prop_B1'), + ('2025-09-05 08:05:00', 'shop_buy', 'user_B', '2025-09-05 08:05:00', 'prop_B2'), + ('2025-09-09 23:59:59', 'shop_buy', 'user_C', '2025-09-09 23:59:59', 'prop_C1'), + ('2025-10-01 00:00:00', 'shop_buy', 'user_D', '2025-10-01 00:00:00', 'prop_D1'); + """ + + sql """ + SELECT + CASE WHEN GROUPING(event_date) = 1 THEN '(TOTAL)' ELSE CAST(event_date AS VARCHAR) END AS event_date, + user_id, + MAX(conversion_level) AS conversion_level, + CASE WHEN GROUPING(event_name_group) = 1 THEN '(TOTAL)' ELSE event_name_group END AS event_name_group +FROM + ( + SELECT + src.event_date, + src.user_id, + WINDOW_FUNNEL( + 3600 * 24 * 1, + 'default', + src.event_time, + src.event_name = 'shop_buy', + src.event_name = 'shop_buy' + ) AS conversion_level, + src.event_name_group + FROM + ( + SELECT + CAST(etb.`@dt` AS DATE) AS event_date, + etb.`@event_name` AS event_name, + etb.`@event_time` AS event_time, + etb.`@event_name` AS event_name_group, + etb.`@user_id` AS user_id + FROM + `test_event` AS etb + WHERE + etb.`@dt` between '2025-09-03 02:00:00' AND '2025-09-10 01:59:59' + AND etb.`@event_name` = 'shop_buy' + AND etb.`@user_id` IS NOT NULL + AND etb.`@user_id` > '0' + ) AS src + GROUP BY + src.event_date, + src.user_id, + src.event_name_group + ) AS fwt +GROUP BY + GROUPING SETS ( + (user_id), + (user_id, event_date), + (user_id, event_name_group), + (user_id, event_date, event_name_group) + ); + + """ } \ No newline at end of file