Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.Utils;
Expand Down Expand Up @@ -81,6 +82,11 @@ public Plan visit(Plan plan, Map<ExprId, ExprId> replaceMap) {
public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Map<ExprId, ExprId> replaceMap) {
aggregate = visitChildren(this, aggregate, replaceMap);
aggregate = (LogicalAggregate<? extends Plan>) exprIdReplacer.rewriteExpr(aggregate, replaceMap);
if (aggregate.getSourceRepeat().isPresent()) {
LogicalRepeat<?> sourceRepeat = (LogicalRepeat<?>) exprIdReplacer.rewriteExpr(
aggregate.getSourceRepeat().get(), replaceMap);
aggregate = aggregate.withSourceRepeat(sourceRepeat);
}

if (aggregate.getGroupByExpressions().isEmpty() || aggregate.getSourceRepeat().isPresent()) {
return aggregate;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,16 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;

import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Map;
import java.util.Optional;

/** replace SlotReference ExprId in logical plans */
public class ExprIdRewriter extends ExpressionRewrite {
Expand Down Expand Up @@ -74,6 +78,25 @@ public Plan rewriteExpr(Plan plan, Map<ExprId, ExprId> replaceMap) {
* SlotReference:a#0 -> a#3, a#1 -> a#7
* */
public static class ReplaceRule implements ExpressionPatternRuleFactory {
private static final DefaultExpressionRewriter<Map<ExprId, ExprId>> SLOT_REPLACER =
new DefaultExpressionRewriter<Map<ExprId, ExprId>>() {
@Override
public Expression visitSlotReference(SlotReference slot, Map<ExprId, ExprId> replaceMap) {
ExprId newId = replaceMap.get(slot.getExprId());
if (newId == null) {
return slot;
}
ExprId lastId = newId;
while (true) {
newId = replaceMap.get(lastId);
if (newId == null) {
return slot.withExprId(lastId);
} else {
lastId = newId;
}
}
}
};
private final Map<ExprId, ExprId> replaceMap;

public ReplaceRule(Map<ExprId, ExprId> replaceMap) {
Expand All @@ -85,14 +108,30 @@ public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
matchesType(SlotReference.class).thenApply(ctx -> {
Slot slot = ctx.expr;
if (replaceMap.containsKey(slot.getExprId())) {
ExprId newId = replaceMap.get(slot.getExprId());
while (replaceMap.containsKey(newId)) {
newId = replaceMap.get(newId);
return slot.accept(SLOT_REPLACER, replaceMap);
}),
matchesType(VirtualSlotReference.class).thenApply(ctx -> {
VirtualSlotReference virtualSlot = ctx.expr;
return virtualSlot.accept(new DefaultExpressionRewriter<Map<ExprId, ExprId>>() {
@Override
public Expression visitVirtualReference(VirtualSlotReference virtualSlot,
Map<ExprId, ExprId> replaceMap) {
Optional<GroupingScalarFunction> originExpression = virtualSlot.getOriginExpression();
if (!originExpression.isPresent()) {
return virtualSlot;
}
GroupingScalarFunction groupingScalarFunction = originExpression.get();
GroupingScalarFunction rewrittenFunction =
(GroupingScalarFunction) groupingScalarFunction.accept(
SLOT_REPLACER, replaceMap);
if (!rewrittenFunction.children().equals(groupingScalarFunction.children())) {
return virtualSlot.withOriginExpressionAndComputeLongValueMethod(
Optional.of(rewrittenFunction),
rewrittenFunction::computeVirtualSlotValue);
}
return virtualSlot;
}
return slot.withExprId(newId);
}
return slot;
}, replaceMap);
})
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,13 @@ public VirtualSlotReference withExprId(ExprId exprId) {
originExpression, computeLongValueMethod);
}

public VirtualSlotReference withOriginExpressionAndComputeLongValueMethod(
Optional<GroupingScalarFunction> originExpression,
Function<GroupingSetShapes, List<Long>> computeLongValueMethod) {
return new VirtualSlotReference(exprId, name.get(), dataType, nullable, qualifier,
originExpression, computeLongValueMethod);
}

@Override
public Slot withIndexInSql(Pair<Integer, Integer> index) {
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,12 @@ public LogicalAggregate<Plan> withNormalized(List<Expression> normalizedGroupBy,
hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), normalizedChild);
}

public LogicalAggregate<Plan> withSourceRepeat(LogicalRepeat<?> sourceRepeat) {
return new LogicalAggregate<>(groupByExpressions, outputExpressions, normalized, ordinalIsResolved,
generated, hasPushed, Optional.ofNullable(sourceRepeat),
Optional.empty(), Optional.empty(), child());
}

private boolean isUniqueGroupByUnique(NamedExpression namedExpression) {
if (namedExpression.children().size() != 1) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,4 +237,96 @@ suite("eliminate_group_by_key_by_uniform") {
qt_to_limit_join_project_shape "explain shape plan select 1 as c1 from test1 t1 inner join (select * from test2 where b=105) t2 on t1.a=t2.a group by c1;"
qt_to_limit_project_uniform_shape "explain shape plan select 1 as c1 from eli_gbk_by_uniform_t group by c1"
qt_to_limit_multi_group_by_shape "explain shape plan select 2 as c1 from eli_gbk_by_uniform_t where a=1 group by c1,a"

// test when has repeat above agg

sql """drop table if exists test_event"""
sql """
CREATE TABLE `test_event` (
`@dt` DATETIME NOT NULL COMMENT '',
`@event_name` VARCHAR(255) NOT NULL COMMENT '',
`@user_id` VARCHAR(100) NOT NULL COMMENT '',
`@event_time` DATETIME NOT NULL COMMENT '',
`@event_property_1` VARCHAR(255) NULL
)
ENGINE=OLAP
DUPLICATE KEY(`@dt`, `@event_name`, `@user_id`)
COMMENT ''
PARTITION BY RANGE(`@dt`)
(
PARTITION p202509 VALUES [('2025-09-01 00:00:00'), ('2025-10-05 00:00:00'))
)
DISTRIBUTED BY HASH(`@user_id`) BUCKETS 10
PROPERTIES (
"replication_num" = "1",
"dynamic_partition.enable" = "true",
"dynamic_partition.time_unit" = "MONTH",
"dynamic_partition.start" = "-2147483648",
"dynamic_partition.end" = "3",
"dynamic_partition.prefix" = "p",
"dynamic_partition.buckets" = "10"
);
"""

sql """
INSERT INTO `test_event` (`@dt`, `@event_name`, `@user_id`, `@event_time`, `@event_property_1`)
VALUES
('2025-09-03 10:00:00', 'shop_buy', 'user_A', '2025-09-03 10:00:00', 'prop_A1'),
('2025-09-03 10:01:00', 'shop_buy', 'user_A', '2025-09-03 10:01:00', 'prop_A2'),
('2025-09-04 15:30:00', 'shop_buy', 'user_A', '2025-09-04 15:30:00', 'prop_A3'),
('2025-09-05 08:00:00', 'shop_buy', 'user_B', '2025-09-05 08:00:00', 'prop_B1'),
('2025-09-05 08:05:00', 'shop_buy', 'user_B', '2025-09-05 08:05:00', 'prop_B2'),
('2025-09-09 23:59:59', 'shop_buy', 'user_C', '2025-09-09 23:59:59', 'prop_C1'),
('2025-10-01 00:00:00', 'shop_buy', 'user_D', '2025-10-01 00:00:00', 'prop_D1');
"""

sql """
SELECT
CASE WHEN GROUPING(event_date) = 1 THEN '(TOTAL)' ELSE CAST(event_date AS VARCHAR) END AS event_date,
user_id,
MAX(conversion_level) AS conversion_level,
CASE WHEN GROUPING(event_name_group) = 1 THEN '(TOTAL)' ELSE event_name_group END AS event_name_group
FROM
(
SELECT
src.event_date,
src.user_id,
WINDOW_FUNNEL(
3600 * 24 * 1,
'default',
src.event_time,
src.event_name = 'shop_buy',
src.event_name = 'shop_buy'
) AS conversion_level,
src.event_name_group
FROM
(
SELECT
CAST(etb.`@dt` AS DATE) AS event_date,
etb.`@event_name` AS event_name,
etb.`@event_time` AS event_time,
etb.`@event_name` AS event_name_group,
etb.`@user_id` AS user_id
FROM
`test_event` AS etb
WHERE
etb.`@dt` between '2025-09-03 02:00:00' AND '2025-09-10 01:59:59'
AND etb.`@event_name` = 'shop_buy'
AND etb.`@user_id` IS NOT NULL
AND etb.`@user_id` > '0'
) AS src
GROUP BY
src.event_date,
src.user_id,
src.event_name_group
) AS fwt
GROUP BY
GROUPING SETS (
(user_id),
(user_id, event_date),
(user_id, event_name_group),
(user_id, event_date, event_name_group)
);

"""
}
Loading