diff --git a/session/nontransactional.go b/session/nontransactional.go index 61816153ec334..aceb46bc1d686 100644 --- a/session/nontransactional.go +++ b/session/nontransactional.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/types" driver "github.com/pingcap/tidb/types/parser_driver" + "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/sqlexec" @@ -304,29 +305,38 @@ func buildShardJobs(ctx context.Context, stmt *ast.NonTransactionalDeleteStmt, s break } - newStart := chk.GetRow(0).GetDatum(0, &rs.Fields()[0].Column.FieldType) + if len(jobs) > 0 && chk.NumRows()+currentSize < batchSize { + // not enough data for a batch + currentSize += chk.NumRows() + newEnd := chk.GetRow(chk.NumRows()-1).GetDatum(0, &rs.Fields()[0].Column.FieldType) + currentEnd = *newEnd.Clone() + continue + } - // end last batch if: (1) current start != last end (2) current size >= batch size - if currentSize >= batchSize { - cmp, err := newStart.Compare(se.GetSessionVars().StmtCtx, &currentEnd, collate.GetCollator(shardColumnCollate)) - if err != nil { - return nil, err + iter := chunk.NewIterator4Chunk(chk) + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + if currentSize == 0 { + newStart := row.GetDatum(0, &rs.Fields()[0].Column.FieldType) + currentStart = *newStart.Clone() } - if cmp != 0 { - jobs = append(jobs, job{jobID: jobCount, start: currentStart, end: currentEnd, jobSize: currentSize}) - jobCount++ - currentSize = 0 + newEnd := row.GetDatum(0, &rs.Fields()[0].Column.FieldType) + if currentSize >= batchSize { + cmp, err := newEnd.Compare(se.GetSessionVars().StmtCtx, &currentEnd, collate.GetCollator(shardColumnCollate)) + if err != nil { + return nil, err + } + if cmp != 0 { + jobs = append(jobs, job{jobID: jobCount, start: *currentStart.Clone(), end: *currentEnd.Clone(), jobSize: currentSize}) + jobCount++ + currentSize = 0 + currentStart = 
newEnd + } } + currentEnd = newEnd + currentSize++ } - - // a new batch - if currentSize == 0 { - currentStart = *newStart.Clone() - } - - currentSize += chk.NumRows() - currentEndPointer := chk.GetRow(chk.NumRows()-1).GetDatum(0, &rs.Fields()[0].Column.FieldType) - currentEnd = *currentEndPointer.Clone() + currentEnd = *currentEnd.Clone() + currentStart = *currentStart.Clone() } return jobs, nil diff --git a/session/nontransactional_test.go b/session/nontransactional_test.go index 404da858e756c..7f210cf777d30 100644 --- a/session/nontransactional_test.go +++ b/session/nontransactional_test.go @@ -24,7 +24,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestNonTransactionalDelete(t *testing.T) { +func TestNonTransactionalDeleteSharding(t *testing.T) { store, clean := createStorage(t) defer clean() tk := testkit.NewTestKit(t, store) @@ -49,30 +49,42 @@ func TestNonTransactionalDelete(t *testing.T) { "create table t(a varchar(30), b int, unique key(a, b))", "create table t(a varchar(30), b int, unique key(a))", } + tableSizes := []int{0, 1, 10, 35, 40, 100} + batchSizes := []int{1, 10, 25, 35, 50, 80, 120} for _, table := range tables { tk.MustExec("drop table if exists t") tk.MustExec(table) - for i := 0; i < 100; i++ { - tk.MustExec(fmt.Sprintf("insert into t values ('%d', %d)", i, i*2)) - } - tk.MustExec("split on a limit 3 delete from t") - tk.MustQuery("select count(*) from t").Check(testkit.Rows("0")) - - for i := 0; i < 100; i++ { - tk.MustExec(fmt.Sprintf("insert into t values ('%d', %d)", i, i*2)) - } - if strings.Contains(table, "a int") { - rows := tk.MustQuery("split on a limit 3 dry run delete from t").Rows() - for _, row := range rows { - require.True(t, strings.HasPrefix(row[0].(string), "DELETE FROM `test`.`t` WHERE `a` BETWEEN")) + for _, tableSize := range tableSizes { + for _, batchSize := range batchSizes { + for i := 0; i < tableSize; i++ { + tk.MustExec(fmt.Sprintf("insert into t values ('%d', %d)", i, i*2)) + } + 
tk.MustQuery(fmt.Sprintf("split on a limit %d delete from t", batchSize)).Check(testkit.Rows(fmt.Sprintf("%d all succeeded", (tableSize+batchSize-1)/batchSize))) + tk.MustQuery("select count(*) from t").Check(testkit.Rows("0")) } } - tk.MustQuery("split on a limit 3 dry run query delete from t").Check(testkit.Rows( - "SELECT `a` FROM `test`.`t` WHERE TRUE ORDER BY IF(ISNULL(`a`),0,1),`a`")) - tk.MustQuery("select count(*) from t").Check(testkit.Rows("100")) } } +func TestNonTransactionalDeleteDryRun(t *testing.T) { + store, clean := createStorage(t) + defer clean() + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@tidb_max_chunk_size=35") + tk.MustExec("use test") + tk.MustExec("create table t(a int, b int, primary key(a, b) clustered)") + for i := 0; i < 100; i++ { + tk.MustExec(fmt.Sprintf("insert into t values ('%d', %d)", i, i*2)) + } + rows := tk.MustQuery("split on a limit 3 dry run delete from t").Rows() + for _, row := range rows { + require.True(t, strings.HasPrefix(row[0].(string), "DELETE FROM `test`.`t` WHERE `a` BETWEEN")) + } + tk.MustQuery("split on a limit 3 dry run query delete from t").Check(testkit.Rows( + "SELECT `a` FROM `test`.`t` WHERE TRUE ORDER BY IF(ISNULL(`a`),0,1),`a`")) + tk.MustQuery("select count(*) from t").Check(testkit.Rows("100")) +} + func TestNonTransactionalDeleteErrorMessage(t *testing.T) { store, clean := createStorage(t) defer clean()