From 8b7057a6a7b61b5fb335acc8a59b6401383cfa3d Mon Sep 17 00:00:00 2001 From: Ti Chi Robot Date: Mon, 4 Dec 2023 17:25:24 +0800 Subject: [PATCH] session: fix select for update statement can't get stmt-count-limit error (#48412) (#48468) close pingcap/tidb#48411 --- server/server_test.go | 74 +++++++++++++++++++++++++++++ server/tidb_test.go | 5 ++ session/session.go | 8 ++++ session/sessiontest/session_test.go | 10 ++++ session/tidb.go | 22 +++++++-- 5 files changed, 114 insertions(+), 5 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index 2e6477d00ff56..a54fcd18ed2e1 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -37,6 +37,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/log" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/kv" tmysql "github.com/pingcap/tidb/parser/mysql" @@ -2569,3 +2570,76 @@ func TestIssue46197(t *testing.T) { path := testdata.ConvertRowsToStrings(tk.MustQuery("select @@tidb_last_plan_replayer_token").Rows()) require.NoError(t, os.Remove(filepath.Join(replayer.GetPlanReplayerDirName(), path[0]))) } + +func (cli *testServerClient) RunTestStmtCountLimit(t *testing.T) { + originalStmtCountLimit := config.GetGlobalConfig().Performance.StmtCountLimit + config.UpdateGlobal(func(conf *config.Config) { + conf.Performance.StmtCountLimit = 3 + }) + defer func() { + config.UpdateGlobal(func(conf *config.Config) { + conf.Performance.StmtCountLimit = originalStmtCountLimit + }) + }() + + cli.runTests(t, nil, func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table t (id int key);") + dbt.MustExec("set @@tidb_disable_txn_auto_retry=0;") + dbt.MustExec("set autocommit=0;") + dbt.MustExec("begin optimistic;") + dbt.MustExec("insert into t values (1);") + dbt.MustExec("insert into t values (2);") + _, err := dbt.GetDB().Query("select * from t for update;") + require.Error(t, err) + require.Equal(t, "Error 1105 (HY000): statement count 4 exceeds the transaction limitation, transaction has been rollback, autocommit = false", err.Error()) + dbt.MustExec("insert into t values (3);") + dbt.MustExec("commit;") + rows := dbt.MustQuery("select * from t;") + var id int + count := 0 + for rows.Next() { + rows.Scan(&id) + count++ + } + require.NoError(t, rows.Close()) + require.Equal(t, 3, id) + require.Equal(t, 1, count) + + dbt.MustExec("delete from t;") + dbt.MustExec("commit;") + dbt.MustExec("set @@tidb_disable_txn_auto_retry=0;") + dbt.MustExec("set autocommit=0;") + dbt.MustExec("begin optimistic;") + dbt.MustExec("insert into t values (1);") + dbt.MustExec("insert into t values (2);") + _, err = dbt.GetDB().Exec("insert into t values (3);") + require.Error(t, err) + require.Equal(t, "Error 1105 (HY000): statement count 4 exceeds the transaction limitation, transaction has been rollback, autocommit = false", err.Error()) + dbt.MustExec("commit;") + rows = dbt.MustQuery("select count(*) from t;") + for rows.Next() { + rows.Scan(&count) + } + require.NoError(t, rows.Close()) + require.Equal(t, 0, count) + + dbt.MustExec("delete from t;") + dbt.MustExec("commit;") + dbt.MustExec("set @@tidb_batch_commit=1;") + dbt.MustExec("set @@tidb_disable_txn_auto_retry=0;") + dbt.MustExec("set autocommit=0;") + dbt.MustExec("begin optimistic;") + dbt.MustExec("insert into t values (1);") + dbt.MustExec("insert into t values (2);") + dbt.MustExec("insert into t values (3);") + dbt.MustExec("insert into t values (4);") + dbt.MustExec("insert into t values (5);") + dbt.MustExec("commit;") + rows = dbt.MustQuery("select count(*) from t;") + for rows.Next() { + rows.Scan(&count) + } + require.NoError(t, rows.Close()) + require.Equal(t, 5, count) + }) +} diff --git a/server/tidb_test.go b/server/tidb_test.go index b51aa67ccd25c..569684be80920 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -1122,6 +1122,11 @@ func TestSumAvg(t *testing.T) { ts.runTestSumAvg(t) } +func TestStmtCountLimit(t *testing.T) { + ts := createTidbTestSuite(t) + ts.RunTestStmtCountLimit(t) +} + func TestNullFlag(t *testing.T) { ts := createTidbTestSuite(t) diff --git a/session/session.go b/session/session.go index 5ccae09f21d35..003467fec3c3a 100644 --- a/session/session.go +++ b/session/session.go @@ -2399,6 +2399,14 @@ func runStmt(ctx context.Context, se *session, s sqlexec.Statement) (rs sqlexec. if err != nil { return nil, err } + if sessVars.TxnCtx.CouldRetry && !s.IsReadOnly(sessVars) { + // Only when the txn is could retry and the statement is not read only, need to do stmt-count-limit check, + // otherwise, the stmt won't be add into stmt history, and also don't need check. + // About `stmt-count-limit`, see more in https://docs.pingcap.com/tidb/stable/tidb-configuration-file#stmt-count-limit + if err := checkStmtLimit(ctx, se, false); err != nil { + return nil, err + } + } rs, err = s.Exec(ctx) se.updateTelemetryMetric(s.(*executor.ExecStmt)) diff --git a/session/sessiontest/session_test.go b/session/sessiontest/session_test.go index 6124d4e4e269e..79f05e0abc2dc 100644 --- a/session/sessiontest/session_test.go +++ b/session/sessiontest/session_test.go @@ -968,6 +968,16 @@ func TestBatchCommit(t *testing.T) { tk.MustExec("insert into t values (7)") tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7")) + tk.MustExec("delete from t") + tk.MustExec("commit") + tk.MustExec("begin") + tk.MustExec("explain analyze insert into t values (5)") + tk1.MustQuery("select * from t").Check(testkit.Rows()) + tk.MustExec("explain analyze insert into t values (6)") + tk1.MustQuery("select * from t").Check(testkit.Rows()) + tk.MustExec("explain analyze insert into t values (7)") + tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7")) + // The session is still in transaction. tk.MustExec("insert into t values (8)") tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7")) diff --git a/session/tidb.go b/session/tidb.go index 310d76e007e5a..dad6d14daf41a 100644 --- a/session/tidb.go +++ b/session/tidb.go @@ -271,7 +271,7 @@ func finishStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.St if err != nil { return err } - return checkStmtLimit(ctx, se) + return checkStmtLimit(ctx, se, true) } func autoCommitAfterStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.Statement) error { @@ -305,18 +305,29 @@ func autoCommitAfterStmt(ctx context.Context, se *session, meetsErr error, sql s return nil } -func checkStmtLimit(ctx context.Context, se *session) error { +func checkStmtLimit(ctx context.Context, se *session, isFinish bool) error { // If the user insert, insert, insert ... but never commit, TiDB would OOM. // So we limit the statement count in a transaction here. var err error sessVars := se.GetSessionVars() history := GetHistory(se) - if history.Count() > int(config.GetGlobalConfig().Performance.StmtCountLimit) { + stmtCount := history.Count() + if !isFinish { + // history stmt count + current stmt, since current stmt is not finish, it has not add to history. + stmtCount++ + } + if stmtCount > int(config.GetGlobalConfig().Performance.StmtCountLimit) { if !sessVars.BatchCommit { se.RollbackTxn(ctx) - return errors.Errorf("statement count %d exceeds the transaction limitation, autocommit = %t", - history.Count(), sessVars.IsAutocommit()) + return errors.Errorf("statement count %d exceeds the transaction limitation, transaction has been rollback, autocommit = %t", + stmtCount, sessVars.IsAutocommit()) + } + if !isFinish { + // if the stmt is not finish execute, then just return, since some work need to be done such as StmtCommit. + return nil } + // If the stmt is finish execute, and exceed the StmtCountLimit, and BatchCommit is true, + // then commit the current transaction and create a new transaction. err = sessiontxn.NewTxn(ctx, se) // The transaction does not committed yet, we need to keep it in transaction. // The last history could not be "commit"/"rollback" statement. @@ -328,6 +339,7 @@ func checkStmtLimit(ctx context.Context, se *session) error { } // GetHistory get all stmtHistory in current txn. Exported only for test. +// If stmtHistory is nil, will create a new one for current txn. func GetHistory(ctx sessionctx.Context) *StmtHistory { hist, ok := ctx.GetSessionVars().TxnCtx.History.(*StmtHistory) if ok {