Skip to content

Commit

Permalink
session: fix select for update statement can't get stmt-count-limit e…
Browse files Browse the repository at this point in the history
…rror (#48412) (#49512)

merge manually
  • Loading branch information
crazycs520 authored Dec 17, 2023
1 parent 8bdf0a2 commit f172580
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 7 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ vet:
$(GO) vet -all $(PACKAGES) 2>&1 | $(FAIL_ON_STDOUT)

staticcheck:
$(GO) get honnef.co/go/tools/cmd/staticcheck
$(GO) get honnef.co/go/tools/cmd/staticcheck@2021.1
$(STATICCHECK) ./...

tidy:
Expand Down Expand Up @@ -274,7 +274,7 @@ checkdep:

tools/bin/megacheck: tools/check/go.mod
cd tools/check; \
$(GO) build -o ../bin/megacheck honnef.co/go/tools/cmd/megacheck
$(GO) build -o ../bin/megacheck honnef.co/go/tools/cmd/megacheck@2021.1

tools/bin/revive: tools/check/go.mod
cd tools/check; \
Expand Down
76 changes: 76 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ import (
"github.com/pingcap/failpoint"
"github.com/pingcap/log"
tmysql "github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/versioninfo"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -1496,3 +1498,77 @@ func (cli *testServerClient) waitUntilServerOnline() {
log.Fatal("failed to connect HTTP status in every 10 ms", zap.Int("retryTime", retryTime))
}
}

func (cli *testServerClient) RunTestStmtCountLimit(t *C) {
originalStmtCountLimit := config.GetGlobalConfig().Performance.StmtCountLimit
config.UpdateGlobal(func(conf *config.Config) {
conf.Performance.StmtCountLimit = 3
})
defer func() {
config.UpdateGlobal(func(conf *config.Config) {
conf.Performance.StmtCountLimit = originalStmtCountLimit
})
}()

cli.runTests(t, nil, func(dbt *DBTest) {
dbt.mustExec("drop table if exists t")
dbt.mustExec("create table t (id int key);")
dbt.mustExec("set @@tidb_disable_txn_auto_retry=0;")
dbt.mustExec("set autocommit=0;")
dbt.mustExec("begin optimistic;")
dbt.mustExec("insert into t values (1);")
dbt.mustExec("insert into t values (2);")
_, err := dbt.db.Query("select * from t for update;")
require.Error(t, err)
require.Equal(t, "Error 1105: statement count 4 exceeds the transaction limitation, transaction has been rollback, autocommit = false", err.Error())
dbt.mustExec("insert into t values (3);")
dbt.mustExec("commit;")
rows := dbt.mustQuery("select * from t;")
var id int
count := 0
for rows.Next() {
rows.Scan(&id)
count++
}
require.NoError(t, rows.Close())
require.Equal(t, 3, id)
require.Equal(t, 1, count)

dbt.mustExec("delete from t;")
dbt.mustExec("commit;")
dbt.mustExec("set @@tidb_disable_txn_auto_retry=0;")
dbt.mustExec("set autocommit=0;")
dbt.mustExec("begin optimistic;")
dbt.mustExec("insert into t values (1);")
dbt.mustExec("insert into t values (2);")
_, err = dbt.db.Exec("insert into t values (3);")
require.Error(t, err)
require.Equal(t, "Error 1105: statement count 4 exceeds the transaction limitation, transaction has been rollback, autocommit = false", err.Error())
dbt.mustExec("commit;")
rows = dbt.mustQuery("select count(*) from t;")
for rows.Next() {
rows.Scan(&count)
}
require.NoError(t, rows.Close())
require.Equal(t, 0, count)

dbt.mustExec("delete from t;")
dbt.mustExec("commit;")
dbt.mustExec("set @@tidb_batch_commit=1;")
dbt.mustExec("set @@tidb_disable_txn_auto_retry=0;")
dbt.mustExec("set autocommit=0;")
dbt.mustExec("begin optimistic;")
dbt.mustExec("insert into t values (1);")
dbt.mustExec("insert into t values (2);")
dbt.mustExec("insert into t values (3);")
dbt.mustExec("insert into t values (4);")
dbt.mustExec("insert into t values (5);")
dbt.mustExec("commit;")
rows = dbt.mustQuery("select count(*) from t;")
for rows.Next() {
rows.Scan(&count)
}
require.NoError(t, rows.Close())
require.Equal(t, 5, count)
})
}
4 changes: 4 additions & 0 deletions server/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,10 @@ func (ts *tidbTestSuite) TestSumAvg(c *C) {
ts.runTestSumAvg(c)
}

func (ts *tidbTestSuite) TestStmtCountLimit(c *C) {
ts.RunTestStmtCountLimit(c)
}

func (ts *tidbTestSuite) TestNullFlag(c *C) {
// issue #9689
qctx, err := ts.tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil)
Expand Down
10 changes: 10 additions & 0 deletions session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2522,6 +2522,16 @@ func (s *testSessionSerialSuite) TestBatchCommit(c *C) {
tk.MustExec("insert into t values (7)")
tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7"))

tk.MustExec("delete from t")
tk.MustExec("commit")
tk.MustExec("begin")
tk.MustExec("explain analyze insert into t values (5)")
tk1.MustQuery("select * from t").Check(testkit.Rows())
tk.MustExec("explain analyze insert into t values (6)")
tk1.MustQuery("select * from t").Check(testkit.Rows())
tk.MustExec("explain analyze insert into t values (7)")
tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7"))

// The session is still in transaction.
tk.MustExec("insert into t values (8)")
tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7"))
Expand Down
30 changes: 25 additions & 5 deletions session/tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ func finishStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.St
if err != nil {
return err
}
return checkStmtLimit(ctx, se)
return checkStmtLimit(ctx, se, true)
}

func autoCommitAfterStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.Statement) error {
Expand Down Expand Up @@ -239,18 +239,29 @@ func autoCommitAfterStmt(ctx context.Context, se *session, meetsErr error, sql s
return nil
}

func checkStmtLimit(ctx context.Context, se *session) error {
func checkStmtLimit(ctx context.Context, se *session, isFinish bool) error {
// If the user insert, insert, insert ... but never commit, TiDB would OOM.
// So we limit the statement count in a transaction here.
var err error
sessVars := se.GetSessionVars()
history := GetHistory(se)
if history.Count() > int(config.GetGlobalConfig().Performance.StmtCountLimit) {
stmtCount := history.Count()
if !isFinish {
// history stmt count + current stmt, since current stmt is not finish, it has not add to history.
stmtCount++
}
if stmtCount > int(config.GetGlobalConfig().Performance.StmtCountLimit) {
if !sessVars.BatchCommit {
se.RollbackTxn(ctx)
return errors.Errorf("statement count %d exceeds the transaction limitation, autocommit = %t",
history.Count(), sessVars.IsAutocommit())
return errors.Errorf("statement count %d exceeds the transaction limitation, transaction has been rollback, autocommit = %t",
stmtCount, sessVars.IsAutocommit())
}
if !isFinish {
// if the stmt is not finish execute, then just return, since some work need to be done such as StmtCommit.
return nil
}
// If the stmt is finish execute, and exceed the StmtCountLimit, and BatchCommit is true,
// then commit the current transaction and create a new transaction.
err = se.NewTxn(ctx)
// The transaction does not committed yet, we need to keep it in transaction.
// The last history could not be "commit"/"rollback" statement.
Expand Down Expand Up @@ -305,6 +316,14 @@ func runStmt(ctx context.Context, sctx sessionctx.Context, s sqlexec.Statement)
if err != nil {
return nil, err
}
if sessVars.TxnCtx.CouldRetry && !s.IsReadOnly(sessVars) {
// Only when the txn is could retry and the statement is not read only, need to do stmt-count-limit check,
// otherwise, the stmt won't be add into stmt history, and also don't need check.
// About `stmt-count-limit`, see more in https://docs.pingcap.com/tidb/stable/tidb-configuration-file#stmt-count-limit
if err := checkStmtLimit(ctx, se, false); err != nil {
return nil, err
}
}
rs, err = s.Exec(ctx)
sessVars.TxnCtx.StatementCount++
if !s.IsReadOnly(sessVars) {
Expand Down Expand Up @@ -349,6 +368,7 @@ func runStmt(ctx context.Context, sctx sessionctx.Context, s sqlexec.Statement)
}

// GetHistory get all stmtHistory in current txn. Exported only for test.
// If stmtHistory is nil, will create a new one for current txn.
func GetHistory(ctx sessionctx.Context) *StmtHistory {
hist, ok := ctx.GetSessionVars().TxnCtx.History.(*StmtHistory)
if ok {
Expand Down

0 comments on commit f172580

Please sign in to comment.