Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner/core: change agg cost factor (#25210) #25241

Merged
merged 3 commits into from
Jun 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 0 additions & 74 deletions planner/core/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -654,80 +654,6 @@ func (s *testIntegrationSerialSuite) TestMPPShuffledJoin(c *C) {
}
}

// TestBroadcastJoin exercises broadcast-join planning against tables that are
// marked as having an available TiFlash replica. It first runs every case
// supplied by s.testData and checks the returned rows against the stored plan,
// then verifies that several unsupported join shapes fail with the
// "Can't find a proper physical plan" error.
func (s *testIntegrationSerialSuite) TestBroadcastJoin(c *C) {
	tk := testkit.NewTestKit(c, s.store)

	// Schema and data setup: three dimension tables (int, decimal and date keys)
	// plus one fact table referencing all of them. Each table is analyzed after
	// it has been populated.
	setupStmts := []string{
		"use test",
		"set session tidb_allow_mpp = OFF",
		"drop table if exists d1_t",
		"create table d1_t(d1_k int, value int)",
		"insert into d1_t values(1,2),(2,3)",
		"analyze table d1_t",
		"drop table if exists d2_t",
		"create table d2_t(d2_k decimal(10,2), value int)",
		"insert into d2_t values(10.11,2),(10.12,3)",
		"analyze table d2_t",
		"drop table if exists d3_t",
		"create table d3_t(d3_k date, value int)",
		"insert into d3_t values(date'2010-01-01',2),(date'2010-01-02',3)",
		"analyze table d3_t",
		"drop table if exists fact_t",
		"create table fact_t(d1_k int, d2_k decimal(10,2), d3_k date, col1 int, col2 int, col3 int)",
		"insert into fact_t values(1,10.11,date'2010-01-01',1,2,3),(1,10.11,date'2010-01-02',1,2,3),(1,10.12,date'2010-01-01',1,2,3),(1,10.12,date'2010-01-02',1,2,3)",
		"insert into fact_t values(2,10.11,date'2010-01-01',1,2,3),(2,10.11,date'2010-01-02',1,2,3),(2,10.12,date'2010-01-01',1,2,3),(2,10.12,date'2010-01-02',1,2,3)",
		"analyze table fact_t",
	}
	for _, stmt := range setupStmts {
		tk.MustExec(stmt)
	}

	// Attach virtual TiFlash replica info to the four tables created above.
	schema, found := domain.GetDomain(tk.Se).InfoSchema().SchemaByName(model.NewCIStr("test"))
	c.Assert(found, IsTrue)
	for _, info := range schema.Tables {
		switch info.Name.L {
		case "fact_t", "d1_t", "d2_t", "d3_t":
			info.TiFlashReplica = &model.TiFlashReplicaInfo{
				Count:     1,
				Available: true,
			}
		}
	}

	// Session settings: read from TiFlash only, allow batch cop and enable
	// broadcast join. The very large CPU factor pushes the cost-based optimizer
	// towards broadcast join, because SQL hints do not work for semi/anti-semi
	// joins.
	sessionStmts := []string{
		"set @@session.tidb_isolation_read_engines = 'tiflash'",
		"set @@session.tidb_allow_batch_cop = 1",
		"set @@session.tidb_opt_broadcast_join = 1",
		"set @@session.tidb_opt_cpu_factor=10000000;",
	}
	for _, stmt := range sessionStmts {
		tk.MustExec(stmt)
	}

	var input []string
	var output []struct {
		SQL  string
		Plan []string
	}
	s.testData.GetTestCases(c, &input, &output)
	for idx, sql := range input {
		// The callback handed to OnRecord stores the query and its current
		// result rows into the matching output slot.
		s.testData.OnRecord(func() {
			output[idx].SQL = sql
			output[idx].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(sql).Rows())
		})
		tk.MustQuery(sql).Check(testkit.Rows(output[idx].Plan...))
	}

	// Each of the following queries must be rejected by the planner with the
	// same "no proper physical plan" error.
	const noPlanErr = "[planner:1815]Internal : Can't find a proper physical plan for this query"
	unsupported := []string{
		// The outer table of an outer join should not be global.
		"explain format = 'brief' select /*+ broadcast_join(fact_t, d1_t), broadcast_join_local(d1_t) */ count(*) from fact_t left join d1_t on fact_t.d1_k = d1_t.d1_k",
		// NullEQ (<=>) is not supported.
		"explain format = 'brief' select /*+ broadcast_join(fact_t, d1_t) */ count(*) from fact_t join d1_t on fact_t.d1_k <=> d1_t.d1_k",
		// Not supported if the join condition contains an unsupported expression.
		"explain format = 'brief' select /*+ broadcast_join(fact_t, d1_t) */ count(*) from fact_t left join d1_t on fact_t.d1_k = d1_t.d1_k and sqrt(fact_t.col1) > 2",
		// Cartesian join is not supported.
		"explain format = 'brief' select /*+ broadcast_join(fact_t, d1_t) */ count(*) from fact_t join d1_t",
	}
	for _, sql := range unsupported {
		_, err := tk.Exec(sql)
		c.Assert(err, NotNil)
		c.Assert(err.Error(), Equals, noPlanErr)
	}
}

