From 789564b4a0ac94b34cdba60b3c48aaedc756a664 Mon Sep 17 00:00:00 2001 From: Ti Chi Robot Date: Tue, 31 Oct 2023 11:54:37 +0800 Subject: [PATCH] planner: fix wrong result when pushing Agg down through Union in MPP plans (#46310) (#46515) close pingcap/tidb#45850 --- executor/tiflash_test.go | 42 +++++++++++++++++++ planner/core/enforce_mpp_test.go | 9 +++- planner/core/exhaust_physical_plans.go | 13 ++++++ .../core/testdata/enforce_mpp_suite_in.json | 3 +- .../core/testdata/enforce_mpp_suite_out.json | 40 ++++++++++++++---- 5 files changed, 97 insertions(+), 10 deletions(-) diff --git a/executor/tiflash_test.go b/executor/tiflash_test.go index 999ea336acffa..979f3333f68e1 100644 --- a/executor/tiflash_test.go +++ b/executor/tiflash_test.go @@ -1248,6 +1248,48 @@ func TestAggPushDownCountStar(t *testing.T) { tk.MustQuery("select count(*) from c, o where c.c_id=o.c_id").Check(testkit.Rows("5")) } +func TestAggPushDownUnionAndMPP(t *testing.T) { + store, clean := testkit.CreateMockStore(t, withMockTiFlash(2)) + defer clean() + tk := testkit.NewTestKit(t, store) + + tk.MustExec("use test") + tk.MustExec("create table t (a int, b int)") + tk.MustExec("alter table t set tiflash replica 1") + tk.MustExec("insert into t values (1, 1);") + tk.MustExec("insert into t values (1, 1);") + tk.MustExec("insert into t values (1, 1);") + tk.MustExec("insert into t values (1, 1);") + tk.MustExec("insert into t values (1, 1);") + tk.MustExec("set @@tidb_allow_mpp=1;") + tk.MustExec("set @@tidb_enforce_mpp=1;") + tk.MustExec("set @@tidb_opt_agg_push_down=1") + tb := external.GetTableByName(t, tk, "test", "t") + err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true) + require.NoError(t, err) + + tk.MustExec("create table c(c_id int)") + tk.MustExec("create table o(o_id int, c_id int)") + tk.MustExec("insert into c values(1),(1),(1),(1)") + tk.MustExec("insert into o values(1,1),(1,1),(1,2)") + tk.MustExec("alter table c set tiflash replica 1") + tk.MustExec("alter table o set tiflash replica 1") + tb = external.GetTableByName(t, tk, "test", "c") + err = domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true) + require.NoError(t, err) + tk.MustExec("alter table o set tiflash replica 1") + tb = external.GetTableByName(t, tk, "test", "o") + err = domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true) + require.NoError(t, err) + + tk.MustQuery("select a, count(*) from (select a, b from t " + + "union all " + + "select a, b from t" + + ") t group by a order by a limit 10;").Check(testkit.Rows("1 10")) + + tk.MustQuery("select o.o_id, count(*) from c, o where c.c_id=o.o_id group by o.o_id").Check(testkit.Rows("1 12")) +} + func TestGroupStreamAggOnTiFlash(t *testing.T) { store, clean := testkit.CreateMockStore(t, withMockTiFlash(2)) defer clean() diff --git a/planner/core/enforce_mpp_test.go b/planner/core/enforce_mpp_test.go index f839d53bfa7ff..7e0979fb2707a 100644 --- a/planner/core/enforce_mpp_test.go +++ b/planner/core/enforce_mpp_test.go @@ -405,13 +405,20 @@ func TestMPP2PhaseAggPushDown(t *testing.T) { tk.MustExec("create table c(c_id bigint)") tk.MustExec("create table o(o_id bigint, c_id bigint not null)") + tk.MustExec("create table t (a int, b int)") + tk.MustExec("insert into t values (1, 1);") + tk.MustExec("insert into t values (1, 1);") + tk.MustExec("insert into t values (1, 1);") + tk.MustExec("insert into t values (1, 1);") + tk.MustExec("insert into t values (1, 1);") + // 
Create virtual tiflash replica info. dom := domain.GetDomain(tk.Session()) is := dom.InfoSchema() db, exists := is.SchemaByName(model.NewCIStr("test")) require.True(t, exists) for _, tblInfo := range db.Tables { - if tblInfo.Name.L == "c" || tblInfo.Name.L == "o" { + if tblInfo.Name.L == "c" || tblInfo.Name.L == "o" || tblInfo.Name.L == "t" { tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ Count: 1, Available: true, diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index 5db8b3a030a84..fca60a8f9435c 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -2687,6 +2687,18 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert // Is this aggregate a final stage aggregate? // Final agg can't be split into multi-stage aggregate hasFinalAgg := len(la.AggFuncs) > 0 && la.AggFuncs[0].Mode == aggregation.FinalMode + // count final agg should become sum for MPP execution path. + // In the traditional case, TiDB take up the final agg role and push partial agg to TiKV, + // while TiDB can tell the partialMode and do the sum computation rather than counting but MPP doesn't + finalAggAdjust := func(aggFuncs []*aggregation.AggFuncDesc) { + for i, agg := range aggFuncs { + if agg.Mode == aggregation.FinalMode && agg.Name == ast.AggFuncCount { + oldFt := agg.RetTp + aggFuncs[i], _ = aggregation.NewAggFuncDesc(la.SCtx(), ast.AggFuncSum, agg.Args, false) + aggFuncs[i].RetTp = oldFt + } + } + } if len(la.GroupByItems) > 0 { partitionCols := la.GetPotentialPartitionKeys() @@ -2710,6 +2722,7 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert agg := NewPhysicalHashAgg(la, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp) agg.SetSchema(la.schema.Clone()) agg.MppRunMode = Mpp1Phase + finalAggAdjust(agg.AggFuncs) hashAggs = append(hashAggs, agg) } diff --git a/planner/core/testdata/enforce_mpp_suite_in.json b/planner/core/testdata/enforce_mpp_suite_in.json index 3a46d1fdcc930..2e2b08cb74044 100644 --- a/planner/core/testdata/enforce_mpp_suite_in.json +++ b/planner/core/testdata/enforce_mpp_suite_in.json @@ -93,7 +93,8 @@ "set @@tidb_allow_mpp=1;set @@tidb_enforce_mpp=1;set @@tidb_opt_agg_push_down=1;", "EXPLAIN select count(*) from c, o where c.c_id=o.c_id; -- 1. test agg push down, scalar aggregate", "EXPLAIN select o.o_id, count(*) from c, o where c.c_id=o.c_id group by o.o_id; -- 2. test agg push down, group by non-join column", - "EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column" + "EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column", + "EXPLAIN format='brief' select a, count(*) from (select a, b from t union all select a, b from t) t group by a order by a limit 10" ] } ] diff --git a/planner/core/testdata/enforce_mpp_suite_out.json b/planner/core/testdata/enforce_mpp_suite_out.json index 4ee73906dd0cf..9d30b25248207 100644 --- a/planner/core/testdata/enforce_mpp_suite_out.json +++ b/planner/core/testdata/enforce_mpp_suite_out.json @@ -676,11 +676,11 @@ { "SQL": "EXPLAIN select o.o_id, count(*) from c, o where c.c_id=o.c_id group by o.o_id; -- 2. 
test agg push down, group by non-join column", "Plan": [ - "TableReader_78 8000.00 root data:ExchangeSender_77", - "└─ExchangeSender_77 8000.00 mpp[tiflash] ExchangeType: PassThrough", + "TableReader_84 8000.00 root data:ExchangeSender_83", + "└─ExchangeSender_83 8000.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection_10 8000.00 mpp[tiflash] test.o.o_id, Column#6", - " └─Projection_76 8000.00 mpp[tiflash] Column#6, test.o.o_id", - " └─HashAgg_75 8000.00 mpp[tiflash] group by:test.o.o_id, funcs:count(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.o_id", + " └─Projection_77 8000.00 mpp[tiflash] Column#6, test.o.o_id", + " └─HashAgg_75 8000.00 mpp[tiflash] group by:test.o.o_id, funcs:sum(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.o_id", " └─ExchangeReceiver_71 9990.00 mpp[tiflash] ", " └─ExchangeSender_70 9990.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.o.o_id, collate: binary]", " └─HashJoin_69 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]", @@ -699,11 +699,11 @@ { "SQL": "EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column", "Plan": [ - "TableReader_78 8000.00 root data:ExchangeSender_77", - "└─ExchangeSender_77 8000.00 mpp[tiflash] ExchangeType: PassThrough", + "TableReader_84 8000.00 root data:ExchangeSender_83", + "└─ExchangeSender_83 8000.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection_10 8000.00 mpp[tiflash] test.o.c_id, Column#6", - " └─Projection_76 8000.00 mpp[tiflash] Column#6, test.o.c_id", - " └─HashAgg_75 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.c_id", + " └─Projection_77 8000.00 mpp[tiflash] Column#6, test.o.c_id", + " └─HashAgg_75 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:sum(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.c_id", " └─ExchangeReceiver_71 9990.00 mpp[tiflash] ", " └─ExchangeSender_70 9990.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.o.c_id, collate: binary]", " └─HashJoin_69 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]", @@ -718,6 +718,30 @@ " └─TableFullScan_18 10000.00 mpp[tiflash] table:c keep order:false, stats:pseudo" ], "Warn": null + }, + { + "SQL": "EXPLAIN format='brief' select a, count(*) from (select a, b from t union all select a, b from t) t group by a order by a limit 10", + "Plan": [ + "Projection 10.00 root Column#7, Column#9", + "└─TopN 10.00 root Column#7, offset:0, count:10", + " └─TableReader 10.00 root data:ExchangeSender", + " └─ExchangeSender 10.00 mpp[tiflash] ExchangeType: PassThrough", + " └─TopN 10.00 mpp[tiflash] Column#7, offset:0, count:10", + " └─Projection 16000.00 mpp[tiflash] Column#9, Column#7", + " └─HashAgg 16000.00 mpp[tiflash] group by:Column#7, funcs:sum(Column#10)->Column#9, funcs:firstrow(Column#11)->Column#7", + " └─ExchangeReceiver 16000.00 mpp[tiflash] ", + " └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: Column#7, collate: binary]", + " └─Union 16000.00 mpp[tiflash] ", + " ├─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:count(1)->Column#10, funcs:firstrow(test.t.a)->Column#11, funcs:firstrow(test.t.a)->Column#7", + " │ └─ExchangeReceiver 10000.00 mpp[tiflash] ", + " │ └─ExchangeSender 10000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary]", + " │ └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo", + " └─HashAgg 8000.00 
mpp[tiflash] group by:test.t.a, funcs:count(1)->Column#10, funcs:firstrow(test.t.a)->Column#11, funcs:firstrow(test.t.a)->Column#7", + " └─ExchangeReceiver 10000.00 mpp[tiflash] ", + " └─ExchangeSender 10000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.a, collate: binary]", + " └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ], + "Warn": null } ] }
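
Reviewer note (not part of the patch): the core of the fix is the finalAggAdjust closure added to tryToGetMppHashAggs. With tidb_opt_agg_push_down=1 a partial aggregate is pushed below the Union, so the upper MPP 1-phase hash aggregation receives pre-aggregated rows; a final-mode count on top of those rows must be rewritten to sum over the partial counts, otherwise it counts the incoming partial rows instead of the original rows. The standalone Go sketch below (plain Go, no TiDB internals, with a hypothetical partialRow type) only illustrates that arithmetic using the data from the new test, where each of the two union branches contributes a partial count of 5 for a = 1:

package main

import "fmt"

// partialRow is a hypothetical stand-in for a row emitted by a partial
// aggregate that was pushed below the union: it carries the group key and
// the partial count computed inside one union branch.
type partialRow struct {
	groupKey     int
	partialCount int64
}

func main() {
	// Mirrors the new test: t holds five rows of (1, 1), and each of the two
	// union branches aggregates them into one partial row with count 5.
	input := []partialRow{
		{groupKey: 1, partialCount: 5},
		{groupKey: 1, partialCount: 5},
	}

	asCount := map[int]int64{} // final agg left as count: counts incoming partial rows
	asSum := map[int]int64{}   // final agg rewritten to sum: adds up the partial counts
	for _, r := range input {
		asCount[r.groupKey]++
		asSum[r.groupKey] += r.partialCount
	}

	// A final-mode count would behave like asCount (2); after finalAggAdjust
	// the plan behaves like asSum (10), matching the expected test result.
	fmt.Println(asCount[1], asSum[1])
}

The adjustment in the patch also keeps the original return type (agg.RetTp) on the rewritten sum, so the column type the optimizer already assigned to the count result is preserved.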