Skip to content

Commit

Permalink
executor: migrate TestBatchInsertWithOnDuplicate to testify (#26712)
Browse files Browse the repository at this point in the history
  • Loading branch information
tisonkun authored Aug 4, 2021
1 parent a0d0b48 commit f9652b2
Show file tree
Hide file tree
Showing 7 changed files with 332 additions and 291 deletions.
26 changes: 26 additions & 0 deletions executor/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright 2021 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package executor

import (
"os"
"testing"

"github.com/pingcap/tidb/util/testbridge"
)

func TestMain(m *testing.M) {
testbridge.WorkaroundGoCheckFlags()
os.Exit(m.Run())
}
45 changes: 30 additions & 15 deletions executor/write_concurrent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,46 +15,61 @@ package executor_test

import (
"context"
"math/rand"
"testing"

. "github.com/pingcap/check"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/util/testkit"
"github.com/pingcap/tidb/testkit"
)

func (s *testSuite) TestBatchInsertWithOnDuplicate(c *C) {
tk := testkit.NewCTestKit(c, s.store)
func TestBatchInsertWithOnDuplicate(t *testing.T) {
t.Parallel()

store, clean := testkit.CreateMockStore(t)
defer clean()

tk := testkit.NewAsyncTestKit(t, store)
// prepare schema.
ctx := tk.OpenSessionWithDB(context.Background(), "test")
ctx := tk.OpenSession(context.Background(), "test")
defer tk.CloseSession(ctx)
tk.MustExec(ctx, "drop table if exists duplicate_test")
tk.MustExec(ctx, "create table duplicate_test(id int auto_increment, k1 int, primary key(id), unique key uk(k1))")
tk.MustExec(ctx, "insert into duplicate_test(k1) values(?),(?),(?),(?),(?)", tk.PermInt(5)...)
tk.MustExec(ctx, "insert into duplicate_test(k1) values(?),(?),(?),(?),(?)", permInt(5)...)

defer config.RestoreFunc()()
config.UpdateGlobal(func(conf *config.Config) {
conf.EnableBatchDML = true
})

tk.ConcurrentRun(c, 3, 2, // concurrent: 3, loops: 2,
tk.ConcurrentRun(
3,
2,
// prepare data for each loop.
func(ctx context.Context, tk *testkit.CTestKit, concurrent int, currentLoop int) [][][]interface{} {
func(ctx context.Context, tk *testkit.AsyncTestKit, concurrent int, currentLoop int) [][][]interface{} {
var ii [][][]interface{}
for i := 0; i < concurrent; i++ {
ii = append(ii, [][]interface{}{tk.PermInt(7)})
ii = append(ii, [][]interface{}{permInt(7)})
}
return ii
},
// concurrent execute logic.
func(ctx context.Context, tk *testkit.CTestKit, input [][]interface{}) {
func(ctx context.Context, tk *testkit.AsyncTestKit, input [][]interface{}) {
tk.MustExec(ctx, "set @@session.tidb_batch_insert=1")
tk.MustExec(ctx, "set @@session.tidb_dml_batch_size=1")
_, err := tk.Exec(ctx, "insert ignore into duplicate_test(k1) values (?),(?),(?),(?),(?),(?),(?)", input[0]...)
tk.IgnoreError(err)
_, _ = tk.Exec(ctx, "insert ignore into duplicate_test(k1) values (?),(?),(?),(?),(?),(?),(?)", input[0]...)
},
// check after all done.
func(ctx context.Context, tk *testkit.CTestKit) {
func(ctx context.Context, tk *testkit.AsyncTestKit) {
tk.MustExec(ctx, "admin check table duplicate_test")
tk.MustQuery(ctx, "select d1.id, d1.k1 from duplicate_test d1 ignore index(uk), duplicate_test d2 use index (uk) where d1.id = d2.id and d1.k1 <> d2.k1").
Check(testkit.Rows())
tk.MustQuery(ctx, "select d1.id, d1.k1 from duplicate_test d1 ignore index(uk), duplicate_test d2 use index (uk) where d1.id = d2.id and d1.k1 <> d2.k1").Check(testkit.Rows())
})
}

func permInt(n int) []interface{} {
randPermSlice := rand.Perm(n)
v := make([]interface{}, 0, len(randPermSlice))
for _, i := range randPermSlice {
v = append(v, i)
}
return v
}
226 changes: 226 additions & 0 deletions testkit/asynctestkit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
// Copyright 2021 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

// +build !codes

package testkit

import (
"context"
"fmt"
"testing"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/sqlexec"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
)

var asyncTestKitIDGenerator atomic.Uint64

// AsyncTestKit is a utility to run sql concurrently.
type AsyncTestKit struct {
require *require.Assertions
assert *assert.Assertions
store kv.Storage
}

// NewAsyncTestKit returns a new *AsyncTestKit.
func NewAsyncTestKit(t *testing.T, store kv.Storage) *AsyncTestKit {
return &AsyncTestKit{
require: require.New(t),
assert: assert.New(t),
store: store,
}
}

// OpenSession opens new session ctx if no exists one and use db.
func (tk *AsyncTestKit) OpenSession(ctx context.Context, db string) context.Context {
if tryRetrieveSession(ctx) == nil {
se, err := session.CreateSession4Test(tk.store)
tk.require.NoError(err)
se.SetConnectionID(asyncTestKitIDGenerator.Inc())
ctx = context.WithValue(ctx, sessionKey, se)
}
tk.MustExec(ctx, fmt.Sprintf("use %s", db))
return ctx
}

// CloseSession closes exists session from ctx.
func (tk *AsyncTestKit) CloseSession(ctx context.Context) {
se := tryRetrieveSession(ctx)
tk.require.NotNil(se)
se.Close()
}

// ConcurrentRun run test in current.
// - concurrent: controls the concurrent worker count.
// - loops: controls run test how much times.
// - prepareFunc: provide test data and will be called for every loop.
// - checkFunc: used to do some check after all workers done.
// works like create table better be put in front of this method calling.
// see more example at TestBatchInsertWithOnDuplicate
func (tk *AsyncTestKit) ConcurrentRun(
concurrent int,
loops int,
prepareFunc func(ctx context.Context, tk *AsyncTestKit, concurrent int, currentLoop int) [][][]interface{},
writeFunc func(ctx context.Context, tk *AsyncTestKit, input [][]interface{}),
checkFunc func(ctx context.Context, tk *AsyncTestKit),
) {
channel := make([]chan [][]interface{}, concurrent)
contextList := make([]context.Context, concurrent)
doneList := make([]context.CancelFunc, concurrent)

for i := 0; i < concurrent; i++ {
w := i
channel[w] = make(chan [][]interface{}, 1)
contextList[w], doneList[w] = context.WithCancel(context.Background())
contextList[w] = tk.OpenSession(contextList[w], "test")
go func() {
defer func() {
r := recover()
tk.require.Nil(r, string(util.GetStack()))
doneList[w]()
}()

for input := range channel[w] {
writeFunc(contextList[w], tk, input)
}
}()
}

defer func() {
for i := 0; i < concurrent; i++ {
tk.CloseSession(contextList[i])
}
}()

ctx := tk.OpenSession(context.Background(), "test")
defer tk.CloseSession(ctx)
tk.MustExec(ctx, "use test")

for j := 0; j < loops; j++ {
data := prepareFunc(ctx, tk, concurrent, j)
for i := 0; i < concurrent; i++ {
channel[i] <- data[i]
}
}

for i := 0; i < concurrent; i++ {
close(channel[i])
}

for i := 0; i < concurrent; i++ {
<-contextList[i].Done()
}
checkFunc(ctx, tk)
}

// Exec executes a sql statement.
func (tk *AsyncTestKit) Exec(ctx context.Context, sql string, args ...interface{}) (sqlexec.RecordSet, error) {
se := tryRetrieveSession(ctx)
tk.require.NotNil(se)

if len(args) == 0 {
rss, err := se.Execute(ctx, sql)
if err == nil && len(rss) > 0 {
return rss[0], nil
}
return nil, err
}

stmtID, _, _, err := se.PrepareStmt(sql)
if err != nil {
return nil, err
}

params := make([]types.Datum, len(args))
for i := 0; i < len(params); i++ {
params[i] = types.NewDatum(args[i])
}

rs, err := se.ExecutePreparedStmt(ctx, stmtID, params)
if err != nil {
return nil, err
}

err = se.DropPreparedStmt(stmtID)
if err != nil {
return nil, err
}

return rs, nil
}

// MustExec executes a sql statement and asserts nil error.
func (tk *AsyncTestKit) MustExec(ctx context.Context, sql string, args ...interface{}) {
res, err := tk.Exec(ctx, sql, args...)
tk.require.NoErrorf(err, "sql:%s, %v, error stack %v", sql, args, errors.ErrorStack(err))
if res != nil {
tk.require.NoError(res.Close())
}
}

// MustQuery query the statements and returns result rows.
// If expected result is set it asserts the query result equals expected result.
func (tk *AsyncTestKit) MustQuery(ctx context.Context, sql string, args ...interface{}) *Result {
comment := fmt.Sprintf("sql:%s, args:%v", sql, args)
rs, err := tk.Exec(ctx, sql, args...)
tk.require.NoError(err, comment)
tk.require.NotNil(rs, comment)
return tk.resultSetToResult(ctx, rs, comment)
}

// resultSetToResult converts ast.RecordSet to testkit.Result.
// It is used to check results of execute statement in binary mode.
func (tk *AsyncTestKit) resultSetToResult(ctx context.Context, rs sqlexec.RecordSet, comment string) *Result {
rows, err := session.GetRows4Test(context.Background(), tryRetrieveSession(ctx), rs)
tk.require.NoError(err, comment)

err = rs.Close()
tk.require.NoError(err, comment)

result := make([][]string, len(rows))
for i := range rows {
row := rows[i]
resultRow := make([]string, row.Len())
for j := 0; j < row.Len(); j++ {
if row.IsNull(j) {
resultRow[j] = "<nil>"
} else {
d := row.GetDatum(j, &rs.Fields()[j].Column.FieldType)
resultRow[j], err = d.ToString()
tk.require.NoError(err, comment)
}
}
result[i] = resultRow
}
return &Result{rows: result, comment: comment, assert: tk.assert, require: tk.require}
}

type sessionCtxKeyType struct{}

var sessionKey = sessionCtxKeyType{}

func tryRetrieveSession(ctx context.Context) session.Session {
s := ctx.Value(sessionKey)
if s == nil {
return nil
}
return s.(session.Session)
}
46 changes: 46 additions & 0 deletions testkit/mockstore.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright 2021 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

// +build !codes

package testkit

import (
"testing"

"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/store/mockstore"
"github.com/stretchr/testify/require"
)

// CreateMockStore return a new mock kv.Storage.
func CreateMockStore(t *testing.T) (store kv.Storage, clean func()) {
store, err := mockstore.NewMockStore()
require.NoError(t, err)

session.SetSchemaLease(0)
session.DisableStats4Test()
d, err := session.BootstrapSession(store)
require.NoError(t, err)

d.SetStatsUpdating(true)

clean = func() {
d.Close()
err := store.Close()
require.NoError(t, err)
}

return
}
4 changes: 2 additions & 2 deletions testkit/testkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
"go.uber.org/atomic"
)

var idGenerator atomic.Uint64
var testKitIDGenerator atomic.Uint64

// TestKit is a utility to run sql test.
type TestKit struct {
Expand Down Expand Up @@ -140,6 +140,6 @@ func (tk *TestKit) Exec(sql string, args ...interface{}) (sqlexec.RecordSet, err
func newSession(t *testing.T, store kv.Storage) session.Session {
se, err := session.CreateSession4Test(store)
require.Nil(t, err)
se.SetConnectionID(idGenerator.Inc())
se.SetConnectionID(testKitIDGenerator.Inc())
return se
}
Loading

0 comments on commit f9652b2

Please sign in to comment.