Skip to content

Commit

Permalink
planner: physical expand support being converted to tipb.Expand2 (#44973
Browse files Browse the repository at this point in the history
)

close #45179
  • Loading branch information
AilinKid authored Jul 5, 2023
1 parent db38184 commit d8b80a4
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 4 deletions.
2 changes: 2 additions & 0 deletions executor/internal/mpp/local_mpp_coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ func (c *localMppCoordinator) fixTaskForCTEStorageAndReader(exec *tipb.Executor,
children = append(children, exec.Sort.Child)
case tipb.ExecType_TypeExpand:
children = append(children, exec.Expand.Child)
case tipb.ExecType_TypeExpand2:
children = append(children, exec.Expand2.Child)
default:
return errors.Errorf("unknown new tipb protocol %d", exec.Tp)
}
Expand Down
2 changes: 2 additions & 0 deletions executor/internal/util/partition_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ func UpdateExecutorTableID(ctx context.Context, exec *tipb.Executor, recursive b
child = exec.Sort.Child
case tipb.ExecType_TypeExpand:
child = exec.Expand.Child
case tipb.ExecType_TypeExpand2:
child = exec.Expand2.Child
default:
return errors.Trace(fmt.Errorf("unknown new tipb protocol %d", exec.Tp))
}
Expand Down
8 changes: 8 additions & 0 deletions expression/builtin_grouping.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
package expression

import (
"context"

"github.com/gogo/protobuf/proto"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tipb/go-tipb"
)

