diff --git a/executor/main_test.go b/executor/main_test.go new file mode 100644 index 0000000000000..dc50dc65feb79 --- /dev/null +++ b/executor/main_test.go @@ -0,0 +1,26 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "os" + "testing" + + "github.com/pingcap/tidb/util/testbridge" +) + +func TestMain(m *testing.M) { + testbridge.WorkaroundGoCheckFlags() + os.Exit(m.Run()) +} diff --git a/executor/write_concurrent_test.go b/executor/write_concurrent_test.go index 4eb5ded77539d..a0be46778fe75 100644 --- a/executor/write_concurrent_test.go +++ b/executor/write_concurrent_test.go @@ -15,46 +15,61 @@ package executor_test import ( "context" + "math/rand" + "testing" - . "github.com/pingcap/check" "github.com/pingcap/tidb/config" - "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/testkit" ) -func (s *testSuite) TestBatchInsertWithOnDuplicate(c *C) { - tk := testkit.NewCTestKit(c, s.store) +func TestBatchInsertWithOnDuplicate(t *testing.T) { + t.Parallel() + + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewAsyncTestKit(t, store) // prepare schema. - ctx := tk.OpenSessionWithDB(context.Background(), "test") + ctx := tk.OpenSession(context.Background(), "test") defer tk.CloseSession(ctx) tk.MustExec(ctx, "drop table if exists duplicate_test") tk.MustExec(ctx, "create table duplicate_test(id int auto_increment, k1 int, primary key(id), unique key uk(k1))") - tk.MustExec(ctx, "insert into duplicate_test(k1) values(?),(?),(?),(?),(?)", tk.PermInt(5)...) + tk.MustExec(ctx, "insert into duplicate_test(k1) values(?),(?),(?),(?),(?)", permInt(5)...) defer config.RestoreFunc()() config.UpdateGlobal(func(conf *config.Config) { conf.EnableBatchDML = true }) - tk.ConcurrentRun(c, 3, 2, // concurrent: 3, loops: 2, + tk.ConcurrentRun( + 3, + 2, // prepare data for each loop. - func(ctx context.Context, tk *testkit.CTestKit, concurrent int, currentLoop int) [][][]interface{} { + func(ctx context.Context, tk *testkit.AsyncTestKit, concurrent int, currentLoop int) [][][]interface{} { var ii [][][]interface{} for i := 0; i < concurrent; i++ { - ii = append(ii, [][]interface{}{tk.PermInt(7)}) + ii = append(ii, [][]interface{}{permInt(7)}) } return ii }, // concurrent execute logic. - func(ctx context.Context, tk *testkit.CTestKit, input [][]interface{}) { + func(ctx context.Context, tk *testkit.AsyncTestKit, input [][]interface{}) { tk.MustExec(ctx, "set @@session.tidb_batch_insert=1") tk.MustExec(ctx, "set @@session.tidb_dml_batch_size=1") - _, err := tk.Exec(ctx, "insert ignore into duplicate_test(k1) values (?),(?),(?),(?),(?),(?),(?)", input[0]...) - tk.IgnoreError(err) + _, _ = tk.Exec(ctx, "insert ignore into duplicate_test(k1) values (?),(?),(?),(?),(?),(?),(?)", input[0]...) }, // check after all done. - func(ctx context.Context, tk *testkit.CTestKit) { + func(ctx context.Context, tk *testkit.AsyncTestKit) { tk.MustExec(ctx, "admin check table duplicate_test") - tk.MustQuery(ctx, "select d1.id, d1.k1 from duplicate_test d1 ignore index(uk), duplicate_test d2 use index (uk) where d1.id = d2.id and d1.k1 <> d2.k1"). - Check(testkit.Rows()) + tk.MustQuery(ctx, "select d1.id, d1.k1 from duplicate_test d1 ignore index(uk), duplicate_test d2 use index (uk) where d1.id = d2.id and d1.k1 <> d2.k1").Check(testkit.Rows()) }) } + +func permInt(n int) []interface{} { + randPermSlice := rand.Perm(n) + v := make([]interface{}, 0, len(randPermSlice)) + for _, i := range randPermSlice { + v = append(v, i) + } + return v +} diff --git a/testkit/asynctestkit.go b/testkit/asynctestkit.go new file mode 100644 index 0000000000000..ffa1d86429643 --- /dev/null +++ b/testkit/asynctestkit.go @@ -0,0 +1,226 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !codes + +package testkit + +import ( + "context" + "fmt" + "testing" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util" + "github.com/pingcap/tidb/util/sqlexec" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" +) + +var asyncTestKitIDGenerator atomic.Uint64 + +// AsyncTestKit is a utility to run sql concurrently. +type AsyncTestKit struct { + require *require.Assertions + assert *assert.Assertions + store kv.Storage +} + +// NewAsyncTestKit returns a new *AsyncTestKit. +func NewAsyncTestKit(t *testing.T, store kv.Storage) *AsyncTestKit { + return &AsyncTestKit{ + require: require.New(t), + assert: assert.New(t), + store: store, + } +} + +// OpenSession opens new session ctx if no exists one and use db. +func (tk *AsyncTestKit) OpenSession(ctx context.Context, db string) context.Context { + if tryRetrieveSession(ctx) == nil { + se, err := session.CreateSession4Test(tk.store) + tk.require.NoError(err) + se.SetConnectionID(asyncTestKitIDGenerator.Inc()) + ctx = context.WithValue(ctx, sessionKey, se) + } + tk.MustExec(ctx, fmt.Sprintf("use %s", db)) + return ctx +} + +// CloseSession closes exists session from ctx. +func (tk *AsyncTestKit) CloseSession(ctx context.Context) { + se := tryRetrieveSession(ctx) + tk.require.NotNil(se) + se.Close() +} + +// ConcurrentRun run test in current. +// - concurrent: controls the concurrent worker count. +// - loops: controls run test how much times. +// - prepareFunc: provide test data and will be called for every loop. +// - checkFunc: used to do some check after all workers done. +// works like create table better be put in front of this method calling. +// see more example at TestBatchInsertWithOnDuplicate +func (tk *AsyncTestKit) ConcurrentRun( + concurrent int, + loops int, + prepareFunc func(ctx context.Context, tk *AsyncTestKit, concurrent int, currentLoop int) [][][]interface{}, + writeFunc func(ctx context.Context, tk *AsyncTestKit, input [][]interface{}), + checkFunc func(ctx context.Context, tk *AsyncTestKit), +) { + channel := make([]chan [][]interface{}, concurrent) + contextList := make([]context.Context, concurrent) + doneList := make([]context.CancelFunc, concurrent) + + for i := 0; i < concurrent; i++ { + w := i + channel[w] = make(chan [][]interface{}, 1) + contextList[w], doneList[w] = context.WithCancel(context.Background()) + contextList[w] = tk.OpenSession(contextList[w], "test") + go func() { + defer func() { + r := recover() + tk.require.Nil(r, string(util.GetStack())) + doneList[w]() + }() + + for input := range channel[w] { + writeFunc(contextList[w], tk, input) + } + }() + } + + defer func() { + for i := 0; i < concurrent; i++ { + tk.CloseSession(contextList[i]) + } + }() + + ctx := tk.OpenSession(context.Background(), "test") + defer tk.CloseSession(ctx) + tk.MustExec(ctx, "use test") + + for j := 0; j < loops; j++ { + data := prepareFunc(ctx, tk, concurrent, j) + for i := 0; i < concurrent; i++ { + channel[i] <- data[i] + } + } + + for i := 0; i < concurrent; i++ { + close(channel[i]) + } + + for i := 0; i < concurrent; i++ { + <-contextList[i].Done() + } + checkFunc(ctx, tk) +} + +// Exec executes a sql statement. +func (tk *AsyncTestKit) Exec(ctx context.Context, sql string, args ...interface{}) (sqlexec.RecordSet, error) { + se := tryRetrieveSession(ctx) + tk.require.NotNil(se) + + if len(args) == 0 { + rss, err := se.Execute(ctx, sql) + if err == nil && len(rss) > 0 { + return rss[0], nil + } + return nil, err + } + + stmtID, _, _, err := se.PrepareStmt(sql) + if err != nil { + return nil, err + } + + params := make([]types.Datum, len(args)) + for i := 0; i < len(params); i++ { + params[i] = types.NewDatum(args[i]) + } + + rs, err := se.ExecutePreparedStmt(ctx, stmtID, params) + if err != nil { + return nil, err + } + + err = se.DropPreparedStmt(stmtID) + if err != nil { + return nil, err + } + + return rs, nil +} + +// MustExec executes a sql statement and asserts nil error. +func (tk *AsyncTestKit) MustExec(ctx context.Context, sql string, args ...interface{}) { + res, err := tk.Exec(ctx, sql, args...) + tk.require.NoErrorf(err, "sql:%s, %v, error stack %v", sql, args, errors.ErrorStack(err)) + if res != nil { + tk.require.NoError(res.Close()) + } +} + +// MustQuery query the statements and returns result rows. +// If expected result is set it asserts the query result equals expected result. +func (tk *AsyncTestKit) MustQuery(ctx context.Context, sql string, args ...interface{}) *Result { + comment := fmt.Sprintf("sql:%s, args:%v", sql, args) + rs, err := tk.Exec(ctx, sql, args...) + tk.require.NoError(err, comment) + tk.require.NotNil(rs, comment) + return tk.resultSetToResult(ctx, rs, comment) +} + +// resultSetToResult converts ast.RecordSet to testkit.Result. +// It is used to check results of execute statement in binary mode. +func (tk *AsyncTestKit) resultSetToResult(ctx context.Context, rs sqlexec.RecordSet, comment string) *Result { + rows, err := session.GetRows4Test(context.Background(), tryRetrieveSession(ctx), rs) + tk.require.NoError(err, comment) + + err = rs.Close() + tk.require.NoError(err, comment) + + result := make([][]string, len(rows)) + for i := range rows { + row := rows[i] + resultRow := make([]string, row.Len()) + for j := 0; j < row.Len(); j++ { + if row.IsNull(j) { + resultRow[j] = "" + } else { + d := row.GetDatum(j, &rs.Fields()[j].Column.FieldType) + resultRow[j], err = d.ToString() + tk.require.NoError(err, comment) + } + } + result[i] = resultRow + } + return &Result{rows: result, comment: comment, assert: tk.assert, require: tk.require} +} + +type sessionCtxKeyType struct{} + +var sessionKey = sessionCtxKeyType{} + +func tryRetrieveSession(ctx context.Context) session.Session { + s := ctx.Value(sessionKey) + if s == nil { + return nil + } + return s.(session.Session) +} diff --git a/testkit/mockstore.go b/testkit/mockstore.go new file mode 100644 index 0000000000000..7bad3b1caeb3b --- /dev/null +++ b/testkit/mockstore.go @@ -0,0 +1,46 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !codes + +package testkit + +import ( + "testing" + + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/store/mockstore" + "github.com/stretchr/testify/require" +) + +// CreateMockStore return a new mock kv.Storage. +func CreateMockStore(t *testing.T) (store kv.Storage, clean func()) { + store, err := mockstore.NewMockStore() + require.NoError(t, err) + + session.SetSchemaLease(0) + session.DisableStats4Test() + d, err := session.BootstrapSession(store) + require.NoError(t, err) + + d.SetStatsUpdating(true) + + clean = func() { + d.Close() + err := store.Close() + require.NoError(t, err) + } + + return +} diff --git a/testkit/testkit.go b/testkit/testkit.go index 3623b519a3afc..93414bddcc625 100644 --- a/testkit/testkit.go +++ b/testkit/testkit.go @@ -30,7 +30,7 @@ import ( "go.uber.org/atomic" ) -var idGenerator atomic.Uint64 +var testKitIDGenerator atomic.Uint64 // TestKit is a utility to run sql test. type TestKit struct { @@ -140,6 +140,6 @@ func (tk *TestKit) Exec(sql string, args ...interface{}) (sqlexec.RecordSet, err func newSession(t *testing.T, store kv.Storage) session.Session { se, err := session.CreateSession4Test(store) require.Nil(t, err) - se.SetConnectionID(idGenerator.Inc()) + se.SetConnectionID(testKitIDGenerator.Inc()) return se } diff --git a/util/admin/admin_integration_test.go b/util/admin/admin_integration_test.go index f61085e8d05c7..a19a31c9d6410 100644 --- a/util/admin/admin_integration_test.go +++ b/util/admin/admin_integration_test.go @@ -17,19 +17,14 @@ import ( "strconv" "testing" - "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx/variable" - "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/testkit" - "github.com/stretchr/testify/require" - "github.com/tikv/client-go/v2/testutils" ) func TestAdminCheckTable(t *testing.T) { t.Parallel() - store, clean := newIntegrationMockStore(t) + store, clean := testkit.CreateMockStore(t) defer clean() // test NULL value. @@ -82,7 +77,7 @@ func TestAdminCheckTable(t *testing.T) { func TestAdminCheckTableClusterIndex(t *testing.T) { t.Parallel() - store, clean := newIntegrationMockStore(t) + store, clean := testkit.CreateMockStore(t) defer clean() tk := testkit.NewTestKit(t, store) @@ -116,27 +111,3 @@ func TestAdminCheckTableClusterIndex(t *testing.T) { tk.MustExec("insert into t values (1000, '1000', 1000, '1000', '1000');") tk.MustExec("admin check table t;") } - -func newIntegrationMockStore(t *testing.T) (store kv.Storage, clean func()) { - store, err := mockstore.NewMockStore( - mockstore.WithClusterInspector(func(c testutils.Cluster) { - mockstore.BootstrapWithSingleStore(c) - }), - ) - require.NoError(t, err) - - session.SetSchemaLease(0) - session.DisableStats4Test() - d, err := session.BootstrapSession(store) - require.NoError(t, err) - - d.SetStatsUpdating(true) - - clean = func() { - d.Close() - err := store.Close() - require.NoError(t, err) - } - - return -} diff --git a/util/testkit/ctestkit.go b/util/testkit/ctestkit.go deleted file mode 100644 index d3f9368861e2d..0000000000000 --- a/util/testkit/ctestkit.go +++ /dev/null @@ -1,243 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build !codes - -package testkit - -import ( - "context" - "math/rand" - "sync/atomic" - - "github.com/pingcap/check" - "github.com/pingcap/errors" - "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/session" - "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/util" - "github.com/pingcap/tidb/util/sqlexec" -) - -type sessionCtxKeyType struct{} - -var sessionKey = sessionCtxKeyType{} - -func getSession(ctx context.Context) session.Session { - s := ctx.Value(sessionKey) - if s == nil { - return nil - } - return s.(session.Session) -} - -func setSession(ctx context.Context, se session.Session) context.Context { - return context.WithValue(ctx, sessionKey, se) -} - -// CTestKit is a utility to run sql test with concurrent execution support. -type CTestKit struct { - c *check.C - store kv.Storage -} - -// NewCTestKit returns a new *CTestKit. -func NewCTestKit(c *check.C, store kv.Storage) *CTestKit { - return &CTestKit{ - c: c, - store: store, - } -} - -// OpenSession opens new session ctx if no exists one. -func (tk *CTestKit) OpenSession(ctx context.Context) context.Context { - if getSession(ctx) == nil { - se, err := session.CreateSession4Test(tk.store) - tk.c.Assert(err, check.IsNil) - id := atomic.AddUint64(&connectionID, 1) - se.SetConnectionID(id) - ctx = setSession(ctx, se) - } - return ctx -} - -// OpenSessionWithDB opens new session ctx if no exists one and use db. -func (tk *CTestKit) OpenSessionWithDB(ctx context.Context, db string) context.Context { - ctx = tk.OpenSession(ctx) - tk.MustExec(ctx, "use "+db) - return ctx -} - -// CloseSession closes exists session from ctx. -func (tk *CTestKit) CloseSession(ctx context.Context) { - se := getSession(ctx) - tk.c.Assert(se, check.NotNil) - se.Close() -} - -// Exec executes a sql statement. -func (tk *CTestKit) Exec(ctx context.Context, sql string, args ...interface{}) (sqlexec.RecordSet, error) { - var err error - tk.c.Assert(getSession(ctx), check.NotNil) - if len(args) == 0 { - var rss []sqlexec.RecordSet - rss, err = getSession(ctx).Execute(ctx, sql) - if err == nil && len(rss) > 0 { - return rss[0], nil - } - return nil, err - } - stmtID, _, _, err := getSession(ctx).PrepareStmt(sql) - if err != nil { - return nil, err - } - params := make([]types.Datum, len(args)) - for i := 0; i < len(params); i++ { - params[i] = types.NewDatum(args[i]) - } - rs, err := getSession(ctx).ExecutePreparedStmt(ctx, stmtID, params) - if err != nil { - return nil, err - } - err = getSession(ctx).DropPreparedStmt(stmtID) - if err != nil { - return nil, err - } - return rs, nil -} - -// CheckExecResult checks the affected rows and the insert id after executing MustExec. -func (tk *CTestKit) CheckExecResult(ctx context.Context, affectedRows, insertID int64) { - tk.c.Assert(getSession(ctx), check.NotNil) - tk.c.Assert(affectedRows, check.Equals, int64(getSession(ctx).AffectedRows())) - tk.c.Assert(insertID, check.Equals, int64(getSession(ctx).LastInsertID())) -} - -// MustExec executes a sql statement and asserts nil error. -func (tk *CTestKit) MustExec(ctx context.Context, sql string, args ...interface{}) { - res, err := tk.Exec(ctx, sql, args...) - tk.c.Assert(err, check.IsNil, check.Commentf("sql:%s, %v, error stack %v", sql, args, errors.ErrorStack(err))) - if res != nil { - tk.c.Assert(res.Close(), check.IsNil) - } -} - -// MustQuery query the statements and returns result rows. -// If expected result is set it asserts the query result equals expected result. -func (tk *CTestKit) MustQuery(ctx context.Context, sql string, args ...interface{}) *Result { - comment := check.Commentf("sql:%s, args:%v", sql, args) - rs, err := tk.Exec(ctx, sql, args...) - tk.c.Assert(errors.ErrorStack(err), check.Equals, "", comment) - tk.c.Assert(rs, check.NotNil, comment) - return tk.resultSetToResult(ctx, rs, comment) -} - -// resultSetToResult converts ast.RecordSet to testkit.Result. -// It is used to check results of execute statement in binary mode. -func (tk *CTestKit) resultSetToResult(ctx context.Context, rs sqlexec.RecordSet, comment check.CommentInterface) *Result { - rows, err := session.GetRows4Test(context.Background(), getSession(ctx), rs) - tk.c.Assert(errors.ErrorStack(err), check.Equals, "", comment) - err = rs.Close() - tk.c.Assert(errors.ErrorStack(err), check.Equals, "", comment) - sRows := make([][]string, len(rows)) - for i := range rows { - row := rows[i] - iRow := make([]string, row.Len()) - for j := 0; j < row.Len(); j++ { - if row.IsNull(j) { - iRow[j] = "" - } else { - d := row.GetDatum(j, &rs.Fields()[j].Column.FieldType) - iRow[j], err = d.ToString() - tk.c.Assert(err, check.IsNil) - } - } - sRows[i] = iRow - } - return &Result{rows: sRows, c: tk.c, comment: comment} -} - -// ConcurrentRun run test in current. -// - concurrent: controls the concurrent worker count. -// - loops: controls run test how much times. -// - prepareFunc: provide test data and will be called for every loop. -// - checkFunc: used to do some check after all workers done. -// works like create table better be put in front of this method calling. -// see more example at TestBatchInsertWithOnDuplicate -func (tk *CTestKit) ConcurrentRun(c *check.C, concurrent int, loops int, - prepareFunc func(ctx context.Context, tk *CTestKit, concurrent int, currentLoop int) [][][]interface{}, - writeFunc func(ctx context.Context, tk *CTestKit, input [][]interface{}), - checkFunc func(ctx context.Context, tk *CTestKit)) { - var ( - channel = make([]chan [][]interface{}, concurrent) - ctxs = make([]context.Context, concurrent) - dones = make([]context.CancelFunc, concurrent) - ) - for i := 0; i < concurrent; i++ { - w := i - channel[w] = make(chan [][]interface{}, 1) - ctxs[w], dones[w] = context.WithCancel(context.Background()) - ctxs[w] = tk.OpenSessionWithDB(ctxs[w], "test") - go func() { - defer func() { - r := recover() - if r != nil { - c.Fatal(r, string(util.GetStack())) - } - dones[w]() - }() - for input := range channel[w] { - writeFunc(ctxs[w], tk, input) - } - }() - } - defer func() { - for i := 0; i < concurrent; i++ { - tk.CloseSession(ctxs[i]) - } - }() - - ctx := tk.OpenSessionWithDB(context.Background(), "test") - defer tk.CloseSession(ctx) - tk.MustExec(ctx, "use test") - - for j := 0; j < loops; j++ { - datas := prepareFunc(ctx, tk, concurrent, j) - for i := 0; i < concurrent; i++ { - channel[i] <- datas[i] - } - } - - for i := 0; i < concurrent; i++ { - close(channel[i]) - } - - for i := 0; i < concurrent; i++ { - <-ctxs[i].Done() - } - checkFunc(ctx, tk) -} - -// PermInt returns, as a slice of n ints, a pseudo-random permutation of the integers [0,n). -func (tk *CTestKit) PermInt(n int) []interface{} { - randPermSlice := rand.Perm(n) - v := make([]interface{}, 0, len(randPermSlice)) - for _, i := range randPermSlice { - v = append(v, i) - } - return v -} - -// IgnoreError ignores error and make errcheck tool happy. -// Deprecated: it's normal to ignore some error in concurrent test, but please don't use this method in other place. -func (tk *CTestKit) IgnoreError(_ error) {}