diff --git a/executor/mpp_gather.go b/executor/mpp_gather.go index 7cfeb613c40f6..e517a0130ec4f 100644 --- a/executor/mpp_gather.go +++ b/executor/mpp_gather.go @@ -50,7 +50,7 @@ type MPPGather struct { respIter distsql.SelectResult } -func (e *MPPGather) appendMPPDispatchReq(pf *plannercore.Fragment, tasks []*kv.MPPTask, isRoot bool) error { +func (e *MPPGather) appendMPPDispatchReq(pf *plannercore.Fragment) error { dagReq, _, err := constructDAGReq(e.ctx, []plannercore.PhysicalPlan{pf.ExchangeSender}, kv.TiFlash) if err != nil { return errors.Trace(err) @@ -58,12 +58,12 @@ func (e *MPPGather) appendMPPDispatchReq(pf *plannercore.Fragment, tasks []*kv.M for i := range pf.ExchangeSender.Schema().Columns { dagReq.OutputOffsets = append(dagReq.OutputOffsets, uint32(i)) } - if !isRoot { + if !pf.IsRoot { dagReq.EncodeType = tipb.EncodeType_TypeCHBlock } else { dagReq.EncodeType = tipb.EncodeType_TypeChunk } - for _, mppTask := range tasks { + for _, mppTask := range pf.ExchangeSender.Tasks { err := updateExecutorTableID(context.Background(), dagReq.RootExecutor, mppTask.TableID, true) if err != nil { return errors.Trace(err) @@ -77,7 +77,7 @@ func (e *MPPGather) appendMPPDispatchReq(pf *plannercore.Fragment, tasks []*kv.M Data: pbData, Meta: mppTask.Meta, ID: mppTask.ID, - IsRoot: isRoot, + IsRoot: pf.IsRoot, Timeout: 10, SchemaVar: e.is.SchemaMetaVersion(), StartTs: e.startTS, @@ -85,12 +85,6 @@ func (e *MPPGather) appendMPPDispatchReq(pf *plannercore.Fragment, tasks []*kv.M } e.mppReqs = append(e.mppReqs, req) } - for _, r := range pf.ExchangeReceivers { - err = e.appendMPPDispatchReq(r.GetExchangeSender().Fragment, r.Tasks, false) - if err != nil { - return errors.Trace(err) - } - } return nil } @@ -108,13 +102,15 @@ func (e *MPPGather) Open(ctx context.Context) (err error) { // TODO: Move the construct tasks logic to planner, so we can see the explain results. 
sender := e.originalPlan.(*plannercore.PhysicalExchangeSender) planIDs := collectPlanIDS(e.originalPlan, nil) - rootTasks, err := plannercore.GenerateRootMPPTasks(e.ctx, e.startTS, sender, e.is) + frags, err := plannercore.GenerateRootMPPTasks(e.ctx, e.startTS, sender, e.is) if err != nil { return errors.Trace(err) } - err = e.appendMPPDispatchReq(sender.Fragment, rootTasks, true) - if err != nil { - return errors.Trace(err) + for _, frag := range frags { + err = e.appendMPPDispatchReq(frag) + if err != nil { + return errors.Trace(err) + } } failpoint.Inject("checkTotalMPPTasks", func(val failpoint.Value) { if val.(int) != len(e.mppReqs) { diff --git a/executor/tiflash_test.go b/executor/tiflash_test.go index 690106a9a38e5..e442a219ac1d2 100644 --- a/executor/tiflash_test.go +++ b/executor/tiflash_test.go @@ -425,6 +425,48 @@ func (s *tiflashTestSuite) TestMppGoroutinesExitFromErrors(c *C) { c.Assert(failpoint.Disable(hang), IsNil) } +func (s *tiflashTestSuite) TestMppUnionAll(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists x1") + tk.MustExec("create table x1(a int , b int);") + tk.MustExec("alter table x1 set tiflash replica 1") + tk.MustExec("drop table if exists x2") + tk.MustExec("create table x2(a int , b int);") + tk.MustExec("alter table x2 set tiflash replica 1") + tb := testGetTableByName(c, tk.Se, "test", "x1") + err := domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true) + c.Assert(err, IsNil) + tb = testGetTableByName(c, tk.Se, "test", "x2") + err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true) + c.Assert(err, IsNil) + + tk.MustExec("insert into x1 values (1, 1), (2, 2), (3, 3), (4, 4)") + tk.MustExec("insert into x2 values (5, 1), (2, 2), (3, 3), (4, 4)") + + // test join + union (join + select) + tk.MustQuery("select x1.a, x.a from x1 left join (select x2.b a, x1.b from x1 join x2 on x1.a = x2.b union all select * from x1 ) x on x1.a = x.a order by x1.a").Check(testkit.Rows("1 1", "1 1", "2 2", "2 2", "3 3", "3 3", "4 4", "4 4")) + tk.MustQuery("select x1.a, x.a from x1 left join (select count(*) a, sum(b) b from x1 group by a union all select * from x2 ) x on x1.a = x.a order by x1.a").Check(testkit.Rows("1 1", "1 1", "1 1", "1 1", "2 2", "3 3", "4 4")) + + tk.MustExec("drop table if exists x3") + tk.MustExec("create table x3(a int , b int);") + tk.MustExec("alter table x3 set tiflash replica 1") + tb = testGetTableByName(c, tk.Se, "test", "x3") + err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true) + c.Assert(err, IsNil) + + tk.MustExec("insert into x3 values (2, 2), (2, 3), (2, 4)") + // test nested union all + tk.MustQuery("select count(*) from (select a, b from x1 union all select a, b from x3 union all (select x1.a, x3.b from (select * from x3 union all select * from x2) x3 left join x1 on x3.a = x1.b))").Check(testkit.Rows("14")) + // test union all join union all + tk.MustQuery("select count(*) from (select * from x1 union all select * from x2 union all select * from x3) x join (select * from x1 union all select * from x2 union all select * from x3) y on x.a = y.b").Check(testkit.Rows("29")) + tk.MustExec("set @@session.tidb_broadcast_join_threshold_count=100000") + failpoint.Enable("github.com/pingcap/tidb/executor/checkTotalMPPTasks", `return(6)`) + tk.MustQuery("select count(*) from (select * from x1 union all select * from x2 union all select * from x3) x join (select * from x1 union all select * from x2 union 
all select * from x3) y on x.a = y.b").Check(testkit.Rows("29")) + failpoint.Disable("github.com/pingcap/tidb/executor/checkTotalMPPTasks") + +} + func (s *tiflashTestSuite) TestMppApply(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/expression/column.go b/expression/column.go index 006b9a3867cda..ebc0feaf93a06 100644 --- a/expression/column.go +++ b/expression/column.go @@ -38,10 +38,9 @@ type CorrelatedColumn struct { // Clone implements Expression interface. func (col *CorrelatedColumn) Clone() Expression { - var d types.Datum return &CorrelatedColumn{ Column: col.Column, - Data: &d, + Data: col.Data, } } diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index dcb4e991ebe33..b2ba91ff99dcd 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -2126,7 +2126,7 @@ func (p *baseLogicalPlan) canPushToCop(storeTp kv.StoreType) bool { } } ret = ret && validDs - case *LogicalAggregation, *LogicalProjection, *LogicalSelection, *LogicalJoin: + case *LogicalAggregation, *LogicalProjection, *LogicalSelection, *LogicalJoin, *LogicalUnionAll: if storeTp == kv.TiFlash { ret = ret && c.canPushToCop(storeTp) } else { @@ -2494,15 +2494,41 @@ func (p *LogicalLock) exhaustPhysicalPlans(prop *property.PhysicalProperty) ([]P func (p *LogicalUnionAll) exhaustPhysicalPlans(prop *property.PhysicalProperty) ([]PhysicalPlan, bool, error) { // TODO: UnionAll can not pass any order, but we can change it to sort merge to keep order. - if !prop.IsEmpty() || prop.IsFlashProp() { + if !prop.IsEmpty() || (prop.IsFlashProp() && prop.TaskTp != property.MppTaskType) { + return nil, true, nil + } + // TODO: UnionAll can pass partition info, but for briefness, we prevent it from pushing down. + if prop.TaskTp == property.MppTaskType && prop.PartitionTp != property.AnyType { return nil, true, nil } + canUseMpp := p.ctx.GetSessionVars().IsMPPAllowed() && p.canPushToCop(kv.TiFlash) chReqProps := make([]*property.PhysicalProperty, 0, len(p.children)) for range p.children { - chReqProps = append(chReqProps, &property.PhysicalProperty{ExpectedCnt: prop.ExpectedCnt}) + if canUseMpp && prop.TaskTp == property.MppTaskType { + chReqProps = append(chReqProps, &property.PhysicalProperty{ + ExpectedCnt: prop.ExpectedCnt, + TaskTp: property.MppTaskType, + }) + } else { + chReqProps = append(chReqProps, &property.PhysicalProperty{ExpectedCnt: prop.ExpectedCnt}) + } } - ua := PhysicalUnionAll{}.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt), p.blockOffset, chReqProps...) + ua := PhysicalUnionAll{ + mpp: canUseMpp && prop.TaskTp == property.MppTaskType, + }.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt), p.blockOffset, chReqProps...) ua.SetSchema(p.Schema()) + if canUseMpp && prop.TaskTp == property.RootTaskType { + chReqProps = make([]*property.PhysicalProperty, 0, len(p.children)) + for range p.children { + chReqProps = append(chReqProps, &property.PhysicalProperty{ + ExpectedCnt: prop.ExpectedCnt, + TaskTp: property.MppTaskType, + }) + } + mppUA := PhysicalUnionAll{mpp: true}.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt), p.blockOffset, chReqProps...) 
+ mppUA.SetSchema(p.Schema()) + return []PhysicalPlan{ua, mppUA}, true, nil + } return []PhysicalPlan{ua}, true, nil } diff --git a/planner/core/fragment.go b/planner/core/fragment.go index d0ab808c742e7..7315da176e6b8 100644 --- a/planner/core/fragment.go +++ b/planner/core/fragment.go @@ -38,32 +38,49 @@ type Fragment struct { // following fields are filled after scheduling. ExchangeSender *PhysicalExchangeSender // data exporter + + IsRoot bool +} + +type tasksAndFrags struct { + tasks []*kv.MPPTask + frags []*Fragment } type mppTaskGenerator struct { ctx sessionctx.Context startTS uint64 is infoschema.InfoSchema + frags []*Fragment + cache map[int]tasksAndFrags } // GenerateRootMPPTasks generate all mpp tasks and return root ones. -func GenerateRootMPPTasks(ctx sessionctx.Context, startTs uint64, sender *PhysicalExchangeSender, is infoschema.InfoSchema) ([]*kv.MPPTask, error) { - g := &mppTaskGenerator{ctx: ctx, startTS: startTs, is: is} +func GenerateRootMPPTasks(ctx sessionctx.Context, startTs uint64, sender *PhysicalExchangeSender, is infoschema.InfoSchema) ([]*Fragment, error) { + g := &mppTaskGenerator{ + ctx: ctx, + startTS: startTs, + is: is, + cache: make(map[int]tasksAndFrags), + } return g.generateMPPTasks(sender) } -func (e *mppTaskGenerator) generateMPPTasks(s *PhysicalExchangeSender) ([]*kv.MPPTask, error) { +func (e *mppTaskGenerator) generateMPPTasks(s *PhysicalExchangeSender) ([]*Fragment, error) { logutil.BgLogger().Info("Mpp will generate tasks", zap.String("plan", ToString(s))) tidbTask := &kv.MPPTask{ StartTs: e.startTS, ID: -1, } - rootTasks, err := e.generateMPPTasksForFragment(s) + _, frags, err := e.generateMPPTasksForExchangeSender(s) if err != nil { return nil, errors.Trace(err) } - s.TargetTasks = []*kv.MPPTask{tidbTask} - return rootTasks, nil + for _, frag := range frags { + frag.ExchangeSender.TargetTasks = []*kv.MPPTask{tidbTask} + frag.IsRoot = true + } + return e.frags, nil } type mppAddr struct { @@ -105,6 +122,8 @@ func (f *Fragment) init(p PhysicalPlan) error { f.TableScan = x case *PhysicalExchangeReceiver: f.ExchangeReceivers = append(f.ExchangeReceivers, x) + case *PhysicalUnionAll: + return errors.New("unexpected union all detected") default: for _, ch := range p.Children() { if err := f.init(ch); err != nil { @@ -115,20 +134,107 @@ func (f *Fragment) init(p PhysicalPlan) error { return nil } -func newFragment(s *PhysicalExchangeSender) (*Fragment, error) { - f := &Fragment{ExchangeSender: s} - s.Fragment = f - err := f.init(s) - return f, errors.Trace(err) +// We remove all the union-all operators by 'untwist'ing and copying the plans above union-all. +// This will make every route from root (ExchangeSender) to leaf nodes (ExchangeReceiver and TableScan) +// a new isolated tree (and also a fragment) without union all. These trees (fragments then tasks) will +// finally be gathered to TiDB or be exchanged to upper tasks again. +// For instance, given a plan "select c1 from t union all select c1 from s" +// after untwist, there will be two plans in `forest` slice: +// - ExchangeSender -> Projection (c1) -> TableScan(t) +// - ExchangeSender -> Projection (c1) -> TableScan(s) +func untwistPlanAndRemoveUnionAll(stack []PhysicalPlan, forest *[]*PhysicalExchangeSender) error { + cur := stack[len(stack)-1] + switch x := cur.(type) { + case *PhysicalTableScan, *PhysicalExchangeReceiver: // This should be the leaf node.
+ p, err := stack[0].Clone() + if err != nil { + return errors.Trace(err) + } + *forest = append(*forest, p.(*PhysicalExchangeSender)) + for i := 1; i < len(stack); i++ { + if _, ok := stack[i].(*PhysicalUnionAll); ok { + continue + } + ch, err := stack[i].Clone() + if err != nil { + return errors.Trace(err) + } + if join, ok := p.(*PhysicalHashJoin); ok { + join.SetChild(1-join.InnerChildIdx, ch) + } else { + p.SetChildren(ch) + } + p = ch + } + case *PhysicalHashJoin: + stack = append(stack, x.children[1-x.InnerChildIdx]) + err := untwistPlanAndRemoveUnionAll(stack, forest) + stack = stack[:len(stack)-1] + return errors.Trace(err) + case *PhysicalUnionAll: + for _, ch := range x.children { + stack = append(stack, ch) + err := untwistPlanAndRemoveUnionAll(stack, forest) + stack = stack[:len(stack)-1] + if err != nil { + return errors.Trace(err) + } + } + default: + if len(cur.Children()) != 1 { + return errors.Trace(errors.New("unexpected plan " + cur.ExplainID().String())) + } + ch := cur.Children()[0] + stack = append(stack, ch) + err := untwistPlanAndRemoveUnionAll(stack, forest) + stack = stack[:len(stack)-1] + return errors.Trace(err) + } + return nil } -func (e *mppTaskGenerator) generateMPPTasksForFragment(s *PhysicalExchangeSender) (tasks []*kv.MPPTask, err error) { - f, err := newFragment(s) +func buildFragments(s *PhysicalExchangeSender) ([]*Fragment, error) { + forest := make([]*PhysicalExchangeSender, 0, 1) + err := untwistPlanAndRemoveUnionAll([]PhysicalPlan{s}, &forest) if err != nil { return nil, errors.Trace(err) } + fragments := make([]*Fragment, 0, len(forest)) + for _, s := range forest { + f := &Fragment{ExchangeSender: s} + err = f.init(s) + if err != nil { + return nil, errors.Trace(err) + } + fragments = append(fragments, f) + } + return fragments, nil +} + +func (e *mppTaskGenerator) generateMPPTasksForExchangeSender(s *PhysicalExchangeSender) ([]*kv.MPPTask, []*Fragment, error) { + if cached, ok := e.cache[s.ID()]; ok { + return cached.tasks, cached.frags, nil + } + frags, err := buildFragments(s) + if err != nil { + return nil, nil, errors.Trace(err) + } + results := make([]*kv.MPPTask, 0, len(frags)) + for _, f := range frags { + tasks, err := e.generateMPPTasksForFragment(f) + if err != nil { + return nil, nil, errors.Trace(err) + } + results = append(results, tasks...) + } + e.frags = append(e.frags, frags...) + e.cache[s.ID()] = tasksAndFrags{results, frags} + return results, frags, nil +} + +func (e *mppTaskGenerator) generateMPPTasksForFragment(f *Fragment) (tasks []*kv.MPPTask, err error) { for _, r := range f.ExchangeReceivers { - r.Tasks, err = e.generateMPPTasksForFragment(r.GetExchangeSender()) + r.Tasks, r.frags, err = e.generateMPPTasksForExchangeSender(r.GetExchangeSender()) if err != nil { return nil, errors.Trace(err) } @@ -149,8 +255,9 @@ func (e *mppTaskGenerator) generateMPPTasksForFragment(s *PhysicalExchangeSender return nil, errors.New("cannot find mpp task") } for _, r := range f.ExchangeReceivers { - s := r.GetExchangeSender() - s.TargetTasks = tasks + for _, frag := range r.frags { + frag.ExchangeSender.TargetTasks = append(frag.ExchangeSender.TargetTasks, tasks...) 
+ } } f.ExchangeSender.Tasks = tasks return tasks, nil diff --git a/planner/core/initialize.go b/planner/core/initialize.go index f41340147abfe..cbf099baf4113 100644 --- a/planner/core/initialize.go +++ b/planner/core/initialize.go @@ -419,7 +419,7 @@ func (p PhysicalTableReader) Init(ctx sessionctx.Context, offset int) *PhysicalT if p.tablePlan != nil { p.TablePlans = flattenPushDownPlan(p.tablePlan) p.schema = p.tablePlan.Schema() - if p.StoreType == kv.TiFlash && !p.GetTableScan().KeepOrder { + if p.StoreType == kv.TiFlash && p.GetTableScan() != nil && !p.GetTableScan().KeepOrder { // When allow batch cop is 1, only agg / topN uses batch cop. // When allow batch cop is 2, every query uses batch cop. switch ctx.GetSessionVars().AllowBatchCop { diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index e6b2c6b6ef20b..ccb93029edd25 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -1378,6 +1378,10 @@ func (b *PlanBuilder) buildProjection4Union(ctx context.Context, u *LogicalUnion b.optFlag |= flagEliminateProjection proj := LogicalProjection{Exprs: exprs, AvoidColumnEvaluator: true}.Init(b.ctx, b.getSelectOffset()) proj.SetSchema(u.schema.Clone()) + // reset the schema type to make the "not null" flag right. + for i, expr := range exprs { + proj.schema.Columns[i].RetType = expr.GetType() + } proj.SetChildren(child) u.children[childID] = proj } diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index 5602399a63b74..74c70e1fce3c9 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -111,7 +111,10 @@ func (p *PhysicalTableReader) GetTableScan() *PhysicalTableScan { } else if chCnt == 1 { curPlan = curPlan.Children()[0] } else { - join := curPlan.(*PhysicalHashJoin) + join, ok := curPlan.(*PhysicalHashJoin) + if !ok { + return nil + } curPlan = join.children[1-join.globalChildIndex] } } @@ -886,6 +889,18 @@ type PhysicalExchangeReceiver struct { basePhysicalPlan Tasks []*kv.MPPTask + frags []*Fragment +} + +// Clone implements PhysicalPlan interface. +func (p *PhysicalExchangeReceiver) Clone() (PhysicalPlan, error) { + np := new(PhysicalExchangeReceiver) + base, err := p.basePhysicalPlan.cloneWithSelf(np) + if err != nil { + return nil, errors.Trace(err) + } + np.basePhysicalPlan = *base + return np, nil } // GetExchangeSender return the connected sender of this receiver. We assume that its child must be a receiver. @@ -900,10 +915,21 @@ type PhysicalExchangeSender struct { TargetTasks []*kv.MPPTask ExchangeType tipb.ExchangeType HashCols []*expression.Column - // Tasks is the mpp task for current PhysicalExchangeSender + // Tasks is the mpp task for current PhysicalExchangeSender. Tasks []*kv.MPPTask +} - Fragment *Fragment +// Clone implements PhysicalPlan interface. +func (p *PhysicalExchangeSender) Clone() (PhysicalPlan, error) { + np := new(PhysicalExchangeSender) + base, err := p.basePhysicalPlan.cloneWithSelf(np) + if err != nil { + return nil, errors.Trace(err) + } + np.basePhysicalPlan = *base + np.ExchangeType = p.ExchangeType + np.HashCols = p.HashCols + return np, nil } // Clone implements PhysicalPlan interface. @@ -952,6 +978,19 @@ func (p *PhysicalLimit) Clone() (PhysicalPlan, error) { // PhysicalUnionAll is the physical operator of UnionAll. type PhysicalUnionAll struct { physicalSchemaProducer + + mpp bool +} + +// Clone implements PhysicalPlan interface.
+func (p *PhysicalUnionAll) Clone() (PhysicalPlan, error) { + cloned := new(PhysicalUnionAll) + base, err := p.physicalSchemaProducer.cloneWithSelf(cloned) + if err != nil { + return nil, err + } + cloned.physicalSchemaProducer = *base + return cloned, nil } // AggMppRunMode defines the running mode of aggregation in MPP diff --git a/planner/core/plan.go b/planner/core/plan.go index dd7e41b77f7fa..2a62afa960b63 100644 --- a/planner/core/plan.go +++ b/planner/core/plan.go @@ -408,6 +408,9 @@ func (p *basePhysicalPlan) cloneWithSelf(newSelf PhysicalPlan) (*basePhysicalPla base.children = append(base.children, cloned) } for _, prop := range p.childrenReqProps { + if prop == nil { + continue + } base.childrenReqProps = append(base.childrenReqProps, prop.CloneEssentialFields()) } return base, nil diff --git a/planner/core/rule_inject_extra_projection.go b/planner/core/rule_inject_extra_projection.go index 2896a1dade0ff..968917a4fef2e 100644 --- a/planner/core/rule_inject_extra_projection.go +++ b/planner/core/rule_inject_extra_projection.go @@ -14,6 +14,7 @@ package core import ( + "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/kv" @@ -62,10 +63,42 @@ func (pe *projInjector) inject(plan PhysicalPlan) PhysicalPlan { plan = InjectProjBelowSort(p, p.ByItems) case *NominalSort: plan = TurnNominalSortIntoProj(p, p.OnlyColumn, p.ByItems) + case *PhysicalUnionAll: + plan = injectProjBelowUnion(p) } return plan } +func injectProjBelowUnion(un *PhysicalUnionAll) *PhysicalUnionAll { + if !un.mpp { + return un + } + for i, ch := range un.children { + exprs := make([]expression.Expression, len(ch.Schema().Columns)) + needChange := false + for i, dstCol := range un.schema.Columns { + dstType := dstCol.RetType + srcCol := ch.Schema().Columns[i] + srcType := srcCol.RetType + if !srcType.Equal(dstType) || !(mysql.HasNotNullFlag(dstType.Flag) == mysql.HasNotNullFlag(srcType.Flag)) { + exprs[i] = expression.BuildCastFunction4Union(un.ctx, srcCol, dstType) + needChange = true + } else { + exprs[i] = srcCol + } + } + if needChange { + proj := PhysicalProjection{ + Exprs: exprs, + }.Init(un.ctx, ch.statsInfo(), 0) + proj.SetSchema(un.schema.Clone()) + proj.SetChildren(ch) + un.children[i] = proj + } + } + return un +} + // wrapCastForAggFunc wraps the args of an aggregate function with a cast function. // If the mode is FinalMode or Partial2Mode, we do not need to wrap cast upon the args, // since the types of the args are already the expected. diff --git a/planner/core/task.go b/planner/core/task.go index fa6855503dd0e..f4c1d1ca72c65 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -1300,7 +1300,31 @@ func (p *PhysicalProjection) attach2Task(tasks ...task) task { return t } +func (p *PhysicalUnionAll) attach2MppTasks(tasks ...task) task { + t := &mppTask{p: p} + childPlans := make([]PhysicalPlan, 0, len(tasks)) + var childMaxCost float64 + for _, tk := range tasks { + if mpp, ok := tk.(*mppTask); ok && !tk.invalid() { + childCost := mpp.cost() + if childCost > childMaxCost { + childMaxCost = childCost + } + childPlans = append(childPlans, mpp.plan()) + } else { + return invalidTask + } + } + p.SetChildren(childPlans...) + t.cst = childMaxCost + p.cost = t.cost() + return t +} + func (p *PhysicalUnionAll) attach2Task(tasks ...task) task { + if _, ok := tasks[0].(*mppTask); ok { + return p.attach2MppTasks(tasks...) 
+ } t := &rootTask{p: p} childPlans := make([]PhysicalPlan, 0, len(tasks)) var childMaxCost float64 diff --git a/store/copr/mpp.go b/store/copr/mpp.go index db0aff7e22696..4be45c3288e23 100644 --- a/store/copr/mpp.go +++ b/store/copr/mpp.go @@ -239,6 +239,7 @@ func (m *mppIterator) handleDispatchReq(ctx context.Context, bo *Backoffer, req realResp := rpcResp.Resp.(*mpp.DispatchTaskResponse) if realResp.Error != nil { + logutil.BgLogger().Error("mpp dispatch response meet error", zap.String("error", realResp.Error.Msg)) m.sendError(errors.New(realResp.Error.Msg)) return }
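Reviewer note (not part of the patch): the core of this change is untwistPlanAndRemoveUnionAll in planner/core/fragment.go, which copies the plan above every union-all so that each root-to-leaf route becomes its own fragment. Below is a minimal, self-contained sketch of that idea using invented node types rather than TiDB's PhysicalPlan; the real code additionally special-cases PhysicalHashJoin and caches tasks per ExchangeReceiver.

package main

import "fmt"

// Illustrative stand-ins only; these are NOT TiDB's plan types.
type node struct {
	name     string
	children []*node
	isUnion  bool
}

// untwist walks every root-to-leaf route. Union-all nodes are skipped while
// copying the ancestor chain, so each leaf yields one isolated tree.
func untwist(stack []*node, forest *[]*node) {
	cur := stack[len(stack)-1]
	if len(cur.children) == 0 {
		root := &node{name: stack[0].name}
		p := root
		for _, n := range stack[1:] {
			if n.isUnion {
				continue // drop the union-all operator itself
			}
			ch := &node{name: n.name}
			p.children = []*node{ch}
			p = ch
		}
		*forest = append(*forest, root)
		return
	}
	for _, ch := range cur.children {
		stack = append(stack, ch)
		untwist(stack, forest)
		stack = stack[:len(stack)-1]
	}
}

func main() {
	// ExchangeSender -> UnionAll -> (Projection -> TableScan(t), Projection -> TableScan(s))
	plan := &node{name: "ExchangeSender", children: []*node{
		{name: "UnionAll", isUnion: true, children: []*node{
			{name: "Projection", children: []*node{{name: "TableScan(t)"}}},
			{name: "Projection", children: []*node{{name: "TableScan(s)"}}},
		}},
	}}
	var forest []*node
	untwist([]*node{plan}, &forest)
	for _, f := range forest {
		line := ""
		for p := f; p != nil; {
			line += p.name
			if len(p.children) > 0 {
				line += " -> "
				p = p.children[0]
			} else {
				p = nil
			}
		}
		fmt.Println(line)
	}
	// Prints two isolated trees, matching the comment in fragment.go:
	//   ExchangeSender -> Projection -> TableScan(t)
	//   ExchangeSender -> Projection -> TableScan(s)
}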
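A second note on type handling: buildProjection4Union now resets each projection column's RetType so the NOT NULL flag follows the expression, and injectProjBelowUnion then casts any child column whose type or nullability still differs from the union's output schema. The sketch below only illustrates that comparison; colType and needsCast are invented names, while the real code works on types.FieldType, mysql.HasNotNullFlag, and expression.BuildCastFunction4Union.

package main

import "fmt"

// Hypothetical, simplified column type: just a type name and a NOT NULL flag.
type colType struct {
	name    string
	notNull bool
}

// needsCast reports whether a union-all child column must be wrapped in a cast
// projection so every child feeds an identically typed column to TiFlash.
func needsCast(src, dst colType) bool {
	return src.name != dst.name || src.notNull != dst.notNull
}

func main() {
	unionSchema := []colType{{"bigint", false}, {"double", false}}
	childSchema := []colType{{"bigint", true}, {"double", false}} // first column is NOT NULL in this child

	for i := range unionSchema {
		if needsCast(childSchema[i], unionSchema[i]) {
			fmt.Printf("column %d: inject cast %+v -> %+v\n", i, childSchema[i], unionSchema[i])
		} else {
			fmt.Printf("column %d: pass through\n", i)
		}
	}
	// column 0: inject cast {name:bigint notNull:true} -> {name:bigint notNull:false}
	// column 1: pass through
}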