Skip to content

Commit

Permalink
planner/core: support union all for mpp. (#24287)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanfei1991 authored Jun 2, 2021
1 parent 7c3e036 commit 52e89cb
Show file tree
Hide file tree
Showing 12 changed files with 314 additions and 40 deletions.
24 changes: 10 additions & 14 deletions executor/mpp_gather.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,20 @@ type MPPGather struct {
respIter distsql.SelectResult
}

func (e *MPPGather) appendMPPDispatchReq(pf *plannercore.Fragment, tasks []*kv.MPPTask, isRoot bool) error {
func (e *MPPGather) appendMPPDispatchReq(pf *plannercore.Fragment) error {
dagReq, _, err := constructDAGReq(e.ctx, []plannercore.PhysicalPlan{pf.ExchangeSender}, kv.TiFlash)
if err != nil {
return errors.Trace(err)
}
for i := range pf.ExchangeSender.Schema().Columns {
dagReq.OutputOffsets = append(dagReq.OutputOffsets, uint32(i))
}
if !isRoot {
if !pf.IsRoot {
dagReq.EncodeType = tipb.EncodeType_TypeCHBlock
} else {
dagReq.EncodeType = tipb.EncodeType_TypeChunk
}
for _, mppTask := range tasks {
for _, mppTask := range pf.ExchangeSender.Tasks {
err := updateExecutorTableID(context.Background(), dagReq.RootExecutor, mppTask.TableID, true)
if err != nil {
return errors.Trace(err)
Expand All @@ -77,20 +77,14 @@ func (e *MPPGather) appendMPPDispatchReq(pf *plannercore.Fragment, tasks []*kv.M
Data: pbData,
Meta: mppTask.Meta,
ID: mppTask.ID,
IsRoot: isRoot,
IsRoot: pf.IsRoot,
Timeout: 10,
SchemaVar: e.is.SchemaMetaVersion(),
StartTs: e.startTS,
State: kv.MppTaskReady,
}
e.mppReqs = append(e.mppReqs, req)
}
for _, r := range pf.ExchangeReceivers {
err = e.appendMPPDispatchReq(r.GetExchangeSender().Fragment, r.Tasks, false)
if err != nil {
return errors.Trace(err)
}
}
return nil
}

Expand All @@ -108,13 +102,15 @@ func (e *MPPGather) Open(ctx context.Context) (err error) {
// TODO: Move the construct tasks logic to planner, so we can see the explain results.
sender := e.originalPlan.(*plannercore.PhysicalExchangeSender)
planIDs := collectPlanIDS(e.originalPlan, nil)
rootTasks, err := plannercore.GenerateRootMPPTasks(e.ctx, e.startTS, sender, e.is)
frags, err := plannercore.GenerateRootMPPTasks(e.ctx, e.startTS, sender, e.is)
if err != nil {
return errors.Trace(err)
}
err = e.appendMPPDispatchReq(sender.Fragment, rootTasks, true)
if err != nil {
return errors.Trace(err)
for _, frag := range frags {
err = e.appendMPPDispatchReq(frag)
if err != nil {
return errors.Trace(err)
}
}
failpoint.Inject("checkTotalMPPTasks", func(val failpoint.Value) {
if val.(int) != len(e.mppReqs) {
Expand Down
42 changes: 42 additions & 0 deletions executor/tiflash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,48 @@ func (s *tiflashTestSuite) TestMppGoroutinesExitFromErrors(c *C) {
c.Assert(failpoint.Disable(hang), IsNil)
}

// TestMppUnionAll checks that UNION ALL queries over TiFlash-replicated tables
// return the expected rows when combined with joins, aggregations and nested
// unions, and that the last query is dispatched as exactly six MPP tasks.
func (s *tiflashTestSuite) TestMppUnionAll(c *C) {
	tk := testkit.NewTestKit(c, s.store)
	tk.MustExec("use test")
	tk.MustExec("drop table if exists x1")
	tk.MustExec("create table x1(a int , b int);")
	tk.MustExec("alter table x1 set tiflash replica 1")
	tk.MustExec("drop table if exists x2")
	tk.MustExec("create table x2(a int , b int);")
	tk.MustExec("alter table x2 set tiflash replica 1")
	// Mark the TiFlash replicas of both tables as available.
	tb := testGetTableByName(c, tk.Se, "test", "x1")
	err := domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
	c.Assert(err, IsNil)
	tb = testGetTableByName(c, tk.Se, "test", "x2")
	err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
	c.Assert(err, IsNil)

	tk.MustExec("insert into x1 values (1, 1), (2, 2), (3, 3), (4, 4)")
	tk.MustExec("insert into x2 values (5, 1), (2, 2), (3, 3), (4, 4)")

	// test join + union (join + select)
	tk.MustQuery("select x1.a, x.a from x1 left join (select x2.b a, x1.b from x1 join x2 on x1.a = x2.b union all select * from x1 ) x on x1.a = x.a order by x1.a").Check(testkit.Rows("1 1", "1 1", "2 2", "2 2", "3 3", "3 3", "4 4", "4 4"))
	tk.MustQuery("select x1.a, x.a from x1 left join (select count(*) a, sum(b) b from x1 group by a union all select * from x2 ) x on x1.a = x.a order by x1.a").Check(testkit.Rows("1 1", "1 1", "1 1", "1 1", "2 2", "3 3", "4 4"))

	tk.MustExec("drop table if exists x3")
	tk.MustExec("create table x3(a int , b int);")
	tk.MustExec("alter table x3 set tiflash replica 1")
	tb = testGetTableByName(c, tk.Se, "test", "x3")
	err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
	c.Assert(err, IsNil)

	tk.MustExec("insert into x3 values (2, 2), (2, 3), (2, 4)")
	// test nested union all
	tk.MustQuery("select count(*) from (select a, b from x1 union all select a, b from x3 union all (select x1.a, x3.b from (select * from x3 union all select * from x2) x3 left join x1 on x3.a = x1.b))").Check(testkit.Rows("14"))
	// test union all join union all
	tk.MustQuery("select count(*) from (select * from x1 union all select * from x2 union all select * from x3) x join (select * from x1 union all select * from x2 union all select * from x3) y on x.a = y.b").Check(testkit.Rows("29"))
	tk.MustExec("set @@session.tidb_broadcast_join_threshold_count=100000")
	// Assert on Enable/Disable: if enabling the failpoint fails, the task-count
	// check below would otherwise be skipped without any signal.
	c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/checkTotalMPPTasks", `return(6)`), IsNil)
	tk.MustQuery("select count(*) from (select * from x1 union all select * from x2 union all select * from x3) x join (select * from x1 union all select * from x2 union all select * from x3) y on x.a = y.b").Check(testkit.Rows("29"))
	c.Assert(failpoint.Disable("github.com/pingcap/tidb/executor/checkTotalMPPTasks"), IsNil)
}

func (s *tiflashTestSuite) TestMppApply(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
3 changes: 1 addition & 2 deletions expression/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,9 @@ type CorrelatedColumn struct {

// Clone implements Expression interface.
func (col *CorrelatedColumn) Clone() Expression {
var d types.Datum
return &CorrelatedColumn{
Column: col.Column,
Data: &d,
Data: col.Data,
}
}

Expand Down
34 changes: 30 additions & 4 deletions planner/core/exhaust_physical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -2126,7 +2126,7 @@ func (p *baseLogicalPlan) canPushToCop(storeTp kv.StoreType) bool {
}
}
ret = ret && validDs
case *LogicalAggregation, *LogicalProjection, *LogicalSelection, *LogicalJoin:
case *LogicalAggregation, *LogicalProjection, *LogicalSelection, *LogicalJoin, *LogicalUnionAll:
if storeTp == kv.TiFlash {
ret = ret && c.canPushToCop(storeTp)
} else {
Expand Down Expand Up @@ -2494,15 +2494,41 @@ func (p *LogicalLock) exhaustPhysicalPlans(prop *property.PhysicalProperty) ([]P

func (p *LogicalUnionAll) exhaustPhysicalPlans(prop *property.PhysicalProperty) ([]PhysicalPlan, bool, error) {
// TODO: UnionAll can not pass any order, but we can change it to sort merge to keep order.
if !prop.IsEmpty() || prop.IsFlashProp() {
if !prop.IsEmpty() || (prop.IsFlashProp() && prop.TaskTp != property.MppTaskType) {
return nil, true, nil
}
// TODO: UnionAll can pass partition info, but for briefness, we prevent it from pushing down.
if prop.TaskTp == property.MppTaskType && prop.PartitionTp != property.AnyType {
return nil, true, nil
}
canUseMpp := p.ctx.GetSessionVars().IsMPPAllowed() && p.canPushToCop(kv.TiFlash)
chReqProps := make([]*property.PhysicalProperty, 0, len(p.children))
for range p.children {
chReqProps = append(chReqProps, &property.PhysicalProperty{ExpectedCnt: prop.ExpectedCnt})
if canUseMpp && prop.TaskTp == property.MppTaskType {
chReqProps = append(chReqProps, &property.PhysicalProperty{
ExpectedCnt: prop.ExpectedCnt,
TaskTp: property.MppTaskType,
})
} else {
chReqProps = append(chReqProps, &property.PhysicalProperty{ExpectedCnt: prop.ExpectedCnt})
}
}
ua := PhysicalUnionAll{}.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt), p.blockOffset, chReqProps...)
ua := PhysicalUnionAll{
mpp: canUseMpp && prop.TaskTp == property.MppTaskType,
}.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt), p.blockOffset, chReqProps...)
ua.SetSchema(p.Schema())
if canUseMpp && prop.TaskTp == property.RootTaskType {
chReqProps = make([]*property.PhysicalProperty, 0, len(p.children))
for range p.children {
chReqProps = append(chReqProps, &property.PhysicalProperty{
ExpectedCnt: prop.ExpectedCnt,
TaskTp: property.MppTaskType,
})
}
mppUA := PhysicalUnionAll{mpp: true}.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt), p.blockOffset, chReqProps...)
mppUA.SetSchema(p.Schema())
return []PhysicalPlan{ua, mppUA}, true, nil
}
return []PhysicalPlan{ua}, true, nil
}

Expand Down
139 changes: 123 additions & 16 deletions planner/core/fragment.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,32 +38,49 @@ type Fragment struct {

// following fields are filled after scheduling.
ExchangeSender *PhysicalExchangeSender // data exporter

IsRoot bool
}

// tasksAndFrags pairs the MPP tasks generated for one exchange sender with the
// fragments that were built from it. It is the value type of
// mppTaskGenerator.cache.
type tasksAndFrags struct {
// tasks holds the tasks of all fragments in frags, concatenated.
tasks []*kv.MPPTask
// frags holds the fragments built from the sender.
frags []*Fragment
}

// mppTaskGenerator carries the state used while turning a physical plan rooted
// at an exchange sender into fragments and MPP tasks.
type mppTaskGenerator struct {
ctx sessionctx.Context
startTS uint64
is infoschema.InfoSchema
// frags accumulates every fragment built so far, across all exchange senders.
frags []*Fragment
// cache memoizes the tasks and fragments generated for an exchange sender,
// keyed by the sender's plan ID.
cache map[int]tasksAndFrags
}

// GenerateRootMPPTasks generates all MPP tasks and returns every fragment built for them; root fragments have IsRoot set.
func GenerateRootMPPTasks(ctx sessionctx.Context, startTs uint64, sender *PhysicalExchangeSender, is infoschema.InfoSchema) ([]*kv.MPPTask, error) {
g := &mppTaskGenerator{ctx: ctx, startTS: startTs, is: is}
func GenerateRootMPPTasks(ctx sessionctx.Context, startTs uint64, sender *PhysicalExchangeSender, is infoschema.InfoSchema) ([]*Fragment, error) {
g := &mppTaskGenerator{
ctx: ctx,
startTS: startTs,
is: is,
cache: make(map[int]tasksAndFrags),
}
return g.generateMPPTasks(sender)
}

func (e *mppTaskGenerator) generateMPPTasks(s *PhysicalExchangeSender) ([]*kv.MPPTask, error) {
func (e *mppTaskGenerator) generateMPPTasks(s *PhysicalExchangeSender) ([]*Fragment, error) {
logutil.BgLogger().Info("Mpp will generate tasks", zap.String("plan", ToString(s)))
tidbTask := &kv.MPPTask{
StartTs: e.startTS,
ID: -1,
}
rootTasks, err := e.generateMPPTasksForFragment(s)
_, frags, err := e.generateMPPTasksForExchangeSender(s)
if err != nil {
return nil, errors.Trace(err)
}
s.TargetTasks = []*kv.MPPTask{tidbTask}
return rootTasks, nil
for _, frag := range frags {
frag.ExchangeSender.TargetTasks = []*kv.MPPTask{tidbTask}
frag.IsRoot = true
}
return e.frags, nil
}

type mppAddr struct {
Expand Down Expand Up @@ -105,6 +122,8 @@ func (f *Fragment) init(p PhysicalPlan) error {
f.TableScan = x
case *PhysicalExchangeReceiver:
f.ExchangeReceivers = append(f.ExchangeReceivers, x)
case *PhysicalUnionAll:
return errors.New("unexpected union all detected")
default:
for _, ch := range p.Children() {
if err := f.init(ch); err != nil {
Expand All @@ -115,20 +134,107 @@ func (f *Fragment) init(p PhysicalPlan) error {
return nil
}

func newFragment(s *PhysicalExchangeSender) (*Fragment, error) {
f := &Fragment{ExchangeSender: s}
s.Fragment = f
err := f.init(s)
return f, errors.Trace(err)
// untwistPlanAndRemoveUnionAll removes every union-all operator from a plan by
// "untwisting" it: each route from the root (stack[0], an ExchangeSender) down
// to a leaf (ExchangeReceiver or TableScan) is copied into a new, isolated tree
// that contains no union all. Each such tree later becomes a fragment whose
// tasks are gathered to TiDB or exchanged to upper tasks again.
//
// For instance, given the plan "select c1 from t union all select c1 from s",
// after untwisting there are two plans in the `forest` slice:
//   - ExchangeSender -> Projection (c1) -> TableScan(t)
//   - ExchangeSender -> Projection (c1) -> TableScan(s)
//
// stack is the path walked so far, root first and current node last. The cloned
// root of every completed route is appended to *forest. An error is returned if
// a clone fails or if an operator other than a hash join or union all does not
// have exactly one child.
func untwistPlanAndRemoveUnionAll(stack []PhysicalPlan, forest *[]*PhysicalExchangeSender) error {
	switch top := stack[len(stack)-1].(type) {
	case *PhysicalTableScan, *PhysicalExchangeReceiver:
		// Leaf node reached: materialize the route as a fresh chain of clones.
		root, err := stack[0].Clone()
		if err != nil {
			return errors.Trace(err)
		}
		*forest = append(*forest, root.(*PhysicalExchangeSender))
		parent := root
		for _, node := range stack[1:] {
			// Union-all nodes are dropped from the copied route.
			if _, isUnion := node.(*PhysicalUnionAll); isUnion {
				continue
			}
			cloned, err := node.Clone()
			if err != nil {
				return errors.Trace(err)
			}
			if join, isJoin := parent.(*PhysicalHashJoin); isJoin {
				// Only the non-inner child slot of a join is part of the route.
				join.SetChild(1-join.InnerChildIdx, cloned)
			} else {
				parent.SetChildren(cloned)
			}
			parent = cloned
		}
		return nil
	case *PhysicalHashJoin:
		// Descend only into the child at index 1-InnerChildIdx.
		stack = append(stack, top.children[1-top.InnerChildIdx])
		return errors.Trace(untwistPlanAndRemoveUnionAll(stack, forest))
	case *PhysicalUnionAll:
		// Every child of a union all starts its own route.
		for _, child := range top.children {
			stack = append(stack, child)
			err := untwistPlanAndRemoveUnionAll(stack, forest)
			stack = stack[:len(stack)-1]
			if err != nil {
				return errors.Trace(err)
			}
		}
		return nil
	default:
		children := top.Children()
		if len(children) != 1 {
			return errors.Trace(errors.New("unexpected plan " + top.ExplainID().String()))
		}
		stack = append(stack, children[0])
		return errors.Trace(untwistPlanAndRemoveUnionAll(stack, forest))
	}
}

func (e *mppTaskGenerator) generateMPPTasksForFragment(s *PhysicalExchangeSender) (tasks []*kv.MPPTask, err error) {
f, err := newFragment(s)
// buildFragments splits the plan rooted at s into union-all-free trees and
// wraps each tree's exchange sender in an initialized Fragment. The returned
// fragments are in the order the trees were produced.
func buildFragments(s *PhysicalExchangeSender) ([]*Fragment, error) {
	senders := make([]*PhysicalExchangeSender, 0, 1)
	if err := untwistPlanAndRemoveUnionAll([]PhysicalPlan{s}, &senders); err != nil {
		return nil, errors.Trace(err)
	}
	fragments := make([]*Fragment, len(senders))
	for i, sender := range senders {
		frag := &Fragment{ExchangeSender: sender}
		if err := frag.init(sender); err != nil {
			return nil, errors.Trace(err)
		}
		fragments[i] = frag
	}
	return fragments, nil
}

// generateMPPTasksForExchangeSender builds the fragments for the plan rooted at
// s and generates the MPP tasks of each fragment. It returns the tasks of all
// those fragments concatenated, together with the fragments themselves.
//
// Results are memoized by the sender's plan ID, so a sender that is reached
// again gets the tasks and fragments produced the first time. Newly built
// fragments are also appended to e.frags.
func (e *mppTaskGenerator) generateMPPTasksForExchangeSender(s *PhysicalExchangeSender) ([]*kv.MPPTask, []*Fragment, error) {
	senderID := s.ID()
	if hit, found := e.cache[senderID]; found {
		return hit.tasks, hit.frags, nil
	}

	fragments, err := buildFragments(s)
	if err != nil {
		return nil, nil, errors.Trace(err)
	}

	allTasks := make([]*kv.MPPTask, 0, len(fragments))
	for _, frag := range fragments {
		fragTasks, err := e.generateMPPTasksForFragment(frag)
		if err != nil {
			return nil, nil, errors.Trace(err)
		}
		allTasks = append(allTasks, fragTasks...)
	}

	// Record the fragments only after all of them produced tasks successfully.
	e.frags = append(e.frags, fragments...)
	e.cache[senderID] = tasksAndFrags{tasks: allTasks, frags: fragments}
	return allTasks, fragments, nil
}

func (e *mppTaskGenerator) generateMPPTasksForFragment(f *Fragment) (tasks []*kv.MPPTask, err error) {
for _, r := range f.ExchangeReceivers {
r.Tasks, err = e.generateMPPTasksForFragment(r.GetExchangeSender())
r.Tasks, r.frags, err = e.generateMPPTasksForExchangeSender(r.GetExchangeSender())
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -149,8 +255,9 @@ func (e *mppTaskGenerator) generateMPPTasksForFragment(s *PhysicalExchangeSender
return nil, errors.New("cannot find mpp task")
}
for _, r := range f.ExchangeReceivers {
s := r.GetExchangeSender()
s.TargetTasks = tasks
for _, frag := range r.frags {
frag.ExchangeSender.TargetTasks = append(frag.ExchangeSender.TargetTasks, tasks...)
}
}
f.ExchangeSender.Tasks = tasks
return tasks, nil
Expand Down
2 changes: 1 addition & 1 deletion planner/core/initialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ func (p PhysicalTableReader) Init(ctx sessionctx.Context, offset int) *PhysicalT
if p.tablePlan != nil {
p.TablePlans = flattenPushDownPlan(p.tablePlan)
p.schema = p.tablePlan.Schema()
if p.StoreType == kv.TiFlash && !p.GetTableScan().KeepOrder {
if p.StoreType == kv.TiFlash && p.GetTableScan() != nil && !p.GetTableScan().KeepOrder {
// When allow batch cop is 1, only agg / topN uses batch cop.
// When allow batch cop is 2, every query uses batch cop.
switch ctx.GetSessionVars().AllowBatchCop {
Expand Down
4 changes: 4 additions & 0 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1378,6 +1378,10 @@ func (b *PlanBuilder) buildProjection4Union(ctx context.Context, u *LogicalUnion
b.optFlag |= flagEliminateProjection
proj := LogicalProjection{Exprs: exprs, AvoidColumnEvaluator: true}.Init(b.ctx, b.getSelectOffset())
proj.SetSchema(u.schema.Clone())
// reset the schema type to make the "not null" flag right.
for i, expr := range exprs {
proj.schema.Columns[i].RetType = expr.GetType()
}
proj.SetChildren(child)
u.children[childID] = proj
}
Expand Down
Loading

0 comments on commit 52e89cb

Please sign in to comment.