Expand Down Expand Up @@ -73,6 +76,7 @@ func (b *BuiltinGroupingImplSig) SetMetadata(mode tipb.GroupingMode, groupingMar
b.isMetaInited = true
err := b.checkMetadata()
if err != nil {
logutil.Logger(context.Background()).Error("grouping meta check err: " + err.Error())
b.isMetaInited = false
return err
}
Expand All @@ -95,6 +99,7 @@ func (b *BuiltinGroupingImplSig) getGroupingMode() tipb.GroupingMode {
func (b *BuiltinGroupingImplSig) metadata() proto.Message {
err := b.checkMetadata()
if err != nil {
logutil.Logger(context.Background()).Error("grouping meta check err: " + err.Error())
return &tipb.GroupingFunctionMetadata{}
}
args := &tipb.GroupingFunctionMetadata{}
Expand All @@ -117,6 +122,9 @@ func (b *BuiltinGroupingImplSig) Clone() builtinFunc {
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.mode = b.mode
newSig.groupingMarks = b.groupingMarks
// mpp task generation will clone whole plan tree, including every expression related.
// if grouping function missed cloning this field, the ToPB check will errors.
newSig.isMetaInited = b.isMetaInited
return newSig
}

Expand Down
14 changes: 14 additions & 0 deletions expression/expr_to_pb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,9 @@ func TestExprPushDownToFlash(t *testing.T) {
float32Column := genColumn(mysql.TypeFloat, 9)
enumColumn := genColumn(mysql.TypeEnum, 10)
durationColumn := genColumn(mysql.TypeDuration, 11)
// uint64 col
uintColumn := genColumn(mysql.TypeLonglong, 12)
uintColumn.RetType.AddFlag(mysql.UnsignedFlag)

function, err := NewFunction(mock.NewContext(), ast.JSONLength, types.NewFieldType(mysql.TypeLonglong), jsonColumn)
require.NoError(t, err)
Expand Down Expand Up @@ -1273,6 +1276,17 @@ func TestExprPushDownToFlash(t *testing.T) {
require.NoError(t, err)
exprs = append(exprs, function)

// Grouping
function, err = NewFunction(mock.NewContext(), ast.Grouping, types.NewFieldType(mysql.TypeLonglong), uintColumn)
require.NoError(t, err)
exprs = append(exprs, function)
if scalarFunc, ok := function.(*ScalarFunction); ok {
if scalarFunc.FuncName.L == ast.Grouping {
scalarFunc.Function.(*BuiltinGroupingImplSig).
SetMetadata(tipb.GroupingMode_ModeBitAnd, []map[uint64]struct{}{})
}
}

pushed, remained = PushDownExprs(sc, exprs, client, kv.TiFlash)
require.Len(t, pushed, len(exprs))
require.Len(t, remained, 0)
Expand Down
28 changes: 28 additions & 0 deletions planner/core/logical_plans_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2206,8 +2206,11 @@ func TestRollupExpand(t *testing.T) {
// after logical optimization, the current select block's expand will generate its level-projections.
require.Equal(t, builder.currentBlockExpand.LevelExprs != nil, true)
require.Equal(t, len(builder.currentBlockExpand.LevelExprs), 3)
// for grouping set {}: gid = '00' = 0
require.Equal(t, expression.ExplainExpressionList(expand.LevelExprs[0], expand.schema), "test.t.a, <nil>->Column#13, <nil>->Column#14, 0->gid")
// for grouping set {a}: gid = '01' = 1
require.Equal(t, expression.ExplainExpressionList(expand.LevelExprs[1], expand.schema), "test.t.a, Column#13, <nil>->Column#14, 1->gid")
// for grouping set {a,b}: gid = '11' = 3
require.Equal(t, expression.ExplainExpressionList(expand.LevelExprs[2], expand.schema), "test.t.a, Column#13, Column#14, 3->gid")

require.Equal(t, expand.Schema().Len(), 4)
Expand All @@ -2222,4 +2225,29 @@ func TestRollupExpand(t *testing.T) {
// the gid col
require.Equal(t, expand.Schema().Columns[3].RetType.GetFlag()&mysql.NotNullFlag, uint(1))
require.Equal(t, expand.names[3].String(), "gid")

// Test grouping marks generation.
// Expand.schema.columns[0] is normal source column.
// Expand.schema.columns[1] is normal grouping set column a.
// Expand.schema.columns[2] is normal grouping set column b.
// Expand.schema.columns[2] is normal grouping gen column gid.
// mock grouping(a)
gm := expand.GenerateGroupingMarks([]*expression.Column{expand.schema.Columns[1]})
require.NotNil(t, gm)
require.Equal(t, len(gm), 1)

// mock grouping(b)
gm = expand.GenerateGroupingMarks([]*expression.Column{expand.schema.Columns[2]})
require.NotNil(t, gm)
require.Equal(t, len(gm), 1)

// mock grouping(a,b)
gm = expand.GenerateGroupingMarks([]*expression.Column{expand.schema.Columns[1], expand.schema.Columns[2]})
require.NotNil(t, gm)
require.Equal(t, len(gm), 2)

// mock grouping(b,a)
gm = expand.GenerateGroupingMarks([]*expression.Column{expand.schema.Columns[2], expand.schema.Columns[1]})
require.NotNil(t, gm)
require.Equal(t, len(gm), 2)
}
31 changes: 27 additions & 4 deletions planner/core/physical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package core
import (
"fmt"
"strconv"
"strings"
"unsafe"

"github.com/pingcap/errors"
Expand Down Expand Up @@ -1058,11 +1059,11 @@ type PhysicalProjection struct {
func (p *PhysicalProjection) Clone() (PhysicalPlan, error) {
cloned := new(PhysicalProjection)
*cloned = *p
base, err := p.basePhysicalPlan.cloneWithSelf(cloned)
base, err := p.physicalSchemaProducer.cloneWithSelf(cloned)
if err != nil {
return nil, err
}
cloned.basePhysicalPlan = *base
cloned.physicalSchemaProducer = *base
cloned.Exprs = util.CloneExprs(p.Exprs)
return cloned, err
}
Expand Down Expand Up @@ -1616,12 +1617,15 @@ func (p PhysicalExpand) Init(ctx sessionctx.Context, stats *property.StatsInfo,

// Clone implements PhysicalPlan interface.
func (p *PhysicalExpand) Clone() (PhysicalPlan, error) {
if len(p.LevelExprs) > 0 {
return p.cloneV2()
}
np := new(PhysicalExpand)
base, err := p.basePhysicalPlan.cloneWithSelf(np)
base, err := p.physicalSchemaProducer.cloneWithSelf(np)
if err != nil {
return nil, errors.Trace(err)
}
np.basePhysicalPlan = *base
np.physicalSchemaProducer = *base
// clone ID cols.
np.GroupingIDCol = p.GroupingIDCol.Clone().(*expression.Column)

Expand All @@ -1634,6 +1638,25 @@ func (p *PhysicalExpand) Clone() (PhysicalPlan, error) {
return np, nil
}

func (p *PhysicalExpand) cloneV2() (PhysicalPlan, error) {
np := new(PhysicalExpand)
base, err := p.physicalSchemaProducer.cloneWithSelf(np)
if err != nil {
return nil, errors.Trace(err)
}
np.physicalSchemaProducer = *base
// clone level projection expressions.
for _, oneLevelProjExprs := range p.LevelExprs {
np.LevelExprs = append(np.LevelExprs, util.CloneExprs(oneLevelProjExprs))
}

// clone generated column names.
for _, name := range p.ExtraGroupingColNames {
np.ExtraGroupingColNames = append(np.ExtraGroupingColNames, strings.Clone(name))
}
return np, nil
}

// MemoryUsage return the memory usage of PhysicalExpand
func (p *PhysicalExpand) MemoryUsage() (sum int64) {
if p == nil {
Expand Down
30 changes: 30 additions & 0 deletions planner/core/plan_to_pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ func (p *basePhysicalPlan) ToPB(_ sessionctx.Context, _ kv.StoreType) (*tipb.Exe

// ToPB implements PhysicalPlan ToPB interface.
func (p *PhysicalExpand) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) {
if len(p.LevelExprs) > 0 {
return p.toPBV2(ctx, storeType)
}
sc := ctx.GetSessionVars().StmtCtx
client := ctx.GetClient()
groupingSetsPB, err := p.GroupingSets.ToPB(sc, client)
Expand All @@ -56,6 +59,33 @@ func (p *PhysicalExpand) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*
return &tipb.Executor{Tp: tipb.ExecType_TypeExpand, Expand: expand, ExecutorId: &executorID}, nil
}

func (p *PhysicalExpand) toPBV2(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) {
sc := ctx.GetSessionVars().StmtCtx
client := ctx.GetClient()
projExprsPB := make([]*tipb.ExprSlice, 0, len(p.LevelExprs))
for _, exprs := range p.LevelExprs {
expressionsPB, err := expression.ExpressionsToPBList(sc, exprs, client)
if err != nil {
return nil, err
}
projExprsPB = append(projExprsPB, &tipb.ExprSlice{Exprs: expressionsPB})
}
expand2 := &tipb.Expand2{
ProjExprs: projExprsPB,
GeneratedOutputNames: p.ExtraGroupingColNames,
}
executorID := ""
if storeType == kv.TiFlash {
var err error
expand2.Child, err = p.children[0].ToPB(ctx, storeType)
if err != nil {
return nil, errors.Trace(err)
}
executorID = p.ExplainID().String()
}
return &tipb.Executor{Tp: tipb.ExecType_TypeExpand2, Expand2: expand2, ExecutorId: &executorID}, nil
}

// ToPB implements PhysicalPlan ToPB interface.
func (p *PhysicalHashAgg) ToPB(ctx sessionctx.Context, storeType kv.StoreType) (*tipb.Executor, error) {
sc := ctx.GetSessionVars().StmtCtx
Expand Down

0 comments on commit d8b80a4

Please sign in to comment.