planner/core: convert decimal type for mpp join before shuffling. (#23191)

* planner: convert decimal type for mpp join before shuffling.

* fix bug and add code

* add some comments

* fix typo

* fix test

* add test and fix

* address comments

* add more tests

* address comments

* address comments

* address comments

* address comments

* add tests

* refine test

Co-authored-by: Ti Chi Robot <71242396+ti-chi-bot@users.noreply.github.com>
hanfei1991 and ti-chi-bot authored Mar 25, 2021
1 parent 1c83b14 commit 40b9218
Showing 11 changed files with 478 additions and 31 deletions.
2 changes: 1 addition & 1 deletion executor/mpp_gather.go
@@ -88,7 +88,7 @@ func (e *MPPGather) appendMPPDispatchReq(pf *plannercore.Fragment, tasks []*kv.M
e.mppReqs = append(e.mppReqs, req)
}
for _, r := range pf.ExchangeReceivers {
err = e.appendMPPDispatchReq(r.ChildPf, r.Tasks, false)
err = e.appendMPPDispatchReq(r.GetExchangeSender().Fragment, r.Tasks, false)
if err != nil {
return errors.Trace(err)
}
14 changes: 14 additions & 0 deletions executor/tiflash_test.go
@@ -183,12 +183,26 @@ func (s *tiflashTestSuite) TestMppExecution(c *C) {
tk.MustQuery("select count(*) from t1 , t where t1.a = t.a").Check(testkit.Rows("3"))
tk.MustQuery("select count(*) from t1 , t, t2 where t1.a = t.a and t2.a = t.a").Check(testkit.Rows("3"))

// test agg by expression
tk.MustExec("insert into t1 values(4,0)")
tk.MustQuery("select count(*) k, t2.b from t1 left join t2 on t1.a = t2.a group by t2.b order by k").Check(testkit.Rows("1 <nil>", "3 0"))
tk.MustQuery("select count(*) k, t2.b+1 from t1 left join t2 on t1.a = t2.a group by t2.b+1 order by k").Check(testkit.Rows("1 <nil>", "3 1"))
tk.MustQuery("select count(*) k, t2.b * t2.a from t2 group by t2.b * t2.a").Check(testkit.Rows("3 0"))
tk.MustQuery("select count(*) k, t2.a/2 m from t2 group by t2.a / 2 order by m").Check(testkit.Rows("1 0.5000", "1 1.0000", "1 1.5000"))
tk.MustQuery("select count(*) k, t2.a div 2 from t2 group by t2.a div 2 order by k").Check(testkit.Rows("1 0", "2 1"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t (c1 decimal(8, 5) not null, c2 decimal(9, 5), c3 decimal(9, 4) , c4 decimal(8, 4) not null)")
tk.MustExec("alter table t set tiflash replica 1")
tb = testGetTableByName(c, tk.Se, "test", "t")
err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
c.Assert(err, IsNil)
tk.MustExec("insert into t values(1.00000,1.00000,1.0000,1.0000)")
tk.MustExec("insert into t values(1.00010,1.00010,1.0001,1.0001)")
tk.MustExec("insert into t values(1.00001,1.00001,1.0000,1.0002)")
tk.MustQuery("select t1.c1 from t t1 join t t2 on t1.c1 = t2.c1 order by t1.c1").Check(testkit.Rows("1.00000", "1.00001", "1.00010"))
tk.MustQuery("select t1.c1 from t t1 join t t2 on t1.c1 = t2.c3 order by t1.c1").Check(testkit.Rows("1.00000", "1.00000", "1.00010"))
tk.MustQuery("select t1.c4 from t t1 join t t2 on t1.c4 = t2.c3 order by t1.c4").Check(testkit.Rows("1.0000", "1.0000", "1.0001"))
}

func (s *tiflashTestSuite) TestPartitionTable(c *C) {
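The new queries in TestMppExecution above join decimal columns of different scales (for example c1 decimal(8,5) against c3 decimal(9,4)). Under a shuffle join, both sides are hash-partitioned on the join key, so equal values must be shuffled under a common decimal type. The snippet below is a minimal, hypothetical illustration of why: the partition hash here is made up and is not TiDB/TiFlash's actual hash function; it only shows that a scale-sensitive key can route equal values to different tasks unless both sides are converted first.

```go
package main

import (
	"fmt"
	"hash/fnv"
	"math/big"
)

// partitionOf hashes a decimal, given as (unscaled integer, scale), into one of
// n partitions. This is NOT the real TiDB/TiFlash partition hash; it only shows
// that a scale-sensitive key separates values that compare as equal.
func partitionOf(unscaled *big.Int, scale int, n uint32) uint32 {
	h := fnv.New32a()
	h.Write(unscaled.Bytes())
	h.Write([]byte{byte(scale)})
	return h.Sum32() % n
}

func main() {
	// 1.00000 as decimal(8,5) and 1.0000 as decimal(9,4) are equal values, but
	// their raw representations differ: 100000 at scale 5 vs 10000 at scale 4.
	left := big.NewInt(100000)
	right := big.NewInt(10000)
	fmt.Println(partitionOf(left, 5, 16), partitionOf(right, 4, 16)) // two different partitions

	// Rescaling the right side to the common scale makes the representations
	// identical, so equal join keys meet in the same partition.
	rightRescaled := new(big.Int).Mul(right, big.NewInt(10))
	fmt.Println(partitionOf(left, 5, 16), partitionOf(rightRescaled, 5, 16)) // same partition twice
}
```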
1 change: 1 addition & 0 deletions planner/core/exhaust_physical_plans.go
@@ -1817,6 +1817,7 @@ func (p *LogicalJoin) tryToGetMppHashJoin(prop *property.PhysicalProperty, useBC
Concurrency: uint(p.ctx.GetSessionVars().CopTiFlashConcurrencyFactor),
EqualConditions: p.EqualConditions,
storeTp: kv.TiFlash,
mppShuffleJoin: !useBCJ,
}.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt), p.blockOffset, childrenProps...)
return []PhysicalPlan{join}
}
1 change: 0 additions & 1 deletion planner/core/find_best_task.go
@@ -1493,7 +1493,6 @@ func (ds *DataSource) convertToTableScan(prop *property.PhysicalProperty, candid
p: ts,
cst: cost,
partTp: property.AnyType,
ts: ts,
}
ts.PartitionInfo = PartitionInfo{
PruningConds: ds.allConds,
38 changes: 34 additions & 4 deletions planner/core/fragment.go
@@ -59,17 +59,47 @@ func (e *mppTaskGenerator) generateMPPTasks(s *PhysicalExchangeSender) ([]*kv.MP
StartTs: e.startTS,
ID: -1,
}
rootTasks, err := e.generateMPPTasksForFragment(s.Fragment)
rootTasks, err := e.generateMPPTasksForFragment(s)
if err != nil {
return nil, errors.Trace(err)
}
s.TargetTasks = []*kv.MPPTask{tidbTask}
return rootTasks, nil
}

func (e *mppTaskGenerator) generateMPPTasksForFragment(f *Fragment) (tasks []*kv.MPPTask, err error) {
func (f *Fragment) init(p PhysicalPlan) error {
switch x := p.(type) {
case *PhysicalTableScan:
if f.TableScan != nil {
return errors.New("one task contains at most one table scan")
}
f.TableScan = x
case *PhysicalExchangeReceiver:
f.ExchangeReceivers = append(f.ExchangeReceivers, x)
default:
for _, ch := range p.Children() {
if err := f.init(ch); err != nil {
return errors.Trace(err)
}
}
}
return nil
}

func newFragment(s *PhysicalExchangeSender) (*Fragment, error) {
f := &Fragment{ExchangeSender: s}
s.Fragment = f
err := f.init(s)
return f, errors.Trace(err)
}

func (e *mppTaskGenerator) generateMPPTasksForFragment(s *PhysicalExchangeSender) (tasks []*kv.MPPTask, err error) {
f, err := newFragment(s)
if err != nil {
return nil, errors.Trace(err)
}
for _, r := range f.ExchangeReceivers {
r.Tasks, err = e.generateMPPTasksForFragment(r.ChildPf)
r.Tasks, err = e.generateMPPTasksForFragment(r.GetExchangeSender())
if err != nil {
return nil, errors.Trace(err)
}
@@ -86,7 +116,7 @@ func (e *mppTaskGenerator) generateMPPTasksForFragment(f *Fragment) (tasks []*kv
return nil, errors.New("cannot find mpp task")
}
for _, r := range f.ExchangeReceivers {
s := r.ChildPf.ExchangeSender
s := r.GetExchangeSender()
s.TargetTasks = tasks
}
f.ExchangeSender.Tasks = tasks
41 changes: 41 additions & 0 deletions planner/core/integration_test.go
@@ -2807,6 +2807,47 @@ func (s *testIntegrationSerialSuite) TestPushDownAggForMPP(c *C) {
}
}

func (s *testIntegrationSerialSuite) TestMppJoinDecimal(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t (c1 decimal(8, 5), c2 decimal(9, 5), c3 decimal(9, 4) NOT NULL, c4 decimal(8, 4) NOT NULL, c5 decimal(40, 20))")
tk.MustExec("analyze table t")

// Create virtual tiflash replica info.
dom := domain.GetDomain(tk.Se)
is := dom.InfoSchema()
db, exists := is.SchemaByName(model.NewCIStr("test"))
c.Assert(exists, IsTrue)
for _, tblInfo := range db.Tables {
if tblInfo.Name.L == "t" {
tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
Count: 1,
Available: true,
}
}
}

tk.MustExec("set @@tidb_allow_mpp=1;")
tk.MustExec("set @@session.tidb_broadcast_join_threshold_size = 1")
tk.MustExec("set @@session.tidb_broadcast_join_threshold_count = 1")

var input []string
var output []struct {
SQL string
Plan []string
}
s.testData.GetTestCases(c, &input, &output)
for i, tt := range input {
s.testData.OnRecord(func() {
output[i].SQL = tt
output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
})
res := tk.MustQuery(tt)
res.Check(testkit.Rows(output[i].Plan...))
}
}

func (s *testIntegrationSerialSuite) TestMppAggWithJoin(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
43 changes: 43 additions & 0 deletions planner/core/optimizer_test.go
@@ -13,6 +13,49 @@

package core

import (
. "github.com/pingcap/check"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/types"
)

// LogicalOptimize exports the `logicalOptimize` function for test packages and
// doesn't affect the normal package and access control of Golang (tricky ^_^)
var LogicalOptimize = logicalOptimize

var _ = Suite(&testPlannerFunctionSuite{})

type testPlannerFunctionSuite struct {
}

func testDecimalConvert(lDec, lLen, rDec, rLen int, lConvert, rConvert bool, cDec, cLen int, c *C) {
lType := types.NewFieldType(mysql.TypeNewDecimal)
lType.Decimal = lDec
lType.Flen = lLen

rType := types.NewFieldType(mysql.TypeNewDecimal)
rType.Decimal = rDec
rType.Flen = rLen

cType, lCon, rCon := negotiateCommonType(lType, rType)
c.Assert(cType.Tp, Equals, mysql.TypeNewDecimal)
c.Assert(cType.Decimal, Equals, cDec)
c.Assert(cType.Flen, Equals, cLen)
c.Assert(lConvert, Equals, lCon)
c.Assert(rConvert, Equals, rCon)
}

func (t *testPlannerFunctionSuite) TestMPPDecimalConvert(c *C) {
testDecimalConvert(5, 9, 5, 8, false, false, 5, 9, c)
testDecimalConvert(5, 8, 5, 9, false, false, 5, 9, c)
testDecimalConvert(0, 8, 0, 11, true, false, 0, 11, c)
testDecimalConvert(0, 16, 0, 11, false, false, 0, 16, c)
testDecimalConvert(5, 9, 4, 9, true, true, 5, 10, c)
testDecimalConvert(5, 8, 4, 9, true, true, 5, 10, c)
testDecimalConvert(5, 9, 4, 8, false, true, 5, 9, c)
testDecimalConvert(10, 16, 0, 11, true, true, 10, 21, c)
testDecimalConvert(5, 19, 0, 20, false, true, 5, 25, c)
testDecimalConvert(20, 20, 0, 60, true, true, 20, 65, c)
testDecimalConvert(20, 40, 0, 60, false, true, 20, 65, c)
testDecimalConvert(0, 40, 0, 60, false, false, 0, 60, c)
}
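
The expectations above appear to encode the following rule: the common type keeps the wider scale and enough total digits for the wider integer part (capped at 65), and a side needs an explicit conversion when its scale changes or when its total precision moves to a different decimal width, assumed here to be TiFlash's Decimal32/64/128/256 buckets at 9, 18, 38 and 65 digits. The sketch below is an illustrative reimplementation checked against these cases; it is not the actual negotiateCommonType in planner/core.

```go
package main

import "fmt"

// decimalBucket returns the smallest assumed TiFlash decimal width
// (Decimal32/64/128/256, i.e. 9, 18, 38 and 65 digits) that can hold flen digits.
func decimalBucket(flen int) int {
	for _, b := range []int{9, 18, 38, 65} {
		if flen <= b {
			return b
		}
	}
	return 65
}

// negotiateDecimal mimics what the expectations above seem to encode: the common
// type keeps the wider scale and enough total digits for the wider integer part
// (capped at 65); a side needs an explicit conversion when its scale changes or
// when its total precision moves to a different decimal width.
func negotiateDecimal(lDec, lFlen, rDec, rFlen int) (cDec, cFlen int, lConvert, rConvert bool) {
	cDec = lDec
	if rDec > cDec {
		cDec = rDec
	}
	lExt, rExt := lFlen+cDec-lDec, rFlen+cDec-rDec
	cFlen = lExt
	if rExt > cFlen {
		cFlen = rExt
	}
	if cFlen > 65 {
		cFlen = 65
	}
	lConvert = lDec != cDec || decimalBucket(lFlen) != decimalBucket(cFlen)
	rConvert = rDec != cDec || decimalBucket(rFlen) != decimalBucket(cFlen)
	return
}

func main() {
	// Mirrors testDecimalConvert(5, 9, 4, 9, true, true, 5, 10, c) above.
	cDec, cFlen, lCon, rCon := negotiateDecimal(5, 9, 4, 9)
	fmt.Println(cDec, cFlen, lCon, rCon) // 5 10 true true
}
```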
9 changes: 7 additions & 2 deletions planner/core/physical_plans.go
@@ -728,6 +728,7 @@ type PhysicalHashJoin struct {
// on which store the join executes.
storeTp kv.StoreType
globalChildIndex int
mppShuffleJoin bool
}

// Clone implements PhysicalPlan interface.
@@ -851,8 +852,12 @@ type PhysicalMergeJoin struct {
type PhysicalExchangeReceiver struct {
basePhysicalPlan

Tasks []*kv.MPPTask
ChildPf *Fragment
Tasks []*kv.MPPTask
}

// GetExchangeSender returns the connected sender of this receiver. We assume that its child must be an exchange sender.
func (p *PhysicalExchangeReceiver) GetExchangeSender() *PhysicalExchangeSender {
return p.children[0].(*PhysicalExchangeSender)
}

// PhysicalExchangeSender dispatches data to upstream tasks. That means push mode processing,
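
The code that actually enforces the conversion lives in the remaining changed files, which are not expanded in this view. As a rough, hypothetical sketch of the idea (not the PR's implementation), a shuffle join, flagged by mppShuffleJoin, would negotiate a common type per equal condition and wrap a side's join key in a cast before building its hash-partition exchange:

```go
package main

import "fmt"

// Toy stand-ins for planner types; the names here are illustrative only.
type ColType struct{ Dec, Flen int }

type JoinKey struct {
	Name string
	Tp   ColType
}

// commonType mirrors the negotiation rule sketched earlier: keep the wider
// scale, widen the integer part to fit both sides, and cap total digits at 65.
func commonType(l, r ColType) ColType {
	c := ColType{Dec: l.Dec}
	if r.Dec > c.Dec {
		c.Dec = r.Dec
	}
	lf, rf := l.Flen+c.Dec-l.Dec, r.Flen+c.Dec-r.Dec
	c.Flen = lf
	if rf > c.Flen {
		c.Flen = rf
	}
	if c.Flen > 65 {
		c.Flen = 65
	}
	return c
}

// maybeCast returns the original column if it already has the common type and a
// cast expression otherwise. Simplified: it casts whenever the types differ at
// all, whereas the bucket check above would skip casts within one decimal width.
func maybeCast(k JoinKey, ct ColType) string {
	if k.Tp == ct {
		return k.Name
	}
	return fmt.Sprintf("cast(%s as decimal(%d,%d))", k.Name, ct.Flen, ct.Dec)
}

// planShuffleKeys decides, per equal condition, what each side should hash on
// before the exchange. This is what "convert before shuffling" boils down to.
func planShuffleKeys(left, right []JoinKey) (lKeys, rKeys []string) {
	for i := range left {
		ct := commonType(left[i].Tp, right[i].Tp)
		lKeys = append(lKeys, maybeCast(left[i], ct))
		rKeys = append(rKeys, maybeCast(right[i], ct))
	}
	return
}

func main() {
	// t1.c1 decimal(8,5) joined with t2.c3 decimal(9,4): both sides are rewritten
	// to decimal(10,5) before being hash-partitioned.
	lk, rk := planShuffleKeys(
		[]JoinKey{{Name: "t1.c1", Tp: ColType{Dec: 5, Flen: 8}}},
		[]JoinKey{{Name: "t2.c3", Tp: ColType{Dec: 4, Flen: 9}}},
	)
	fmt.Println(lk, rk)
}
```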