Skip to content

Commit

Permalink
[Bug](materialized-view) enable rewrite on select materialized index …
Browse files Browse the repository at this point in the history
…with aggregate mode (apache#24691)

enable rewrite on select materialized index with aggregate mode
  • Loading branch information
BiteTheDDDDt authored and vinlee19 committed Oct 7, 2023
1 parent 0a8383c commit 55cd33a
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.rules.rewrite.mv.AbstractSelectMaterializedIndexRule.SlotContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ExprId;
Expand Down Expand Up @@ -65,6 +66,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.planner.PlanNode;
Expand Down Expand Up @@ -657,12 +659,8 @@ public List<Rule> buildRules() {
* 2. filter indexes that have all the required columns.
* 3. select the best index from all the candidate indexes that could be used.
*/
private SelectResult select(
LogicalOlapScan scan,
Set<Slot> requiredScanOutput,
Set<Expression> predicates,
List<AggregateFunction> aggregateFunctions,
List<Expression> groupingExprs,
private SelectResult select(LogicalOlapScan scan, Set<Slot> requiredScanOutput, Set<Expression> predicates,
List<AggregateFunction> aggregateFunctions, List<Expression> groupingExprs,
Set<? extends Expression> requiredExpr) {
// remove virtual slot for grouping sets.
Set<Slot> nonVirtualRequiredScanOutput = requiredScanOutput.stream()
Expand All @@ -677,105 +675,57 @@ private SelectResult select(
}

OlapTable table = scan.getTable();
switch (scan.getTable().getKeysType()) {
case AGG_KEYS:
case UNIQUE_KEYS:
case DUP_KEYS:
break;
default:
throw new RuntimeException("Not supported keys type: " + scan.getTable().getKeysType());
}

Map<Boolean, List<MaterializedIndex>> indexesGroupByIsBaseOrNot = table.getVisibleIndex()
.stream()
.collect(Collectors.groupingBy(index -> index.getId() == table.getBaseIndexId()));
if (table.isDupKeysOrMergeOnWrite()) {
// Duplicate-keys table could use base index and indexes whose pre-aggregation status is on.
Set<MaterializedIndex> candidatesWithoutRewriting =
indexesGroupByIsBaseOrNot.getOrDefault(false, ImmutableList.of())
.stream()
.filter(index -> checkPreAggStatus(scan, index.getId(), predicates,
aggregateFunctions, groupingExprs).isOn())
.collect(Collectors.toSet());

// try to rewrite bitmap, hll by materialized index columns.
List<AggRewriteResult> candidatesWithRewriting = indexesGroupByIsBaseOrNot.getOrDefault(false,
ImmutableList.of())
.stream()
.filter(index -> !candidatesWithoutRewriting.contains(index))
.map(index -> rewriteAgg(index, scan, nonVirtualRequiredScanOutput, predicates,
aggregateFunctions,
groupingExprs))
.filter(aggRewriteResult -> checkPreAggStatus(scan, aggRewriteResult.index.getId(),
predicates,
// check pre-agg status of aggregate function that couldn't rewrite.
aggFuncsDiff(aggregateFunctions, aggRewriteResult),
groupingExprs).isOn())
.filter(result -> result.success)
.collect(Collectors.toList());

List<MaterializedIndex> haveAllRequiredColumns = Streams.concat(
candidatesWithoutRewriting.stream()
.filter(index -> containAllRequiredColumns(index, scan, nonVirtualRequiredScanOutput,
requiredExpr, predicates)),
candidatesWithRewriting.stream()
.filter(aggRewriteResult -> containAllRequiredColumns(aggRewriteResult.index, scan,
aggRewriteResult.requiredScanOutput,
requiredExpr.stream().map(e -> aggRewriteResult.exprRewriteMap.replaceAgg(e))
.collect(Collectors.toSet()),
predicates))
.map(aggRewriteResult -> aggRewriteResult.index))
.collect(Collectors.toList());

long selectIndexId = selectBestIndex(haveAllRequiredColumns, scan, predicates);
Optional<AggRewriteResult> rewriteResultOpt = candidatesWithRewriting.stream()
.filter(aggRewriteResult -> aggRewriteResult.index.getId() == selectIndexId)
.findAny();
// Pre-aggregation is set to `on` by default for duplicate-keys table.
return new SelectResult(PreAggStatus.on(), selectIndexId,
rewriteResultOpt.map(r -> r.exprRewriteMap).orElse(new ExprRewriteMap()));
} else {
if (scan.getPreAggStatus().isOff()) {
return new SelectResult(scan.getPreAggStatus(),
scan.getTable().getBaseIndexId(), new ExprRewriteMap());
}

Set<MaterializedIndex> candidatesWithoutRewriting = new HashSet<>();

for (MaterializedIndex index : indexesGroupByIsBaseOrNot.getOrDefault(false, ImmutableList.of())) {
final PreAggStatus preAggStatus;
if (preAggEnabledByHint(scan)) {
preAggStatus = PreAggStatus.on();
} else {
preAggStatus = checkPreAggStatus(scan, index.getId(), predicates,
aggregateFunctions, groupingExprs);
}
Set<MaterializedIndex> candidatesWithoutRewriting = indexesGroupByIsBaseOrNot
.getOrDefault(false, ImmutableList.of()).stream()
.filter(index -> preAggEnabledByHint(scan)
|| checkPreAggStatus(scan, index.getId(), predicates, aggregateFunctions, groupingExprs).isOn())
.collect(Collectors.toSet());

// try to rewrite bitmap, hll by materialized index columns.
List<AggRewriteResult> candidatesWithRewriting = indexesGroupByIsBaseOrNot
.getOrDefault(false, ImmutableList.of()).stream()
.filter(index -> !candidatesWithoutRewriting.contains(index))
.map(index -> rewriteAgg(index, scan, nonVirtualRequiredScanOutput, predicates, aggregateFunctions,
groupingExprs))
.filter(aggRewriteResult -> checkPreAggStatus(scan, aggRewriteResult.index.getId(), predicates,
// check pre-agg status of aggregate function that couldn't rewrite.
aggFuncsDiff(aggregateFunctions, aggRewriteResult), groupingExprs).isOn())
.filter(result -> result.success).collect(Collectors.toList());

List<MaterializedIndex> haveAllRequiredColumns = Streams.concat(
candidatesWithoutRewriting.stream()
.filter(index -> containAllRequiredColumns(index, scan, nonVirtualRequiredScanOutput,
requiredExpr, predicates)),
candidatesWithRewriting.stream()
.filter(aggRewriteResult -> containAllRequiredColumns(aggRewriteResult.index, scan,
aggRewriteResult.requiredScanOutput,
requiredExpr.stream().map(e -> aggRewriteResult.exprRewriteMap.replaceAgg(e))
.collect(Collectors.toSet()),
predicates))
.map(aggRewriteResult -> aggRewriteResult.index))
.collect(Collectors.toList());

if (preAggStatus.isOn()) {
candidatesWithoutRewriting.add(index);
}
}
SelectResult baseIndexSelectResult = new SelectResult(
checkPreAggStatus(scan, scan.getTable().getBaseIndexId(),
predicates, aggregateFunctions, groupingExprs),
scan.getTable().getBaseIndexId(), new ExprRewriteMap());
if (candidatesWithoutRewriting.isEmpty()) {
// return early if pre-agg status is off.
return baseIndexSelectResult;
} else {
List<MaterializedIndex> rollupsWithAllRequiredCols =
Stream.concat(candidatesWithoutRewriting.stream(), indexesGroupByIsBaseOrNot.get(true).stream())
.filter(index -> containAllRequiredColumns(index, scan, nonVirtualRequiredScanOutput,
requiredExpr, predicates))
.collect(Collectors.toList());

long selectedIndex = selectBestIndex(rollupsWithAllRequiredCols, scan, predicates);
if (selectedIndex == scan.getTable().getBaseIndexId()) {
return baseIndexSelectResult;
}
return new SelectResult(PreAggStatus.on(), selectedIndex, new ExprRewriteMap());
long selectIndexId = selectBestIndex(haveAllRequiredColumns, scan, predicates);
// Pre-aggregation is set to `on` by default for duplicate-keys table.
// In other cases where mv is not hit, preagg may turn off from on.
if (!table.isDupKeysOrMergeOnWrite() && (new CheckContext(scan, selectIndexId)).isBaseIndex()) {
PreAggStatus preagg = scan.getPreAggStatus();
if (preagg.isOn()) {
preagg = checkPreAggStatus(scan, scan.getTable().getBaseIndexId(), predicates, aggregateFunctions,
groupingExprs);
}
return new SelectResult(preagg, selectIndexId, new ExprRewriteMap());
}

Optional<AggRewriteResult> rewriteResultOpt = candidatesWithRewriting.stream()
.filter(aggRewriteResult -> aggRewriteResult.index.getId() == selectIndexId).findAny();
return new SelectResult(PreAggStatus.on(), selectIndexId,
rewriteResultOpt.map(r -> r.exprRewriteMap).orElse(new ExprRewriteMap()));
}

private List<AggregateFunction> aggFuncsDiff(List<AggregateFunction> aggregateFunctions,
Expand Down Expand Up @@ -1191,6 +1141,13 @@ public RewriteContext(CheckContext context, ExprRewriteMap exprRewriteMap) {
}
}

/**
 * Converts {@code expr} to {@code targetType}, adding a {@link Cast} only when one is required.
 *
 * @param expr the expression whose data type is inspected
 * @param targetType the data type the returned expression should have
 * @return {@code expr} itself when its data type already equals {@code targetType};
 *         otherwise a new {@link Cast} of {@code expr} to {@code targetType}
 */
private static Expression castIfNeed(Expression expr, DataType targetType) {
    boolean alreadyTargetType = expr.getDataType().equals(targetType);
    return alreadyTargetType ? expr : new Cast(expr, targetType);
}

private static class AggFuncRewriter extends DefaultExpressionRewriter<RewriteContext> {
public static final AggFuncRewriter INSTANCE = new AggFuncRewriter();

Expand All @@ -1212,7 +1169,7 @@ public Expression visitCount(Count count, RewriteContext context) {
// count(distinct col) -> bitmap_union_count(mv_bitmap_union_col)
Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(count.child(0));

Expression expr = new ToBitmapWithCheck(new Cast(count.child(0), BigIntType.INSTANCE));
Expression expr = new ToBitmapWithCheck(castIfNeed(count.child(0), BigIntType.INSTANCE));
// count distinct a value column.
if (slotOpt.isPresent() && !context.checkContext.keyNameToColumn.containsKey(
normalizeName(expr.toSql()))) {
Expand Down Expand Up @@ -1425,7 +1382,7 @@ public Expression visitNdv(Ndv ndv, RewriteContext context) {
// ndv on a value column.
if (slotOpt.isPresent() && !context.checkContext.keyNameToColumn.containsKey(
normalizeName(slotOpt.get().toSql()))) {
Expression expr = new Cast(ndv.child(), VarcharType.SYSTEM_DEFAULT);
Expression expr = castIfNeed(ndv.child(), VarcharType.SYSTEM_DEFAULT);
String hllUnionColumn = normalizeName(
CreateMaterializedViewStmt.mvColumnBuilder(AggregateType.HLL_UNION,
CreateMaterializedViewStmt.mvColumnBuilder(new HllHash(expr).toSql())));
Expand Down Expand Up @@ -1459,7 +1416,7 @@ public Expression visitSum(Sum sum, RewriteContext context) {
Optional<Slot> slotOpt = ExpressionUtils.extractSlotOrCastOnSlot(sum.child(0));
if (!sum.isDistinct() && slotOpt.isPresent()
&& !context.checkContext.keyNameToColumn.containsKey(normalizeName(slotOpt.get().toSql()))) {
Expression expr = new Cast(sum.child(), BigIntType.INSTANCE);
Expression expr = castIfNeed(sum.child(), BigIntType.INSTANCE);
String sumColumn = normalizeName(CreateMaterializedViewStmt.mvColumnBuilder(AggregateType.SUM,
CreateMaterializedViewStmt.mvColumnBuilder(expr.toSql())));
Column mvColumn = context.checkContext.getColumn(sumColumn);
Expand Down
4 changes: 4 additions & 0 deletions regression-test/data/mv_p0/test_o2/test_o2.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select_mv --
2023-08-16T22:27 ax asd 2

Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,11 @@
-- !select_mv --
1 2

-- !select_star --
2020-01-01 1 a 1
2020-01-01 1 a 2
2020-01-02 2 b 2

-- !select_mv --
1 2

60 changes: 60 additions & 0 deletions regression-test/suites/mv_p0/test_o2/test_o2.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

import org.codehaus.groovy.runtime.IOGroovyMethods

// Regression suite "test_o2": creates an AGGREGATE KEY table, adds a materialized view that
// sums count_value grouped by (ts, metric_name, platform), and checks that the same grouped
// query is planned against that materialized view.
suite ("test_o2") {
// Use the Nereids planner and disable fallback to the original planner.
sql """set enable_nereids_planner=true"""
sql """SET enable_fallback_to_original_planner=false"""
// Start from a clean state so the suite can be re-run.
sql """ DROP TABLE IF EXISTS o2_order_events; """

// Base table: eleven key columns plus one SUM-aggregated value column (count_value).
sql """
CREATE TABLE `o2_order_events` (
`ts` datetime NULL,
`metric_name` varchar(20) NULL,
`city_id` int(11) NULL,
`platform` varchar(20) NULL,
`vendor_id` int(11) NULL,
`pos_id` int(11) NULL,
`is_instant_restaurant` boolean NULL,
`country_id` int(11) NULL,
`logistics_partner_id` int(11) NULL,
`rpf_order` int(11) NULL,
`rejected_message_id` int(11) NULL,
`count_value` int(11) SUM NULL DEFAULT "0"
) ENGINE=OLAP
AGGREGATE KEY(`ts`, `metric_name`, `city_id`, `platform`, `vendor_id`, `pos_id`, `is_instant_restaurant`, `country_id`, `logistics_partner_id`, `rpf_order`, `rejected_message_id`)
COMMENT 'OLAP'
DISTRIBUTED BY HASH(`metric_name`, `platform`) BUCKETS 2
PROPERTIES (
"replication_allocation" = "tag.location.default: 1"
);
"""

// First row is loaded before the materialized view exists.
sql """insert into o2_order_events values ("2023-08-16 22:27:00 ","ax",1,"asd",2,1,1,1,1,1,1,1);"""

// createMV is a helper supplied by the regression framework (not defined in this file);
// presumably it submits the statement and waits for the view to be built - TODO confirm.
// NOTE(review): the statement below ends with a doubled ';' - looks like a typo; confirm it is harmless.
createMV ("""
create materialized view o2_order_events_mv as select ts,metric_name,platform,sum(count_value) from o2_order_events group by ts,metric_name,platform;;""")

// Second, identical row is loaded after the materialized view has been created.
sql """insert into o2_order_events values ("2023-08-16 22:27:00 ","ax",1,"asd",2,1,1,1,1,1,1,1);"""

// The explain output of the grouped sum query must mention the materialized view by name.
explain {
sql("select ts,metric_name,platform,sum(count_value) from o2_order_events group by ts,metric_name,platform;")
contains "(o2_order_events_mv)"
}
// Run the same query under the "select_mv" tag; presumably the framework compares the result
// with the "-- !select_mv --" section of test_o2.out - TODO confirm.
qt_select_mv "select ts,metric_name,platform,sum(count_value) from o2_order_events group by ts,metric_name,platform;"
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,34 @@ suite ("testCountDistinctToBitmap") {
contains "(user_tags_mv)"
}
qt_select_mv "select user_id, count(distinct tag_id) a from user_tags group by user_id having a>1 order by a;"


sql """ DROP TABLE IF EXISTS user_tags2; """

sql """ create table user_tags2 (
time_col date,
user_id bigint,
user_name varchar(20),
tag_id bigint)
partition by range (time_col) (partition p1 values less than MAXVALUE) distributed by hash(time_col) buckets 3 properties('replication_num' = '1');
"""

sql """insert into user_tags2 values("2020-01-01",1,"a",1);"""
sql """insert into user_tags2 values("2020-01-02",2,"b",2);"""

createMV("create materialized view user_tags_mv as select user_id, bitmap_union(to_bitmap(tag_id)) from user_tags2 group by user_id;")

sql """insert into user_tags2 values("2020-01-01",1,"a",2);"""

explain {
sql("select * from user_tags2 order by time_col;")
contains "(user_tags2)"
}
qt_select_star "select * from user_tags2 order by time_col,tag_id;"

explain {
sql("select user_id, count(distinct tag_id) a from user_tags2 group by user_id having a>1 order by a;")
contains "(user_tags_mv)"
}
qt_select_mv "select user_id, count(distinct tag_id) a from user_tags2 group by user_id having a>1 order by a;"
}

0 comments on commit 55cd33a

Please sign in to comment.