Skip to content

Commit

Permalink
*: prevent cursor read from being cancelled by GC (#39950) (#39989)
Browse files Browse the repository at this point in the history
close #39447
  • Loading branch information
ti-chi-bot authored Jan 16, 2023
1 parent ee7b18f commit cfdd74f
Show file tree
Hide file tree
Showing 20 changed files with 394 additions and 20 deletions.
4 changes: 4 additions & 0 deletions bindinfo/session_handle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,10 @@ func (msm *mockSessionManager) GetInternalSessionStartTSList() []uint64 {
return nil
}

func (msm *mockSessionManager) GetMinStartTS(lowerBound uint64) uint64 {
return 0
}

func TestIssue19836(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
Expand Down
21 changes: 15 additions & 6 deletions ddl/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1415,8 +1415,10 @@ func TestLogAndShowSlowLog(t *testing.T) {
}

func TestReportingMinStartTimestamp(t *testing.T) {
_, dom, clean := testkit.CreateMockStoreAndDomainWithSchemaLease(t, dbTestLease)
store, dom, clean := testkit.CreateMockStoreAndDomainWithSchemaLease(t, dbTestLease)
defer clean()
tk := testkit.NewTestKit(t, store)
se := tk.Session()

infoSyncer := dom.InfoSyncer()
sm := &testkit.MockSessionManager{
Expand All @@ -1432,12 +1434,19 @@ func TestReportingMinStartTimestamp(t *testing.T) {
validTS := oracle.GoTimeToLowerLimitStartTS(now.Add(time.Minute), tikv.MaxTxnTimeUse)
lowerLimit := oracle.GoTimeToLowerLimitStartTS(now, tikv.MaxTxnTimeUse)
sm.PS = []*util.ProcessInfo{
{CurTxnStartTS: 0},
{CurTxnStartTS: math.MaxUint64},
{CurTxnStartTS: lowerLimit},
{CurTxnStartTS: validTS},
{CurTxnStartTS: 0, ProtectedTSList: &se.GetSessionVars().ProtectedTSList},
{CurTxnStartTS: math.MaxUint64, ProtectedTSList: &se.GetSessionVars().ProtectedTSList},
{CurTxnStartTS: lowerLimit, ProtectedTSList: &se.GetSessionVars().ProtectedTSList},
{CurTxnStartTS: validTS, ProtectedTSList: &se.GetSessionVars().ProtectedTSList},
}
infoSyncer.SetSessionManager(sm)
infoSyncer.ReportMinStartTS(dom.Store())
require.Equal(t, validTS, infoSyncer.GetMinStartTS())

unhold := se.GetSessionVars().ProtectedTSList.HoldTS(validTS - 1)
infoSyncer.ReportMinStartTS(dom.Store())
require.Equal(t, validTS-1, infoSyncer.GetMinStartTS())

unhold()
infoSyncer.ReportMinStartTS(dom.Store())
require.Equal(t, validTS, infoSyncer.GetMinStartTS())
}
16 changes: 2 additions & 14 deletions domain/infosync/info.go
Original file line number Diff line number Diff line change
Expand Up @@ -629,8 +629,6 @@ func (is *InfoSyncer) ReportMinStartTS(store kv.Storage) {
if sm == nil {
return
}
pl := sm.ShowProcessList()
innerSessionStartTSList := sm.GetInternalSessionStartTSList()

// Calculate the lower limit of the start timestamp to avoid extremely old transaction delaying GC.
currentVer, err := store.CurrentVersion(kv.GlobalTxnScope)
Expand All @@ -644,18 +642,8 @@ func (is *InfoSyncer) ReportMinStartTS(store kv.Storage) {
minStartTS := oracle.GoTimeToTS(now)
logutil.BgLogger().Debug("ReportMinStartTS", zap.Uint64("initial minStartTS", minStartTS),
zap.Uint64("StartTSLowerLimit", startTSLowerLimit))
for _, info := range pl {
if info.CurTxnStartTS > startTSLowerLimit && info.CurTxnStartTS < minStartTS {
minStartTS = info.CurTxnStartTS
}
}

for _, innerTS := range innerSessionStartTSList {
logutil.BgLogger().Debug("ReportMinStartTS", zap.Uint64("Internal Session Transaction StartTS", innerTS))
kv.PrintLongTimeInternalTxn(now, innerTS, false)
if innerTS > startTSLowerLimit && innerTS < minStartTS {
minStartTS = innerTS
}
if ts := sm.GetMinStartTS(startTSLowerLimit); ts > startTSLowerLimit && ts < minStartTS {
minStartTS = ts
}

is.minStartTS = kv.GetMinInnerTxnStartTS(now, startTSLowerLimit, minStartTS)
Expand Down
4 changes: 4 additions & 0 deletions executor/executor_pkg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ func (msm *mockSessionManager) GetInternalSessionStartTSList() []uint64 {
return nil
}

func (msm *mockSessionManager) GetMinStartTS(lowerBound uint64) uint64 {
return 0
}

func TestShowProcessList(t *testing.T) {
// Compose schema.
names := []string{"Id", "User", "Host", "db", "Command", "Time", "State", "Info"}
Expand Down
4 changes: 4 additions & 0 deletions executor/infoschema_cluster_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,10 @@ func (sm *mockSessionManager) SetServerID(serverID uint64) {
sm.serverID = serverID
}

func (sm *mockSessionManager) GetMinStartTS(lowerBound uint64) uint64 {
return 0
}

type mockStore struct {
helper.Storage
host string
Expand Down
4 changes: 4 additions & 0 deletions executor/prepared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ func (sm *mockSessionManager2) GetInternalSessionStartTSList() []uint64 {
return nil
}

func (sm *mockSessionManager2) GetMinStartTS(lowerBound uint64) uint64 {
return 0
}

func TestPreparedStmtWithHint(t *testing.T) {
// see https://github.com/pingcap/tidb/issues/18535
store, dom, clean := testkit.CreateMockStoreAndDomain(t)
Expand Down
4 changes: 4 additions & 0 deletions executor/seqtest/prepared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,10 @@ func (msm *mockSessionManager1) GetInternalSessionStartTSList() []uint64 {
return nil
}

func (msm *mockSessionManager1) GetMinStartTS(lowerBound uint64) uint64 {
return 0
}

func TestPreparedIssue17419(t *testing.T) {
store, dom, clean := testkit.CreateMockStoreAndDomain(t)
defer clean()
Expand Down
4 changes: 4 additions & 0 deletions infoschema/tables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,10 @@ func (sm *mockSessionManager) GetInternalSessionStartTSList() []uint64 {
return nil
}

func (sm *mockSessionManager) GetMinStartTS(lowerBound uint64) uint64 {
return 0
}

func TestSomeTables(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
Expand Down
6 changes: 6 additions & 0 deletions server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,12 @@ func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stm
if useCursor {
cc.initResultEncoder(ctx)
defer cc.rsEncoder.clean()
// fix https://github.com/pingcap/tidb/issues/39447. we need to hold the start-ts here because the process info
// will be set to sleep after fetch returned.
if pi := cc.ctx.ShowProcess(); pi != nil && pi.ProtectedTSList != nil && pi.CurTxnStartTS > 0 {
unhold := pi.HoldTS(pi.CurTxnStartTS)
rs = &rsWithHooks{ResultSet: rs, onClosed: unhold}
}
stmt.StoreResultSet(rs)
err = cc.writeColumnInfo(rs.Columns(), mysql.ServerStatusCursorExists)
if err != nil {
Expand Down
90 changes: 90 additions & 0 deletions server/conn_stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
package server

import (
"context"
"encoding/binary"
"testing"

"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/types"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -248,3 +251,90 @@ func TestParseStmtFetchCmd(t *testing.T) {
require.Equal(t, tc.err, err)
}
}

func TestCursorReadHoldTS(t *testing.T) {
store, dom, clean := testkit.CreateMockStoreAndDomain(t)
defer clean()
srv := CreateMockServer(t, store)
srv.SetDomain(dom)
defer srv.Close()

appendUint32 := binary.LittleEndian.AppendUint32
ctx := context.Background()
c := CreateMockConn(t, store, srv)
tk := testkit.NewTestKitWithSession(t, store, c.Context().Session)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int primary key)")
tk.MustExec("insert into t values (1), (2), (3), (4), (5), (6), (7), (8)")
tk.MustQuery("select count(*) from t").Check(testkit.Rows("8"))

stmt, _, _, err := c.Context().Prepare("select * from t")
require.NoError(t, err)
require.Zero(t, tk.Session().ShowProcess().GetMinStartTS(0))

// should hold ts after executing stmt with cursor
require.NoError(t, c.Dispatch(ctx, append(
appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())),
0x1, 0x1, 0x0, 0x0, 0x0,
)))
ts := tk.Session().ShowProcess().GetMinStartTS(0)
require.Positive(t, ts)
// should unhold ts when result set exhausted
require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5)))
require.Equal(t, ts, tk.Session().ShowProcess().GetMinStartTS(0))
require.Equal(t, ts, srv.GetMinStartTS(0))
require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5)))
require.Equal(t, ts, tk.Session().ShowProcess().GetMinStartTS(0))
require.Equal(t, ts, srv.GetMinStartTS(0))
require.NoError(t, c.Dispatch(ctx, appendUint32(appendUint32([]byte{mysql.ComStmtFetch}, uint32(stmt.ID())), 5)))
require.Zero(t, tk.Session().ShowProcess().GetMinStartTS(0))

// should hold ts after executing stmt with cursor
require.NoError(t, c.Dispatch(ctx, append(
appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())),
0x1, 0x1, 0x0, 0x0, 0x0,
)))
require.Positive(t, tk.Session().ShowProcess().GetMinStartTS(0))
// should unhold ts when stmt reset
require.NoError(t, c.Dispatch(ctx, appendUint32([]byte{mysql.ComStmtReset}, uint32(stmt.ID()))))
require.Zero(t, tk.Session().ShowProcess().GetMinStartTS(0))

// should hold ts after executing stmt with cursor
require.NoError(t, c.Dispatch(ctx, append(
appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt.ID())),
0x1, 0x1, 0x0, 0x0, 0x0,
)))
require.Positive(t, tk.Session().ShowProcess().GetMinStartTS(0))
// should unhold ts when stmt closed
require.NoError(t, c.Dispatch(ctx, appendUint32([]byte{mysql.ComStmtClose}, uint32(stmt.ID()))))
require.Zero(t, tk.Session().ShowProcess().GetMinStartTS(0))

// create another 2 stmts and execute them
stmt1, _, _, err := c.Context().Prepare("select * from t")
require.NoError(t, err)
require.NoError(t, c.Dispatch(ctx, append(
appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt1.ID())),
0x1, 0x1, 0x0, 0x0, 0x0,
)))
ts1 := tk.Session().ShowProcess().GetMinStartTS(0)
require.Positive(t, ts1)
stmt2, _, _, err := c.Context().Prepare("select * from t")
require.NoError(t, err)
require.NoError(t, c.Dispatch(ctx, append(
appendUint32([]byte{mysql.ComStmtExecute}, uint32(stmt2.ID())),
0x1, 0x1, 0x0, 0x0, 0x0,
)))
ts2 := tk.Session().ShowProcess().GetMinStartTS(ts1)
require.Positive(t, ts2)

