From 2e537c3935a42738c39349eb3d6b946c8cbb4ee5 Mon Sep 17 00:00:00 2001 From: Shuo Wang Date: Thu, 13 Oct 2022 16:25:36 +0800 Subject: [PATCH] [Feature](Nereids) Support materialized index selection. --- .../translator/PhysicalPlanTranslator.java | 12 +- .../jobs/batch/NereidsRewriteJobExecutor.java | 8 +- .../apache/doris/nereids/rules/RuleType.java | 16 +- .../AbstractSelectMaterializedIndexRule.java | 262 +++++ ...SelectMaterializedIndexWithAggregate.java} | 309 ++--- ...lectMaterializedIndexWithoutAggregate.java | 149 +++ .../mv/SelectRollupWithoutAggregate.java | 60 - .../expressions/functions/agg/Count.java | 3 +- .../trees/plans/logical/LogicalOlapScan.java | 27 +- .../apache/doris/planner/OlapScanNode.java | 5 + .../mv/BaseMaterializedIndexSelectTest.java | 51 + .../nereids/rules/mv/SelectMvIndexTest.java | 1035 +++++++++++++++++ ...upTest.java => SelectRollupIndexTest.java} | 191 ++- .../doris/utframe/TestWithFeService.java | 53 +- 14 files changed, 1788 insertions(+), 393 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/AbstractSelectMaterializedIndexRule.java rename fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/{SelectRollupWithAggregate.java => SelectMaterializedIndexWithAggregate.java} (63%) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/SelectMaterializedIndexWithoutAggregate.java delete mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/SelectRollupWithoutAggregate.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/BaseMaterializedIndexSelectTest.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/SelectMvIndexTest.java rename fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/{SelectRollupTest.java => SelectRollupIndexTest.java} (59%) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 818fa16f6257d61..13028b30473fcf5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -33,7 +33,6 @@ import org.apache.doris.catalog.OlapTable; import org.apache.doris.catalog.Table; import org.apache.doris.common.Pair; -import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.properties.DistributionSpecHash; import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType; import org.apache.doris.nereids.properties.OrderKey; @@ -313,22 +312,13 @@ public PlanFragment visitPhysicalOlapScan(PhysicalOlapScan olapScan, PlanTransla tupleDescriptor.setRef(tableRef); olapScanNode.setSelectedPartitionIds(olapScan.getSelectedPartitionIds()); - // TODO: Unify the logic here for all the table types once aggregate/unique key types are fully supported. switch (olapScan.getTable().getKeysType()) { case AGG_KEYS: case UNIQUE_KEYS: - // TODO: Improve complete info for aggregate and unique key types table. 
+ case DUP_KEYS: PreAggStatus preAgg = olapScan.getPreAggStatus(); olapScanNode.setSelectedIndexInfo(olapScan.getSelectedIndexId(), preAgg.isOn(), preAgg.getOffReason()); break; - case DUP_KEYS: - try { - olapScanNode.updateScanRangeInfoByNewMVSelector(olapScan.getSelectedIndexId(), true, ""); - olapScanNode.setIsPreAggregation(true, ""); - } catch (Exception e) { - throw new AnalysisException(e.getMessage()); - } - break; default: throw new RuntimeException("Not supported key type: " + olapScan.getTable().getKeysType()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java index 91ede3e80cb3185..6fff27ede50d163 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java @@ -22,8 +22,8 @@ import org.apache.doris.nereids.rules.RuleSet; import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization; import org.apache.doris.nereids.rules.expression.rewrite.ExpressionOptimization; -import org.apache.doris.nereids.rules.mv.SelectRollupWithAggregate; -import org.apache.doris.nereids.rules.mv.SelectRollupWithoutAggregate; +import org.apache.doris.nereids.rules.mv.SelectMaterializedIndexWithAggregate; +import org.apache.doris.nereids.rules.mv.SelectMaterializedIndexWithoutAggregate; import org.apache.doris.nereids.rules.rewrite.logical.ColumnPruning; import org.apache.doris.nereids.rules.rewrite.logical.EliminateFilter; import org.apache.doris.nereids.rules.rewrite.logical.EliminateLimit; @@ -73,8 +73,8 @@ public NereidsRewriteJobExecutor(CascadesContext cascadesContext) { .add(topDownBatch(ImmutableList.of(new EliminateLimit()))) .add(topDownBatch(ImmutableList.of(new EliminateFilter()))) .add(topDownBatch(ImmutableList.of(new PruneOlapScanPartition()))) - .add(topDownBatch(ImmutableList.of(new SelectRollupWithAggregate()))) - .add(topDownBatch(ImmutableList.of(new SelectRollupWithoutAggregate()))) + .add(topDownBatch(ImmutableList.of(new SelectMaterializedIndexWithAggregate()))) + .add(topDownBatch(ImmutableList.of(new SelectMaterializedIndexWithoutAggregate()))) .build(); rulesJob.addAll(jobs); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 642d6bfd0e421d1..9a4e192649067eb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -108,12 +108,16 @@ public enum RuleType { ELIMINATE_FILTER(RuleTypeClass.REWRITE), ELIMINATE_OUTER(RuleTypeClass.REWRITE), FIND_HASH_CONDITION_FOR_JOIN(RuleTypeClass.REWRITE), - ROLLUP_AGG_SCAN(RuleTypeClass.REWRITE), - ROLLUP_AGG_FILTER_SCAN(RuleTypeClass.REWRITE), - ROLLUP_AGG_PROJECT_SCAN(RuleTypeClass.REWRITE), - ROLLUP_AGG_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE), - ROLLUP_AGG_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE), - ROLLUP_WITH_OUT_AGG(RuleTypeClass.REWRITE), + MATERIALIZED_INDEX_AGG_SCAN(RuleTypeClass.REWRITE), + MATERIALIZED_INDEX_AGG_FILTER_SCAN(RuleTypeClass.REWRITE), + MATERIALIZED_INDEX_AGG_PROJECT_SCAN(RuleTypeClass.REWRITE), + MATERIALIZED_INDEX_AGG_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE), + MATERIALIZED_INDEX_AGG_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE), + MATERIALIZED_INDEX_SCAN(RuleTypeClass.REWRITE), + 
MATERIALIZED_INDEX_FILTER_SCAN(RuleTypeClass.REWRITE), + MATERIALIZED_INDEX_PROJECT_SCAN(RuleTypeClass.REWRITE), + MATERIALIZED_INDEX_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE), + MATERIALIZED_INDEX_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE), OLAP_SCAN_PARTITION_PRUNE(RuleTypeClass.REWRITE), EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION(RuleTypeClass.REWRITE), REWRITE_SENTINEL(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/AbstractSelectMaterializedIndexRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/AbstractSelectMaterializedIndexRule.java new file mode 100644 index 000000000000000..523ef54da7ff6db --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/AbstractSelectMaterializedIndexRule.java @@ -0,0 +1,262 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.mv; + +import org.apache.doris.catalog.Column; +import org.apache.doris.catalog.MaterializedIndex; +import org.apache.doris.catalog.OlapTable; +import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.InPredicate; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.NullSafeEqual; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; + +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.TreeMap; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Base class for selecting materialized index rules. + */ +public abstract class AbstractSelectMaterializedIndexRule { + /** + * 1. indexes have all the required columns. + * 2. find matching key prefix most. + * 3. sort by row count, column count and index id. 
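+ * For example (hypothetical rollups): an index that matches a two-column key prefix of the pushdown predicates is kept over one that matches only a single column; among indexes with the same prefix match, the one with the smaller row count, then fewer columns, then the smaller index id is ordered first.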
+ */ + protected List select( + Stream inputCandidates, + LogicalOlapScan scan, + Set requiredScanOutput, + List predicates) { + + OlapTable table = scan.getTable(); + // Scan slot exprId -> slot name + Map exprIdToName = scan.getOutput() + .stream() + .collect(Collectors.toMap(NamedExpression::getExprId, NamedExpression::getName)); + + // get required column names in metadata. + Set requiredColumnNames = requiredScanOutput + .stream() + .map(slot -> exprIdToName.get(slot.getExprId())) + .collect(Collectors.toSet()); + + // 1. filter index contains all the required columns by column name. + List containAllRequiredColumns = inputCandidates + .filter(index -> table.getSchemaByIndexId(index.getId(), true) + .stream() + .map(Column::getName) + .collect(Collectors.toSet()) + .containsAll(requiredColumnNames) + ).collect(Collectors.toList()); + + // 2. find matching key prefix most. + List matchingKeyPrefixMost = matchPrefixMost(scan, containAllRequiredColumns, predicates, + exprIdToName); + + List partitionIds = scan.getSelectedPartitionIds(); + // 3. sort by row count, column count and index id. + return matchingKeyPrefixMost.stream() + .map(MaterializedIndex::getId) + .sorted(Comparator + // compare by row count + .comparing(rid -> partitionIds.stream() + .mapToLong(pid -> table.getPartition(pid).getIndex((Long) rid).getRowCount()) + .sum()) + // compare by column count + .thenComparing(rid -> table.getSchemaByIndexId((Long) rid).size()) + // compare by index id + .thenComparing(rid -> (Long) rid)) + .collect(Collectors.toList()); + } + + protected List matchPrefixMost( + LogicalOlapScan scan, + List candidate, + List predicates, + Map exprIdToName) { + Map> split = filterCanUsePrefixIndexAndSplitByEquality(predicates, exprIdToName); + Set equalColNames = split.getOrDefault(true, ImmutableSet.of()); + Set nonEqualColNames = split.getOrDefault(false, ImmutableSet.of()); + + if (!(equalColNames.isEmpty() && nonEqualColNames.isEmpty())) { + List matchingResult = matchKeyPrefixMost(scan.getTable(), candidate, + equalColNames, nonEqualColNames); + return matchingResult.isEmpty() ? candidate : matchingResult; + } else { + return candidate; + } + } + + /////////////////////////////////////////////////////////////////////////// + // Split conjuncts into equal-to and non-equal-to. + /////////////////////////////////////////////////////////////////////////// + + /** + * Filter the input conjuncts those can use prefix and split into 2 groups: is equal-to or non-equal-to predicate + * when comparing the key column. 
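+ * For example (hypothetical columns): {@code k1 = 1} falls into the equal-to group and {@code k2 > 10} into the non-equal-to group, while {@code k1 + 1 = 2} is discarded because an expression over a column cannot use the prefix index.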
+ */ + private Map> filterCanUsePrefixIndexAndSplitByEquality( + List conjunct, Map exprIdToColName) { + return conjunct.stream() + .map(expr -> PredicateChecker.canUsePrefixIndex(expr, exprIdToColName)) + .filter(result -> !result.equals(PrefixIndexCheckResult.FAILURE)) + .collect(Collectors.groupingBy( + result -> result.type == ResultType.SUCCESS_EQUAL, + Collectors.mapping(result -> result.colName, Collectors.toSet()) + )); + } + + private enum ResultType { + FAILURE, + SUCCESS_EQUAL, + SUCCESS_NON_EQUAL, + } + + private static class PrefixIndexCheckResult { + public static final PrefixIndexCheckResult FAILURE = new PrefixIndexCheckResult(null, ResultType.FAILURE); + private final String colName; + private final ResultType type; + + private PrefixIndexCheckResult(String colName, ResultType result) { + this.colName = colName; + this.type = result; + } + + public static PrefixIndexCheckResult createEqual(String name) { + return new PrefixIndexCheckResult(name, ResultType.SUCCESS_EQUAL); + } + + public static PrefixIndexCheckResult createNonEqual(String name) { + return new PrefixIndexCheckResult(name, ResultType.SUCCESS_NON_EQUAL); + } + } + + /** + * Check if an expression could prefix key index. + */ + private static class PredicateChecker extends ExpressionVisitor> { + private static final PredicateChecker INSTANCE = new PredicateChecker(); + + private PredicateChecker() { + } + + public static PrefixIndexCheckResult canUsePrefixIndex(Expression expression, + Map exprIdToName) { + return expression.accept(INSTANCE, exprIdToName); + } + + @Override + public PrefixIndexCheckResult visit(Expression expr, Map context) { + return PrefixIndexCheckResult.FAILURE; + } + + @Override + public PrefixIndexCheckResult visitInPredicate(InPredicate in, Map context) { + Optional slotOrCastOnSlot = ExpressionUtils.isSlotOrCastOnSlot(in.getCompareExpr()); + if (slotOrCastOnSlot.isPresent() && in.getOptions().stream().allMatch(Literal.class::isInstance)) { + return PrefixIndexCheckResult.createEqual(context.get(slotOrCastOnSlot.get())); + } else { + return PrefixIndexCheckResult.FAILURE; + } + } + + @Override + public PrefixIndexCheckResult visitComparisonPredicate(ComparisonPredicate cp, Map context) { + if (cp instanceof EqualTo || cp instanceof NullSafeEqual) { + return check(cp, context, PrefixIndexCheckResult::createEqual); + } else { + return check(cp, context, PrefixIndexCheckResult::createNonEqual); + } + } + + private PrefixIndexCheckResult check(ComparisonPredicate cp, Map exprIdToColumnName, + Function resultMapper) { + return check(cp).map(exprId -> resultMapper.apply(exprIdToColumnName.get(exprId))) + .orElse(PrefixIndexCheckResult.FAILURE); + } + + private Optional check(ComparisonPredicate cp) { + Optional exprId = check(cp.left(), cp.right()); + return exprId.isPresent() ? exprId : check(cp.right(), cp.left()); + } + + private Optional check(Expression maybeSlot, Expression maybeConst) { + Optional exprIdOpt = ExpressionUtils.isSlotOrCastOnSlot(maybeSlot); + return exprIdOpt.isPresent() && maybeConst.isConstant() ? 
exprIdOpt : Optional.empty(); + } + } + + /////////////////////////////////////////////////////////////////////////// + // Matching key prefix + /////////////////////////////////////////////////////////////////////////// + private List matchKeyPrefixMost( + OlapTable table, + List indexes, + Set equalColumns, + Set nonEqualColumns) { + TreeMap> collect = indexes.stream() + .collect(Collectors.toMap( + index -> indexKeyPrefixMatchCount(table, index, equalColumns, nonEqualColumns), + Lists::newArrayList, + (l1, l2) -> { + l1.addAll(l2); + return l1; + }, + TreeMap::new) + ); + return collect.descendingMap().firstEntry().getValue(); + } + + private int indexKeyPrefixMatchCount( + OlapTable table, + MaterializedIndex index, + Set equalColNames, + Set nonEqualColNames) { + int matchCount = 0; + for (Column column : table.getSchemaByIndexId(index.getId())) { + if (equalColNames.contains(column.getName())) { + matchCount++; + } else if (nonEqualColNames.contains(column.getName())) { + // Unequivalence predicate's columns can match only first column in index. + matchCount++; + break; + } else { + break; + } + } + return matchCount; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/SelectRollupWithAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/SelectMaterializedIndexWithAggregate.java similarity index 63% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/SelectRollupWithAggregate.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/SelectMaterializedIndexWithAggregate.java index b31fbb2a7c5f790..5cdb32cf69b2db2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/SelectRollupWithAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/SelectMaterializedIndexWithAggregate.java @@ -26,19 +26,15 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory; -import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; -import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.NullSafeEqual; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; -import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.trees.plans.PreAggStatus; import org.apache.doris.nereids.trees.plans.algebra.Project; @@ -50,24 +46,26 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Lists; -import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; import java.util.Set; 
-import java.util.TreeMap; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; /** - * Select rollup index when aggregate is present. + * Select materialized index, i.e., both for rollup and materialized view when aggregate is present. + * TODO: optimize queries with aggregate not on top of scan directly, e.g., aggregate -> join -> scan + * to use materialized index. */ @Developing -public class SelectRollupWithAggregate implements RewriteRuleFactory { +public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterializedIndexRule + implements RewriteRuleFactory { /////////////////////////////////////////////////////////////////////////// // All the patterns /////////////////////////////////////////////////////////////////////////// @@ -76,9 +74,9 @@ public List buildRules() { return ImmutableList.of( // only agg above scan // Aggregate(Scan) - logicalAggregate(logicalOlapScan().when(LogicalOlapScan::shouldSelectRollup)).then(agg -> { + logicalAggregate(logicalOlapScan().when(LogicalOlapScan::shouldSelectIndex)).then(agg -> { LogicalOlapScan scan = agg.child(); - Pair> result = selectCandidateRollupIds( + Pair> result = selectCandidateIndexIds( scan, agg.getInputSlots(), ImmutableList.of(), @@ -87,11 +85,11 @@ public List buildRules() { return agg.withChildren( scan.withMaterializedIndexSelected(result.key(), result.value()) ); - }).toRule(RuleType.ROLLUP_AGG_SCAN), + }).toRule(RuleType.MATERIALIZED_INDEX_AGG_SCAN), // filter could push down scan. // Aggregate(Filter(Scan)) - logicalAggregate(logicalFilter(logicalOlapScan().when(LogicalOlapScan::shouldSelectRollup))) + logicalAggregate(logicalFilter(logicalOlapScan().when(LogicalOlapScan::shouldSelectIndex))) .then(agg -> { LogicalFilter filter = agg.child(); LogicalOlapScan scan = filter.child(); @@ -100,7 +98,7 @@ public List buildRules() { .addAll(filter.getInputSlots()) .build(); - Pair> result = selectCandidateRollupIds( + Pair> result = selectCandidateIndexIds( scan, requiredSlots, filter.getConjuncts(), @@ -110,15 +108,15 @@ public List buildRules() { return agg.withChildren(filter.withChildren( scan.withMaterializedIndexSelected(result.key(), result.value()) )); - }).toRule(RuleType.ROLLUP_AGG_FILTER_SCAN), + }).toRule(RuleType.MATERIALIZED_INDEX_AGG_FILTER_SCAN), // column pruning or other projections such as alias, etc. // Aggregate(Project(Scan)) - logicalAggregate(logicalProject(logicalOlapScan().when(LogicalOlapScan::shouldSelectRollup))) + logicalAggregate(logicalProject(logicalOlapScan().when(LogicalOlapScan::shouldSelectIndex))) .then(agg -> { LogicalProject project = agg.child(); LogicalOlapScan scan = project.child(); - Pair> result = selectCandidateRollupIds( + Pair> result = selectCandidateIndexIds( scan, project.getInputSlots(), ImmutableList.of(), @@ -131,18 +129,21 @@ public List buildRules() { scan.withMaterializedIndexSelected(result.key(), result.value()) ) ); - }).toRule(RuleType.ROLLUP_AGG_PROJECT_SCAN), + }).toRule(RuleType.MATERIALIZED_INDEX_AGG_PROJECT_SCAN), // filter could push down and project. 
// Aggregate(Project(Filter(Scan))) logicalAggregate(logicalProject(logicalFilter(logicalOlapScan() - .when(LogicalOlapScan::shouldSelectRollup)))).then(agg -> { + .when(LogicalOlapScan::shouldSelectIndex)))).then(agg -> { LogicalProject> project = agg.child(); LogicalFilter filter = project.child(); LogicalOlapScan scan = filter.child(); - Pair> result = selectCandidateRollupIds( + Set requiredSlots = Stream.concat( + project.getInputSlots().stream(), filter.getInputSlots().stream()) + .collect(Collectors.toSet()); + Pair> result = selectCandidateIndexIds( scan, - agg.getInputSlots(), + requiredSlots, filter.getConjuncts(), extractAggFunctionAndReplaceSlot(agg, Optional.of(project)), ExpressionUtils.replace(agg.getGroupByExpressions(), @@ -151,16 +152,16 @@ public List buildRules() { return agg.withChildren(project.withChildren(filter.withChildren( scan.withMaterializedIndexSelected(result.key(), result.value()) ))); - }).toRule(RuleType.ROLLUP_AGG_PROJECT_FILTER_SCAN), + }).toRule(RuleType.MATERIALIZED_INDEX_AGG_PROJECT_FILTER_SCAN), // filter can't push down // Aggregate(Filter(Project(Scan))) logicalAggregate(logicalFilter(logicalProject(logicalOlapScan() - .when(LogicalOlapScan::shouldSelectRollup)))).then(agg -> { + .when(LogicalOlapScan::shouldSelectIndex)))).then(agg -> { LogicalFilter> filter = agg.child(); LogicalProject project = filter.child(); LogicalOlapScan scan = project.child(); - Pair> result = selectCandidateRollupIds( + Pair> result = selectCandidateIndexIds( scan, project.getInputSlots(), ImmutableList.of(), @@ -171,23 +172,24 @@ public List buildRules() { return agg.withChildren(filter.withChildren(project.withChildren( scan.withMaterializedIndexSelected(result.key(), result.value()) ))); - }).toRule(RuleType.ROLLUP_AGG_FILTER_PROJECT_SCAN) + }).toRule(RuleType.MATERIALIZED_INDEX_AGG_FILTER_PROJECT_SCAN) ); } /////////////////////////////////////////////////////////////////////////// - // Main entrance of select rollup + // Main entrance of select materialized index. /////////////////////////////////////////////////////////////////////////// /** - * Select candidate rollup ids. + * Select candidate materialized index ids. + *

+ * 0. decide whether to turn pre-aggregation off by checking the input aggregate functions, group-by expressions and pushdown predicates. *

- * 0. turn off pre agg, checking input aggregate functions and group by expressions, etc. - * 1. rollup contains all the required output slots. + * 1. index contains all the required output slots. * 2. match the most prefix index if pushdown predicates present. - * 3. sort the result matching rollup index ids. + * 3. sort the result matching materialized index ids. */ - private Pair> selectCandidateRollupIds( + private Pair> selectCandidateIndexIds( LogicalOlapScan scan, Set requiredScanOutput, List predicates, @@ -197,204 +199,46 @@ private Pair> selectCandidateRollupIds( String.format("Scan's output (%s) should contains all the input required scan output (%s).", scan.getOutput(), requiredScanOutput)); - // 0. maybe turn off pre agg. - PreAggStatus preAggStatus = checkPreAggStatus(scan, predicates, aggregateFunctions, groupingExprs); - if (preAggStatus.isOff()) { - // return early if pre agg status if off. - return Pair.of(preAggStatus, ImmutableList.of(scan.getTable().getBaseIndexId())); - } - OlapTable table = scan.getTable(); - // Scan slot exprId -> slot name - Map exprIdToName = scan.getOutput() - .stream() - .collect(Collectors.toMap(NamedExpression::getExprId, NamedExpression::getName)); - - // get required column names in metadata. - Set requiredColumnNames = requiredScanOutput - .stream() - .map(slot -> exprIdToName.get(slot.getExprId())) - .collect(Collectors.toSet()); - - // 1. filter rollup contains all the required columns by column name. - List containAllRequiredColumns = table.getVisibleIndex().stream() - .filter(rollup -> table.getSchemaByIndexId(rollup.getId(), true) + // 0. check pre-aggregation status. + final PreAggStatus preAggStatus; + final Stream checkPreAggResult; + switch (table.getKeysType()) { + case AGG_KEYS: + case UNIQUE_KEYS: + // Check pre-aggregation status by base index for aggregate-keys and unique-keys OLAP table. + preAggStatus = checkPreAggStatus(scan, table.getBaseIndexId(), predicates, + aggregateFunctions, groupingExprs); + if (preAggStatus.isOff()) { + // return early if pre agg status if off. + return Pair.of(preAggStatus, ImmutableList.of(scan.getTable().getBaseIndexId())); + } + checkPreAggResult = table.getVisibleIndex().stream(); + break; + case DUP_KEYS: + Map> indexesGroupByIsBaseOrNot = table.getVisibleIndex() .stream() - .map(Column::getName) - .collect(Collectors.toSet()) - .containsAll(requiredColumnNames) - ).collect(Collectors.toList()); - - Map> split = filterCanUsePrefixIndexAndSplitByEquality(predicates, exprIdToName); - Set equalColNames = split.getOrDefault(true, ImmutableSet.of()); - Set nonEqualColNames = split.getOrDefault(false, ImmutableSet.of()); - - // 2. find matching key prefix most. - List matchingKeyPrefixMost; - if (!(equalColNames.isEmpty() && nonEqualColNames.isEmpty())) { - List matchingResult = matchKeyPrefixMost(table, containAllRequiredColumns, - equalColNames, nonEqualColNames); - matchingKeyPrefixMost = matchingResult.isEmpty() ? containAllRequiredColumns : matchingResult; - } else { - matchingKeyPrefixMost = containAllRequiredColumns; - } - - List partitionIds = scan.getSelectedPartitionIds(); - // 3. sort by row count, column count and index id. 
- List sortedIndexId = matchingKeyPrefixMost.stream() - .map(MaterializedIndex::getId) - .sorted(Comparator - // compare by row count - .comparing(rid -> partitionIds.stream() - .mapToLong(pid -> table.getPartition(pid).getIndex((Long) rid).getRowCount()) - .sum()) - // compare by column count - .thenComparing(rid -> table.getSchemaByIndexId((Long) rid).size()) - // compare by rollup index id - .thenComparing(rid -> (Long) rid)) - .collect(Collectors.toList()); - return Pair.of(preAggStatus, sortedIndexId); - } - - /////////////////////////////////////////////////////////////////////////// - // Matching key prefix - /////////////////////////////////////////////////////////////////////////// - private List matchKeyPrefixMost( - OlapTable table, - List rollups, - Set equalColumns, - Set nonEqualColumns) { - TreeMap> collect = rollups.stream() - .collect(Collectors.toMap( - rollup -> rollupKeyPrefixMatchCount(table, rollup, equalColumns, nonEqualColumns), - Lists::newArrayList, - (l1, l2) -> { - l1.addAll(l2); - return l1; - }, - TreeMap::new) + .collect(Collectors.groupingBy(index -> index.getId() == table.getBaseIndexId())); + + // Duplicate-keys table could use base index and indexes that pre-aggregation status is on. + checkPreAggResult = Stream.concat( + indexesGroupByIsBaseOrNot.get(true).stream(), + indexesGroupByIsBaseOrNot.getOrDefault(false, ImmutableList.of()) + .stream() + .filter(index -> checkPreAggStatus(scan, index.getId(), predicates, + aggregateFunctions, groupingExprs).isOn()) ); - return collect.descendingMap().firstEntry().getValue(); - } - private int rollupKeyPrefixMatchCount( - OlapTable table, - MaterializedIndex rollup, - Set equalColNames, - Set nonEqualColNames) { - int matchCount = 0; - for (Column column : table.getSchemaByIndexId(rollup.getId())) { - if (equalColNames.contains(column.getName())) { - matchCount++; - } else if (nonEqualColNames.contains(column.getName())) { - // Unequivalence predicate's columns can match only first column in rollup. - matchCount++; + // Pre-aggregation is set to `on` by default for duplicate-keys table. + preAggStatus = PreAggStatus.on(); break; - } else { - break; - } - } - return matchCount; - } - - /////////////////////////////////////////////////////////////////////////// - // Split conjuncts into equal-to and non-equal-to. - /////////////////////////////////////////////////////////////////////////// - - /** - * Filter the input conjuncts those can use prefix and split into 2 groups: is equal-to or non-equal-to predicate - * when comparing the key column. 
- */ - private Map> filterCanUsePrefixIndexAndSplitByEquality( - List conjunct, Map exprIdToColName) { - return conjunct.stream() - .map(expr -> PredicateChecker.canUsePrefixIndex(expr, exprIdToColName)) - .filter(result -> !result.equals(PrefixIndexCheckResult.FAILURE)) - .collect(Collectors.groupingBy( - result -> result.type == ResultType.SUCCESS_EQUAL, - Collectors.mapping(result -> result.colName, Collectors.toSet()) - )); - } - - private enum ResultType { - FAILURE, - SUCCESS_EQUAL, - SUCCESS_NON_EQUAL, - } - - private static class PrefixIndexCheckResult { - public static final PrefixIndexCheckResult FAILURE = new PrefixIndexCheckResult(null, ResultType.FAILURE); - private final String colName; - private final ResultType type; - - private PrefixIndexCheckResult(String colName, ResultType result) { - this.colName = colName; - this.type = result; - } - - public static PrefixIndexCheckResult createEqual(String name) { - return new PrefixIndexCheckResult(name, ResultType.SUCCESS_EQUAL); + default: + throw new RuntimeException("Not supported keys type: " + table.getKeysType()); } - public static PrefixIndexCheckResult createNonEqual(String name) { - return new PrefixIndexCheckResult(name, ResultType.SUCCESS_NON_EQUAL); - } - } - - /** - * Check if an expression could prefix key index. - */ - private static class PredicateChecker extends ExpressionVisitor> { - private static final PredicateChecker INSTANCE = new PredicateChecker(); - - private PredicateChecker() { - } - - public static PrefixIndexCheckResult canUsePrefixIndex(Expression expression, - Map exprIdToName) { - return expression.accept(INSTANCE, exprIdToName); - } - - @Override - public PrefixIndexCheckResult visit(Expression expr, Map context) { - return PrefixIndexCheckResult.FAILURE; - } - - @Override - public PrefixIndexCheckResult visitInPredicate(InPredicate in, Map context) { - Optional slotOrCastOnSlot = ExpressionUtils.isSlotOrCastOnSlot(in.getCompareExpr()); - if (slotOrCastOnSlot.isPresent() && in.getOptions().stream().allMatch(Literal.class::isInstance)) { - return PrefixIndexCheckResult.createEqual(context.get(slotOrCastOnSlot.get())); - } else { - return PrefixIndexCheckResult.FAILURE; - } - } - - @Override - public PrefixIndexCheckResult visitComparisonPredicate(ComparisonPredicate cp, Map context) { - if (cp instanceof EqualTo || cp instanceof NullSafeEqual) { - return check(cp, context, PrefixIndexCheckResult::createEqual); - } else { - return check(cp, context, PrefixIndexCheckResult::createNonEqual); - } - } - - private PrefixIndexCheckResult check(ComparisonPredicate cp, Map exprIdToColumnName, - Function resultMapper) { - return check(cp).map(exprId -> resultMapper.apply(exprIdToColumnName.get(exprId))) - .orElse(PrefixIndexCheckResult.FAILURE); - } - - private Optional check(ComparisonPredicate cp) { - Optional exprId = check(cp.left(), cp.right()); - return exprId.isPresent() ? exprId : check(cp.right(), cp.left()); - } - - private Optional check(Expression maybeSlot, Expression maybeConst) { - Optional exprIdOpt = ExpressionUtils.isSlotOrCastOnSlot(maybeSlot); - return exprIdOpt.isPresent() && maybeConst.isConstant() ? 
exprIdOpt : Optional.empty(); - } + List sortedIndexId = select(checkPreAggResult, scan, requiredScanOutput, predicates); + return Pair.of(preAggStatus, sortedIndexId); } /** @@ -434,10 +278,11 @@ private List extractAggFunctionAndReplaceSlot( /////////////////////////////////////////////////////////////////////////// private PreAggStatus checkPreAggStatus( LogicalOlapScan olapScan, + long indexId, List predicates, List aggregateFuncs, List groupingExprs) { - CheckContext checkContext = new CheckContext(olapScan); + CheckContext checkContext = new CheckContext(olapScan, indexId); return checkAggregateFunctions(aggregateFuncs, checkContext) .offOrElse(() -> checkGroupingExprs(groupingExprs, checkContext)) .offOrElse(() -> checkPredicates(predicates, checkContext)); @@ -491,6 +336,18 @@ public PreAggStatus visitSum(Sum sum, CheckContext context) { return checkAggFunc(sum, AggregateType.SUM, extractSlotId(sum.child()), context, false); } + @Override + public PreAggStatus visitCount(Count count, CheckContext context) { + Optional exprIdOpt = extractSlotId(count.child()); + // Only count(distinct key_column) is supported. + if (count.isDistinct() && exprIdOpt.isPresent() && context.exprIdToKeyColumn.containsKey(exprIdOpt.get())) { + return PreAggStatus.on(); + } else { + return PreAggStatus.off(String.format( + "Count distinct is only valid for key columns, but meet %s.", count.toSql())); + } + } + private PreAggStatus checkAggFunc( AggregateFunction aggFunc, AggregateType matchingAggType, @@ -532,9 +389,9 @@ private static class CheckContext { public final Map exprIdToKeyColumn; public final Map exprIdToValueColumn; - public CheckContext(LogicalOlapScan scan) { + public CheckContext(LogicalOlapScan scan, long indexId) { // map> - Map> nameToColumnGroupingByIsKey = scan.getTable().getBaseSchema() + Map> nameToColumnGroupingByIsKey = scan.getTable().getSchemaByIndexId(indexId) .stream() .collect(Collectors.groupingBy( Column::isKey, @@ -543,7 +400,7 @@ public CheckContext(LogicalOlapScan scan) { ) )); Map keyNameToColumn = nameToColumnGroupingByIsKey.get(true); - Map valueNameToColumn = nameToColumnGroupingByIsKey.get(false); + Map valueNameToColumn = nameToColumnGroupingByIsKey.getOrDefault(false, ImmutableMap.of()); Map nameToExprId = scan.getOutput() .stream() .collect(Collectors.toMap( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/SelectMaterializedIndexWithoutAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/SelectMaterializedIndexWithoutAggregate.java new file mode 100644 index 000000000000000..7ce345b3898542a --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/SelectMaterializedIndexWithoutAggregate.java @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.mv; + +import org.apache.doris.catalog.MaterializedIndex; +import org.apache.doris.catalog.OlapTable; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.PreAggStatus; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.List; +import java.util.Set; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +/** + * Select materialized index, i.e., both for rollup and materialized view when aggregate is not present. + *

+ * Scanning an OLAP table with an aggregate on top is handled in {@link SelectMaterializedIndexWithAggregate}. *

+ * Note that we should first apply {@link SelectMaterializedIndexWithAggregate} and then + * {@link SelectMaterializedIndexWithoutAggregate}. + * Besides, these two rules should run in isolated batches, so when this rule is entered it is guaranteed that there is + * no aggregation on top of the scan. *

+ * TODO: optimize queries with aggregate not on top of scan directly, e.g., aggregate -> join -> scan + * to use materialized index. + */ +public class SelectMaterializedIndexWithoutAggregate extends AbstractSelectMaterializedIndexRule + implements RewriteRuleFactory { + + @Override + public List buildRules() { + return ImmutableList.of( + // project with pushdown filter. + // Project(Filter(Scan)) + logicalProject(logicalFilter(logicalOlapScan().whenNot(LogicalOlapScan::isIndexSelected))) + .then(project -> { + LogicalFilter filter = project.child(); + LogicalOlapScan scan = filter.child(); + return project.withChildren(filter.withChildren( + selectIndex(scan, project::getInputSlots, filter::getConjuncts))); + }).toRule(RuleType.MATERIALIZED_INDEX_PROJECT_FILTER_SCAN), + + // project with filter that cannot be pushdown. + // Filter(Project(Scan)) + logicalFilter(logicalProject(logicalOlapScan().whenNot(LogicalOlapScan::isIndexSelected))) + .then(filter -> { + LogicalProject project = filter.child(); + LogicalOlapScan scan = project.child(); + return filter.withChildren(project.withChildren( + selectIndex(scan, project::getInputSlots, ImmutableList::of) + )); + }).toRule(RuleType.MATERIALIZED_INDEX_FILTER_PROJECT_SCAN), + + // scan with filters could be pushdown. + // Filter(Scan) + logicalFilter(logicalOlapScan().whenNot(LogicalOlapScan::isIndexSelected)) + .then(filter -> { + LogicalOlapScan scan = filter.child(); + return filter.withChildren(selectIndex(scan, ImmutableSet::of, filter::getConjuncts)); + }) + .toRule(RuleType.MATERIALIZED_INDEX_FILTER_SCAN), + + // project and scan. + // Project(Scan) + logicalProject(logicalOlapScan().whenNot(LogicalOlapScan::isIndexSelected)) + .then(project -> { + LogicalOlapScan scan = project.child(); + return project.withChildren( + selectIndex(scan, project::getInputSlots, ImmutableList::of)); + }) + .toRule(RuleType.MATERIALIZED_INDEX_PROJECT_SCAN), + + // only scan. + logicalOlapScan() + .whenNot(LogicalOlapScan::isIndexSelected) + .then(scan -> selectIndex(scan, scan::getOutputSet, ImmutableList::of)) + .toRule(RuleType.MATERIALIZED_INDEX_SCAN) + ); + } + + /** + * Select materialized index when aggregate node is not present. + * + * @param scan Scan node. + * @param requiredScanOutputSupplier Supplier to get the required scan output. + * @param predicatesSupplier Supplier to get pushdown predicates. + * @return Result scan node. + */ + private LogicalOlapScan selectIndex( + LogicalOlapScan scan, + Supplier> requiredScanOutputSupplier, + Supplier> predicatesSupplier) { + switch (scan.getTable().getKeysType()) { + case AGG_KEYS: + case UNIQUE_KEYS: + OlapTable table = scan.getTable(); + long baseIndexId = table.getBaseIndexId(); + int baseIndexKeySize = table.getKeyColumnsByIndexId(table.getBaseIndexId()).size(); + // No on aggregate on scan. + // So only base index and indexes that have all the keys could be used. + List candidates = table.getVisibleIndex().stream() + .filter(index -> index.getId() == baseIndexId + || table.getKeyColumnsByIndexId(index.getId()).size() == baseIndexKeySize) + .collect(Collectors.toList()); + PreAggStatus preAgg = PreAggStatus.off("No aggregate on scan."); + if (candidates.size() == 1) { + // `candidates` only have base index. 
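+ // Only the base index covers the full key set here, so keep it; pre-aggregation stays off because there is no aggregate above this scan.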
+ return scan.withMaterializedIndexSelected(preAgg, ImmutableList.of(baseIndexId)); + } else { + return scan.withMaterializedIndexSelected(preAgg, + select(candidates.stream(), scan, requiredScanOutputSupplier.get(), + predicatesSupplier.get())); + } + case DUP_KEYS: + // Set pre-aggregation to `on` to keep consistency with legacy logic. + return scan.withMaterializedIndexSelected(PreAggStatus.on(), + select(scan.getTable().getVisibleIndex().stream(), scan, requiredScanOutputSupplier.get(), + predicatesSupplier.get())); + default: + throw new RuntimeException("Not supported keys type: " + scan.getTable().getKeysType()); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/SelectRollupWithoutAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/SelectRollupWithoutAggregate.java deleted file mode 100644 index a37084953a809ae..000000000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/SelectRollupWithoutAggregate.java +++ /dev/null @@ -1,60 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.mv; - -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; -import org.apache.doris.nereids.trees.plans.PreAggStatus; -import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; - -import com.google.common.collect.ImmutableList; - -/** - * Select rollup index when aggregate is not present. - *

- * Scan OLAP table with aggregate is handled in {@link SelectRollupWithAggregate}. This rule is to disable - * pre-aggregation for OLAP scan when there is no aggregate plan. - *

- * Note that we should first apply {@link SelectRollupWithAggregate} and then {@link SelectRollupWithoutAggregate}. - * Besides, these two rules should run in isolated batches, thus when enter this rule, it's guaranteed that there is - * no aggregation on top of the scan. - */ -public class SelectRollupWithoutAggregate extends OneRewriteRuleFactory { - - @Override - public Rule build() { - return logicalOlapScan() - .whenNot(LogicalOlapScan::isRollupSelected) - .then(this::scanWithoutAggregate) - .toRule(RuleType.ROLLUP_WITH_OUT_AGG); - } - - private LogicalOlapScan scanWithoutAggregate(LogicalOlapScan scan) { - switch (scan.getTable().getKeysType()) { - case AGG_KEYS: - case UNIQUE_KEYS: - return scan.withMaterializedIndexSelected(PreAggStatus.off("No aggregate on scan."), - ImmutableList.of(scan.getTable().getBaseIndexId())); - default: - // Set pre-aggregation to `on` to keep consistency with legacy logic. - return scan.withMaterializedIndexSelected(PreAggStatus.on(), - ImmutableList.of(scan.getTable().getBaseIndexId())); - } - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java index 0dea798b04771c3..ce63dbd6ffa0a06 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Count.java @@ -20,6 +20,7 @@ import org.apache.doris.nereids.exceptions.UnboundException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; +import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DataType; @@ -30,7 +31,7 @@ import java.util.stream.Collectors; /** count agg function. 
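+ * Count additionally implements {@link UnaryExpression}, which lets callers such as the materialized index selection rule access its child expression directly (see {@code visitCount} in SelectMaterializedIndexWithAggregate).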
*/ -public class Count extends AggregateFunction implements AlwaysNotNullable { +public class Count extends AggregateFunction implements UnaryExpression, AlwaysNotNullable { private final boolean isStar; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java index 0b6563034cae1fd..ab94c80a37792e3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java @@ -22,7 +22,7 @@ import org.apache.doris.catalog.Table; import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.properties.LogicalProperties; -import org.apache.doris.nereids.rules.mv.SelectRollupWithAggregate; +import org.apache.doris.nereids.rules.mv.SelectMaterializedIndexWithAggregate; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.PreAggStatus; @@ -50,7 +50,7 @@ public class LogicalOlapScan extends LogicalRelation { private final boolean partitionPruned; private final List candidateIndexIds; - private final boolean rollupSelected; + private final boolean indexSelected; private final PreAggStatus preAggStatus; @@ -87,7 +87,7 @@ public LogicalOlapScan(RelationId id, Table table, List qualifier, } this.partitionPruned = partitionPruned; this.candidateIndexIds = candidateIndexIds; - this.rollupSelected = rollupSelected; + this.indexSelected = rollupSelected; this.preAggStatus = preAggStatus; } @@ -128,18 +128,18 @@ public int hashCode() { @Override public Plan withGroupExpression(Optional groupExpression) { return new LogicalOlapScan(id, table, qualifier, groupExpression, Optional.of(getLogicalProperties()), - selectedPartitionIds, partitionPruned, candidateIndexIds, rollupSelected, preAggStatus); + selectedPartitionIds, partitionPruned, candidateIndexIds, indexSelected, preAggStatus); } @Override public LogicalOlapScan withLogicalProperties(Optional logicalProperties) { return new LogicalOlapScan(id, table, qualifier, Optional.empty(), logicalProperties, selectedPartitionIds, - partitionPruned, candidateIndexIds, rollupSelected, preAggStatus); + partitionPruned, candidateIndexIds, indexSelected, preAggStatus); } public LogicalOlapScan withSelectedPartitionId(List selectedPartitionId) { return new LogicalOlapScan(id, table, qualifier, Optional.empty(), Optional.of(getLogicalProperties()), - selectedPartitionId, true, candidateIndexIds, rollupSelected, preAggStatus); + selectedPartitionId, true, candidateIndexIds, indexSelected, preAggStatus); } public LogicalOlapScan withMaterializedIndexSelected(PreAggStatus preAgg, List candidateIndexIds) { @@ -164,8 +164,8 @@ public long getSelectedIndexId() { return selectedIndexId; } - public boolean isRollupSelected() { - return rollupSelected; + public boolean isIndexSelected() { + return indexSelected; } public PreAggStatus getPreAggStatus() { @@ -173,21 +173,22 @@ public PreAggStatus getPreAggStatus() { } /** - * Should apply {@link SelectRollupWithAggregate} or not. + * Should apply {@link SelectMaterializedIndexWithAggregate} or not. 
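+ * Index selection only applies to OLAP tables with AGG_KEYS, UNIQUE_KEYS or DUP_KEYS key types, and only while no materialized index has been selected yet.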
*/ - public boolean shouldSelectRollup() { + public boolean shouldSelectIndex() { switch (((OlapTable) table).getKeysType()) { case AGG_KEYS: case UNIQUE_KEYS: - return !rollupSelected; + case DUP_KEYS: + return !indexSelected; default: return false; } } @VisibleForTesting - public Optional getSelectRollupName() { - return rollupSelected ? Optional.ofNullable(((OlapTable) table).getIndexNameById(selectedIndexId)) + public Optional getSelectedMaterializedIndexName() { + return indexSelected ? Optional.ofNullable(((OlapTable) table).getIndexNameById(selectedIndexId)) : Optional.empty(); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java index a445e777db09a00..fb210915f406940 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java @@ -1203,4 +1203,9 @@ public DataPartition constructInputPartitionByDistributionInfo() throws UserExce public String getReasonOfPreAggregation() { return reasonOfPreAggregation; } + + @VisibleForTesting + public String getSelectedIndexName() { + return olapTable.getIndexNameById(selectedIndexId); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/BaseMaterializedIndexSelectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/BaseMaterializedIndexSelectTest.java new file mode 100644 index 000000000000000..0b09a30ac977803 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/BaseMaterializedIndexSelectTest.java @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.mv; + +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.planner.OlapScanNode; +import org.apache.doris.planner.ScanNode; +import org.apache.doris.utframe.TestWithFeService; + +import org.junit.jupiter.api.Assertions; + +import java.util.List; +import java.util.function.Consumer; + +/** + * Base class to test selecting materialized index. 
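+ * Subclasses call {@code singleTableTest} to plan a single-table query and assert which materialized index is selected and whether pre-aggregation is enabled.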
+ */ +public abstract class BaseMaterializedIndexSelectTest extends TestWithFeService { + protected void singleTableTest(String sql, String indexName, boolean preAgg) { + singleTableTest(sql, scan -> { + Assertions.assertEquals(preAgg, scan.isPreAggregation()); + Assertions.assertEquals(indexName, scan.getSelectedIndexName()); + }); + } + + protected void singleTableTest(String sql, Consumer scanConsumer) { + PlanChecker.from(connectContext).checkPlannerResult(sql, planner -> { + List scans = planner.getScanNodes(); + Assertions.assertEquals(1, scans.size()); + ScanNode scanNode = scans.get(0); + Assertions.assertTrue(scanNode instanceof OlapScanNode); + OlapScanNode olapScan = (OlapScanNode) scanNode; + scanConsumer.accept(olapScan); + }); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/SelectMvIndexTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/SelectMvIndexTest.java new file mode 100644 index 000000000000000..6ec7ab3764fed3b --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/SelectMvIndexTest.java @@ -0,0 +1,1035 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.mv; + +import org.apache.doris.catalog.FunctionSet; +import org.apache.doris.common.FeConstants; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.planner.OlapScanNode; +import org.apache.doris.planner.ScanNode; +import org.apache.doris.utframe.DorisAssert; + +import com.google.common.collect.ImmutableMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +/** + * Tests ported from {@link org.apache.doris.planner.MaterializedViewFunctionTest} + */ +public class SelectMvIndexTest extends BaseMaterializedIndexSelectTest implements PatternMatchSupported { + + private static final String EMPS_TABLE_NAME = "emps"; + private static final String EMPS_MV_NAME = "emps_mv"; + private static final String HR_DB_NAME = "db1"; + private static final String DEPTS_TABLE_NAME = "depts"; + private static final String DEPTS_MV_NAME = "depts_mv"; + private static final String USER_TAG_TABLE_NAME = "user_tags"; + private static final String TEST_TABLE_NAME = "test_tb"; + + @Override + protected void beforeCreatingConnectContext() throws Exception { + FeConstants.default_scheduler_interval_millisecond = 10; + FeConstants.runningUnitTest = true; + } + + @Override + protected void runBeforeAll() throws Exception { + createDatabase(HR_DB_NAME); + useDatabase(HR_DB_NAME); + } + + @BeforeEach + void before() throws Exception { + createTable("create table " + HR_DB_NAME + "." + EMPS_TABLE_NAME + " (time_col date, empid int, " + + "name varchar, deptno int, salary int, commission int) partition by range (time_col) " + + "(partition p1 values less than MAXVALUE) distributed by hash(time_col) buckets 3" + + " properties('replication_num' = '1');"); + + createTable("create table " + HR_DB_NAME + "." + DEPTS_TABLE_NAME + + " (time_col date, deptno int, name varchar, cost int) partition by range (time_col) " + + "(partition p1 values less than MAXVALUE) " + + "distributed by hash(time_col) buckets 3 properties('replication_num' = '1');"); + + createTable("create table " + HR_DB_NAME + "." 
+ USER_TAG_TABLE_NAME + + " (time_col date, user_id int, user_name varchar(20), tag_id int) partition by range (time_col) " + + " (partition p1 values less than MAXVALUE) " + + "distributed by hash(time_col) buckets 3 properties('replication_num' = '1');"); + } + + @AfterEach + public void after() throws Exception { + dropTable(EMPS_TABLE_NAME, true); + dropTable(DEPTS_TABLE_NAME, true); + dropTable(USER_TAG_TABLE_NAME, true); + } + + @Test + public void testProjectionMV1() throws Exception { + String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, empid from " + + EMPS_TABLE_NAME + " order by deptno;"; + String query = "select empid, deptno from " + EMPS_TABLE_NAME + ";"; + createMv(createMVSql); + testMv(query, EMPS_MV_NAME); + } + + @Test + public void testProjectionMV2() throws Exception { + String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, empid from " + + EMPS_TABLE_NAME + " order by deptno;"; + String query1 = "select empid + 1 from " + EMPS_TABLE_NAME + " where deptno = 10;"; + createMv(createMVSql); + testMv(query1, EMPS_MV_NAME); + String query2 = "select name from " + EMPS_TABLE_NAME + " where deptno -10 = 0;"; + testMv(query2, EMPS_TABLE_NAME); + } + + @Test + public void testProjectionMV3() throws Exception { + String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, empid, name from " + + EMPS_TABLE_NAME + " order by deptno;"; + String query1 = "select empid +1, name from " + EMPS_TABLE_NAME + " where deptno = 10;"; + createMv(createMVSql); + testMv(query1, EMPS_MV_NAME); + String query2 = "select name from " + EMPS_TABLE_NAME + " where deptno - 10 = 0;"; + testMv(query2, EMPS_MV_NAME); + } + + // @Test + // public void testProjectionMV4() throws Exception { + // String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select name, deptno, salary from " + // + EMPS_TABLE_NAME + ";"; + // String query1 = "select name from " + EMPS_TABLE_NAME + " where deptno > 30 and salary > 3000;"; + // createMv(createMVSql); + // testMv(query1, EMPS_MV_NAME); + // String query2 = "select empid from " + EMPS_TABLE_NAME + " where deptno > 30 and empid > 10;"; + // dorisAssert.query(query2).explainWithout(QUERY_USE_EMPS_MV); + // } + + /** + * TODO: enable this when union is supported. 
+ */ + @Disabled + public void testUnionQueryOnProjectionMV() throws Exception { + String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, empid from " + + EMPS_TABLE_NAME + " order by deptno;"; + String union = "select empid from " + EMPS_TABLE_NAME + " where deptno > 300" + " union all select empid from" + + " " + EMPS_TABLE_NAME + " where deptno < 200"; + createMv(createMVSql); + testMv(union, EMPS_MV_NAME); + } + + @Test + public void testAggQueryOnAggMV1() throws Exception { + String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, sum(salary), " + + "max(commission) from " + EMPS_TABLE_NAME + " group by deptno;"; + String query = "select sum(salary), deptno from " + EMPS_TABLE_NAME + " group by deptno;"; + createMv(createMVSql); + testMv(query, EMPS_MV_NAME); + } + + @Test + public void testAggQueryOnAggMV2() throws Exception { + String agg = "select deptno, sum(salary) from " + EMPS_TABLE_NAME + " group by deptno"; + String createMVSql = "create materialized view " + EMPS_MV_NAME + " as " + agg + ";"; + String query = "select * from (select deptno, sum(salary) as sum_salary from " + EMPS_TABLE_NAME + " group " + + "by" + " deptno) a where (sum_salary * 2) > 3;"; + createMv(createMVSql); + testMv(query, EMPS_MV_NAME); + } + + /* + TODO + The deduplicate materialized view is not yet supported + @Test + public void testAggQueryOnDeduplicatedMV() throws Exception { + String deduplicateSQL = "select deptno, empid, name, salary, commission from " + EMPS_TABLE_NAME + " group " + + "by" + " deptno, empid, name, salary, commission"; + String createMVSql = "create materialized view " + EMPS_MV_NAME + " as " + deduplicateSQL + ";"; + String query1 = "select deptno, sum(salary) from (" + deduplicateSQL + ") A group by deptno;"; + createMv(createMVSql); + testMv(query1, EMPS_MV_NAME); + String query2 = "select deptno, empid from " + EMPS_TABLE_NAME + ";"; + dorisAssert.query(query2).explainWithout(QUERY_USE_EMPS_MV); + } + */ + + @Test + public void testAggQueryOnAggMV3() throws Exception { + String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, commission, sum(salary)" + + " from " + EMPS_TABLE_NAME + " group by deptno, commission;"; + String query = "select commission, sum(salary) from " + EMPS_TABLE_NAME + " where commission * (deptno + " + + "commission) = 100 group by commission;"; + createMv(createMVSql); + testMv(query, EMPS_MV_NAME); + } + + /** + * Matching failed because the filtering condition under Aggregate + * references columns for aggregation. + */ + @Test + public void testAggQueryOnAggMV4() throws Exception { + String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, commission, sum(salary)" + + " from " + EMPS_TABLE_NAME + " group by deptno, commission;"; + String query = "select deptno, sum(salary) from " + EMPS_TABLE_NAME + " where salary>1000 group by deptno;"; + createMv(createMVSql); + testMv(query, EMPS_TABLE_NAME); + } + + /** + * There will be a compensating Project added after matching of the Aggregate. 
+     */
+    @Test
+    public void testAggQueryOnAggMV5() throws Exception {
+        String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, commission, sum(salary)"
+                + " from " + EMPS_TABLE_NAME + " group by deptno, commission;";
+        String query = "select * from (select deptno, sum(salary) as sum_salary from " + EMPS_TABLE_NAME
+                + " group by deptno) a where sum_salary>10;";
+        createMv(createMVSql);
+        testMv(query, EMPS_MV_NAME);
+    }
+
+    /**
+     * There will be a compensating Project + Filter added after matching of the Aggregate.
+     */
+    @Test
+    public void testAggQueryOnAggMV6() throws Exception {
+        String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, commission, sum(salary)"
+                + " from " + EMPS_TABLE_NAME + " group by deptno, commission;";
+        String query = "select * from (select deptno, sum(salary) as sum_salary from " + EMPS_TABLE_NAME
+                + " where deptno>=20 group by deptno) a where sum_salary>10;";
+        createMv(createMVSql);
+        testMv(query, EMPS_MV_NAME);
+    }
+
+    /**
+     * Aggregation query with groupSets at coarser level of aggregation than
+     * aggregation materialized view.
+     * TODO: enable this when group by rollup is supported.
+     */
+    @Disabled
+    public void testGroupingSetQueryOnAggMV() throws Exception {
+        String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno, sum(salary) "
+                + "from " + EMPS_TABLE_NAME + " group by empid, deptno;";
+        String query = "select sum(salary), empid, deptno from " + EMPS_TABLE_NAME + " group by rollup(empid,deptno);";
+        createMv(createMVSql);
+        testMv(query, EMPS_MV_NAME);
+    }
+
+    /**
+     * Aggregation query at coarser level of aggregation than aggregation materialized view.
+     */
+    @Test
+    public void testAggQueryOnAggMV7() throws Exception {
+        String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, commission, sum(salary) "
+                + "from " + EMPS_TABLE_NAME + " " + "group by deptno, commission;";
+        String query = "select deptno, sum(salary) from " + EMPS_TABLE_NAME + " where deptno>=20 group by deptno;";
+        createMv(createMVSql);
+        testMv(query, EMPS_MV_NAME);
+    }
+
+    @Test
+    public void testAggQueryOnAggMV8() throws Exception {
+        String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, sum(salary) "
+                + "from " + EMPS_TABLE_NAME + " group by deptno;";
+        String query = "select deptno, sum(salary) + 1 from " + EMPS_TABLE_NAME + " group by deptno;";
+        createMv(createMVSql);
+        testMv(query, EMPS_MV_NAME);
+    }
+
+    /**
+     * Query with cube and arithmetic expr
+     * TODO: enable this when group by cube is supported.
+     */
+    @Disabled
+    public void testAggQueryOnAggMV9() throws Exception {
+        String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, commission, sum(salary) "
+                + "from " + EMPS_TABLE_NAME + " group by deptno, commission;";
+        String query = "select deptno, commission, sum(salary) + 1 from " + EMPS_TABLE_NAME
+                + " group by cube(deptno,commission);";
+        createMv(createMVSql);
+        testMv(query, EMPS_MV_NAME);
+    }
+
+    /**
+     * Query with rollup and arithmetic expr
+     * TODO: enable this when group by rollup is supported.
+ */ + @Disabled + public void testAggQueryOnAggMV10() throws Exception { + String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, commission, sum(salary) " + + "from " + EMPS_TABLE_NAME + " group by deptno, commission;"; + String query = "select deptno, commission, sum(salary) + 1 from " + EMPS_TABLE_NAME + + " group by rollup (deptno, commission);"; + createMv(createMVSql); + testMv(query, EMPS_MV_NAME); + } + + /** + * Aggregation query with two aggregation operators + */ + @Test + public void testAggQueryOnAggMV11() throws Exception { + String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, count(salary) " + + "from " + EMPS_TABLE_NAME + " group by deptno;"; + String query = "select deptno, count(salary) + count(1) from " + EMPS_TABLE_NAME + + " group by deptno;"; + createMv(createMVSql); + testMv(query, EMPS_TABLE_NAME); + } + + /** + * Aggregation query with set operand + * TODO: enable this when union is supported. + */ + @Disabled + public void testAggQueryWithSetOperandOnAggMV() throws Exception { + String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select deptno, count(salary) " + + "from " + EMPS_TABLE_NAME + " group by deptno;"; + String query = "select deptno, count(salary) + count(1) from " + EMPS_TABLE_NAME + + " group by deptno union " + + "select deptno, count(salary) + count(1) from " + EMPS_TABLE_NAME + + " group by deptno;"; + createMv(createMVSql); + testMv(query, EMPS_TABLE_NAME); + } + + @Test + public void testJoinOnLeftProjectToJoin() throws Exception { + String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + + " as select deptno, sum(salary), sum(commission) from " + EMPS_TABLE_NAME + " group by deptno;"; + String createDeptsMVSQL = "create materialized view " + DEPTS_MV_NAME + " as select deptno, max(cost) from " + + DEPTS_TABLE_NAME + " group by deptno;"; + String query = "select * from (select deptno , sum(salary) from " + EMPS_TABLE_NAME + " group by deptno) A " + + "join (select deptno, max(cost) from " + DEPTS_TABLE_NAME + " group by deptno ) B on A.deptno = B" + + ".deptno;"; + createMv(createEmpsMVsql); + createMv(createDeptsMVSQL); + testMv(query, ImmutableMap.of(EMPS_TABLE_NAME, EMPS_MV_NAME, DEPTS_TABLE_NAME, DEPTS_MV_NAME)); + } + + @Test + public void testJoinOnRightProjectToJoin() throws Exception { + String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select deptno, sum(salary), sum" + + "(commission) from " + EMPS_TABLE_NAME + " group by deptno;"; + String createDeptsMVSQL = "create materialized view " + DEPTS_MV_NAME + " as select deptno, max(cost) from " + + DEPTS_TABLE_NAME + " group by deptno;"; + String query = "select * from (select deptno , sum(salary), sum(commission) from " + EMPS_TABLE_NAME + + " group by deptno) A join (select deptno from " + DEPTS_TABLE_NAME + " group by deptno ) B on A" + + ".deptno = B.deptno;"; + createMv(createEmpsMVsql); + createMv(createDeptsMVSQL); + testMv(query, ImmutableMap.of(EMPS_TABLE_NAME, EMPS_MV_NAME, DEPTS_TABLE_NAME, DEPTS_MV_NAME)); + } + + @Test + public void testJoinOnProjectsToJoin() throws Exception { + String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select deptno, sum(salary), sum" + + "(commission) from " + EMPS_TABLE_NAME + " group by deptno;"; + String createDeptsMVSQL = "create materialized view " + DEPTS_MV_NAME + " as select deptno, max(cost) from " + + DEPTS_TABLE_NAME + " group by deptno;"; + String query = "select * from (select deptno , 
sum(salary) from " + EMPS_TABLE_NAME + " group by deptno) A " + + "join (select deptno from " + DEPTS_TABLE_NAME + " group by deptno ) B on A.deptno = B.deptno;"; + createMv(createEmpsMVsql); + createMv(createDeptsMVSQL); + testMv(query, ImmutableMap.of(EMPS_TABLE_NAME, EMPS_MV_NAME, DEPTS_TABLE_NAME, DEPTS_MV_NAME)); + } + + @Test + public void testJoinOnCalcToJoin0() throws Exception { + String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno from " + + EMPS_TABLE_NAME + ";"; + String createDeptsMVSQL = "create materialized view " + DEPTS_MV_NAME + " as select deptno from " + + DEPTS_TABLE_NAME + ";"; + String query = "select * from (select empid, deptno from " + EMPS_TABLE_NAME + " where deptno > 10 ) A " + + "join (select deptno from " + DEPTS_TABLE_NAME + " ) B on A.deptno = B.deptno;"; + // createMv(createEmpsMVsql); + // createMv(createDeptsMVSQL); + new DorisAssert(connectContext).withMaterializedView(createDeptsMVSQL).withMaterializedView(createEmpsMVsql); + testMv(query, ImmutableMap.of(EMPS_TABLE_NAME, EMPS_MV_NAME, DEPTS_TABLE_NAME, DEPTS_MV_NAME)); + } + + @Test + public void testJoinOnCalcToJoin1() throws Exception { + String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno from " + + EMPS_TABLE_NAME + ";"; + String createDeptsMVSQL = "create materialized view " + DEPTS_MV_NAME + " as select deptno from " + + DEPTS_TABLE_NAME + ";"; + String query = "select * from (select empid, deptno from " + EMPS_TABLE_NAME + " ) A join (select " + + "deptno from " + DEPTS_TABLE_NAME + " where deptno > 10 ) B on A.deptno = B.deptno;"; + createMv(createEmpsMVsql); + createMv(createDeptsMVSQL); + testMv(query, ImmutableMap.of(EMPS_TABLE_NAME, EMPS_MV_NAME, DEPTS_TABLE_NAME, DEPTS_MV_NAME)); + } + + @Test + public void testJoinOnCalcToJoin2() throws Exception { + String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno from " + + EMPS_TABLE_NAME + ";"; + String createDeptsMVSQL = "create materialized view " + DEPTS_MV_NAME + " as select deptno from " + + DEPTS_TABLE_NAME + ";"; + String query = "select * from (select empid, deptno from " + EMPS_TABLE_NAME + " where empid >10 ) A " + + "join (select deptno from " + DEPTS_TABLE_NAME + " where deptno > 10 ) B on A.deptno = B.deptno;"; + createMv(createEmpsMVsql); + createMv(createDeptsMVSQL); + testMv(query, ImmutableMap.of(EMPS_TABLE_NAME, EMPS_MV_NAME, DEPTS_TABLE_NAME, DEPTS_MV_NAME)); + } + + @Test + public void testJoinOnCalcToJoin3() throws Exception { + String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno from " + + EMPS_TABLE_NAME + ";"; + String createDeptsMVSQL = "create materialized view " + DEPTS_MV_NAME + " as select deptno from " + + DEPTS_TABLE_NAME + ";"; + String query = "select * from (select empid, deptno + 1 deptno from " + EMPS_TABLE_NAME + " where empid >10 )" + + " A join (select deptno from " + DEPTS_TABLE_NAME + + " where deptno > 10 ) B on A.deptno = B.deptno;"; + createMv(createEmpsMVsql); + createMv(createDeptsMVSQL); + testMv(query, ImmutableMap.of(EMPS_TABLE_NAME, EMPS_MV_NAME, DEPTS_TABLE_NAME, DEPTS_MV_NAME)); + } + + /** + * TODO: enable this when implicit case is fully developed. 
+ */ + @Disabled + public void testJoinOnCalcToJoin4() throws Exception { + String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno from " + + EMPS_TABLE_NAME + ";"; + String createDeptsMVSQL = "create materialized view " + DEPTS_MV_NAME + " as select deptno from " + + DEPTS_TABLE_NAME + ";"; + String query = "select * from (select empid, deptno + 1 deptno from " + EMPS_TABLE_NAME + + " where empid is not null ) A full join (select deptno from " + DEPTS_TABLE_NAME + + " where deptno is not null ) B on A.deptno = B.deptno;"; + createMv(createEmpsMVsql); + createMv(createDeptsMVSQL); + testMv(query, ImmutableMap.of(EMPS_TABLE_NAME, EMPS_MV_NAME, DEPTS_TABLE_NAME, DEPTS_MV_NAME)); + } + + /** + * TODO: enable this when order by column not in project is supported. + */ + @Disabled + public void testOrderByQueryOnProjectView() throws Exception { + String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select deptno, empid from " + + EMPS_TABLE_NAME + ";"; + String query = "select empid from " + EMPS_TABLE_NAME + " order by deptno"; + createMv(createEmpsMVsql); + testMv(query, EMPS_MV_NAME); + createMv(createEmpsMVsql); + testMv(query, EMPS_MV_NAME); + } + + /** + * TODO: enable this when order by column not in select is supported. + */ + @Disabled + public void testOrderByQueryOnOrderByView() throws Exception { + String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select deptno, empid from " + + EMPS_TABLE_NAME + " order by deptno;"; + String query = "select empid from " + EMPS_TABLE_NAME + " order by deptno"; + createMv(createEmpsMVsql); + testMv(query, EMPS_MV_NAME); + } + + @Test + public void testAggregateMVAggregateFuncs1() throws Exception { + String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno, sum(salary) " + + "from " + EMPS_TABLE_NAME + " group by empid, deptno;"; + String query = "select deptno from " + EMPS_TABLE_NAME + " group by deptno"; + createMv(createEmpsMVsql); + testMv(query, EMPS_MV_NAME); + } + + @Test + public void testAggregateMVAggregateFuncs2() throws Exception { + String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno, sum(salary) " + + "from " + EMPS_TABLE_NAME + " group by empid, deptno;"; + String query = "select deptno, sum(salary) from " + EMPS_TABLE_NAME + " group by deptno"; + createMv(createEmpsMVsql); + testMv(query, EMPS_MV_NAME); + } + + @Test + public void testAggregateMVAggregateFuncs3() throws Exception { + String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno, sum(salary) " + + "from " + EMPS_TABLE_NAME + " group by empid, deptno;"; + String query = "select deptno, empid, sum(salary) from " + EMPS_TABLE_NAME + " group by deptno, empid"; + createMv(createEmpsMVsql); + testMv(query, EMPS_MV_NAME); + } + + @Test + public void testAggregateMVAggregateFuncs4() throws Exception { + String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno, sum(salary) " + + "from " + EMPS_TABLE_NAME + " group by empid, deptno;"; + String query = "select deptno, sum(salary) from " + EMPS_TABLE_NAME + " where deptno > 10 group by deptno"; + createMv(createEmpsMVsql); + testMv(query, EMPS_MV_NAME); + } + + @Test + public void testAggregateMVAggregateFuncs5() throws Exception { + String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select deptno, empid, sum(salary) " + + "from " + EMPS_TABLE_NAME + " group by 
empid, deptno;";
+        String query = "select deptno, sum(salary) + 1 from " + EMPS_TABLE_NAME + " where deptno > 10 group by deptno";
+        createMv(createEmpsMVsql);
+        testMv(query, EMPS_MV_NAME);
+    }
+
+    @Test
+    public void testAggregateMVCalcGroupByQuery1() throws Exception {
+        String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select deptno, empid, sum(salary) "
+                + "from " + EMPS_TABLE_NAME + " group by empid, deptno;";
+        String query = "select deptno+1, sum(salary) + 1 from " + EMPS_TABLE_NAME + " where deptno > 10 "
+                + "group by deptno+1;";
+        createMv(createEmpsMVsql);
+        testMv(query, EMPS_MV_NAME);
+    }
+
+    @Test
+    public void testAggregateMVCalcGroupByQuery2() throws Exception {
+        String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select deptno, empid, sum(salary) "
+                + "from " + EMPS_TABLE_NAME + " group by empid, deptno;";
+        String query = "select deptno * empid, sum(salary) + 1 from " + EMPS_TABLE_NAME + " where deptno > 10 "
+                + "group by deptno * empid;";
+        createMv(createEmpsMVsql);
+        testMv(query, EMPS_MV_NAME);
+    }
+
+    @Test
+    public void testAggregateMVCalcGroupByQuery3() throws Exception {
+        String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select deptno, empid, sum(salary) "
+                + "from " + EMPS_TABLE_NAME + " group by empid, deptno;";
+        String query = "select empid, deptno * empid, sum(salary) + 1 from " + EMPS_TABLE_NAME + " where deptno > 10 "
+                + "group by empid, deptno * empid;";
+        createMv(createEmpsMVsql);
+        testMv(query, EMPS_MV_NAME);
+    }
+
+    @Test
+    public void testAggregateMVCalcAggFunctionQuery() throws Exception {
+        String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select deptno, empid, sum(salary) "
+                + "from " + EMPS_TABLE_NAME + " group by empid, deptno;";
+        String query = "select deptno, sum(salary + 1) from " + EMPS_TABLE_NAME + " where deptno > 10 "
+                + "group by deptno;";
+        createMv(createEmpsMVsql);
+        testMv(query, EMPS_TABLE_NAME);
+    }
+
+    /**
+     * TODO: enable this when the estimate stats bug is fixed.
+     */
+    @Disabled
+    public void testSubQuery() throws Exception {
+        String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select deptno, empid "
+                + "from " + EMPS_TABLE_NAME + ";";
+        createMv(createEmpsMVsql);
+        String query = "select empid, deptno, salary from " + EMPS_TABLE_NAME + " e1 where empid = (select max(empid)"
+                + " from " + EMPS_TABLE_NAME + " where deptno = e1.deptno);";
+        PlanChecker.from(connectContext).checkPlannerResult(query, planner -> {
+            List<ScanNode> scans = planner.getScanNodes();
+            Assertions.assertEquals(2, scans.size());
+
+            ScanNode scanNode0 = scans.get(0);
+            Assertions.assertTrue(scanNode0 instanceof OlapScanNode);
+            OlapScanNode scan0 = (OlapScanNode) scanNode0;
+            Assertions.assertTrue(scan0.isPreAggregation());
+            Assertions.assertEquals("emps_mv", scan0.getSelectedIndexName());
+
+            ScanNode scanNode1 = scans.get(1);
+            Assertions.assertTrue(scanNode1 instanceof OlapScanNode);
+            OlapScanNode scan1 = (OlapScanNode) scanNode1;
+            Assertions.assertTrue(scan1.isPreAggregation());
+            Assertions.assertEquals("emps", scan1.getSelectedIndexName());
+        });
+    }
+
+    /**
+     * TODO: enable this when sum(distinct xxx) is supported.
+     */
+    @Disabled
+    public void testDistinctQuery() throws Exception {
+        String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select deptno, sum(salary) "
+                + "from " + EMPS_TABLE_NAME + " group by deptno;";
+        String query1 = "select distinct deptno from " + EMPS_TABLE_NAME + ";";
+        createMv(createEmpsMVsql);
+        testMv(query1, EMPS_MV_NAME);
+        String query2 = "select deptno, sum(distinct salary) from " + EMPS_TABLE_NAME + " group by deptno;";
+        testMv(query2, EMPS_MV_NAME);
+    }
+
+    @Test
+    public void testSingleMVMultiUsage() throws Exception {
+        String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select deptno, empid, salary "
+                + "from " + EMPS_TABLE_NAME + " order by deptno;";
+        String query = "select * from (select deptno, empid from " + EMPS_TABLE_NAME + " where deptno>100) A join "
+                + "(select deptno, empid from " + EMPS_TABLE_NAME + " where deptno >200) B on A.deptno=B.deptno;";
+        createMv(createEmpsMVsql);
+        // both table scans should use the mv index.
+        testMv(query, ImmutableMap.of(EMPS_TABLE_NAME, EMPS_MV_NAME));
+    }
+
+    @Test
+    public void testMultiMVMultiUsage() throws Exception {
+        String createEmpsMVSql01 = "create materialized view emp_mv_01 as select deptno, empid, salary "
+                + "from " + EMPS_TABLE_NAME + " order by deptno;";
+        String createEmpsMVSql02 = "create materialized view emp_mv_02 as select deptno, sum(salary) "
+                + "from " + EMPS_TABLE_NAME + " group by deptno;";
+        createMv(createEmpsMVSql01);
+        createMv(createEmpsMVSql02);
+        String query = "select * from (select deptno, empid from " + EMPS_TABLE_NAME + " where deptno>100) A join "
+                + "(select deptno, sum(salary) from " + EMPS_TABLE_NAME + " where deptno >200 group by deptno) B "
+                + "on A.deptno=B.deptno";
+        PlanChecker.from(connectContext).checkPlannerResult(query, planner -> {
+            List<ScanNode> scans = planner.getScanNodes();
+            Assertions.assertEquals(2, scans.size());
+
+            ScanNode scanNode0 = scans.get(0);
+            Assertions.assertTrue(scanNode0 instanceof OlapScanNode);
+            OlapScanNode scan0 = (OlapScanNode) scanNode0;
+            Assertions.assertTrue(scan0.isPreAggregation());
+            Assertions.assertEquals("emp_mv_01", scan0.getSelectedIndexName());
+
+            ScanNode scanNode1 = scans.get(1);
+            Assertions.assertTrue(scanNode1 instanceof OlapScanNode);
+            OlapScanNode scan1 = (OlapScanNode) scanNode1;
+            Assertions.assertTrue(scan1.isPreAggregation());
+            Assertions.assertEquals("emp_mv_02", scan1.getSelectedIndexName());
+        });
+    }
+
+    @Test
+    public void testMVOnJoinQuery() throws Exception {
+        String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select salary, empid, deptno from "
+                + EMPS_TABLE_NAME + " order by salary;";
+        createMv(createEmpsMVsql);
+        String query = "select empid, salary from " + EMPS_TABLE_NAME + " join " + DEPTS_TABLE_NAME
+                + " on emps.deptno=depts.deptno where salary > 300;";
+        testMv(query, ImmutableMap.of(EMPS_TABLE_NAME, EMPS_MV_NAME, DEPTS_TABLE_NAME, DEPTS_TABLE_NAME));
+    }
+
+    @Test
+    public void testAggregateMVOnCountDistinctQuery1() throws Exception {
+        String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno, sum(salary) "
+                + "from " + EMPS_TABLE_NAME + " group by empid, deptno;";
+        String query = "select deptno, count(distinct empid) from " + EMPS_TABLE_NAME + " group by deptno;";
+        createMv(createEmpsMVsql);
+        testMv(query, EMPS_MV_NAME);
+    }
+
+    @Test
+    public void testQueryAfterTrimmingOfUnusedFields() throws Exception {
+        String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno from "
+                + EMPS_TABLE_NAME + " order by empid, deptno;";
+        String query = "select empid, deptno from (select empid, deptno, salary from " + EMPS_TABLE_NAME + ") A;";
+        createMv(createEmpsMVsql);
+        testMv(query, EMPS_MV_NAME);
+    }
+
+    /**
+     * TODO: enable this when union is supported.
+     */
+    @Disabled
+    public void testUnionAll() throws Exception {
+        // String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno from "
+        //         + EMPS_TABLE_NAME + " order by empid, deptno;";
+        // String query = "select empid, deptno from " + EMPS_TABLE_NAME + " where empid >1 union all select empid,"
+        //         + " deptno from " + EMPS_TABLE_NAME + " where empid <0;";
+        // dorisAssert.withMaterializedView(createEmpsMVsql).query(query).explainContains(QUERY_USE_EMPS_MV, 2);
+    }
+
+    /**
+     * TODO: enable this when union is supported.
+     */
+    @Disabled
+    public void testUnionDistinct() throws Exception {
+        // String createEmpsMVsql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno from "
+        //         + EMPS_TABLE_NAME + " order by empid, deptno;";
+        // String query = "select empid, deptno from " + EMPS_TABLE_NAME + " where empid >1 union select empid,"
+        //         + " deptno from " + EMPS_TABLE_NAME + " where empid <0;";
+        // dorisAssert.withMaterializedView(createEmpsMVsql).query(query).explainContains(QUERY_USE_EMPS_MV, 2);
+    }
+
+    /**
+     * Rollup that contains only key columns, for an aggregate-keys table.
+     */
+    @Test
+    public void testDeduplicateQueryInAgg() throws Exception {
+        String aggregateTable = "create table agg_table (k1 int, k2 int, v1 bigint sum) aggregate key (k1, k2) "
+                + "distributed by hash(k1) buckets 3 properties('replication_num' = '1');";
+        createTable(aggregateTable);
+
+        // don't use rollup k1_v1
+        addRollup("alter table agg_table add rollup k1_v1(k1, v1)");
+        // use rollup only_keys
+        addRollup("alter table agg_table add rollup only_keys (k1, k2) properties ('replication_num' = '1')");
+
+        String query = "select k1, k2 from agg_table;";
+        // todo: `preagg` should be true when the rollup can be used.
+        singleTableTest(query, "only_keys", false);
+    }
+
+    /**
+     * Group-by-only mv for a duplicate-keys table.
+     * duplicate table (k1, k2, v1)
+     * aggregate mv index (k1, k2)
+     */
+    @Test
+    public void testGroupByOnlyForDuplicateTable() throws Exception {
+        createTable("create table t (k1 int, k2 int, v1 bigint) duplicate key(k1, k2, v1)"
+                + "distributed by hash(k1) buckets 3 properties('replication_num' = '1')");
+        createMv("create materialized view k1_k2 as select k1, k2 from t group by k1, k2");
+        singleTableTest("select k1, k2 from t group by k1, k2", "k1_k2", true);
+    }
+
+    @Test
+    public void testAggFunctionInHaving() throws Exception {
+        String duplicateTable = "CREATE TABLE " + TEST_TABLE_NAME + " ( k1 int(11) NOT NULL , k2 int(11) NOT NULL ,"
+                + "v1 varchar(4096) NOT NULL, v2 float NOT NULL , v3 decimal(20, 7) NOT NULL ) ENGINE=OLAP "
+                + "DUPLICATE KEY( k1 , k2 ) DISTRIBUTED BY HASH( k1 , k2 ) BUCKETS 3 "
+                + "PROPERTIES ('replication_num' = '1'); ";
+        createTable(duplicateTable);
+        String createK1K2MV = "create materialized view k1_k2 as select k1,k2 from " + TEST_TABLE_NAME + " group by "
+                + "k1,k2;";
+        createMv(createK1K2MV);
+        String query = "select k1 from " + TEST_TABLE_NAME + " group by k1 having max(v1) > 10;";
+        testMv(query, TEST_TABLE_NAME);
+        dropTable(TEST_TABLE_NAME, true);
+    }
+
+    /**
+     * TODO: enable this when order by aggregate function is supported.
+ */ + @Disabled + public void testAggFunctionInOrder() throws Exception { + String duplicateTable = "CREATE TABLE " + TEST_TABLE_NAME + " ( k1 int(11) NOT NULL , k2 int(11) NOT NULL ," + + "v1 varchar(4096) NOT NULL, v2 float NOT NULL , v3 decimal(20, 7) NOT NULL ) ENGINE=OLAP " + + "DUPLICATE KEY( k1 , k2 ) DISTRIBUTED BY HASH( k1 , k2 ) BUCKETS 3 " + + "PROPERTIES ('replication_num' = '1'); "; + createTable(duplicateTable); + String createK1K2MV = "create materialized view k1_k2 as select k1,k2 from " + TEST_TABLE_NAME + " group by " + + "k1,k2;"; + createMv(createK1K2MV); + String query = "select k1 from " + TEST_TABLE_NAME + " group by k1 order by max(v1);"; + testMv(query, TEST_TABLE_NAME); + dropTable(TEST_TABLE_NAME, true); + } + + /** + * TODO: enable when window is supported. + */ + @Test + public void testWindowsFunctionInQuery() throws Exception { + // String duplicateTable = "CREATE TABLE " + TEST_TABLE_NAME + " ( k1 int(11) NOT NULL , k2 int(11) NOT NULL ," + // + "v1 varchar(4096) NOT NULL, v2 float NOT NULL , v3 decimal(20, 7) NOT NULL ) ENGINE=OLAP " + // + "DUPLICATE KEY( k1 , k2 ) DISTRIBUTED BY HASH( k1 , k2 ) BUCKETS 3 " + // + "PROPERTIES ('replication_num' = '1'); "; + // dorisAssert.withTable(duplicateTable); + // String createK1K2MV = "create materialized view k1_k2 as select k1,k2 from " + TEST_TABLE_NAME + " group by " + // + "k1,k2;"; + // String query = "select k1 , sum(k2) over (partition by v1 ) from " + TEST_TABLE_NAME + ";"; + // dorisAssert.withMaterializedView(createK1K2MV).query(query).explainWithout("k1_k2"); + // dorisAssert.dropTable(TEST_TABLE_NAME, true); + } + + @Test + public void testUniqueTableInQuery() throws Exception { + String uniqueTable = "CREATE TABLE " + TEST_TABLE_NAME + " (k1 int, v1 int) UNIQUE KEY (k1) " + + "DISTRIBUTED BY HASH(k1) BUCKETS 3 PROPERTIES ('replication_num' = '1');"; + createTable(uniqueTable); + String createK1MV = "create materialized view only_k1 as select k1 from " + TEST_TABLE_NAME + " group by " + + "k1;"; + createMv(createK1MV); + String query = "select * from " + TEST_TABLE_NAME + ";"; + singleTableTest(query, TEST_TABLE_NAME, false); + } + + /** + * TODO: enable this when bitmap is supported. + */ + @Disabled + public void testBitmapUnionInQuery() throws Exception { + // String createUserTagMVSql = "create materialized view " + USER_TAG_MV_NAME + // + " as select user_id, bitmap_union(to_bitmap(tag_id)) from " + // + USER_TAG_TABLE_NAME + " group by user_id;"; + // dorisAssert.withMaterializedView(createUserTagMVSql); + // String query = "select user_id, bitmap_union_count(to_bitmap(tag_id)) a from " + USER_TAG_TABLE_NAME + // + " group by user_id having a>1 order by a;"; + // dorisAssert.query(query).explainContains(QUERY_USE_USER_TAG_MV); + } + + /** + * TODO: enable this when bitmap is supported. + */ + @Disabled + public void testBitmapUnionInSubquery() throws Exception { + // String createUserTagMVSql = "create materialized view " + USER_TAG_MV_NAME + " as select user_id, " + // + "bitmap_union(to_bitmap(tag_id)) from " + USER_TAG_TABLE_NAME + " group by user_id;"; + // dorisAssert.withMaterializedView(createUserTagMVSql); + // String query = "select user_id from " + USER_TAG_TABLE_NAME + " where user_id in (select user_id from " + // + USER_TAG_TABLE_NAME + " group by user_id having bitmap_union_count(to_bitmap(tag_id)) >1 ) ;"; + // dorisAssert.query(query).explainContains(USER_TAG_MV_NAME, USER_TAG_TABLE_NAME); + } + + /** + * TODO: enable this when bitmap is supported. 
+ */ + @Disabled + public void testIncorrectMVRewriteInQuery() throws Exception { + // String createUserTagMVSql = "create materialized view " + USER_TAG_MV_NAME + " as select user_id, " + // + "bitmap_union(to_bitmap(tag_id)) from " + USER_TAG_TABLE_NAME + " group by user_id;"; + // dorisAssert.withMaterializedView(createUserTagMVSql); + // String createEmpMVSql = "create materialized view " + EMPS_MV_NAME + " as select name, deptno from " + // + EMPS_TABLE_NAME + ";"; + // dorisAssert.withMaterializedView(createEmpMVSql); + // String query = "select user_name, bitmap_union_count(to_bitmap(tag_id)) a from " + USER_TAG_TABLE_NAME + ", " + // + "(select name, deptno from " + EMPS_TABLE_NAME + ") a" + " where user_name=a.name group by " + // + "user_name having a>1 order by a;"; + // testMv(query, EMPS_MV_NAME); + // dorisAssert.query(query).explainWithout(QUERY_USE_USER_TAG_MV); + } + + /** + * TODO: enable this when bitmap is supported. + */ + @Disabled + public void testIncorrectMVRewriteInSubquery() throws Exception { + // String createUserTagMVSql = "create materialized view " + USER_TAG_MV_NAME + " as select user_id, " + // + "bitmap_union(to_bitmap(tag_id)) from " + USER_TAG_TABLE_NAME + " group by user_id;"; + // dorisAssert.withMaterializedView(createUserTagMVSql); + // String query = "select user_id, bitmap_union(to_bitmap(tag_id)) from " + USER_TAG_TABLE_NAME + " where " + // + "user_name in (select user_name from " + USER_TAG_TABLE_NAME + " group by user_name having " + // + "bitmap_union_count(to_bitmap(tag_id)) >1 )" + " group by user_id;"; + // dorisAssert.query(query).explainContains(QUERY_USE_USER_TAG); + } + + /** + * TODO: enable this when bitmap is supported. + */ + @Disabled + public void testTwoTupleInQuery() throws Exception { + // String createUserTagMVSql = "create materialized view " + USER_TAG_MV_NAME + " as select user_id, " + // + "bitmap_union(to_bitmap(tag_id)) from " + USER_TAG_TABLE_NAME + " group by user_id;"; + // dorisAssert.withMaterializedView(createUserTagMVSql); + // String query = "select * from (select user_id, bitmap_union_count(to_bitmap(tag_id)) x from " + // + USER_TAG_TABLE_NAME + " group by user_id) a, (select user_name, bitmap_union_count(to_bitmap(tag_id))" + // + "" + " y from " + USER_TAG_TABLE_NAME + " group by user_name) b where a.x=b.y;"; + // dorisAssert.query(query).explainContains(QUERY_USE_USER_TAG, QUERY_USE_USER_TAG_MV); + } + + /** + * TODO: enable this when bitmap is supported. + */ + @Disabled + public void testAggTableCountDistinctInBitmapType() throws Exception { + // String aggTable = "CREATE TABLE " + TEST_TABLE_NAME + " (k1 int, v1 bitmap bitmap_union) Aggregate KEY (k1) " + // + "DISTRIBUTED BY HASH(k1) BUCKETS 3 PROPERTIES ('replication_num' = '1');"; + // dorisAssert.withTable(aggTable); + // String query = "select k1, count(distinct v1) from " + TEST_TABLE_NAME + " group by k1;"; + // dorisAssert.query(query).explainContains(TEST_TABLE_NAME, "bitmap_union_count"); + // dorisAssert.dropTable(TEST_TABLE_NAME, true); + } + + /** + * TODO: enable this when hll is supported. 
+ */ + @Disabled + public void testAggTableCountDistinctInHllType() throws Exception { + // String aggTable = "CREATE TABLE " + TEST_TABLE_NAME + " (k1 int, v1 hll " + FunctionSet.HLL_UNION + // + ") Aggregate KEY (k1) " + // + "DISTRIBUTED BY HASH(k1) BUCKETS 3 PROPERTIES ('replication_num' = '1');"; + // dorisAssert.withTable(aggTable); + // String query = "select k1, count(distinct v1) from " + TEST_TABLE_NAME + " group by k1;"; + // dorisAssert.query(query).explainContains(TEST_TABLE_NAME, "hll_union_agg"); + // dorisAssert.dropTable(TEST_TABLE_NAME, true); + } + + /** + * TODO: enable this when bitmap is supported. + */ + @Disabled + public void testCountDistinctToBitmap() throws Exception { + // String createUserTagMVSql = "create materialized view " + USER_TAG_MV_NAME + " as select user_id, " + // + "bitmap_union(to_bitmap(tag_id)) from " + USER_TAG_TABLE_NAME + " group by user_id;"; + // dorisAssert.withMaterializedView(createUserTagMVSql); + // String query = "select count(distinct tag_id) from " + USER_TAG_TABLE_NAME + ";"; + // dorisAssert.query(query).explainContains(USER_TAG_MV_NAME, "bitmap_union_count"); + } + + /** + * TODO: enable this when bitmap is supported. + */ + @Disabled + public void testIncorrectRewriteCountDistinct() throws Exception { + // String createUserTagMVSql = "create materialized view " + USER_TAG_MV_NAME + " as select user_id, " + // + "bitmap_union(to_bitmap(tag_id)) from " + USER_TAG_TABLE_NAME + " group by user_id;"; + // dorisAssert.withMaterializedView(createUserTagMVSql); + // String query = "select user_name, count(distinct tag_id) from " + USER_TAG_TABLE_NAME + " group by user_name;"; + // dorisAssert.query(query).explainContains(USER_TAG_TABLE_NAME, FunctionSet.COUNT); + } + + /** + * TODO: enable this when hll is supported. + */ + @Disabled + public void testNDVToHll() throws Exception { + // String createUserTagMVSql = "create materialized view " + USER_TAG_MV_NAME + " as select user_id, " + // + "`" + FunctionSet.HLL_UNION + "`(" + FunctionSet.HLL_HASH + "(tag_id)) from " + USER_TAG_TABLE_NAME + // + " group by user_id;"; + // dorisAssert.withMaterializedView(createUserTagMVSql); + // String query = "select ndv(tag_id) from " + USER_TAG_TABLE_NAME + ";"; + // dorisAssert.query(query).explainContains(USER_TAG_MV_NAME, "hll_union_agg"); + } + + /** + * TODO: enable this when hll is supported. + */ + @Disabled + public void testApproxCountDistinctToHll() throws Exception { + // String createUserTagMVSql = "create materialized view " + USER_TAG_MV_NAME + " as select user_id, " + // + "`" + FunctionSet.HLL_UNION + "`(" + FunctionSet.HLL_HASH + "(tag_id)) from " + USER_TAG_TABLE_NAME + // + " group by user_id;"; + // dorisAssert.withMaterializedView(createUserTagMVSql); + // String query = "select approx_count_distinct(tag_id) from " + USER_TAG_TABLE_NAME + ";"; + // dorisAssert.query(query).explainContains(USER_TAG_MV_NAME, "hll_union_agg"); + } + + /** + * TODO: enable this when hll is supported. 
+ */ + @Test + public void testHLLUnionFamilyRewrite() throws Exception { + // String createUserTagMVSql = "create materialized view " + USER_TAG_MV_NAME + " as select user_id, " + // + "`" + FunctionSet.HLL_UNION + "`(" + FunctionSet.HLL_HASH + "(tag_id)) from " + USER_TAG_TABLE_NAME + // + " group by user_id;"; + // createMv(createUserTagMVSql); + // String query = "select `" + FunctionSet.HLL_UNION + "`(" + FunctionSet.HLL_HASH + "(tag_id)) from " + // + USER_TAG_TABLE_NAME + ";"; + // String mvColumnName = CreateMaterializedViewStmt.mvColumnBuilder("" + FunctionSet.HLL_UNION + "", "tag_id"); + // dorisAssert.query(query).explainContains(USER_TAG_MV_NAME, mvColumnName); + // query = "select hll_union_agg(" + FunctionSet.HLL_HASH + "(tag_id)) from " + USER_TAG_TABLE_NAME + ";"; + // dorisAssert.query(query).explainContains(USER_TAG_MV_NAME, mvColumnName); + // query = "select hll_raw_agg(" + FunctionSet.HLL_HASH + "(tag_id)) from " + USER_TAG_TABLE_NAME + ";"; + // dorisAssert.query(query).explainContains(USER_TAG_MV_NAME, mvColumnName); + } + + @Test + public void testAggInHaving() throws Exception { + String createMVSql = "create materialized view " + EMPS_MV_NAME + " as select empid, deptno from " + + EMPS_TABLE_NAME + " group by empid, deptno;"; + createMv(createMVSql); + String query = "select empid from " + EMPS_TABLE_NAME + " group by empid having max(salary) > 1;"; + testMv(query, EMPS_TABLE_NAME); + } + + /** + * TODO: support count in mv. + */ + @Disabled + public void testCountFieldInQuery() throws Exception { + // String createUserTagMVSql = "create materialized view " + USER_TAG_MV_NAME + " as select user_id, " + // + "count(tag_id) from " + USER_TAG_TABLE_NAME + " group by user_id;"; + // createMv(createUserTagMVSql); + // String query = "select count(tag_id) from " + USER_TAG_TABLE_NAME + ";"; + // String mvColumnName = CreateMaterializedViewStmt.mvColumnBuilder(FunctionSet.COUNT, "tag_id"); + // // dorisAssert.query(query).explainContains(USER_TAG_MV_NAME, mvColumnName); + // + // String explain = getSQLPlanOrErrorMsg(query); + // mv_count_tag_id + /* + PARTITION: HASH_PARTITIONED: `default_cluster:db1`.`user_tags`.`time_col` + + STREAM DATA SINK + EXCHANGE ID: 02 + UNPARTITIONED + + 1:VAGGREGATE (update serialize) + | output: sum(`mv_count_tag_id`) + | group by: + | cardinality=1 + | + 0:VOlapScanNode + TABLE: user_tags(user_tags_mv), PREAGGREGATION: ON + partitions=1/1, tablets=3/3, tabletList=10034,10036,10038 + cardinality=0, avgRowSize=8.0, numNodes=1 + */ + // System.out.println("mvColumnName:" + mvColumnName); + // System.out.println("explain:\n" + explain); + // query = "select user_name, count(tag_id) from " + USER_TAG_TABLE_NAME + " group by user_name;"; + // dorisAssert.query(query).explainWithout(USER_TAG_MV_NAME); + } + + /** + * TODO: enable this when bitmap is supported. 
+     */
+    @Disabled
+    public void testCreateMVBaseBitmapAggTable() throws Exception {
+        String createTableSQL = "create table " + HR_DB_NAME + ".agg_table "
+                + "(empid int, name varchar, salary bitmap " + FunctionSet.BITMAP_UNION + ") "
+                + "aggregate key (empid, name) "
+                + "partition by range (empid) "
+                + "(partition p1 values less than MAXVALUE) "
+                + "distributed by hash(empid) buckets 3 properties('replication_num' = '1');";
+        createTable(createTableSQL);
+        String createMVSql = "create materialized view mv as select empid, " + FunctionSet.BITMAP_UNION
+                + "(salary) from agg_table "
+                + "group by empid;";
+        createMv(createMVSql);
+        String query = "select count(distinct salary) from agg_table;";
+        testMv(query, "mv");
+        dropTable("agg_table", true);
+    }
+
+    /**
+     * TODO: support count in mv.
+     */
+    @Disabled
+    public void testSelectMVWithTableAlias() throws Exception {
+        // String createUserTagMVSql = "create materialized view " + USER_TAG_MV_NAME + " as select user_id, "
+        //         + "count(tag_id) from " + USER_TAG_TABLE_NAME + " group by user_id;";
+        // dorisAssert.withMaterializedView(createUserTagMVSql);
+        // String query = "select count(tag_id) from " + USER_TAG_TABLE_NAME + " t ;";
+        // String mvColumnName = CreateMaterializedViewStmt.mvColumnBuilder(FunctionSet.COUNT, "tag_id");
+        // dorisAssert.query(query).explainContains(USER_TAG_MV_NAME, mvColumnName);
+    }
+
+    private void testMv(String sql, Map<String, String> tableToIndex) {
+        PlanChecker.from(connectContext).checkPlannerResult(sql, planner -> {
+            List<ScanNode> scans = planner.getScanNodes();
+            for (ScanNode scanNode : scans) {
+                Assertions.assertTrue(scanNode instanceof OlapScanNode);
+                OlapScanNode olapScan = (OlapScanNode) scanNode;
+                Assertions.assertTrue(olapScan.isPreAggregation());
+                Assertions.assertEquals(tableToIndex.get(olapScan.getOlapTable().getName()),
+                        olapScan.getSelectedIndexName());
+            }
+        });
+    }
+
+    private void testMv(String sql, String indexName) {
+        singleTableTest(sql, indexName, true);
+    }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/SelectRollupTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/SelectRollupIndexTest.java
similarity index 59%
rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/SelectRollupTest.java
rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/SelectRollupIndexTest.java
index 256aec0b6613c3a..091d39013c98f38 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/SelectRollupTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/SelectRollupIndexTest.java
@@ -21,22 +21,22 @@
 import org.apache.doris.nereids.trees.plans.PreAggStatus;
 import org.apache.doris.nereids.util.PatternMatchSupported;
 import org.apache.doris.nereids.util.PlanChecker;
-import org.apache.doris.planner.OlapScanNode;
-import org.apache.doris.planner.ScanNode;
-import org.apache.doris.utframe.TestWithFeService;
 
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
-import java.util.List;
+class SelectRollupIndexTest extends BaseMaterializedIndexSelectTest implements PatternMatchSupported {
 
-class SelectRollupTest extends TestWithFeService implements PatternMatchSupported {
+    @Override
+    protected void beforeCreatingConnectContext() throws Exception {
+        FeConstants.default_scheduler_interval_millisecond = 10;
+        FeConstants.runningUnitTest = true;
+    }
 
     @Override
     protected void runBeforeAll() throws Exception {
-        FeConstants.runningUnitTest = true;
         createDatabase("test");
-        
connectContext.setDatabase("default_cluster:test"); + useDatabase("test"); createTable("CREATE TABLE `t` (\n" + " `k1` int(11) NULL,\n" @@ -57,6 +57,26 @@ protected void runBeforeAll() throws Exception { // waiting table state to normal Thread.sleep(500); addRollup("alter table t add rollup r2(k2, k3, v1)"); + addRollup("alter table t add rollup r3(k2)"); + addRollup("alter table t add rollup r4(k2, k3)"); + + createTable("CREATE TABLE `t1` (\n" + + " `k1` int(11) NULL,\n" + + " `k2` int(11) NULL,\n" + + " `v1` int(11) SUM NULL\n" + + ") ENGINE=OLAP\n" + + "AGGREGATE KEY(`k1`, `k2`)\n" + + "COMMENT 'OLAP'\n" + + "DISTRIBUTED BY HASH(`k1`) BUCKETS 3\n" + + "PROPERTIES (\n" + + "\"replication_allocation\" = \"tag.location.default: 1\",\n" + + "\"in_memory\" = \"false\",\n" + + "\"storage_format\" = \"V2\",\n" + + "\"disable_auto_compaction\" = \"false\"\n" + + ");"); + addRollup("alter table t1 add rollup r1(k1)"); + addRollup("alter table t1 add rollup r2(k2, v1)"); + addRollup("alter table t1 add rollup r3(k1, k2)"); createTable("CREATE TABLE `duplicate_tbl` (\n" + " `k1` int(11) NULL,\n" @@ -77,24 +97,17 @@ protected void runBeforeAll() throws Exception { @Test public void testAggMatching() { - PlanChecker.from(connectContext) - .analyze(" select k2, sum(v1) from t group by k2") - .applyTopDown(new SelectRollupWithAggregate()) - .matches(logicalOlapScan().when(scan -> { - Assertions.assertTrue(scan.getPreAggStatus().isOn()); - Assertions.assertEquals("r1", scan.getSelectRollupName().get()); - return true; - })); + singleTableTest("select k2, sum(v1) from t group by k2", "r1", true); } @Test public void testMatchingBase() { PlanChecker.from(connectContext) .analyze(" select k1, sum(v1) from t group by k1") - .applyTopDown(new SelectRollupWithAggregate()) + .applyTopDown(new SelectMaterializedIndexWithAggregate()) .matches(logicalOlapScan().when(scan -> { Assertions.assertTrue(scan.getPreAggStatus().isOn()); - Assertions.assertEquals("t", scan.getSelectRollupName().get()); + Assertions.assertEquals("t", scan.getSelectedMaterializedIndexName().get()); return true; })); } @@ -103,10 +116,10 @@ public void testMatchingBase() { void testAggFilterScan() { PlanChecker.from(connectContext) .analyze("select k2, sum(v1) from t where k3=0 group by k2") - .applyTopDown(new SelectRollupWithAggregate()) + .applyTopDown(new SelectMaterializedIndexWithAggregate()) .matches(logicalOlapScan().when(scan -> { Assertions.assertTrue(scan.getPreAggStatus().isOn()); - Assertions.assertEquals("r2", scan.getSelectRollupName().get()); + Assertions.assertEquals("r2", scan.getSelectedMaterializedIndexName().get()); return true; })); } @@ -118,29 +131,22 @@ void testTranslate() { @Test public void testTranslateWhenPreAggIsOff() { - PlanChecker.from(connectContext).checkPlannerResult( - "select k2, min(v1) from t group by k2", - planner -> { - List scans = planner.getScanNodes(); - Assertions.assertEquals(1, scans.size()); - ScanNode scanNode = scans.get(0); - Assertions.assertTrue(scanNode instanceof OlapScanNode); - OlapScanNode olapScan = (OlapScanNode) scanNode; - Assertions.assertFalse(olapScan.isPreAggregation()); - Assertions.assertEquals("Aggregate operator don't match, " - + "aggregate function: min(v1), column aggregate type: SUM", - olapScan.getReasonOfPreAggregation()); - }); + singleTableTest("select k2, min(v1) from t group by k2", scan -> { + Assertions.assertFalse(scan.isPreAggregation()); + Assertions.assertEquals("Aggregate operator don't match, " + + "aggregate function: min(v1), column aggregate 
type: SUM", + scan.getReasonOfPreAggregation()); + }); } @Test public void testWithEqualFilter() { PlanChecker.from(connectContext) .analyze("select k2, sum(v1) from t where k3=0 group by k2") - .applyTopDown(new SelectRollupWithAggregate()) + .applyTopDown(new SelectMaterializedIndexWithAggregate()) .matches(logicalOlapScan().when(scan -> { Assertions.assertTrue(scan.getPreAggStatus().isOn()); - Assertions.assertEquals("r2", scan.getSelectRollupName().get()); + Assertions.assertEquals("r2", scan.getSelectedMaterializedIndexName().get()); return true; })); } @@ -149,10 +155,10 @@ public void testWithEqualFilter() { public void testWithNonEqualFilter() { PlanChecker.from(connectContext) .analyze("select k2, sum(v1) from t where k3>0 group by k2") - .applyTopDown(new SelectRollupWithAggregate()) + .applyTopDown(new SelectMaterializedIndexWithAggregate()) .matches(logicalOlapScan().when(scan -> { Assertions.assertTrue(scan.getPreAggStatus().isOn()); - Assertions.assertEquals("r2", scan.getSelectRollupName().get()); + Assertions.assertEquals("r2", scan.getSelectedMaterializedIndexName().get()); return true; })); } @@ -161,10 +167,10 @@ public void testWithNonEqualFilter() { public void testWithFilter() { PlanChecker.from(connectContext) .analyze("select k2, sum(v1) from t where k2>3 group by k3") - .applyTopDown(new SelectRollupWithAggregate()) + .applyTopDown(new SelectMaterializedIndexWithAggregate()) .matches(logicalOlapScan().when(scan -> { Assertions.assertTrue(scan.getPreAggStatus().isOn()); - Assertions.assertEquals("r2", scan.getSelectRollupName().get()); + Assertions.assertEquals("r2", scan.getSelectedMaterializedIndexName().get()); return true; })); } @@ -176,11 +182,11 @@ public void testWithFilterAndProject() { + " where c3>0 group by c2"; PlanChecker.from(connectContext) .analyze(sql) - .applyTopDown(new SelectRollupWithAggregate()) - .applyTopDown(new SelectRollupWithoutAggregate()) + .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) .matches(logicalOlapScan().when(scan -> { Assertions.assertTrue(scan.getPreAggStatus().isOn()); - Assertions.assertEquals("r2", scan.getSelectRollupName().get()); + Assertions.assertEquals("r2", scan.getSelectedMaterializedIndexName().get()); return true; })); } @@ -193,8 +199,8 @@ public void testWithFilterAndProject() { public void testNoAggregate() { PlanChecker.from(connectContext) .analyze("select k1, v1 from t") - .applyTopDown(new SelectRollupWithAggregate()) - .applyTopDown(new SelectRollupWithoutAggregate()) + .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOff()); @@ -207,8 +213,8 @@ public void testNoAggregate() { public void testAggregateTypeNotMatch() { PlanChecker.from(connectContext) .analyze("select k1, min(v1) from t group by k1") - .applyTopDown(new SelectRollupWithAggregate()) - .applyTopDown(new SelectRollupWithoutAggregate()) + .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOff()); @@ -222,8 +228,8 @@ public void testAggregateTypeNotMatch() { public void testInvalidSlotInAggFunction() { PlanChecker.from(connectContext) .analyze("select k1, sum(v1 + 1) from t group by k1") - 
.applyTopDown(new SelectRollupWithAggregate()) - .applyTopDown(new SelectRollupWithoutAggregate()) + .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOff()); @@ -237,8 +243,8 @@ public void testInvalidSlotInAggFunction() { public void testKeyColumnInAggFunction() { PlanChecker.from(connectContext) .analyze("select k1, sum(k2) from t group by k1") - .applyTopDown(new SelectRollupWithAggregate()) - .applyTopDown(new SelectRollupWithoutAggregate()) + .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOff()); @@ -252,12 +258,12 @@ public void testKeyColumnInAggFunction() { public void testMaxCanUseKeyColumn() { PlanChecker.from(connectContext) .analyze("select k2, max(k3) from t group by k3") - .applyTopDown(new SelectRollupWithAggregate()) - .applyTopDown(new SelectRollupWithoutAggregate()) + .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOn()); - Assertions.assertEquals("r2", scan.getSelectRollupName().get()); + Assertions.assertEquals("r4", scan.getSelectedMaterializedIndexName().get()); return true; })); } @@ -266,12 +272,12 @@ public void testMaxCanUseKeyColumn() { public void testMinCanUseKeyColumn() { PlanChecker.from(connectContext) .analyze("select k2, min(k3) from t group by k3") - .applyTopDown(new SelectRollupWithAggregate()) - .applyTopDown(new SelectRollupWithoutAggregate()) + .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOn()); - Assertions.assertEquals("r2", scan.getSelectRollupName().get()); + Assertions.assertEquals("r4", scan.getSelectedMaterializedIndexName().get()); return true; })); } @@ -280,8 +286,8 @@ public void testMinCanUseKeyColumn() { public void testDuplicatePreAggOn() { PlanChecker.from(connectContext) .analyze("select k1, sum(k1) from duplicate_tbl group by k1") - .applyTopDown(new SelectRollupWithAggregate()) - .applyTopDown(new SelectRollupWithoutAggregate()) + .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOn()); @@ -293,8 +299,8 @@ public void testDuplicatePreAggOn() { public void testDuplicatePreAggOnEvenWithoutAggregate() { PlanChecker.from(connectContext) .analyze("select k1, v1 from duplicate_tbl") - .applyTopDown(new SelectRollupWithAggregate()) - .applyTopDown(new SelectRollupWithoutAggregate()) + .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOn()); @@ -302,4 +308,71 @@ public void testDuplicatePreAggOnEvenWithoutAggregate() { })); } + @Test + public void testKeysOnlyQuery() throws Exception { + singleTableTest("select 
k1 from t1", "r3", false);
+        singleTableTest("select k2 from t1", "r3", false);
+        singleTableTest("select k1, k2 from t1", "r3", false);
+        singleTableTest("select k1 from t1 group by k1", "r1", true);
+        singleTableTest("select k2 from t1 group by k2", "r2", true);
+        singleTableTest("select k1, k2 from t1 group by k1, k2", "r3", true);
+    }
+
+    /**
+     * Rollup with all the keys should be used.
+     */
+    @Test
+    public void testRollupWithAllTheKeys() throws Exception {
+        createTable(" CREATE TABLE `t4` (\n"
+                + "  `k1` int(11) NULL,\n"
+                + "  `k2` int(11) NULL,\n"
+                + "  `v1` int(11) SUM NULL,\n"
+                + "  `v2` int(11) SUM NULL\n"
+                + ") ENGINE=OLAP\n"
+                + "AGGREGATE KEY(`k1`, `k2`)\n"
+                + "COMMENT 'OLAP'\n"
+                + "DISTRIBUTED BY HASH(`k1`) BUCKETS 3\n"
+                + "PROPERTIES (\n"
+                + "\"replication_allocation\" = \"tag.location.default: 1\",\n"
+                + "\"in_memory\" = \"false\",\n"
+                + "\"storage_format\" = \"V2\",\n"
+                + "\"disable_auto_compaction\" = \"false\"\n"
+                + ");");
+        addRollup("alter table t4 add rollup r1(k1, k2, v1)");
+
+        singleTableTest("select k1, k2, v1 from t4", "r1", false);
+        singleTableTest("select k1, k2, sum(v1) from t4 group by k1, k2", "r1", true);
+        singleTableTest("select k1, v1 from t4", "r1", false);
+        singleTableTest("select k1, sum(v1) from t4 group by k1", "r1", true);
+    }
+
+    @Test
+    public void testComplexGroupingExpr() throws Exception {
+        singleTableTest("select k2 + 1, sum(v1) from t group by k2 + 1", "r1", true);
+    }
+
+    @Test
+    public void testCountDistinctKeyColumn() {
+        singleTableTest("select k2, count(distinct k3) from t group by k2", "r4", true);
+    }
+
+    @Test
+    public void testCountDistinctValueColumn() {
+        singleTableTest("select k1, count(distinct v1) from t group by k1", scan -> {
+            Assertions.assertFalse(scan.isPreAggregation());
+            Assertions.assertEquals("Count distinct is only valid for key columns, but meet count(distinct v1).",
+                    scan.getReasonOfPreAggregation());
+            Assertions.assertEquals("t", scan.getSelectedIndexName());
+        });
+    }
+
+    @Test
+    public void testOnlyValueColumn1() throws Exception {
+        singleTableTest("select sum(v1) from t", "r1", true);
+    }
+
+    @Test
+    public void testOnlyValueColumn2() throws Exception {
+        singleTableTest("select v1 from t", "t", false);
+    }
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java b/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java
index fc04e524f2d4d64..f02a698c6efde86 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java
@@ -22,6 +22,7 @@
 import org.apache.doris.analysis.AlterTableStmt;
 import org.apache.doris.analysis.Analyzer;
 import org.apache.doris.analysis.CreateDbStmt;
+import org.apache.doris.analysis.CreateMaterializedViewStmt;
 import org.apache.doris.analysis.CreatePolicyStmt;
 import org.apache.doris.analysis.CreateSqlBlockRuleStmt;
 import org.apache.doris.analysis.CreateTableAsSelectStmt;
@@ -29,6 +30,7 @@
 import org.apache.doris.analysis.CreateViewStmt;
 import org.apache.doris.analysis.DropPolicyStmt;
 import org.apache.doris.analysis.DropSqlBlockRuleStmt;
+import org.apache.doris.analysis.DropTableStmt;
 import org.apache.doris.analysis.ExplainOptions;
 import org.apache.doris.analysis.ShowCreateTableStmt;
 import org.apache.doris.analysis.SqlParser;
@@ -73,7 +75,6 @@
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
-import org.junit.Assert;
 import 
org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; @@ -116,6 +117,7 @@ public abstract class TestWithFeService { @BeforeAll public final void beforeAll() throws Exception { + beforeCreatingConnectContext(); connectContext = createDefaultCtx(); createDorisCluster(); runBeforeAll(); @@ -134,6 +136,10 @@ public final void beforeEach() throws Exception { runBeforeEach(); } + protected void beforeCreatingConnectContext() throws Exception { + + } + protected void runBeforeAll() throws Exception { } @@ -452,6 +458,12 @@ public void createTable(String sql) throws Exception { createTables(sql); } + public void dropTable(String table, boolean force) throws Exception { + DropTableStmt dropTableStmt = (DropTableStmt) parseAndAnalyzeStmt( + "drop table " + table + (force ? " force" : "") + ";", connectContext); + Env.getCurrentEnv().dropTable(dropTableStmt); + } + public void createTableAsSelect(String sql) throws Exception { CreateTableAsSelectStmt createTableAsSelectStmt = (CreateTableAsSelectStmt) parseAndAnalyzeStmt(sql); Env.getCurrentEnv().createTableAsSelect(createTableAsSelectStmt); @@ -522,7 +534,16 @@ protected void addRollup(String sql) throws Exception { Thread.sleep(100); } - private void checkAlterJob() throws InterruptedException, MetaNotFoundException { + protected void createMv(String sql) throws Exception { + CreateMaterializedViewStmt createMaterializedViewStmt = + (CreateMaterializedViewStmt) UtFrameUtils.parseAndAnalyzeStmt(sql, connectContext); + Env.getCurrentEnv().createMaterializedView(createMaterializedViewStmt); + checkAlterJob(); + // waiting table state to normal + Thread.sleep(100); + } + + private void checkAlterJob() throws InterruptedException { // check alter job Map<Long, AlterJobV2> alterJobs = Env.getCurrentEnv().getMaterializedViewHandler().getAlterJobsV2(); for (AlterJobV2 alterJobV2 : alterJobs.values()) { @@ -532,17 +553,23 @@ private void checkAlterJob() throws InterruptedException, MetaNotFoundException Thread.sleep(100); } System.out.println("alter job " + alterJobV2.getDbId() + " is done. state: " + alterJobV2.getJobState()); - Assert.assertEquals(AlterJobV2.JobState.FINISHED, alterJobV2.getJobState()); - - // Add table state check in case of below Exception: - // there is still a short gap between "job finish" and "table become normal", - // so if user send next alter job right after the "job finish", - // it may encounter "table's state not NORMAL" error. - Database db = - Env.getCurrentInternalCatalog().getDbOrMetaException(alterJobV2.getDbId()); - OlapTable tbl = (OlapTable) db.getTableOrMetaException(alterJobV2.getTableId(), Table.TableType.OLAP); - while (tbl.getState() != OlapTable.OlapTableState.NORMAL) { - Thread.sleep(1000); + Assertions.assertEquals(AlterJobV2.JobState.FINISHED, alterJobV2.getJobState()); + + try { + // Add table state check in case of below Exception: + // there is still a short gap between "job finish" and "table become normal", + // so if user send next alter job right after the "job finish", + // it may encounter "table's state not NORMAL" error. + Database db = + Env.getCurrentInternalCatalog().getDbOrMetaException(alterJobV2.getDbId()); + OlapTable tbl = (OlapTable) db.getTableOrMetaException(alterJobV2.getTableId(), Table.TableType.OLAP); + while (tbl.getState() != OlapTable.OlapTableState.NORMAL) { + Thread.sleep(1000); + } + } catch (MetaNotFoundException e) { + // Sometimes table could be dropped by tests, but the corresponding alter job is not deleted yet. 
+ // Ignore this error. + System.out.println(e.getMessage()); } } }
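
Illustration only, not part of the diff above: a minimal sketch of how the new rules and test helpers could be exercised together, assuming the helpers used in the tests keep their signatures (createTable(String), addRollup(String), singleTableTest(String sql, String expectedIndex, boolean expectedPreAgg)) and that TestWithFeService provides the usual createDatabase/useDatabase setup methods. The table t5, rollup r_sum, and database name are invented fixtures; the expected index names and pre-aggregation flags follow the same conventions as the t4 test above (the base index shares the table's name, and pre-aggregation stays on when the aggregate functions match the index's value columns).

// Illustrative sketch, not included in this patch. Fixture names (t5, r_sum,
// test_mv_example) are invented for the example.
package org.apache.doris.nereids.rules.mv;

import org.junit.jupiter.api.Test;

public class MaterializedIndexSelectionExampleTest extends BaseMaterializedIndexSelectTest {

    @Override
    protected void runBeforeAll() throws Exception {
        // Assumed TestWithFeService helpers for database setup.
        createDatabase("test_mv_example");
        useDatabase("test_mv_example");
        // Aggregate-key table; the rollup covers only (k1, v1).
        createTable("CREATE TABLE `t5` (\n"
                + "  `k1` int(11) NULL,\n"
                + "  `k2` int(11) NULL,\n"
                + "  `v1` int(11) SUM NULL\n"
                + ") ENGINE=OLAP\n"
                + "AGGREGATE KEY(`k1`, `k2`)\n"
                + "DISTRIBUTED BY HASH(`k1`) BUCKETS 3\n"
                + "PROPERTIES (\"replication_allocation\" = \"tag.location.default: 1\");");
        addRollup("alter table t5 add rollup r_sum(k1, v1)");
    }

    @Test
    public void testRollupChosenWhenItCoversTheQuery() throws Exception {
        // Grouping by k1 and summing v1 is fully covered by r_sum, so the
        // rewrite is expected to pick the rollup and keep pre-aggregation on.
        singleTableTest("select k1, sum(v1) from t5 group by k1", "r_sum", true);
    }

    @Test
    public void testBaseIndexKeptWhenRollupMissesAKey() throws Exception {
        // k2 is not part of r_sum, so the base index (named after the table)
        // should be kept; sum(v1) still matches the SUM column, so
        // pre-aggregation can stay on.
        singleTableTest("select k2, sum(v1) from t5 group by k2", "t5", true);
    }
}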