Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public enum ExpressionRuleType {
DIGITAL_MASKING_CONVERT,
DISTINCT_PREDICATES,
EXPR_ID_REWRITE_REPLACE,
VIRTUAL_EXPR_ID_REWRITE_REPLACE,
EXTRACT_COMMON_FACTOR,
FOLD_CONSTANT_ON_BE,
FOLD_CONSTANT_ON_FE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.Utils;
Expand Down Expand Up @@ -79,6 +80,11 @@ public Plan visit(Plan plan, Map<ExprId, ExprId> replaceMap) {
public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Map<ExprId, ExprId> replaceMap) {
aggregate = visitChildren(this, aggregate, replaceMap);
aggregate = (LogicalAggregate<? extends Plan>) exprIdReplacer.rewriteExpr(aggregate, replaceMap);
if (aggregate.getSourceRepeat().isPresent()) {
LogicalRepeat<?> sourceRepeat = (LogicalRepeat<?>) exprIdReplacer.rewriteExpr(
aggregate.getSourceRepeat().get(), replaceMap);
aggregate = aggregate.withSourceRepeat(sourceRepeat);
}
Copy link
Contributor

@feiniaofeiafei feiniaofeiafei Oct 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this will lead to the aggregate's source repeat being rewritten twice.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not the same instance, so the repeat is not rewritten twice.


if (aggregate.getGroupByExpressions().isEmpty() || aggregate.getSourceRepeat().isPresent()) {
return aggregate;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,16 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;

import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Map;
import java.util.Optional;

/** replace SlotReference ExprId in logical plans */
public class ExprIdRewriter extends ExpressionRewrite {
Expand Down Expand Up @@ -75,18 +79,10 @@ public Plan rewriteExpr(Plan plan, Map<ExprId, ExprId> replaceMap) {
* SlotReference:a#0 -> a#3, a#1 -> a#7
* */
public static class ReplaceRule implements ExpressionPatternRuleFactory {
private final Map<ExprId, ExprId> replaceMap;

public ReplaceRule(Map<ExprId, ExprId> replaceMap) {
this.replaceMap = replaceMap;
}

@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
matchesType(SlotReference.class).thenApply(ctx -> {
Slot slot = ctx.expr;

private static final DefaultExpressionRewriter<Map<ExprId, ExprId>> SLOT_REPLACER =
new DefaultExpressionRewriter<Map<ExprId, ExprId>>() {
@Override
public Expression visitSlotReference(SlotReference slot, Map<ExprId, ExprId> replaceMap) {
ExprId newId = replaceMap.get(slot.getExprId());
if (newId == null) {
return slot;
Expand All @@ -100,7 +96,44 @@ public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
lastId = newId;
}
}
}).toRule(ExpressionRuleType.EXPR_ID_REWRITE_REPLACE)
}
};
private final Map<ExprId, ExprId> replaceMap;

public ReplaceRule(Map<ExprId, ExprId> replaceMap) {
this.replaceMap = replaceMap;
}

@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
matchesType(SlotReference.class).thenApply(ctx -> {
Slot slot = ctx.expr;
return slot.accept(SLOT_REPLACER, replaceMap);
}).toRule(ExpressionRuleType.EXPR_ID_REWRITE_REPLACE),
matchesType(VirtualSlotReference.class).thenApply(ctx -> {
VirtualSlotReference virtualSlot = ctx.expr;
Copy link
Contributor

@feiniaofeiafei feiniaofeiafei Oct 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1. There is an attribute "realExpressions" in VirtualSlotReference; should we replace "realExpressions" too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

realExpressions is derived from the original expressions; since we have already replaced the original expressions correctly, realExpressions is automatically correct.

return virtualSlot.accept(new DefaultExpressionRewriter<Map<ExprId, ExprId>>() {
@Override
public Expression visitVirtualReference(VirtualSlotReference virtualSlot,
Map<ExprId, ExprId> replaceMap) {
Optional<GroupingScalarFunction> originExpression = virtualSlot.getOriginExpression();
if (!originExpression.isPresent()) {
return virtualSlot;
}
GroupingScalarFunction groupingScalarFunction = originExpression.get();
GroupingScalarFunction rewrittenFunction =
(GroupingScalarFunction) groupingScalarFunction.accept(
SLOT_REPLACER, replaceMap);
if (!rewrittenFunction.children().equals(groupingScalarFunction.children())) {
return virtualSlot.withOriginExpressionAndComputeLongValueMethod(
Optional.of(rewrittenFunction),
rewrittenFunction::computeVirtualSlotValue);
}
return virtualSlot;
}
}, replaceMap);
}).toRule(ExpressionRuleType.VIRTUAL_EXPR_ID_REWRITE_REPLACE)
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,13 @@ public VirtualSlotReference withExprId(ExprId exprId) {
originExpression, computeLongValueMethod);
}

/**
 * Creates a copy of this virtual slot that carries a different origin expression
 * and a different value-computing function.
 *
 * <p>The copy reuses this slot's exprId, name, dataType, nullable flag and qualifier unchanged;
 * only the two supplied arguments differ from this instance.
 *
 * @param originExpression the grouping scalar function to store in the copy, wrapped in an Optional
 * @param computeLongValueMethod the function from GroupingSetShapes to a list of long values
 *                               to store in the copy
 * @return a newly constructed VirtualSlotReference
 */
public VirtualSlotReference withOriginExpressionAndComputeLongValueMethod(
Optional<GroupingScalarFunction> originExpression,
Function<GroupingSetShapes, List<Long>> computeLongValueMethod) {
// name is unwrapped with get() before being passed to the constructor.
return new VirtualSlotReference(exprId, name.get(), dataType, nullable, qualifier,
originExpression, computeLongValueMethod);
}

@Override
public Slot withIndexInSql(Pair<Integer, Integer> index) {
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,12 @@ public LogicalAggregate<Plan> withInProjection(boolean withInProjection) {
sourceRepeat, Optional.empty(), Optional.empty(), child());
}

/**
 * Creates a copy of this aggregate with the given source repeat.
 *
 * <p>groupByExpressions, outputExpressions, normalized, ordinalIsResolved, generated, hasPushed,
 * withInProjection and the child are taken from this instance unchanged.
 *
 * @param sourceRepeat the repeat node to record on the copy; a null value is stored as an
 *                     empty Optional
 * @return a newly constructed LogicalAggregate
 */
public LogicalAggregate<Plan> withSourceRepeat(LogicalRepeat<?> sourceRepeat) {
return new LogicalAggregate<>(groupByExpressions, outputExpressions, normalized, ordinalIsResolved,
generated, hasPushed, withInProjection, Optional.ofNullable(sourceRepeat),
// NOTE(review): the two empty Optionals are presumably the group expression and logical
// properties, as in other with* copies of plan nodes; confirm against the constructor.
Optional.empty(), Optional.empty(), child());
}

private boolean isUniqueGroupByUnique(NamedExpression namedExpression) {
if (namedExpression.children().size() != 1) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,4 +237,99 @@ suite("eliminate_group_by_key_by_uniform") {
qt_to_limit_join_project_shape "explain shape plan select 1 as c1 from test1 t1 inner join (select * from test2 where b=105) t2 on t1.a=t2.a group by c1 order by 1;"
qt_to_limit_project_uniform_shape "explain shape plan select 1 as c1 from eli_gbk_by_uniform_t group by c1"
qt_to_limit_multi_group_by_shape "explain shape plan select 2 as c1 from eli_gbk_by_uniform_t where a=1 group by c1,a"

// test the case where there is a repeat above the agg
// disable CONSTANT_PROPAGATION rules to test eliminate aggregate by uniform
sql """set disable_nereids_rules = 'CONSTANT_PROPAGATION';"""


sql """drop table if exists test_event"""
sql """
CREATE TABLE `test_event` (
`@dt` DATETIME NOT NULL COMMENT '',
`@event_name` VARCHAR(255) NOT NULL COMMENT '',
`@user_id` VARCHAR(100) NOT NULL COMMENT '',
`@event_time` DATETIME NOT NULL COMMENT '',
`@event_property_1` VARCHAR(255) NULL
)
ENGINE=OLAP
DUPLICATE KEY(`@dt`, `@event_name`, `@user_id`)
COMMENT ''
PARTITION BY RANGE(`@dt`)
(
PARTITION p202509 VALUES [('2025-09-01 00:00:00'), ('2025-10-05 00:00:00'))
)
DISTRIBUTED BY HASH(`@user_id`) BUCKETS 10
PROPERTIES (
"replication_num" = "1",
"dynamic_partition.enable" = "true",
"dynamic_partition.time_unit" = "MONTH",
"dynamic_partition.start" = "-2147483648",
"dynamic_partition.end" = "3",
"dynamic_partition.prefix" = "p",
"dynamic_partition.buckets" = "10"
);
"""

sql """
INSERT INTO `test_event` (`@dt`, `@event_name`, `@user_id`, `@event_time`, `@event_property_1`)
VALUES
('2025-09-03 10:00:00', 'shop_buy', 'user_A', '2025-09-03 10:00:00', 'prop_A1'),
('2025-09-03 10:01:00', 'shop_buy', 'user_A', '2025-09-03 10:01:00', 'prop_A2'),
('2025-09-04 15:30:00', 'shop_buy', 'user_A', '2025-09-04 15:30:00', 'prop_A3'),
('2025-09-05 08:00:00', 'shop_buy', 'user_B', '2025-09-05 08:00:00', 'prop_B1'),
('2025-09-05 08:05:00', 'shop_buy', 'user_B', '2025-09-05 08:05:00', 'prop_B2'),
('2025-09-09 23:59:59', 'shop_buy', 'user_C', '2025-09-09 23:59:59', 'prop_C1'),
('2025-10-01 00:00:00', 'shop_buy', 'user_D', '2025-10-01 00:00:00', 'prop_D1');
"""

sql """
SELECT
CASE WHEN GROUPING(event_date) = 1 THEN '(TOTAL)' ELSE CAST(event_date AS VARCHAR) END AS event_date,
user_id,
MAX(conversion_level) AS conversion_level,
CASE WHEN GROUPING(event_name_group) = 1 THEN '(TOTAL)' ELSE event_name_group END AS event_name_group
FROM
(
SELECT
src.event_date,
src.user_id,
WINDOW_FUNNEL(
3600 * 24 * 1,
'default',
src.event_time,
src.event_name = 'shop_buy',
src.event_name = 'shop_buy'
) AS conversion_level,
src.event_name_group
FROM
(
SELECT
CAST(etb.`@dt` AS DATE) AS event_date,
etb.`@event_name` AS event_name,
etb.`@event_time` AS event_time,
etb.`@event_name` AS event_name_group,
etb.`@user_id` AS user_id
FROM
`test_event` AS etb
WHERE
etb.`@dt` between '2025-09-03 02:00:00' AND '2025-09-10 01:59:59'
AND etb.`@event_name` = 'shop_buy'
AND etb.`@user_id` IS NOT NULL
AND etb.`@user_id` > '0'
) AS src
GROUP BY
src.event_date,
src.user_id,
src.event_name_group
) AS fwt
GROUP BY
GROUPING SETS (
(user_id),
(user_id, event_date),
(user_id, event_name_group),
(user_id, event_date, event_name_group)
);

"""
}
Loading