Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: fix panic of aggregation distinct function when distinct_agg… #26386

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion planner/cascades/testdata/integration_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@
"select /*+ STREAM_AGG() */ count(distinct c) from t;", // should push down after stream agg implemented
"select /*+ HASH_AGG() */ count(distinct c) from t;",
"select count(distinct c) from t group by c;",
"select count(distinct c) from t;"
"select count(distinct c) from t;",
"select count(*) from t group by a having avg(distinct a)>1;" // #24449 Projection should be add between HashAgg and TableReader
]
},
{
Expand Down
15 changes: 15 additions & 0 deletions planner/cascades/testdata/integration_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,21 @@
"Result": [
"2"
]
},
{
"SQL": "select count(*) from t group by a having avg(distinct a)>1;",
"Plan": [
"Projection_14 6400.00 root Column#5",
"└─Selection_15 6400.00 root gt(Column#6, 1)",
" └─HashAgg_20 8000.00 root group by:test.t.a, funcs:count(Column#8)->Column#5, funcs:avg(distinct Column#10)->Column#6",
" └─Projection_21 8000.00 root Column#8, cast(test.t.a, decimal(15,4) BINARY)->Column#10, test.t.a",
" └─TableReader_22 8000.00 root data:HashAgg_23",
" └─HashAgg_23 8000.00 cop[tikv] group by:test.t.a, funcs:count(1)->Column#8",
" └─TableFullScan_19 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"
],
"Result": [
"1"
]
}
]
},
Expand Down
16 changes: 11 additions & 5 deletions planner/cascades/transformation_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -1882,7 +1882,7 @@ func (*outerJoinEliminator) prepareForEliminateOuterJoin(joinExpr *memo.GroupExp
return
}