require.Less(t, ts1, ts2)
require.Equal(t, ts1, srv.GetMinStartTS(0))
require.Equal(t, ts2, srv.GetMinStartTS(ts1))
require.Zero(t, srv.GetMinStartTS(ts2))

// should unhold all when session closed
c.Close()
require.Zero(t, tk.Session().ShowProcess().GetMinStartTS(0))
require.Zero(t, srv.GetMinStartTS(0))
}
2 changes: 2 additions & 0 deletions server/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ type ResultSet interface {
StoreFetchedRows(rows []chunk.Row)
GetFetchedRows() []chunk.Row
Close() error
// IsClosed checks whether the result set is closed.
IsClosed() bool
}

// fetchNotifier represents notifier will be called in COM_FETCH.
Expand Down
45 changes: 45 additions & 0 deletions server/driver_tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,11 @@ func (trs *tidbResultSet) Close() error {
return err
}

// IsClosed implements ResultSet.IsClosed interface.
func (trs *tidbResultSet) IsClosed() bool {
return atomic.LoadInt32(&trs.closed) == 1
}

// OnFetchReturned implements fetchNotifier#OnFetchReturned
func (trs *tidbResultSet) OnFetchReturned() {
if cl, ok := trs.recordSet.(fetchNotifier); ok {
Expand Down Expand Up @@ -375,6 +380,46 @@ func (trs *tidbResultSet) Columns() []*ColumnInfo {
return trs.columns
}

// rsWithHooks wraps a ResultSet with some hooks (currently only onClosed).
type rsWithHooks struct {
ResultSet
onClosed func()
}

// Close implements ResultSet#Close
func (rs *rsWithHooks) Close() error {
closed := rs.IsClosed()
err := rs.ResultSet.Close()
if !closed && rs.onClosed != nil {
rs.onClosed()
}
return err
}

// OnFetchReturned implements fetchNotifier#OnFetchReturned
func (rs *rsWithHooks) OnFetchReturned() {
if impl, ok := rs.ResultSet.(fetchNotifier); ok {
impl.OnFetchReturned()
}
}

// Unwrap returns the underlying result set
func (rs *rsWithHooks) Unwrap() ResultSet {
return rs.ResultSet
}

// unwrapResultSet likes errors.Cause but for ResultSet
func unwrapResultSet(rs ResultSet) ResultSet {
var unRS ResultSet
if u, ok := rs.(interface{ Unwrap() ResultSet }); ok {
unRS = u.Unwrap()
}
if unRS == nil {
return rs
}
return unwrapResultSet(unRS)
}

func convertColumnInfo(fld *ast.ResultField) (ci *ColumnInfo) {
ci = &ColumnInfo{
Name: fld.ColumnAsName.O,
Expand Down
25 changes: 25 additions & 0 deletions server/driver_tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/sqlexec"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -95,3 +96,27 @@ func TestConvertColumnInfo(t *testing.T) {
colInfo = convertColumnInfo(&resultField)
require.Equal(t, uint32(4), colInfo.ColumnLength)
}

func TestRSWithHooks(t *testing.T) {
closeCount := 0
rs := &rsWithHooks{
ResultSet: &tidbResultSet{recordSet: new(sqlexec.SimpleRecordSet)},
onClosed: func() { closeCount++ },
}
require.Equal(t, 0, closeCount)
rs.Close()
require.Equal(t, 1, closeCount)
rs.Close()
require.Equal(t, 1, closeCount)
}

func TestUnwrapRS(t *testing.T) {
var nilRS ResultSet
require.Nil(t, unwrapResultSet(nilRS))
rs0 := new(tidbResultSet)
rs1 := &rsWithHooks{ResultSet: rs0}
rs2 := &rsWithHooks{ResultSet: rs1}
for _, rs := range []ResultSet{rs0, rs1, rs2} {
require.Equal(t, rs0, unwrapResultSet(rs))
}
}
3 changes: 3 additions & 0 deletions server/mock_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ func CreateMockConn(t *testing.T, store kv.Storage, server *Server) MockConn {
},
}
cc.setCtx(tc)
cc.server.rwlock.Lock()
server.clients[cc.connectionID] = cc
cc.server.rwlock.Unlock()
return &mockConn{
clientConn: cc,
t: t,
Expand Down
Loading

0 comments on commit cfdd74f

Please sign in to comment.