diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java index e7bf9a57358d5d..98daf7425f5ae0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java @@ -31,7 +31,6 @@ import org.apache.doris.nereids.rules.analysis.CollectSubQueryAlias; import org.apache.doris.nereids.rules.analysis.CompressedMaterialize; import org.apache.doris.nereids.rules.analysis.EliminateDistinctConstant; -import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant; import org.apache.doris.nereids.rules.analysis.EliminateLogicalSelectHint; import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots; import org.apache.doris.nereids.rules.analysis.FillUpQualifyMissingSlot; @@ -141,8 +140,6 @@ private static List buildAnalyzerJobs() { // select SUM(lo_tax) FROM lineorder group by 1; // errCode = 2, detailMessage = GROUP BY expression must not contain aggregate functions: sum(lo_tax) bottomUp(new CheckAnalysis()), - topDown(new EliminateGroupByConstant()), - topDown(new SimplifyAggGroupBy()), bottomUp(new CompressedMaterialize()), topDown(new NormalizeAggregate()), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 80c8760d0e545e..f2a5f9d881f746 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -25,7 +25,6 @@ import org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet; import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount; import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite; -import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant; 
import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject; import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite; @@ -175,7 +174,6 @@ public class Rewriter extends AbstractBatchJobExecutor { topDown( new EliminateOrderByConstant(), new EliminateSortUnderSubqueryOrView(), - new EliminateGroupByConstant(), // MergeProjects depends on this rule new LogicalSubQueryAliasToLogicalProject(), // TODO: we should do expression normalization after plan normalization diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index e5ebee120a310c..4a2e226caae962 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -17,9 +17,12 @@ package org.apache.doris.nereids.rules.analysis; +import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE; import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot; import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.Alias; @@ -35,6 +38,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction; import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.trees.plans.Plan; import 
org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; @@ -50,6 +54,7 @@ import com.google.common.collect.Sets; import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -111,14 +116,16 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot { public List buildRules() { return ImmutableList.of( logicalHaving(logicalAggregate().whenNot(LogicalAggregate::isNormalized)) - .then(having -> normalizeAgg(having.child(), Optional.of(having))) + .thenApply(ctx -> normalizeAgg(ctx.root.child(), Optional.of(ctx.root), ctx.cascadesContext)) .toRule(RuleType.NORMALIZE_AGGREGATE), logicalAggregate().whenNot(LogicalAggregate::isNormalized) - .then(aggregate -> normalizeAgg(aggregate, Optional.empty())) + .thenApply(ctx -> normalizeAgg(ctx.root, Optional.empty(), ctx.cascadesContext)) .toRule(RuleType.NORMALIZE_AGGREGATE)); } - private LogicalPlan normalizeAgg(LogicalAggregate aggregate, Optional> having) { + @SuppressWarnings("checkstyle:UnusedLocalVariable") + private LogicalPlan normalizeAgg(LogicalAggregate aggregate, Optional> having, + CascadesContext ctx) { // The LogicalAggregate node may contain window agg functions and usual agg functions // we call window agg functions as window-agg and usual agg functions as trivial-agg for short // This rule simplify LogicalAggregate node by: @@ -279,8 +286,10 @@ private LogicalPlan normalizeAgg(LogicalAggregate aggregate, Optional upperProjects = normalizeOutput(aggregateOutput, groupByExprContext, argsOfAggFuncNeedPushDownContext, normalizedAggFuncsToSlotContext); - // create a parent project node - LogicalProject project = new LogicalProject<>(upperProjects, newAggregate); + ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx); + LogicalProject project = eliminateGroupByConstant(groupByExprContext, rewriteContext, + normalizedGroupExprs, 
normalizedAggOutput, bottomProjects, aggregate, upperProjects, newAggregate); + // verify project used slots are all coming from agg's output List slots = collectAllUsedSlots(upperProjects); if (!slots.isEmpty()) { @@ -389,4 +398,93 @@ private Expression normalizeAggFuncChildren(NormalizeToSlotContext context, Expr return expr; } } + + private LogicalProject eliminateGroupByConstant(NormalizeToSlotContext groupByExprContext, + ExpressionRewriteContext rewriteContext, List normalizedGroupExprs, + List normalizedAggOutput, Set bottomProjects, + LogicalAggregate aggregate, List upperProjects, LogicalAggregate newAggregate) { + // 1. Find the expressions in group by that can be folded into constants and build a map(slot, literal) + Map replaceMap = groupByExprContext.getNormalizeToSlotMap(); + if (replaceMap.isEmpty()) { + return new LogicalProject<>(upperProjects, newAggregate); + } + Map slotToLiteral = new HashMap<>(); + for (Map.Entry entry : replaceMap.entrySet()) { + Expression foldExpression = FoldConstantRuleOnFE.evaluate(entry.getKey(), rewriteContext); + if (foldExpression.isConstant()) { + slotToLiteral.put(entry.getValue().remainExpr, foldExpression); + } + } + if (slotToLiteral.isEmpty()) { + return new LogicalProject<>(upperProjects, newAggregate); + } + // 2. 
Regenerate a group by list without constant key + List newNormalizedGroupExprs = new ArrayList<>(); + for (Expression normalizedGroupExpr : normalizedGroupExprs) { + if (!slotToLiteral.containsKey((Slot) normalizedGroupExpr)) { + newNormalizedGroupExprs.add(normalizedGroupExpr); + } + } + if (newNormalizedGroupExprs.size() == normalizedGroupExprs.size()) { + return new LogicalProject<>(upperProjects, newAggregate); + } + if (newNormalizedGroupExprs.isEmpty()) { + Alias tinyInt = new Alias(new TinyIntLiteral((byte) 1)); + bottomProjects = new HashSet<>(bottomProjects); + bottomProjects.add(tinyInt); + normalizedAggOutput = new ArrayList<>(normalizedAggOutput); + Slot tinyIntSlot = tinyInt.toSlot(); + normalizedAggOutput.add(tinyIntSlot); + newNormalizedGroupExprs.add(tinyIntSlot); + } + // 3. Replace the agg output expression and delete the constant group by key in the output + ImmutableList.Builder nonConstAggOutput = ImmutableList.builder(); + for (NamedExpression ne : normalizedAggOutput) { + if (ne instanceof Alias) { + nonConstAggOutput.add(ExpressionUtils.replaceNameExpression(ne, slotToLiteral)); + continue; + } else if (ne instanceof Slot) { + if (!slotToLiteral.containsKey(ne)) { + nonConstAggOutput.add(ne); + } + continue; + } + nonConstAggOutput.add(ne); + } + + // 4. The constant expression calculation in bottom projects needs to be deleted + // and put into upperProjects for calculation + Plan bottomPlan; + if (!bottomProjects.isEmpty()) { + ImmutableList.Builder builder = ImmutableList.builder(); + for (NamedExpression bottomProject : bottomProjects) { + if (!slotToLiteral.containsKey(bottomProject.toSlot())) { + builder.add(bottomProject); + } + } + bottomPlan = new LogicalProject<>(builder.build(), aggregate.child()); + } else { + bottomPlan = aggregate.child(); + } + LogicalAggregate newAggAfterEliminate = aggregate.withNormalized(newNormalizedGroupExprs, + nonConstAggOutput.build(), bottomPlan); + // 5. 
This upperProjects needs to add the constant key that was deleted in the group by key + // and change the reference to the constant key to a constant expression + ImmutableList.Builder newUpperProjects = ImmutableList.builder(); + for (NamedExpression upperProject : upperProjects) { + if (upperProject instanceof Alias) { + newUpperProjects.add(ExpressionUtils.replaceNameExpression(upperProject, slotToLiteral)); + continue; + } else if (upperProject instanceof Slot) { + if (slotToLiteral.containsKey(upperProject)) { + Alias newLiteral = new Alias(upperProject.getExprId(), slotToLiteral.get(upperProject), + upperProject.getName()); + newUpperProjects.add(newLiteral); + continue; + } + } + newUpperProjects.add(upperProject); + } + return new LogicalProject<>(newUpperProjects.build(), newAggAfterEliminate); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstantTest.java deleted file mode 100644 index c35b983911c859..00000000000000 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstantTest.java +++ /dev/null @@ -1,165 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.analysis; - -import org.apache.doris.catalog.AggregateType; -import org.apache.doris.catalog.Column; -import org.apache.doris.catalog.KeysType; -import org.apache.doris.catalog.OlapTable; -import org.apache.doris.catalog.PartitionInfo; -import org.apache.doris.catalog.Type; -import org.apache.doris.nereids.trees.expressions.Add; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.functions.agg.Max; -import org.apache.doris.nereids.trees.expressions.functions.agg.Min; -import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; -import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; -import org.apache.doris.nereids.trees.plans.RelationId; -import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; -import org.apache.doris.nereids.util.LogicalPlanBuilder; -import org.apache.doris.nereids.util.MemoPatternMatchSupported; -import org.apache.doris.nereids.util.MemoTestUtils; -import org.apache.doris.nereids.util.PlanChecker; -import org.apache.doris.thrift.TStorageType; - -import com.google.common.collect.ImmutableList; -import org.junit.jupiter.api.Test; - -/** Tests for {@link EliminateGroupByConstant}. 
*/ -class EliminateGroupByConstantTest implements MemoPatternMatchSupported { - private static final OlapTable table = new OlapTable(0L, "student", - ImmutableList.of(new Column("k1", Type.INT, true, AggregateType.NONE, "0", ""), - new Column("k2", Type.INT, false, AggregateType.NONE, "0", ""), - new Column("k3", Type.INT, true, AggregateType.NONE, "", "")), - KeysType.PRIMARY_KEYS, new PartitionInfo(), null); - - static { - table.setIndexMeta(-1, - "t1", - table.getFullSchema(), - 0, 0, (short) 0, - TStorageType.COLUMN, - KeysType.PRIMARY_KEYS); - } - - private static final LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table); - private static final Slot k1 = scan.getOutput().get(0); - private static final Slot k2 = scan.getOutput().get(1); - - @Test - void testIntegerLiteral() { - LogicalPlan aggregate = new LogicalPlanBuilder(scan) - .agg(ImmutableList.of(new IntegerLiteral(1), k2), - ImmutableList.of(k1, k2)) - .build(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate) - .applyTopDown(new EliminateGroupByConstant()) - .applyBottomUp(new CheckAfterRewrite()) - .matches( - aggregate().when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2))) - ); - } - - @Test - void testOtherLiteral() { - LogicalPlan aggregate = new LogicalPlanBuilder(scan) - .agg(ImmutableList.of( - new StringLiteral("str"), k2), - ImmutableList.of( - new Alias(new StringLiteral("str"), "str"), k1, k2)) - .build(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate) - .applyTopDown(new EliminateGroupByConstant()) - .applyBottomUp(new CheckAfterRewrite()) - .matches( - aggregate().when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2))) - ); - } - - @Test - void testMixedLiteral() { - LogicalPlan aggregate = new LogicalPlanBuilder(scan) - .agg(ImmutableList.of( - new StringLiteral("str"), k2, - new IntegerLiteral(1), - new IntegerLiteral(2), - new IntegerLiteral(3), - new Add(k1, k2)), - 
ImmutableList.of( - new Alias(new StringLiteral("str"), "str"), - k2, k1, new Alias(new IntegerLiteral(1), "integer"))) - .build(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate) - .applyTopDown(new EliminateGroupByConstant()) - .applyBottomUp(new CheckAfterRewrite()) - .matches( - aggregate() - .when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2, new Add(k1, k2)))) - ); - } - - @Test - void testComplexGroupBy() { - LogicalPlan aggregate = new LogicalPlanBuilder(scan) - .agg(ImmutableList.of( - new IntegerLiteral(1), - new IntegerLiteral(2), - new Add(k1, k2)), - ImmutableList.of( - new Alias(new Max(k1), "max"), - new Alias(new Min(k2), "min"), - new Alias(new Add(k1, k2), "add"))) - .build(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate) - .applyTopDown(new EliminateGroupByConstant()) - .applyBottomUp(new CheckAfterRewrite()) - .matches( - aggregate() - .when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(new Add(k1, k2)))) - ); - } - - @Test - void testOutOfRange() { - LogicalPlan aggregate = new LogicalPlanBuilder(scan) - .agg(ImmutableList.of( - new StringLiteral("str"), k2, - new IntegerLiteral(1), - new IntegerLiteral(2), - new IntegerLiteral(3), - new IntegerLiteral(5), - new Add(k1, k2)), - ImmutableList.of( - new Alias(new StringLiteral("str"), "str"), - k2, k1, new Alias(new IntegerLiteral(1), "integer"))) - .build(); - PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate) - .applyTopDown(new EliminateGroupByConstant()) - .applyBottomUp(new CheckAfterRewrite()) - .matches( - aggregate() - .when(agg -> agg.getGroupByExpressions().equals(ImmutableList.of(k2, new Add(k1, k2)))) - ); - } -} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java index 3808fd1842810f..2fa945b0011e2b 100644 --- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java @@ -37,23 +37,35 @@ import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; +import org.apache.doris.utframe.TestWithFeService; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; -import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import java.util.List; @TestInstance(TestInstance.Lifecycle.PER_CLASS) -public class NormalizeAggregateTest implements MemoPatternMatchSupported { +public class NormalizeAggregateTest extends TestWithFeService implements MemoPatternMatchSupported { private LogicalPlan rStudent; - @BeforeAll - public final void beforeAll() { + @Override + protected void runBeforeAll() throws Exception { rStudent = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), PlanConstructor.student, ImmutableList.of()); + createDatabase("test"); + connectContext.setDatabase("default_cluster:test"); + createTables( + "CREATE TABLE IF NOT EXISTS t1 (\n" + + " id int not null,\n" + + " name char\n" + + ")\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 10\n" + + "PROPERTIES (\"replication_num\" = \"1\")\n" + ); + connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION"); } /*- @@ -190,4 +202,103 @@ public void testComplexKeyWithComplexOutputOfKey() { ); } + + // add test for agg eliminate const + @Test + void testEliminateGroupByConst() { + String sql = "select id ,1, 'abc' from t1 group by 1,2,3"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalAggregate().when(aggregate -> aggregate.getGroupByExpressions().size() == 1)); + } + + @Test + void useTinyIntEliminateGroupByConst() { + String sql 
= "select 1, 'abc' from t1 group by 1,2"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .nonMatch(logicalAggregate()); + } + + @Test + void testMixedConstTypes() { + String sql = "select id, 1, 'abc', true from t1 group by 1, 2, 3, 4"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalAggregate().when(agg -> + agg.getGroupByExpressions().size() == 1)); + } + + @Test + void testNullConst() { + String sql = "select id, NULL from t1 group by 1, 2"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalAggregate().when(agg -> + agg.getGroupByExpressions().size() == 1)); + } + + @Test + void testTwoNullConst() { + String sql = "select Null, NULL from t1 group by 1, 2"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .nonMatch(logicalAggregate()); + } + + @Test + void testExpressionConst() { + String sql = "select id, 1+1, CONCAT('a','b') from t1 group by 1, 2, 3"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalAggregate().when(agg -> + agg.getGroupByExpressions().size() == 1)); + } + + @Test + void testFunctionCallConst() { + String sql = "select id, NOW(), PI() from t1 group by 1, 2, 3"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalAggregate().when(agg -> + agg.getGroupByExpressions().size() == 1)); + } + + @Test + void testDifferentOrder() { + String sql = "select 1, id, 'abc' from t1 group by 2, 1, 3"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalAggregate().when(agg -> + agg.getGroupByExpressions().size() == 1)); + } + + @Test + void testDuplicateConst() { + String sql = "select id, 1, 1 from t1 group by 1, 2, 3"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalAggregate().when(agg -> + agg.getGroupByExpressions().size() == 1)); + } + + @Test + void testWithAggFunction() { + String sql = "select 'abc', 1, COUNT(*) from t1 group 
by 1, 2"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalAggregate().when(agg -> + agg.getGroupByExpressions().size() == 1 + && agg.getOutputExpressions().stream().anyMatch(e -> e.toString().contains("COUNT")))); + } } diff --git a/regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_constant_gby_key.out b/regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_constant_gby_key.out new file mode 100644 index 00000000000000..f161b693bd1479 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_constant_gby_key.out @@ -0,0 +1,196 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !test_1 -- +0 \N \N +0 -1 -1 +0 0 0 +0 1 1 +0 2 2 +0 3 3 +0 4 4 +0 5 5 +0 10 10 +0 100 100 + +-- !test_2 -- +0 \N \N +0 -1 -1 +0 0 0 +0 1 1 +0 2 2 +0 3 3 +0 4 4 +0 5 5 +0 10 10 +0 100 100 + +-- !test_3 -- +1 0 \N +1 0 -1 +1 0 0 +1 0 1 +1 0 2 +1 0 3 +1 0 4 +1 0 5 +1 0 10 +1 0 100 + +-- !test_4 -- +1 0 \N +1 0 -1 +1 0 0 +1 0 1 +1 0 2 +1 0 3 +1 0 4 +1 0 5 +1 0 10 +1 0 100 + +-- !test_5 -- +1 0 \N +1 0 -1 +1 0 0 +1 0 1 +1 0 2 +1 0 3 +1 0 4 +1 0 5 +1 0 10 +1 0 100 + +-- !test_6 -- +1 0 \N +1 0 -1 +1 0 0 +1 0 1 +1 0 2 +1 0 3 +1 0 4 +1 0 5 +1 0 10 +1 0 100 + +-- !test_7 -- +\N 0 Honeydew +-1 0 Grape +0 0 Fig +1 0 Apple +2 0 Banana +3 0 Cherry +4 0 Date +5 0 Elderberry +10 0 Iceberg +100 0 Jackfruit + +-- !test_8 -- +2023-12-19 \N \N +2023-12-19 -1 20231218 +2023-12-19 0 20231219 +2023-12-19 1 20231220 +2023-12-19 2 20231221 +2023-12-19 3 20231222 +2023-12-19 4 20231223 +2023-12-19 5 20231224 +2023-12-19 10 20231229 +2023-12-19 100 20231319 + +-- !test_9 -- +2023-12-19 \N \N +2023-12-19 -1 20231218 +2023-12-19 0 20231219 +2023-12-19 1 20231220 +2023-12-19 2 20231221 +2023-12-19 3 20231222 +2023-12-19 4 20231223 +2023-12-19 5 20231224 +2023-12-19 10 20231229 +2023-12-19 100 20231319 + +-- !test_10 -- +1 2023-12-19 \N +1 2023-12-19 -1 +1 2023-12-19 0 +1 
2023-12-19 1 +1 2023-12-19 2 +1 2023-12-19 3 +1 2023-12-19 4 +1 2023-12-19 5 +1 2023-12-19 10 +1 2023-12-19 100 + +-- !test_11 -- +1 2023-12-19 \N +1 2023-12-19 -1 +1 2023-12-19 0 +1 2023-12-19 1 +1 2023-12-19 2 +1 2023-12-19 3 +1 2023-12-19 4 +1 2023-12-19 5 +1 2023-12-19 10 +1 2023-12-19 100 + +-- !test_12 -- +1 2023-12-19 \N +1 2023-12-19 -1 +1 2023-12-19 0 +1 2023-12-19 1 +1 2023-12-19 2 +1 2023-12-19 3 +1 2023-12-19 4 +1 2023-12-19 5 +1 2023-12-19 10 +1 2023-12-19 100 + +-- !test_13 -- +20231220 2023-12-19 \N +20231220 2023-12-19 -1 +20231220 2023-12-19 0 +20231220 2023-12-19 1 +20231220 2023-12-19 2 +20231220 2023-12-19 3 +20231220 2023-12-19 4 +20231220 2023-12-19 5 +20231220 2023-12-19 10 +20231220 2023-12-19 100 + +-- !test_14 -- +\N 2023-12-19 Honeydew +20231218 2023-12-19 Grape +20231219 2023-12-19 Fig +20231220 2023-12-19 Apple +20231221 2023-12-19 Banana +20231222 2023-12-19 Cherry +20231223 2023-12-19 Date +20231224 2023-12-19 Elderberry +20231229 2023-12-19 Iceberg +20231319 2023-12-19 Jackfruit + +-- !gby_key_is_constant_expr_not_literal -- +1 2025-03-25 07:59:04 \N +1 2025-03-25 07:59:04 -1 +1 2025-03-25 07:59:04 0 +1 2025-03-25 07:59:04 1 +1 2025-03-25 07:59:04 2 +1 2025-03-25 07:59:04 3 +1 2025-03-25 07:59:04 4 +1 2025-03-25 07:59:04 5 +1 2025-03-25 07:59:04 10 +1 2025-03-25 07:59:04 100 + +-- !test_gby_key_is_all_constant -- +1 2025-03-25 07:59:04 2023-12-19 + +-- !duplicate_gby_key -- +2025-03-25 07:59:04 2025-03-25 07:59:04 +2025-03-25 07:59:04 2025-03-25 07:59:04 +2025-03-25 07:59:04 2025-03-25 07:59:04 +2025-03-25 07:59:04 2025-03-25 07:59:04 +2025-03-25 07:59:04 2025-03-25 07:59:04 +2025-03-25 07:59:04 2025-03-25 07:59:04 +2025-03-25 07:59:04 2025-03-25 07:59:04 +2025-03-25 07:59:04 2025-03-25 07:59:04 +2025-03-25 07:59:04 2025-03-25 07:59:04 +2025-03-25 07:59:04 2025-03-25 07:59:04 + diff --git a/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_constant_gby_key.groovy 
b/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_constant_gby_key.groovy
new file mode 100644
index 00000000000000..3158e2cededffa
--- /dev/null
+++ b/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_constant_gby_key.groovy
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Result checks for queries whose GROUP BY list contains keys that fold to constants.
+// Every query orders its output so that the rows can be compared with the expected
+// results in eliminate_constant_gby_key.out.
+suite("eliminate_constant_gby_key") {
+    // from_unixtime() renders the epoch value in the session time zone; the expected output
+    // (2025-03-25 07:59:04 for 1742860744) is only correct at UTC+8, so pin it explicitly.
+    sql "set time_zone = 'Asia/Shanghai'"
+
+    sql """DROP TABLE IF EXISTS t1;"""
+    sql """CREATE TABLE t1 (
+        c1 INT,
+        c2 VARCHAR(50),
+        c3 DECIMAL(10,2),
+        c4 DATETIME,
+        c5 BOOLEAN
+    ) distributed by hash(c1) properties("replication_num"="1");"""
+
+    sql """INSERT INTO t1 (c1, c2, c3, c4, c5) VALUES
+        (1, 'Apple', 10.50, '2023-01-01 10:00:00', true),
+        (2, 'Banana', 20.75, '2023-01-02 11:30:00', false),
+        (3, 'Cherry', 15.25, '2023-01-03 09:15:00', true),
+        (4, 'Date', 30.00, '2023-01-04 14:45:00', false),
+        (5, 'Elderberry', 12.99, '2023-01-05 16:20:00', true),
+        (0, 'Fig', 5.50, '2023-01-06 08:00:00', false),
+        (-1, 'Grape', 8.25, '2023-01-07 12:30:00', true),
+        (NULL, 'Honeydew', NULL, NULL, NULL),
+        (10, 'Iceberg', 18.40, '2023-01-08 13:10:00', false),
+        (100, 'Jackfruit', 42.99, '2023-01-09 17:55:00', true);
+    """
+
+    // Two expressions that fold to a constant: the first is always 0, the second always takes its
+    // first WHEN branch ('2024-01-08' < '2024-02-18' is a constant true), i.e. the date 2023-12-19,
+    // even though the other branches mention column c4.
+    def funAList = [
+        "TIMESTAMPDIFF(YEAR, NOW(), NOW())",
+        """(TO_DATE(CASE
+            WHEN ('2024-01-08' < '2024-02-18') THEN '2023-12-19'
+            WHEN (c4 < '2024-01-01') THEN '2026-02-18'
+            ELSE DATE_ADD(c4, INTERVAL 365 DAY) END))"""
+    ]
+
+    // Query templates; each closure receives one funA expression and returns the SQL text.
+    def testCases = [
+        [desc: "select funA, c1, funA+c1 group by funA, c1",
+         sql: { funA -> """
+            SELECT
+                ${funA} AS funA,
+                c1,
+                ${funA} + c1 AS 'funA+c1'
+            FROM t1
+            GROUP BY ${funA}, c1
+            ORDER BY 1, 2, 3
+         """ }],
+
+        [desc: "select funA, c1, funA+c1 group by funA, c1, funA+c1",
+         sql: { funA -> """
+            SELECT
+                ${funA} AS funA,
+                c1,
+                ${funA} + c1 AS 'funA+c1'
+            FROM t1
+            GROUP BY ${funA}, c1, ${funA} + c1
+            ORDER BY 1, 2, 3
+         """ }],
+
+        [desc: "select count(distinct funA), funA, c1 group by funA,c1",
+         sql: { funA -> """
+            SELECT
+                COUNT(DISTINCT ${funA}) AS 'count(distinct funA)',
+                ${funA} AS funA,
+                c1
+            FROM t1
+            GROUP BY ${funA}, c1
+            ORDER BY 1, 2, 3
+         """ }],
+
+        [desc: "select count(funA), funA, c1 group by funA, c1",
+         sql: { funA -> """
+            SELECT
+                COUNT(${funA}) AS 'count(funA)',
+                ${funA} AS funA,
+                c1
+            FROM t1
+            GROUP BY ${funA}, c1
+            ORDER BY 1, 2, 3
+         """ }],
+
+        [desc: "select COUNT(distinct funA+1), funA, c1 group by funA,c1",
+         sql: { funA -> """
+            SELECT
+                COUNT(DISTINCT ${funA} + 1) AS 'count(distinct funA+1)',
+                ${funA} AS funA,
+                c1
+            FROM t1
+            GROUP BY ${funA}, c1
+            ORDER BY 1, 2, 3
+         """ }],
+
+        [desc: "select max(funA+1), funA, c1 group by funA, c1",
+         sql: { funA -> """
+            SELECT
+                MAX(${funA} + 1) AS 'max(funA+1)',
+                ${funA} AS funA,
+                c1
+            FROM t1
+            GROUP BY ${funA}, c1
+            ORDER BY 1, 2, 3
+         """ }],
+
+        [desc: "select max(funA+c1), funA, c2 group by funA, c2",
+         sql: { funA -> """
+            SELECT
+                MAX(${funA} + c1) AS 'max(funA+c1)',
+                ${funA} AS funA,
+                c2
+            FROM t1
+            GROUP BY ${funA}, c2
+            ORDER BY 1, 2, 3
+         """ }]
+    ]
+
+    // 2 expressions x 7 templates => checks test_1 .. test_14, numbered in iteration order.
+    def idx = 1
+    funAList.each { funA ->
+        testCases.each { testCase ->
+            quickTest("test_${idx}", testCase.sql(funA))
+            idx++
+        }
+    }
+
+    // The constant key is a function call over a literal rather than a plain literal.
+    qt_gby_key_is_constant_expr_not_literal """
+        SELECT
+            count(DISTINCT from_unixtime(1742860744.003242)) AS 'max(distinct funA)',
+            from_unixtime(1742860744.003242) AS funA,
+            c1
+        FROM t1
+        GROUP BY from_unixtime(1742860744.003242), c1
+        order by 1,2,3
+    """
+
+    // Every group-by key is constant, so the whole table collapses into a single row.
+    // NOTE(review): there is no comma before `c1`, so c1 is the alias of the TO_DATE(...) column and
+    // the query returns three columns; `order by ... 4` therefore points past the select list.
+    // Confirm this is intended before changing it, the expected output depends on it.
+    qt_test_gby_key_is_all_constant """
+        SELECT
+            count(DISTINCT from_unixtime(1742860744.003242)) AS 'max(distinct funA)',
+            from_unixtime(1742860744.003242) AS funA,
+            (TO_DATE(CASE
+            WHEN ('2024-01-08' < '2024-02-18') THEN '2023-12-19'
+            WHEN (c4 < '2024-01-01') THEN '2026-02-18'
+            ELSE DATE_ADD(c4, INTERVAL 365 DAY) END))
+            c1
+        FROM t1
+        GROUP BY from_unixtime(1742860744.003242), (TO_DATE(CASE
+            WHEN ('2024-01-08' < '2024-02-18') THEN '2023-12-19'
+            WHEN (c4 < '2024-01-01') THEN '2026-02-18'
+            ELSE DATE_ADD(c4, INTERVAL 365 DAY) END)), 'abc'
+        order by 1,2,3,4
+    """
+
+    // The same constant key appears twice next to the real key c1.
+    // NOTE(review): as above, `c1` is parsed as an alias of the second from_unixtime(...) column.
+    qt_duplicate_gby_key """
+        SELECT
+            from_unixtime(1742860744.003242),
+            from_unixtime(1742860744.003242)
+            c1
+        FROM t1
+        GROUP BY from_unixtime(1742860744.003242), from_unixtime(1742860744.003242),'abc',c1
+        order by 1,2,3,4
+    """
+}
\ No newline at end of file
diff --git a/regression-test/suites/trino_p0/constant_group_key.groovy 
b/regression-test/suites/trino_p0/constant_group_key.groovy index ccee425070435b..a0f3cea04b5b89 100644 --- a/regression-test/suites/trino_p0/constant_group_key.groovy +++ b/regression-test/suites/trino_p0/constant_group_key.groovy @@ -31,7 +31,7 @@ suite("constant_group_key") { explain { sql("select 'oneline', sum(n_nationkey) from nation group by 'constant1', 'constant2'") - contains "group by: 'constant2'" + contains "group by: 1" } sql "drop table if exists cgk_tbl"