func (s *testIntegrationSerialSuite) TestJoinNotSupportedByTiFlash(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
12 changes: 10 additions & 2 deletions planner/core/physical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ func (p *basePhysicalAgg) numDistinctFunc() (num int) {
return
}

func (p *basePhysicalAgg) getAggFuncCostFactor() (factor float64) {
func (p *basePhysicalAgg) getAggFuncCostFactor(isMPP bool) (factor float64) {
factor = 0.0
for _, agg := range p.AggFuncs {
if fac, ok := aggFuncFactor[agg.Name]; ok {
Expand All @@ -1018,7 +1018,15 @@ func (p *basePhysicalAgg) getAggFuncCostFactor() (factor float64) {
}
}
if factor == 0 {
factor = 1.0
if isMPP {
// The default factor 1.0 leads to 1-phase aggregation under pseudo stats settings,
// but in MPP cases 2-phase aggregation is more common, so we use a different factor here.
// TODO: This is still a little tricky and might cause regressions. We should
// calibrate these factors and polish our cost model in the future.
factor = aggFuncFactor[ast.AggFuncFirstRow]
} else {
factor = 1.0
}
}
return
}
Expand Down
18 changes: 9 additions & 9 deletions planner/core/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -1739,7 +1739,7 @@ func (p *PhysicalStreamAgg) attach2Task(tasks ...task) task {

// GetCost computes cost of stream aggregation considering CPU/memory.
func (p *PhysicalStreamAgg) GetCost(inputRows float64, isRoot bool) float64 {
aggFuncFactor := p.getAggFuncCostFactor()
aggFuncFactor := p.getAggFuncCostFactor(false)
var cpuCost float64
sessVars := p.ctx.GetSessionVars()
if isRoot {
Expand Down Expand Up @@ -1786,7 +1786,7 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task {
if proj != nil {
attachPlan2Task(proj, mpp)
}
mpp.addCost(p.GetCost(inputRows, false))
mpp.addCost(p.GetCost(inputRows, false, true))
return mpp
case Mpp2Phase:
proj := p.convertAvgForMPP()
Expand Down Expand Up @@ -1817,18 +1817,18 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task {
attachPlan2Task(proj, newMpp)
}
// TODO: how to set 2-phase cost?
newMpp.addCost(p.GetCost(inputRows, false))
newMpp.addCost(p.GetCost(inputRows, false, true))
return newMpp
case MppTiDB:
partialAgg, finalAgg := p.newPartialAggregate(kv.TiFlash, false)
if partialAgg != nil {
attachPlan2Task(partialAgg, mpp)
}
mpp.addCost(p.GetCost(inputRows, false))
mpp.addCost(p.GetCost(inputRows, false, true))
t = mpp.convertToRootTask(p.ctx)
inputRows = t.count()
attachPlan2Task(finalAgg, t)
t.addCost(p.GetCost(inputRows, true))
t.addCost(p.GetCost(inputRows, true, false))
return t
default:
return invalidTask
Expand Down Expand Up @@ -1858,7 +1858,7 @@ func (p *PhysicalHashAgg) attach2Task(tasks ...task) task {
partialAgg.SetChildren(cop.indexPlan)
cop.indexPlan = partialAgg
}
cop.addCost(p.GetCost(inputRows, false))
cop.addCost(p.GetCost(inputRows, false, false))
}
// In `newPartialAggregate`, we are using stats of final aggregation as stats
// of `partialAgg`, so the network cost of transferring result rows of `partialAgg`
Expand Down Expand Up @@ -1891,15 +1891,15 @@ func (p *PhysicalHashAgg) attach2Task(tasks ...task) task {
// hash aggregation, it would cause under-estimation as the reason mentioned in comment above.
// To make it simple, we also treat 2-phase parallel hash aggregation in TiDB layer as
// 1-phase when computing cost.
t.addCost(p.GetCost(inputRows, true))
t.addCost(p.GetCost(inputRows, true, false))
return t
}

// GetCost computes the cost of hash aggregation considering CPU/memory.
func (p *PhysicalHashAgg) GetCost(inputRows float64, isRoot bool) float64 {
func (p *PhysicalHashAgg) GetCost(inputRows float64, isRoot bool, isMPP bool) float64 {
cardinality := p.statsInfo().RowCount
numDistinctFunc := p.numDistinctFunc()
aggFuncFactor := p.getAggFuncCostFactor()
aggFuncFactor := p.getAggFuncCostFactor(isMPP)
var cpuCost float64
sessVars := p.ctx.GetSessionVars()
if isRoot {
Expand Down
20 changes: 0 additions & 20 deletions planner/core/testdata/integration_serial_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,6 @@
"explain format = 'brief' select count(*) from fact_t where not exists (select 1 from d1_t where d1_k = fact_t.d1_k and value > fact_t.col1)"
]
},
{
"name": "TestBroadcastJoin",
"cases": [
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t, d1_t where fact_t.d1_k = d1_t.d1_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t,d2_t,d3_t) */ count(*) from fact_t, d1_t, d2_t, d3_t where fact_t.d1_k = d1_t.d1_k and fact_t.d2_k = d2_t.d2_k and fact_t.d3_k = d3_t.d3_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t), broadcast_join_local(d1_t) */ count(*) from fact_t, d1_t where fact_t.d1_k = d1_t.d1_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t,d2_t,d3_t), broadcast_join_local(d2_t) */ count(*) from fact_t, d1_t, d2_t, d3_t where fact_t.d1_k = d1_t.d1_k and fact_t.d2_k = d2_t.d2_k and fact_t.d3_k = d3_t.d3_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t left join d1_t on fact_t.d1_k = d1_t.d1_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t right join d1_t on fact_t.d1_k = d1_t.d1_k",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t join d1_t on fact_t.d1_k = d1_t.d1_k and fact_t.col1 > d1_t.value",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t left join d1_t on fact_t.d1_k = d1_t.d1_k and fact_t.col1 > 10",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t left join d1_t on fact_t.d1_k = d1_t.d1_k and fact_t.col2 > 10 and fact_t.col1 > d1_t.value",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t right join d1_t on fact_t.d1_k = d1_t.d1_k and d1_t.value > 10",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t right join d1_t on fact_t.d1_k = d1_t.d1_k and d1_t.value > 10 and fact_t.col1 > d1_t.value",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t where exists (select 1 from d1_t where d1_k = fact_t.d1_k)",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t where exists (select 1 from d1_t where d1_k = fact_t.d1_k and value > fact_t.col1)",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t where not exists (select 1 from d1_t where d1_k = fact_t.d1_k)",
"explain format = 'brief' select /*+ broadcast_join(fact_t,d1_t) */ count(*) from fact_t where not exists (select 1 from d1_t where d1_k = fact_t.d1_k and value > fact_t.col1)"
]
},
{
"name": "TestJoinNotSupportedByTiFlash",
"cases": [
Expand Down
Loading