Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat](Nereids) support outer join and aggregate bitmap rewrite by mv #28596

Merged
merged 10 commits into from
Dec 20, 2023
Merged
10 changes: 4 additions & 6 deletions fe/fe-core/src/main/java/org/apache/doris/catalog/MTMV.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.doris.mtmv.MTMVRelation;
import org.apache.doris.mtmv.MTMVStatus;
import org.apache.doris.persist.gson.GsonUtils;
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.Sets;
import com.google.gson.annotations.SerializedName;
Expand Down Expand Up @@ -128,10 +129,6 @@ public MTMVRelation getRelation() {
return relation;
}

public MTMVCache getCache() {
return cache;
}

public void setCache(MTMVCache cache) {
this.cache = cache;
}
Expand Down Expand Up @@ -193,12 +190,13 @@ public Set<String> getExcludedTriggerTables() {
return Sets.newHashSet(split);
}

public MTMVCache getOrGenerateCache() throws AnalysisException {
// this should use the same connectContext with query, to use the same session variable
public MTMVCache getOrGenerateCache(ConnectContext parent) throws AnalysisException {
if (cache == null) {
writeMvLock();
try {
if (cache == null) {
this.cache = MTMVCache.from(this, MTMVPlanUtil.createMTMVContext(this));
this.cache = MTMVCache.from(this, parent);
}
} finally {
writeMvUnlock();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public static MTMVCache from(MTMV mtmv, ConnectContext connectContext) {
? (Plan) ((LogicalResultSink) mvRewrittenPlan).child() : mvRewrittenPlan;
// use rewritten plan output expression currently, if expression rewrite fail,
// consider to use the analyzed plan for output expressions only
List<NamedExpression> mvOutputExpressions = mvRewrittenPlan.getExpressions().stream()
List<NamedExpression> mvOutputExpressions = mvPlan.getExpressions().stream()
.map(NamedExpression.class::cast)
.collect(Collectors.toList());
return new MTMVCache(mvPlan, mvOutputExpressions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.doris.nereids.trees.plans.commands.info.TableNameInfo;
import org.apache.doris.persist.AlterMTMV;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.commons.collections.CollectionUtils;
Expand All @@ -49,7 +50,7 @@ public class MTMVRelationManager implements MTMVHookService {
private Map<BaseTableInfo, Set<BaseTableInfo>> tableMTMVs = Maps.newConcurrentMap();

public Set<BaseTableInfo> getMtmvsByBaseTable(BaseTableInfo table) {
return tableMTMVs.get(table);
return tableMTMVs.getOrDefault(table, ImmutableSet.of());
}

public Set<MTMV> getAvailableMTMVs(List<BaseTableInfo> tableInfos) {
Expand Down
4 changes: 2 additions & 2 deletions fe/fe-core/src/main/java/org/apache/doris/mtmv/MTMVUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ public static Collection<Partition> getMTMVCanRewritePartitions(MTMV mtmv, Conne
List<Partition> res = Lists.newArrayList();
Collection<Partition> allPartitions = mtmv.getPartitions();
// check session variable if enable rewrite
if (!ctx.getSessionVariable().isEnableMvRewrite()) {
if (!ctx.getSessionVariable().isEnableMaterializedViewRewrite()) {
return res;
}
MTMVRelation mtmvRelation = mtmv.getRelation();
Expand Down Expand Up @@ -438,7 +438,7 @@ private static long getTableMinVisibleVersionTime(OlapTable table) {
* @param relatedTable
* @return mv.partitionId ==> relatedTable.partitionId
*/
private static Map<Long, Set<Long>> getMvToBasePartitions(MTMV mtmv, OlapTable relatedTable)
public static Map<Long, Set<Long>> getMvToBasePartitions(MTMV mtmv, OlapTable relatedTable)
throws AnalysisException {
HashMap<Long, Set<Long>> res = Maps.newHashMap();
Map<Long, PartitionItem> relatedTableItems = relatedTable.getPartitionInfo().getIdToItem(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,9 @@ public void setOuterScope(@Nullable Scope outerScope) {
}

public List<MaterializationContext> getMaterializationContexts() {
return materializationContexts;
return materializationContexts.stream()
.filter(MaterializationContext::isAvailable)
.collect(Collectors.toList());
}

public void addMaterializationContext(MaterializationContext materializationContext) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
* @see PlaceholderCollector
*/
public class PlaceholderExpression extends Expression implements AlwaysNotNullable {

private final Class<? extends Expression> delegateClazz;
/**
* 1 based
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@
import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTranspose;
import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTransposeProject;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewAggregateRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewFilterJoinRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewFilterProjectJoinRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewOnlyJoinRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectAggregateRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectFilterJoinRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectJoinRule;
import org.apache.doris.nereids.rules.implementation.AggregateStrategies;
import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows;
Expand Down Expand Up @@ -222,7 +226,11 @@ public class RuleSet {
.build();

public static final List<Rule> MATERIALIZED_VIEW_RULES = planRuleFactories()
.add(MaterializedViewOnlyJoinRule.INSTANCE)
.add(MaterializedViewProjectJoinRule.INSTANCE)
.add(MaterializedViewFilterJoinRule.INSTANCE)
.add(MaterializedViewFilterProjectJoinRule.INSTANCE)
.add(MaterializedViewProjectFilterJoinRule.INSTANCE)
.add(MaterializedViewAggregateRule.INSTANCE)
.add(MaterializedViewProjectAggregateRule.INSTANCE)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@
import org.apache.doris.nereids.rules.exploration.mv.StructInfo.PlanSplitContext;
import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.trees.expressions.Any;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.CouldRollUp;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
Expand All @@ -39,8 +43,11 @@
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand All @@ -53,6 +60,17 @@
*/
public abstract class AbstractMaterializedViewAggregateRule extends AbstractMaterializedViewRule {

protected static final Map<Expression, Expression>
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = new HashMap<>();
protected final String currentClassName = this.getClass().getSimpleName();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a blank line between static and non-static attr


private final Logger logger = LogManager.getLogger(this.getClass());

static {
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true, Any.INSTANCE),
new BitmapUnion(Any.INSTANCE));
}

@Override
protected Plan rewriteQueryByView(MatchMode matchMode,
StructInfo queryStructInfo,
Expand All @@ -63,10 +81,12 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
// get view and query aggregate and top plan correspondingly
Pair<Plan, LogicalAggregate<Plan>> viewTopPlanAndAggPair = splitToTopPlanAndAggregate(viewStructInfo);
if (viewTopPlanAndAggPair == null) {
logger.warn(currentClassName + " split to view to top plan and agg fail so return null");
return null;
}
Pair<Plan, LogicalAggregate<Plan>> queryTopPlanAndAggPair = splitToTopPlanAndAggregate(queryStructInfo);
if (queryTopPlanAndAggPair == null) {
logger.warn(currentClassName + " split to query to top plan and agg fail so return null");
return null;
}
// Firstly, handle query group by expression rewrite
Expand All @@ -88,13 +108,14 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
needRollUp = !queryGroupShuttledExpression.equals(viewGroupShuttledExpression);
}
if (!needRollUp) {
List<Expression> rewrittenQueryGroupExpr = rewriteExpression(queryTopPlan.getOutput(),
List<Expression> rewrittenQueryGroupExpr = rewriteExpression(queryTopPlan.getExpressions(),
queryTopPlan,
materializationContext.getMvExprToMvScanExprMapping(),
queryToViewSlotMapping,
true);
if (rewrittenQueryGroupExpr == null) {
if (rewrittenQueryGroupExpr.isEmpty()) {
// can not rewrite, bail out.
logger.debug(currentClassName + " can not rewrite expression when not need roll up");
return null;
}
return new LogicalProject<>(
Expand All @@ -109,12 +130,14 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
viewExpr -> viewExpr.anyMatch(expr -> expr instanceof AggregateFunction
&& ((AggregateFunction) expr).isDistinct()))) {
// if mv aggregate function contains distinct, can not roll up, bail out.
logger.debug(currentClassName + " view contains distinct function so can not roll up");
return null;
}
// split the query top plan expressions to group expressions and functions, if can not, bail out.
Pair<Set<? extends Expression>, Set<? extends Expression>> queryGroupAndFunctionPair
= topPlanSplitToGroupAndFunction(queryTopPlanAndAggPair);
if (queryGroupAndFunctionPair == null) {
logger.warn(currentClassName + " query top plan split to group by and function fail so return null");
return null;
}
// Secondly, try to roll up the agg functions
Expand All @@ -132,30 +155,27 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
for (Expression topExpression : queryTopPlan.getExpressions()) {
// is agg function, try to roll up and rewrite
if (queryTopPlanFunctionSet.contains(topExpression)) {
Expression needRollupShuttledExpr = ExpressionUtils.shuttleExpressionWithLineage(
Expression queryFunctionShuttled = ExpressionUtils.shuttleExpressionWithLineage(
topExpression,
queryTopPlan);
if (!mvExprToMvScanExprQueryBased.containsKey(needRollupShuttledExpr)) {
// function can not rewrite by view
return null;
}
// try to roll up
AggregateFunction needRollupAggFunction = (AggregateFunction) topExpression.firstMatch(
AggregateFunction queryFunction = (AggregateFunction) topExpression.firstMatch(
expr -> expr instanceof AggregateFunction);
AggregateFunction rollupAggregateFunction = rollup(needRollupAggFunction,
mvExprToMvScanExprQueryBased.get(needRollupShuttledExpr));
Function rollupAggregateFunction = rollup(queryFunction, queryFunctionShuttled,
mvExprToMvScanExprQueryBased);
if (rollupAggregateFunction == null) {
return null;
}
// key is query need roll up expr, value is mv scan based roll up expr
needRollupExprMap.put(needRollupShuttledExpr, rollupAggregateFunction);
needRollupExprMap.put(queryFunctionShuttled, rollupAggregateFunction);
// rewrite query function expression by mv expression
Expression rewrittenFunctionExpression = rewriteExpression(topExpression,
queryTopPlan,
new ExpressionMapping(needRollupExprMap),
queryToViewSlotMapping,
false);
if (rewrittenFunctionExpression == null) {
logger.debug(currentClassName + " roll up expression can not rewrite by view so return null");
return null;
}
finalAggregateExpressions.add((NamedExpression) rewrittenFunctionExpression);
Expand All @@ -165,6 +185,8 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
ExpressionUtils.shuttleExpressionWithLineage(topExpression, queryTopPlan);
if (!mvExprToMvScanExprQueryBased.containsKey(queryGroupShuttledExpr)) {
// group expr can not rewrite by view
logger.debug(currentClassName
+ " view group expressions can not contains the query group by expression so return null");
return null;
}
groupRewrittenExprMap.put(queryGroupShuttledExpr,
Expand All @@ -177,6 +199,8 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
queryToViewSlotMapping,
true);
if (rewrittenGroupExpression == null) {
logger.debug(currentClassName
+ " query top expression can not be rewritten by view so return null");
return null;
}
finalAggregateExpressions.add((NamedExpression) rewrittenGroupExpression);
Expand Down Expand Up @@ -226,17 +250,33 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
}

// only support sum roll up, support other agg functions later.
private AggregateFunction rollup(AggregateFunction originFunction,
Expression mappedExpression) {
Class<? extends AggregateFunction> rollupAggregateFunction = originFunction.getRollup();
if (rollupAggregateFunction == null) {
private Function rollup(AggregateFunction queryFunction,
Expression queryFunctionShuttled,
Map<Expression, Expression> mvExprToMvScanExprQueryBased) {
if (!(queryFunction instanceof CouldRollUp)) {
return null;
}
if (Sum.class.isAssignableFrom(rollupAggregateFunction)) {
return new Sum(originFunction.isDistinct(), mappedExpression);
Expression rollupParam = null;
if (mvExprToMvScanExprQueryBased.containsKey(queryFunctionShuttled)) {
// function can rewrite by view
rollupParam = mvExprToMvScanExprQueryBased.get(queryFunctionShuttled);
} else {
// function can not rewrite by view, try to use complex roll up param
// eg: query is count(distinct param), mv sql is bitmap_union(to_bitmap(param))
for (Expression mvExprShuttled : mvExprToMvScanExprQueryBased.keySet()) {
if (!(mvExprShuttled instanceof Function)) {
continue;
}
if (isAggregateFunctionEquivalent(queryFunction, (Function) mvExprShuttled)) {
rollupParam = mvExprToMvScanExprQueryBased.get(mvExprShuttled);
}
}
}
// can rollup return null
return null;
if (rollupParam == null) {
return null;
}
// do roll up
return ((CouldRollUp) queryFunction).constructRollUp(rollupParam);
}

private Pair<Set<? extends Expression>, Set<? extends Expression>> topPlanSplitToGroupAndFunction(
Expand Down Expand Up @@ -306,4 +346,23 @@ protected boolean checkPattern(StructInfo structInfo) {
}
return true;
}

private boolean isAggregateFunctionEquivalent(Function queryFunction, Function viewFunction) {
if (queryFunction.equals(viewFunction)) {
return true;
}
// get query equivalent function
Expression equivalentFunction = null;
for (Map.Entry<Expression, Expression> entry : AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.entrySet()) {
if (entry.getKey().equals(queryFunction)) {
equivalentFunction = entry.getValue();
}
}
// check is have equivalent function or not
if (equivalentFunction == null) {
return false;
}
// current compare
return equivalentFunction.equals(viewFunction);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.List;
import java.util.stream.Collectors;

Expand All @@ -35,6 +38,10 @@
* This is responsible for common join rewriting
*/
public abstract class AbstractMaterializedViewJoinRule extends AbstractMaterializedViewRule {

protected final String currentClassName = this.getClass().getSimpleName();
private final Logger logger = LogManager.getLogger(this.getClass());

@Override
protected Plan rewriteQueryByView(MatchMode matchMode,
StructInfo queryStructInfo,
Expand All @@ -53,6 +60,7 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
// Can not rewrite, bail out
if (expressionsRewritten.isEmpty()
|| expressionsRewritten.stream().anyMatch(expr -> !(expr instanceof NamedExpression))) {
logger.warn(currentClassName + " expression to rewrite is not named expr so return null");
return null;
}
// record the group id in materializationContext, and when rewrite again in
Expand Down
Loading
Loading