// check whether one of unique keys sets is contained by inner join keys.
// isInnerJoinKeysContainUniqueKey check whether one of unique keys sets is contained by inner join keys.
func (*outerJoinEliminator) isInnerJoinKeysContainUniqueKey(innerGroup *memo.Group, joinKeys *expression.Schema) (bool, error) {
// builds UniqueKey info of innerGroup.
innerGroup.BuildKeyInfo()
Expand Down Expand Up @@ -2129,7 +2129,7 @@ func (r *TransformAggregateCaseToSelection) isOnlyOneNotNull(ctx sessionctx.Cont
return !args[outputIdx].Equal(ctx, expression.NewNull()) && (argsNum == 2 || args[3-outputIdx].Equal(ctx, expression.NewNull()))
}

// TransformAggregateCaseToSelection only support `case when cond then var end` and `case when cond then var1 else var2 end`.
// isTwoOrThreeArgCase represents that TransformAggregateCaseToSelection only support `case when cond then var end` and `case when cond then var1 else var2 end`.
func (r *TransformAggregateCaseToSelection) isTwoOrThreeArgCase(expr expression.Expression) bool {
scalarFunc, ok := expr.(*expression.ScalarFunction)
if !ok {
Expand Down Expand Up @@ -2315,7 +2315,7 @@ func NewRuleInjectProjectionBelowAgg() Transformation {
// Match implements Transformation interface.
func (r *InjectProjectionBelowAgg) Match(expr *memo.ExprIter) bool {
agg := expr.GetExpr().ExprNode.(*plannercore.LogicalAggregation)
return agg.IsCompleteModeAgg()
return agg.HasCompleteModeAgg()
}

// OnTransform implements Transformation interface.
Expand All @@ -2326,9 +2326,15 @@ func (r *InjectProjectionBelowAgg) OnTransform(old *memo.ExprIter) (newExprs []*
hasScalarFunc := false
copyFuncs := make([]*aggregation.AggFuncDesc, 0, len(agg.AggFuncs))
for _, aggFunc := range agg.AggFuncs {
copyFunc := aggFunc.Clone()
// WrapCastForAggArgs will modify AggFunc, so we should clone AggFunc.
copyFunc.WrapCastForAggArgs(agg.SCtx())
copyFunc := aggFunc.Clone()

// if aggFunc input is from 'partial data', no need to wrap cast for agg args
copyFunc.WrapCastAsDecimalForAggArgs(agg.SCtx())
if copyFunc.Mode != aggregation.FinalMode && copyFunc.Mode != aggregation.Partial2Mode {
copyFunc.WrapCastForAggArgs(agg.SCtx())
}

copyFuncs = append(copyFuncs, copyFunc)
for _, arg := range copyFunc.Args {
_, isScalarFunc := arg.(*expression.ScalarFunction)
Expand Down
23 changes: 19 additions & 4 deletions planner/core/logical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,19 @@ func (la *LogicalAggregation) HasDistinct() bool {
return false
}

// HasCompleteModeAgg shows whether LogicalAggregation has functions with CompleteMode.
func (la *LogicalAggregation) HasCompleteModeAgg() bool {
// not all of the AggFunctions has the same AggMode
// for example: when cascades planner on, after PushAggDownGather transformed,
// some aggFunctions are CompleteMode, and the others are FinalMode
for _, aggFunc := range la.AggFuncs {
if aggFunc.Mode == aggregation.CompleteMode {
return true
}
}
return false
}

// CopyAggHints copies the aggHints from another LogicalAggregation.
func (la *LogicalAggregation) CopyAggHints(agg *LogicalAggregation) {
// TODO: Copy the hint may make the un-applicable hint throw the
Expand Down Expand Up @@ -391,6 +404,7 @@ func (la *LogicalAggregation) GetUsedCols() (usedCols []*expression.Column) {
type LogicalSelection struct {
baseLogicalPlan

// Conditions represents a list of AND conditions.
// Originally the WHERE or ON condition is parsed into a single expression,
// but after we converted to CNF(Conjunctive normal form), it can be
// split into a list of AND conditions.
Expand Down Expand Up @@ -495,12 +509,13 @@ type DataSource struct {
// possibleAccessPaths stores all the possible access path for physical plan, including table scan.
possibleAccessPaths []*util.AccessPath

// isPartition represents whether the data source is a partition.
// The data source may be a partition, rather than a real table.
isPartition bool
physicalTableID int64
partitionNames []model.CIStr

// handleCol represents the handle column for the datasource, either the
// handleCols represents the handle column for the datasource, either the
// int primary key column or extra handle column.
// handleCol *expression.Column
handleCols HandleCols
Expand Down Expand Up @@ -558,7 +573,7 @@ type LogicalTableScan struct {
// LogicalIndexScan is the logical index scan operator for TiKV.
type LogicalIndexScan struct {
logicalSchemaProducer
// DataSource should be read-only here.
// Source should be read-only here.
Source *DataSource
IsDoubleRead bool

Expand Down Expand Up @@ -1191,7 +1206,7 @@ type LogicalShowDDLJobs struct {
// CTEClass holds the information and plan for a CTE. Most of the fields in this struct are the same as cteInfo.
// But the cteInfo is used when building the plan, and CTEClass is used also for building the executor.
type CTEClass struct {
// The union between seed part and recursive part is DISTINCT or DISTINCT ALL.
// IsDistinct represents the union between seed part and recursive part is DISTINCT or DISTINCT ALL.
IsDistinct bool
// seedPartLogicalPlan and recursivePartLogicalPlan are the logical plans for the seed part and recursive part of this CTE.
seedPartLogicalPlan LogicalPlan
Expand All @@ -1201,7 +1216,7 @@ type CTEClass struct {
recursivePartPhysicalPlan PhysicalPlan
// cteTask is the physical plan for this CTE, is a wrapper of the PhysicalCTE.
cteTask task
// storageID for this CTE.
// IDForStorage represents the storageID for this CTE.
IDForStorage int
// optFlag is the optFlag for the whole CTE.
optFlag uint64
Expand Down
26 changes: 26 additions & 0 deletions planner/core/logical_plans_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/planner/util"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -207,3 +208,28 @@ func (s *testUnitTestSuit) TestIndexPathSplitCorColCond(c *C) {
}
collate.SetNewCollationEnabledForTest(false)
}

func (s *testUnitTestSuit) TestHasCompleteModeAgg(c *C) {
defer testleak.AfterTest(c)()

aggFuncs := make([]*aggregation.AggFuncDesc, 2)
aggFuncs[0] = &aggregation.AggFuncDesc{
Mode: aggregation.FinalMode,
HasDistinct: true,
}
aggFuncs[1] = &aggregation.AggFuncDesc{
Mode: aggregation.CompleteMode,
HasDistinct: true,
}

newAgg := &LogicalAggregation{
AggFuncs: aggFuncs,
}
c.Assert(newAgg.HasCompleteModeAgg(), Equals, true)

aggFuncs[1] = &aggregation.AggFuncDesc{
Mode: aggregation.FinalMode,
HasDistinct: true,
}
c.Assert(newAgg.HasCompleteModeAgg(), Equals, false)
}