planner: support 3 stage aggregation for single scalar distinct agg (#…
fixdb committed Sep 17, 2022
1 parent 0f93f7b commit efc0720
Showing 16 changed files with 602 additions and 71 deletions.
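
As context for the diff below, here is a hedged, self-contained illustration in plain Go (not TiDB planner code; the table data and the two-task hash partitioning are invented for the example) of the 3-stage shape this commit generates for a query such as select count(distinct a), count(b) from foo:

package main

import "fmt"

type row struct{ a, b int }

func main() {
	data := []row{{1, 10}, {1, 20}, {2, 30}, {3, 40}, {3, 50}}

	// Stage 1 (partial agg): group by a and count(b) within each group
	// (b is never NULL in this example, so count(b) counts rows).
	partialCountB := map[int]int64{}
	for _, r := range data {
		partialCountB[r.a]++
	}

	// Exchange: hash-partition the partial results by a, so every distinct
	// value of a lands on exactly one of the (here: two) MPP tasks.
	tasks := []map[int]int64{{}, {}}
	for a, c := range partialCountB {
		tasks[a%2][a] += c
	}

	// Stage 2 (middle agg, per task): count(distinct a) is now just the
	// number of groups seen locally; the partial count(b) values are summed.
	type middle struct{ distinctA, sumCountB int64 }
	middles := make([]middle, 0, len(tasks))
	for _, part := range tasks {
		m := middle{}
		for _, c := range part {
			m.distinctA++
			m.sumCountB += c
		}
		middles = append(middles, m)
	}

	// Stage 3 (final agg): count(distinct a) has been rewritten to sum(),
	// so both results are obtained by summing the per-task middle outputs.
	var distinctA, countB int64
	for _, m := range middles {
		distinctA += m.distinctA
		countB += m.sumCountB
	}
	fmt.Println(distinctA, countB) // prints: 3 5
}

The point of the rewrite is that once the partial stage groups by a, the middle stage can count distinct values locally, and the final stage only needs to sum the per-task results.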
8 changes: 8 additions & 0 deletions executor/set_test.go
@@ -699,6 +699,14 @@ func TestSetVar(t *testing.T) {
tk.MustExec("set global tidb_opt_skew_distinct_agg=1")
tk.MustQuery("select @@global.tidb_opt_skew_distinct_agg").Check(testkit.Rows("1"))

// test for tidb_opt_three_stage_distinct_agg
tk.MustQuery("select @@session.tidb_opt_three_stage_distinct_agg").Check(testkit.Rows("1")) // default value is 1
tk.MustExec("set session tidb_opt_three_stage_distinct_agg=0")
tk.MustQuery("select @@session.tidb_opt_three_stage_distinct_agg").Check(testkit.Rows("0"))
tk.MustQuery("select @@global.tidb_opt_three_stage_distinct_agg").Check(testkit.Rows("1")) // default value is 1
tk.MustExec("set global tidb_opt_three_stage_distinct_agg=0")
tk.MustQuery("select @@global.tidb_opt_three_stage_distinct_agg").Check(testkit.Rows("0"))

// the value of max_allowed_packet should be a multiple of 1024
tk.MustExec("set @@global.max_allowed_packet=16385")
tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect max_allowed_packet value: '16385'"))
21 changes: 21 additions & 0 deletions expression/column.go
@@ -202,6 +202,18 @@ func (col *CorrelatedColumn) MemoryUsage() (sum int64) {
return sum
}

// RemapColumn remaps the columns with the provided mapping and returns a new expression.
func (col *CorrelatedColumn) RemapColumn(m map[int64]*Column) (Expression, error) {
mapped := m[(&col.Column).UniqueID]
if mapped == nil {
return nil, errors.Errorf("Can't remap column for %s", col)
}
return &CorrelatedColumn{
Column: *mapped,
Data: col.Data,
}, nil
}

// Column represents a column.
type Column struct {
RetType *types.FieldType
@@ -537,6 +549,15 @@ func (col *Column) resolveIndicesByVirtualExpr(schema *Schema) bool {
return false
}

// RemapColumn remaps the columns with the provided mapping and returns a new expression.
func (col *Column) RemapColumn(m map[int64]*Column) (Expression, error) {
mapped := m[col.UniqueID]
if mapped == nil {
return nil, errors.Errorf("Can't remap column for %s", col)
}
return mapped, nil
}

// Vectorized returns if this expression supports vectorized evaluation.
func (col *Column) Vectorized() bool {
return true
5 changes: 5 additions & 0 deletions expression/constant.go
@@ -422,6 +422,11 @@ func (c *Constant) resolveIndicesByVirtualExpr(_ *Schema) bool {
return true
}

// RemapColumn remaps the columns with the provided mapping and returns a new expression.
func (c *Constant) RemapColumn(_ map[int64]*Column) (Expression, error) {
return c, nil
}

// Vectorized returns if this expression supports vectorized evaluation.
func (c *Constant) Vectorized() bool {
if c.DeferredExpr != nil {
3 changes: 3 additions & 0 deletions expression/expression.go
@@ -164,6 +164,9 @@ type Expression interface {
// resolveIndicesByVirtualExpr is called inside `ResolveIndicesByVirtualExpr`; it performs the resolution on the expression itself.
resolveIndicesByVirtualExpr(schema *Schema) bool

// RemapColumn remaps the columns with the provided mapping and returns a new expression.
RemapColumn(map[int64]*Column) (Expression, error)

// ExplainInfo returns operator information to be explained.
ExplainInfo() string

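
Below is a hedged, self-contained sketch of the remap-by-UniqueID pattern that the RemapColumn method declared above follows. The stripped-down Column type and the sample mapping are hypothetical stand-ins, not TiDB's actual types; the real method works on expression.Column and returns an Expression.

package main

import (
	"errors"
	"fmt"
)

// Column is a hypothetical, stripped-down stand-in for expression.Column.
type Column struct {
	UniqueID int64
	Name     string
}

// remapColumn mirrors Column.RemapColumn: look the column up by UniqueID and
// substitute the mapped column, failing when no mapping exists.
func remapColumn(col *Column, m map[int64]*Column) (*Column, error) {
	mapped := m[col.UniqueID]
	if mapped == nil {
		return nil, errors.New("can't remap column")
	}
	return mapped, nil
}

func main() {
	// In the planner change further down, the map sends a partial-agg output
	// column to the corresponding middle-agg output column before the final
	// aggregate's arguments are rewritten.
	partialOut := &Column{UniqueID: 3, Name: "partial count(b)"}
	middleOut := &Column{UniqueID: 7, Name: "middle sum(#3)"}
	schemaMap := map[int64]*Column{partialOut.UniqueID: middleOut}

	newCol, err := remapColumn(partialOut, schemaMap)
	fmt.Println(newCol.Name, err) // prints: middle sum(#3) <nil>
}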
18 changes: 18 additions & 0 deletions expression/scalar_function.go
@@ -488,6 +488,24 @@ func (sf *ScalarFunction) resolveIndicesByVirtualExpr(schema *Schema) bool {
return true
}

// RemapColumn remaps the columns with the provided mapping and returns a new expression.
func (sf *ScalarFunction) RemapColumn(m map[int64]*Column) (Expression, error) {
newSf, ok := sf.Clone().(*ScalarFunction)
if !ok {
return nil, errors.New("failed to cast to scalar function")
}
for i, arg := range sf.GetArgs() {
newArg, err := arg.RemapColumn(m)
if err != nil {
return nil, err
}
newSf.GetArgs()[i] = newArg
}
// clear the cached hash code so it is recomputed for the remapped arguments
newSf.hashcode = nil
return newSf, nil
}

// GetSingleColumn returns (Col, Desc) when the ScalarFunction is equivalent to (Col, Desc)
// when used as a sort key, otherwise returns (nil, false).
//
1 change: 1 addition & 0 deletions expression/util_test.go
@@ -579,6 +579,7 @@ func (m *MockExpr) ResolveIndices(schema *Schema) (Expression, error)
func (m *MockExpr) resolveIndices(schema *Schema) error { return nil }
func (m *MockExpr) ResolveIndicesByVirtualExpr(schema *Schema) (Expression, bool) { return m, true }
func (m *MockExpr) resolveIndicesByVirtualExpr(schema *Schema) bool { return true }
func (m *MockExpr) RemapColumn(_ map[int64]*Column) (Expression, error) { return m, nil }
func (m *MockExpr) ExplainInfo() string { return "" }
func (m *MockExpr) ExplainNormalizedInfo() string { return "" }
func (m *MockExpr) HashCode(sc *stmtctx.StatementContext) []byte { return nil }
51 changes: 51 additions & 0 deletions planner/core/enforce_mpp_test.go
@@ -484,3 +484,54 @@ func TestMPPSkewedGroupDistinctRewrite(t *testing.T) {
require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()))
}
}

// TestMPPSingleDistinct3Stage tests the 3-stage aggregation plan for a single count(distinct) aggregate.
func TestMPPSingleDistinct3Stage(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)

// test table
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int, b bigint not null, c bigint, d date, e varchar(20) collate utf8mb4_general_ci)")

// Create virtual tiflash replica info.
dom := domain.GetDomain(tk.Session())
is := dom.InfoSchema()
db, exists := is.SchemaByName(model.NewCIStr("test"))
require.True(t, exists)
for _, tblInfo := range db.Tables {
if tblInfo.Name.L == "t" {
tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
Count: 1,
Available: true,
}
}
}

var input []string
var output []struct {
SQL string
Plan []string
Warn []string
}
enforceMPPSuiteData := plannercore.GetEnforceMPPSuiteData()
enforceMPPSuiteData.LoadTestCases(t, &input, &output)
for i, tt := range input {
testdata.OnRecord(func() {
output[i].SQL = tt
})
if strings.HasPrefix(tt, "set") || strings.HasPrefix(tt, "UPDATE") {
tk.MustExec(tt)
continue
}
testdata.OnRecord(func() {
output[i].SQL = tt
output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
output[i].Warn = testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())
})
res := tk.MustQuery(tt)
res.Check(testkit.Rows(output[i].Plan...))
require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()))
}
}
112 changes: 112 additions & 0 deletions planner/core/task.go
@@ -1436,6 +1436,35 @@ func (p *basePhysicalAgg) newPartialAggregate(copTaskType kv.StoreType, isMPPTas
return partialAgg, finalAgg
}

// canUse3StageDistinctAgg returns true if this aggregation can use the 3-stage plan for a single scalar distinct aggregate.
func (p *basePhysicalAgg) canUse3StageDistinctAgg() bool {
num := 0
if !p.ctx.GetSessionVars().Enable3StageDistinctAgg || len(p.GroupByItems) > 0 {
return false
}
for _, fun := range p.AggFuncs {
if fun.HasDistinct {
num++
if num > 1 || fun.Name != ast.AggFuncCount {
return false
}
for _, arg := range fun.Args {
// bail out when an argument is not a simple column, see GitHub issue #35417
if _, ok := arg.(*expression.Column); !ok {
return false
}
}
} else if len(fun.Args) > 1 {
return false
}

if len(fun.OrderByItems) > 0 {
return false
}
}
return num == 1
}

func genFirstRowAggForGroupBy(ctx sessionctx.Context, groupByItems []expression.Expression) ([]*aggregation.AggFuncDesc, error) {
aggFuncs := make([]*aggregation.AggFuncDesc, 0, len(groupByItems))
for _, groupBy := range groupByItems {
@@ -1642,15 +1671,98 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task {
if !mpp.needEnforceExchanger(prop) {
return p.attach2TaskForMpp1Phase(mpp)
}
// this has to be checked before the content of p is modified below
canUse3StageAgg := p.canUse3StageDistinctAgg()
proj := p.convertAvgForMPP()
partialAgg, finalAgg := p.newPartialAggregate(kv.TiFlash, true)
if finalAgg == nil {
return invalidTask
}

// generate 3 stage aggregation for single count distinct if applicable.
// select count(distinct a), count(b) from foo
// will generate a plan like:
// HashAgg sum(#1), sum(#2) -> final agg
// +- Exchange Passthrough
// +- HashAgg count(distinct a) #1, sum(#3) #2 -> middle agg
// +- Exchange HashPartition by a
// +- HashAgg count(b) #3, group by a -> partial agg
// +- TableScan foo
var middleAgg *PhysicalHashAgg
if partialAgg != nil && canUse3StageAgg {
clonedAgg, err := finalAgg.Clone()
if err != nil {
return invalidTask
}
middleAgg = clonedAgg.(*PhysicalHashAgg)
distinctPos := 0
middleSchema := expression.NewSchema()
schemaMap := make(map[int64]*expression.Column, len(middleAgg.AggFuncs))
for i, fun := range middleAgg.AggFuncs {
col := &expression.Column{
UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(),
RetType: fun.RetTp,
}
if fun.HasDistinct {
distinctPos = i
} else {
fun.Mode = aggregation.Partial2Mode
originalCol := fun.Args[0].(*expression.Column)
schemaMap[originalCol.UniqueID] = col
}
middleSchema.Append(col)
}
middleAgg.schema = middleSchema

finalHashAgg := finalAgg.(*PhysicalHashAgg)
finalAggDescs := make([]*aggregation.AggFuncDesc, 0, len(finalHashAgg.AggFuncs))
for i, fun := range finalHashAgg.AggFuncs {
newArgs := make([]expression.Expression, 0, 1)
if distinctPos == i {
// change count(distinct) to sum()
fun.Name = ast.AggFuncSum
fun.HasDistinct = false
newArgs = append(newArgs, middleSchema.Columns[i])
} else {
for _, arg := range fun.Args {
newCol, err := arg.RemapColumn(schemaMap)
if err != nil {
return invalidTask
}
newArgs = append(newArgs, newCol)
}
}
fun.Args = newArgs
finalAggDescs = append(finalAggDescs, fun)
}
finalHashAgg.AggFuncs = finalAggDescs
}

// partialAgg would be nil if the scalar aggregation cannot run in two-phase mode
if partialAgg != nil {
attachPlan2Task(partialAgg, mpp)
}

if middleAgg != nil && canUse3StageAgg {
items := partialAgg.(*PhysicalHashAgg).GroupByItems
partitionCols := make([]*property.MPPPartitionColumn, 0, len(items))
for _, expr := range items {
col, ok := expr.(*expression.Column)
if !ok {
continue
}
partitionCols = append(partitionCols, &property.MPPPartitionColumn{
Col: col,
CollateID: property.GetCollateIDByNameForPartition(col.GetType().GetCollate()),
})
}

prop := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: partitionCols}
newMpp := mpp.enforceExchanger(prop)
attachPlan2Task(middleAgg, newMpp)
mpp = newMpp
}

newMpp := mpp.enforceExchanger(prop)
attachPlan2Task(finalAgg, newMpp)
if proj == nil {
19 changes: 19 additions & 0 deletions planner/core/testdata/enforce_mpp_suite_in.json
@@ -113,5 +113,24 @@
"EXPLAIN select a, count(b), avg(distinct c), count(distinct c) from t group by a; -- multi distinct funcs, bail out",
"EXPLAIN select count(b), count(distinct c) from t; -- single distinct func but no group key, bail out"
]
},
{
"name": "TestMPPSingleDistinct3Stage",
"cases": [
"set @@tidb_allow_mpp=1;set @@tidb_enforce_mpp=1;",
"EXPLAIN select count(distinct b) from t;",
"EXPLAIN select count(distinct c) from t;",
"EXPLAIN select count(distinct e) from t;",
"EXPLAIN select count(distinct a,b,c,e) from t;",
"EXPLAIN select count(distinct c), count(a), count(*) from t;",
"EXPLAIN select sum(b), count(a), count(*), count(distinct c) from t;",
"EXPLAIN select sum(b+a), count(*), count(distinct c), count(a) from t having count(distinct c) > 2;",
"EXPLAIN select sum(b+a), count(*), count(a) from t having count(distinct c) > 2;",
"EXPLAIN select sum(b+a), max(b), count(distinct c), count(*) from t having count(a) > 2;",
"EXPLAIN select sum(b), count(distinct a, b, e), count(a+b) from t;",
"EXPLAIN select count(distinct b), json_objectagg(d,c) from t;",
"EXPLAIN select count(distinct c+a), count(a) from t;",
"EXPLAIN select sum(b), count(distinct c+a, b, e), count(a+b) from t;"
]
}
]