Skip to content

Commit

Permalink
planner: fix panic of aggregation distinct function when distinct_agg…
Browse files Browse the repository at this point in the history
…_push_down and enable_cascades_planner enabled (pingcap#24449)
  • Loading branch information
zoomxi committed Jul 20, 2021
1 parent cca097d commit a638590
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 10 deletions.
3 changes: 2 additions & 1 deletion planner/cascades/testdata/integration_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@
"select /*+ STREAM_AGG() */ count(distinct c) from t;", // should push down after stream agg implemented
"select /*+ HASH_AGG() */ count(distinct c) from t;",
"select count(distinct c) from t group by c;",
"select count(distinct c) from t;"
"select count(distinct c) from t;",
"select count(*) from t group by a having avg(distinct a)>1;" // #24449 Projection should be added between HashAgg and TableReader
]
},
{
Expand Down
15 changes: 15 additions & 0 deletions planner/cascades/testdata/integration_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,21 @@
"Result": [
"2"
]
},
{
"SQL": "select count(*) from t group by a having avg(distinct a)>1;",
"Plan": [
"Projection_14 6400.00 root Column#5",
"└─Selection_15 6400.00 root gt(Column#6, 1)",
" └─HashAgg_20 8000.00 root group by:test.t.a, funcs:count(Column#8)->Column#5, funcs:avg(distinct Column#10)->Column#6",
" └─Projection_21 8000.00 root Column#8, cast(test.t.a, decimal(15,4) BINARY)->Column#10, test.t.a",
" └─TableReader_22 8000.00 root data:HashAgg_23",
" └─HashAgg_23 8000.00 cop[tikv] group by:test.t.a, funcs:count(1)->Column#8",
" └─TableFullScan_19 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"
],
"Result": [
"1"
]
}
]
},
Expand Down
16 changes: 11 additions & 5 deletions planner/cascades/transformation_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -1882,7 +1882,7 @@ func (*outerJoinEliminator) prepareForEliminateOuterJoin(joinExpr *memo.GroupExp
return
}

