Skip to content

Commit e07cf40

Browse files
authored
planner: fix wrong result when pushing Agg down through Union in MPP plans (#46310) (#46517)
close #45850
1 parent 715048e commit e07cf40

File tree

6 files changed

+116
-37
lines changed

6 files changed

+116
-37
lines changed

executor/tiflashtest/BUILD.bazel

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ go_test(
99
],
1010
flaky = True,
1111
race = "on",
12-
shard_count = 37,
12+
shard_count = 38,
1313
deps = [
1414
"//config",
1515
"//domain",

executor/tiflashtest/tiflash_test.go

+29
Original file line numberDiff line numberDiff line change
@@ -1241,6 +1241,35 @@ func TestAggPushDownCountStar(t *testing.T) {
12411241
tk.MustQuery("select count(*) from c, o where c.c_id=o.c_id").Check(testkit.Rows("5"))
12421242
}
12431243

1244+
func TestAggPushDownUnionAndMPP(t *testing.T) {
1245+
store := testkit.CreateMockStore(t, withMockTiFlash(2))
1246+
tk := testkit.NewTestKit(t, store)
1247+
1248+
tk.MustExec("use test")
1249+
tk.MustExec("create table t (a int, b int)")
1250+
tk.MustExec("alter table t set tiflash replica 1")
1251+
tk.MustExec("insert into t values (1, 1);")
1252+
tk.MustExec("insert into t values (1, 1);")
1253+
tk.MustExec("insert into t values (1, 1);")
1254+
tk.MustExec("insert into t values (1, 1);")
1255+
tk.MustExec("insert into t values (1, 1);")
1256+
tk.MustExec("set @@tidb_allow_mpp=1;")
1257+
tk.MustExec("set @@tidb_enforce_mpp=1;")
1258+
tk.MustExec("set @@tidb_opt_agg_push_down=1")
1259+
1260+
tk.MustExec("create table c(c_id int)")
1261+
tk.MustExec("create table o(o_id int, c_id int)")
1262+
tk.MustExec("insert into c values(1),(1),(1),(1)")
1263+
tk.MustExec("insert into o values(1,1),(1,1),(1,2)")
1264+
tk.MustExec("alter table c set tiflash replica 1")
1265+
tk.MustExec("alter table o set tiflash replica 1")
1266+
1267+
tk.MustQuery("select a, count(1) from (select a, b from t union all select a, " +
1268+
"b from t) s group by a order by a").Check(testkit.Rows("1 10"))
1269+
1270+
tk.MustQuery("select o.o_id, count(*) from c, o where c.c_id=o.o_id group by o.o_id").Check(testkit.Rows("1 12"))
1271+
}
1272+
12441273
func TestGroupStreamAggOnTiFlash(t *testing.T) {
12451274
store := testkit.CreateMockStore(t, withMockTiFlash(2))
12461275
tk := testkit.NewTestKit(t, store)

planner/core/casetest/enforce_mpp_test.go

+8-1
Original file line numberDiff line numberDiff line change
@@ -345,13 +345,20 @@ func TestMPP2PhaseAggPushDown(t *testing.T) {
345345
tk.MustExec("create table c(c_id bigint)")
346346
tk.MustExec("create table o(o_id bigint, c_id bigint not null)")
347347

348+
tk.MustExec("create table t (a int, b int)")
349+
tk.MustExec("insert into t values (1, 1);")
350+
tk.MustExec("insert into t values (1, 1);")
351+
tk.MustExec("insert into t values (1, 1);")
352+
tk.MustExec("insert into t values (1, 1);")
353+
tk.MustExec("insert into t values (1, 1);")
354+
348355
// Create virtual tiflash replica info.
349356
dom := domain.GetDomain(tk.Session())
350357
is := dom.InfoSchema()
351358
db, exists := is.SchemaByName(model.NewCIStr("test"))
352359
require.True(t, exists)
353360
for _, tblInfo := range db.Tables {
354-
if tblInfo.Name.L == "c" || tblInfo.Name.L == "o" {
361+
if tblInfo.Name.L == "c" || tblInfo.Name.L == "o" || tblInfo.Name.L == "t" {
355362
tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
356363
Count: 1,
357364
Available: true,

planner/core/casetest/testdata/enforce_mpp_suite_in.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@
9292
"set @@tidb_allow_mpp=1;set @@tidb_enforce_mpp=1;set @@tidb_opt_agg_push_down=1;",
9393
"EXPLAIN select count(*) from c, o where c.c_id=o.c_id; -- 1. test agg push down, scalar aggregate",
9494
"EXPLAIN select o.o_id, count(*) from c, o where c.c_id=o.c_id group by o.o_id; -- 2. test agg push down, group by non-join column",
95-
"EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column"
95+
"EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column",
96+
"EXPLAIN format='brief' select a, count(*) from (select a, b from t union all select a, b from t) t group by a order by a limit 10"
9697
]
9798
},
9899
{

planner/core/casetest/testdata/enforce_mpp_suite_out.json

+65-34
Original file line numberDiff line numberDiff line change
@@ -658,48 +658,79 @@
658658
{
659659
"SQL": "EXPLAIN select o.o_id, count(*) from c, o where c.c_id=o.c_id group by o.o_id; -- 2. test agg push down, group by non-join column",
660660
"Plan": [
661-
"TableReader_78 8000.00 root MppVersion: 1, data:ExchangeSender_77",
662-
"└─ExchangeSender_77 8000.00 mpp[tiflash] ExchangeType: PassThrough",
661+
"TableReader_84 8000.00 root MppVersion: 1, data:ExchangeSender_83",
662+
"└─ExchangeSender_83 8000.00 mpp[tiflash] ExchangeType: PassThrough",
663663
" └─Projection_10 8000.00 mpp[tiflash] test.o.o_id, Column#6",
664-
" └─Projection_76 8000.00 mpp[tiflash] Column#6, test.o.o_id",
665-
" └─HashAgg_75 8000.00 mpp[tiflash] group by:test.o.o_id, funcs:count(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.o_id",
666-
" └─ExchangeReceiver_71 9990.00 mpp[tiflash] ",
667-
" └─ExchangeSender_70 9990.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary]",
668-
" └─HashJoin_69 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
669-
" ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ",
670-
" │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST",
671-
" │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.o_id, test.o.c_id",
672-
" │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.o_id)->Column#8, funcs:firstrow(test.o.o_id)->test.o.o_id, funcs:firstrow(test.o.c_id)->test.o.c_id",
673-
" │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ",
674-
" │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary], [name: test.o.c_id, collate: binary]",
675-
" │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:count(1)->Column#9",
676-
" │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
677-
" └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
678-
" └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo"
664+
" └─Projection_79 8000.00 mpp[tiflash] Column#6, test.o.o_id",
665+
" └─HashAgg_80 8000.00 mpp[tiflash] group by:test.o.o_id, funcs:sum(Column#25)->Column#6, funcs:firstrow(Column#26)->test.o.o_id",
666+
" └─ExchangeReceiver_82 8000.00 mpp[tiflash] ",
667+
" └─ExchangeSender_81 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary]",
668+
" └─HashAgg_76 8000.00 mpp[tiflash] group by:Column#29, funcs:sum(Column#27)->Column#25, funcs:firstrow(Column#28)->Column#26",
669+
" └─Projection_85 9990.00 mpp[tiflash] cast(Column#7, decimal(20,0) BINARY)->Column#27, Column#8, test.o.o_id",
670+
" └─HashJoin_78 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
671+
" ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ",
672+
" │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST",
673+
" │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.o_id, test.o.c_id",
674+
" │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.o_id)->Column#8, funcs:firstrow(test.o.o_id)->test.o.o_id, funcs:firstrow(test.o.c_id)->test.o.c_id",
675+
" │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ",
676+
" │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary], [name: test.o.c_id, collate: binary]",
677+
" │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:count(1)->Column#9",
678+
" │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
679+
" └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
680+
" └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo"
679681
],
680682
"Warn": null
681683
},
682684
{
683685
"SQL": "EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column",
684686
"Plan": [
685-
"TableReader_78 8000.00 root MppVersion: 1, data:ExchangeSender_77",
686-
"└─ExchangeSender_77 8000.00 mpp[tiflash] ExchangeType: PassThrough",
687+
"TableReader_84 8000.00 root MppVersion: 1, data:ExchangeSender_83",
688+
"└─ExchangeSender_83 8000.00 mpp[tiflash] ExchangeType: PassThrough",
687689
" └─Projection_10 8000.00 mpp[tiflash] test.o.c_id, Column#6",
688-
" └─Projection_76 8000.00 mpp[tiflash] Column#6, test.o.c_id",
689-
" └─HashAgg_75 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.c_id",
690-
" └─ExchangeReceiver_71 9990.00 mpp[tiflash] ",
691-
" └─ExchangeSender_70 9990.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]",
692-
" └─HashJoin_69 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
693-
" ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ",
694-
" │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST",
695-
" │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.c_id",
696-
" │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.c_id)->Column#8, funcs:firstrow(test.o.c_id)->test.o.c_id",
697-
" │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ",
698-
" │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]",
699-
" │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(1)->Column#9",
700-
" │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
701-
" └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
702-
" └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo"
690+
" └─Projection_79 8000.00 mpp[tiflash] Column#6, test.o.c_id",
691+
" └─HashAgg_80 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:sum(Column#21)->Column#6, funcs:firstrow(Column#22)->test.o.c_id",
692+
" └─ExchangeReceiver_82 8000.00 mpp[tiflash] ",
693+
" └─ExchangeSender_81 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]",
694+
" └─HashAgg_76 8000.00 mpp[tiflash] group by:Column#25, funcs:sum(Column#23)->Column#21, funcs:firstrow(Column#24)->Column#22",
695+
" └─Projection_85 9990.00 mpp[tiflash] cast(Column#7, decimal(20,0) BINARY)->Column#23, Column#8, test.o.c_id",
696+
" └─HashJoin_78 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
697+
" ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ",
698+
" │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST",
699+
" │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.c_id",
700+
" │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.c_id)->Column#8, funcs:firstrow(test.o.c_id)->test.o.c_id",
701+
" │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ",
702+
" │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]",
703+
" │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(1)->Column#9",
704+
" │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
705+
" └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
706+
" └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo"
707+
],
708+
"Warn": null
709+
},
710+
{
711+
"SQL": "EXPLAIN format='brief' select a, count(*) from (select a, b from t union all select a, b from t) t group by a order by a limit 10",
712+
"Plan": [
713+
"Projection 10.00 root Column#7, Column#9",
714+
"└─TopN 10.00 root Column#7, offset:0, count:10",
715+
" └─TableReader 10.00 root MppVersion: 1, data:ExchangeSender",
716+
" └─ExchangeSender 10.00 mpp[tiflash] ExchangeType: PassThrough",
717+
" └─TopN 10.00 mpp[tiflash] Column#7, offset:0, count:10",
718+
" └─Projection 16000.00 mpp[tiflash] Column#9, Column#7",
719+
" └─HashAgg 16000.00 mpp[tiflash] group by:Column#40, funcs:sum(Column#38)->Column#9, funcs:firstrow(Column#39)->Column#7",
720+
" └─Projection 16000.00 mpp[tiflash] cast(Column#10, decimal(20,0) BINARY)->Column#38, Column#11, Column#7",
721+
" └─ExchangeReceiver 16000.00 mpp[tiflash] ",
722+
" └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: Column#7, collate: binary]",
723+
" └─Union 16000.00 mpp[tiflash] ",
724+
" ├─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:sum(Column#30)->Column#10, funcs:firstrow(test.t.a)->Column#11, funcs:firstrow(test.t.a)->Column#7",
725+
" │ └─ExchangeReceiver 8000.00 mpp[tiflash] ",
726+
" │ └─ExchangeSender 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary]",
727+
" │ └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:count(1)->Column#30",
728+
" │ └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo",
729+
" └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:sum(Column#33)->Column#10, funcs:firstrow(test.t.a)->Column#11, funcs:firstrow(test.t.a)->Column#7",
730+
" └─ExchangeReceiver 8000.00 mpp[tiflash] ",
731+
" └─ExchangeSender 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary]",
732+
" └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:count(1)->Column#33",
733+
" └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo"
703734
],
704735
"Warn": null
705736
}

planner/core/exhaust_physical_plans.go

+11
Original file line numberDiff line numberDiff line change
@@ -3151,6 +3151,16 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
31513151
// Is this aggregate a final stage aggregate?
31523152
// Final agg can't be split into multi-stage aggregate
31533153
hasFinalAgg := len(la.AggFuncs) > 0 && la.AggFuncs[0].Mode == aggregation.FinalMode
3154+
// A final-mode count agg must become a sum on the MPP execution path.
3155+
// In the traditional case, TiDB takes the final agg role and pushes the partial agg down to TiKV;
3156+
// TiDB can tell the partial mode and sums the partial counts rather than counting them, but MPP cannot.
3157+
finalAggAdjust := func(aggFuncs []*aggregation.AggFuncDesc) {
3158+
for i, agg := range aggFuncs {
3159+
if agg.Mode == aggregation.FinalMode && agg.Name == ast.AggFuncCount {
3160+
aggFuncs[i], _ = aggregation.NewAggFuncDesc(la.SCtx(), ast.AggFuncSum, agg.Args, false)
3161+
}
3162+
}
3163+
}
31543164

31553165
if len(la.GroupByItems) > 0 {
31563166
partitionCols := la.GetPotentialPartitionKeys()
@@ -3176,6 +3186,7 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
31763186
agg := NewPhysicalHashAgg(la, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp)
31773187
agg.SetSchema(la.schema.Clone())
31783188
agg.MppRunMode = Mpp1Phase
3189+
finalAggAdjust(agg.AggFuncs)
31793190
hashAggs = append(hashAggs, agg)
31803191
}
31813192

0 commit comments

Comments
 (0)