// check whether one of unique keys sets is contained by inner join keys.
// isInnerJoinKeysContainUniqueKey checks whether one of the unique key sets is contained by the inner join keys.
func (*outerJoinEliminator) isInnerJoinKeysContainUniqueKey(innerGroup *memo.Group, joinKeys *expression.Schema) (bool, error) {
// builds UniqueKey info of innerGroup.
innerGroup.BuildKeyInfo()
Expand Down Expand Up @@ -2129,7 +2129,7 @@ func (r *TransformAggregateCaseToSelection) isOnlyOneNotNull(ctx sessionctx.Cont
return !args[outputIdx].Equal(ctx, expression.NewNull()) && (argsNum == 2 || args[3-outputIdx].Equal(ctx, expression.NewNull()))
}

// TransformAggregateCaseToSelection only support `case when cond then var end` and `case when cond then var1 else var2 end`.
// isTwoOrThreeArgCase represents that TransformAggregateCaseToSelection only supports `case when cond then var end` and `case when cond then var1 else var2 end`.
func (r *TransformAggregateCaseToSelection) isTwoOrThreeArgCase(expr expression.Expression) bool {
scalarFunc, ok := expr.(*expression.ScalarFunction)
if !ok {
Expand Down Expand Up @@ -2315,7 +2315,7 @@ func NewRuleInjectProjectionBelowAgg() Transformation {
// Match implements Transformation interface.
func (r *InjectProjectionBelowAgg) Match(expr *memo.ExprIter) bool {
agg := expr.GetExpr().ExprNode.(*plannercore.LogicalAggregation)
return agg.IsCompleteModeAgg()
return agg.HasCompleteModeAgg()
}

// OnTransform implements Transformation interface.
Expand All @@ -2326,9 +2326,15 @@ func (r *InjectProjectionBelowAgg) OnTransform(old *memo.ExprIter) (newExprs []*
hasScalarFunc := false
copyFuncs := make([]*aggregation.AggFuncDesc, 0, len(agg.AggFuncs))
for _, aggFunc := range agg.AggFuncs {
copyFunc := aggFunc.Clone()
// WrapCastForAggArgs will modify AggFunc, so we should clone AggFunc.
copyFunc.WrapCastForAggArgs(agg.SCtx())
copyFunc := aggFunc.Clone()

// if aggFunc input is from 'partial data', no need to wrap cast for agg args
copyFunc.WrapCastAsDecimalForAggArgs(agg.SCtx())
if copyFunc.Mode != aggregation.FinalMode && copyFunc.Mode != aggregation.Partial2Mode {
copyFunc.WrapCastForAggArgs(agg.SCtx())
}

copyFuncs = append(copyFuncs, copyFunc)
for _, arg := range copyFunc.Args {
_, isScalarFunc := arg.(*expression.ScalarFunction)
Expand Down
23 changes: 19 additions & 4 deletions planner/core/logical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,19 @@ func (la *LogicalAggregation) HasDistinct() bool {
return false
}

// HasCompleteModeAgg reports whether at least one aggregate function of the
// LogicalAggregation is in CompleteMode.
//
// The aggregate functions of a single LogicalAggregation do not necessarily
// share the same AggMode. For example, when the cascades planner is on, after
// the PushAggDownGather transformation some aggregate functions are in
// CompleteMode while the others are in FinalMode.
func (la *LogicalAggregation) HasCompleteModeAgg() bool {
	for i := range la.AggFuncs {
		if la.AggFuncs[i].Mode != aggregation.CompleteMode {
			continue
		}
		// One CompleteMode function is enough; no need to look further.
		return true
	}
	return false
}

// CopyAggHints copies the aggHints from another LogicalAggregation.
func (la *LogicalAggregation) CopyAggHints(agg *LogicalAggregation) {
// TODO: Copy the hint may make the un-applicable hint throw the
Expand Down Expand Up @@ -391,6 +404,7 @@ func (la *LogicalAggregation) GetUsedCols() (usedCols []*expression.Column) {
type LogicalSelection struct {
baseLogicalPlan

// Conditions represents a list of AND conditions.
// Originally the WHERE or ON condition is parsed into a single expression,
// but after we converted to CNF(Conjunctive normal form), it can be
// split into a list of AND conditions.
Expand Down Expand Up @@ -495,12 +509,13 @@ type DataSource struct {
// possibleAccessPaths stores all the possible access path for physical plan, including table scan.
possibleAccessPaths []*util.AccessPath

// isPartition represents whether the data source is a partition.
// The data source may be a partition, rather than a real table.
isPartition bool
physicalTableID int64
partitionNames []model.CIStr

// handleCol represents the handle column for the datasource, either the
// handleCols represents the handle column for the datasource, either the
// int primary key column or extra handle column.
// handleCol *expression.Column
handleCols HandleCols
Expand Down Expand Up @@ -558,7 +573,7 @@ type LogicalTableScan struct {
// LogicalIndexScan is the logical index scan operator for TiKV.
type LogicalIndexScan struct {
logicalSchemaProducer
// DataSource should be read-only here.
// Source should be read-only here.
Source *DataSource
IsDoubleRead bool

Expand Down Expand Up @@ -1191,7 +1206,7 @@ type LogicalShowDDLJobs struct {
// CTEClass holds the information and plan for a CTE. Most of the fields in this struct are the same as cteInfo.
// But the cteInfo is used when building the plan, and CTEClass is used also for building the executor.
type CTEClass struct {
// The union between seed part and recursive part is DISTINCT or DISTINCT ALL.
// IsDistinct indicates whether the union between the seed part and the recursive part is UNION DISTINCT (true) or UNION ALL (false).
IsDistinct bool
// seedPartLogicalPlan and recursivePartLogicalPlan are the logical plans for the seed part and recursive part of this CTE.
seedPartLogicalPlan LogicalPlan
Expand All @@ -1201,7 +1216,7 @@ type CTEClass struct {
recursivePartPhysicalPlan PhysicalPlan
// cteTask is the physical plan for this CTE, is a wrapper of the PhysicalCTE.
cteTask task
// storageID for this CTE.
// IDForStorage represents the storageID for this CTE.
IDForStorage int
// optFlag is the optFlag for the whole CTE.
optFlag uint64
Expand Down
26 changes: 26 additions & 0 deletions planner/core/logical_plans_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/planner/util"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -207,3 +208,28 @@ func (s *testUnitTestSuit) TestIndexPathSplitCorColCond(c *C) {
}
collate.SetNewCollationEnabledForTest(false)
}

// TestHasCompleteModeAgg checks that HasCompleteModeAgg returns true when at
// least one aggregate function is in CompleteMode, and false once every
// aggregate function is in FinalMode.
func (s *testUnitTestSuit) TestHasCompleteModeAgg(c *C) {
	defer testleak.AfterTest(c)()

	// Mixed modes: one FinalMode function followed by one CompleteMode function.
	agg := &LogicalAggregation{
		AggFuncs: []*aggregation.AggFuncDesc{
			{Mode: aggregation.FinalMode, HasDistinct: true},
			{Mode: aggregation.CompleteMode, HasDistinct: true},
		},
	}
	c.Assert(agg.HasCompleteModeAgg(), Equals, true)

	// Swap the CompleteMode function for a FinalMode one, so that no
	// CompleteMode function remains in the aggregation.
	agg.AggFuncs[1] = &aggregation.AggFuncDesc{
		Mode:        aggregation.FinalMode,
		HasDistinct: true,
	}
	c.Assert(agg.HasCompleteModeAgg(), Equals, false)
}

0 comments on commit a638590

Please sign in to comment.