Skip to content

Commit

Permalink
*: optimize temptable.SessionSnapshotInterceptor (#36999) (#37031)
Browse files Browse the repository at this point in the history
close #37000
  • Loading branch information
ti-srebot authored Aug 10, 2022
1 parent b08ec41 commit 93ffafc
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 18 deletions.
7 changes: 4 additions & 3 deletions sessiontxn/internal/txn.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/table/temptable"
"github.com/pingcap/tidb/util/logutil"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -62,9 +61,11 @@ func CommitBeforeEnterNewTxn(ctx context.Context, sctx sessionctx.Context) error
}

// GetSnapshotWithTS returns a snapshot with ts.
func GetSnapshotWithTS(s sessionctx.Context, ts uint64) kv.Snapshot {
func GetSnapshotWithTS(s sessionctx.Context, ts uint64, interceptor kv.SnapshotInterceptor) kv.Snapshot {
snap := s.GetStore().GetSnapshot(kv.Version{Ver: ts})
snap.SetOption(kv.SnapInterceptor, temptable.SessionSnapshotInterceptor(s))
if interceptor != nil {
snap.SetOption(kv.SnapInterceptor, interceptor)
}
if s.GetSessionVars().InRestrictedSQL {
snap.SetOption(kv.RequestSourceInternal, true)
}
Expand Down
8 changes: 6 additions & 2 deletions sessiontxn/isolation/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func (p *baseTxnContextProvider) ActivateTxn() (kv.Transaction, error) {
if readReplicaType.IsFollowerRead() {
txn.SetOption(kv.ReplicaRead, readReplicaType)
}
txn.SetOption(kv.SnapInterceptor, temptable.SessionSnapshotInterceptor(p.sctx))
txn.SetOption(kv.SnapInterceptor, temptable.SessionSnapshotInterceptor(p.sctx, p.infoSchema))

if sessVars.StmtCtx.WeakConsistency {
txn.SetOption(kv.IsolationLevel, kv.RC)
Expand Down Expand Up @@ -378,7 +378,11 @@ func (p *baseTxnContextProvider) getSnapshotByTS(snapshotTS uint64) (kv.Snapshot
}

sessVars := p.sctx.GetSessionVars()
snapshot := internal.GetSnapshotWithTS(p.sctx, snapshotTS)
snapshot := internal.GetSnapshotWithTS(
p.sctx,
snapshotTS,
temptable.SessionSnapshotInterceptor(p.sctx, p.infoSchema),
)

replicaReadType := sessVars.GetReplicaRead()
if replicaReadType.IsFollowerRead() && !sessVars.StmtCtx.RCCheckTS {
Expand Down
8 changes: 6 additions & 2 deletions sessiontxn/staleread/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func (p *StalenessTxnContextProvider) activateStaleTxn() error {
TxnScope: txnScope,
},
}
txn.SetOption(kv.SnapInterceptor, temptable.SessionSnapshotInterceptor(p.sctx))
txn.SetOption(kv.SnapInterceptor, temptable.SessionSnapshotInterceptor(p.sctx, is))

p.is = is
err = p.sctx.GetSessionVars().SetSystemVar(variable.TiDBSnapshot, "")
Expand Down Expand Up @@ -209,7 +209,11 @@ func (p *StalenessTxnContextProvider) GetSnapshotWithStmtReadTS() (kv.Snapshot,
}

sessVars := p.sctx.GetSessionVars()
snapshot := internal.GetSnapshotWithTS(p.sctx, p.ts)
snapshot := internal.GetSnapshotWithTS(
p.sctx,
p.ts,
temptable.SessionSnapshotInterceptor(p.sctx, p.is),
)

replicaReadType := sessVars.GetReplicaRead()
if replicaReadType.IsFollowerRead() {
Expand Down
19 changes: 10 additions & 9 deletions sessiontxn/txn_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/pingcap/tidb/sessiontxn"
"github.com/pingcap/tidb/sessiontxn/internal"
"github.com/pingcap/tidb/sessiontxn/staleread"
"github.com/pingcap/tidb/table/temptable"
"github.com/pingcap/tidb/tablecodec"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/tests/realtikvtest"
Expand Down Expand Up @@ -310,7 +311,7 @@ func TestGetSnapshot(t *testing.T) {
check: func(t *testing.T, sctx sessionctx.Context) {
ts, err := mgr.GetStmtReadTS()
require.NoError(t, err)
compareSnap := internal.GetSnapshotWithTS(sctx, ts)
compareSnap := internal.GetSnapshotWithTS(sctx, ts, nil)
snap, err := mgr.GetSnapshotWithStmtReadTS()
require.NoError(t, err)
require.True(t, isSnapshotEqual(t, compareSnap, snap))
Expand All @@ -320,7 +321,7 @@ func TestGetSnapshot(t *testing.T) {
tk.MustQuery("select * from t for update").Check(testkit.Rows("1", "3", "10"))
ts, err = mgr.GetStmtForUpdateTS()
require.NoError(t, err)
compareSnap2 := internal.GetSnapshotWithTS(sctx, ts)
compareSnap2 := internal.GetSnapshotWithTS(sctx, ts, nil)
snap, err = mgr.GetSnapshotWithStmtReadTS()
require.NoError(t, err)
require.False(t, isSnapshotEqual(t, compareSnap2, snap))
Expand All @@ -340,7 +341,7 @@ func TestGetSnapshot(t *testing.T) {
check: func(t *testing.T, sctx sessionctx.Context) {
ts, err := mgr.GetStmtReadTS()
require.NoError(t, err)
compareSnap := internal.GetSnapshotWithTS(sctx, ts)
compareSnap := internal.GetSnapshotWithTS(sctx, ts, nil)
snap, err := mgr.GetSnapshotWithStmtReadTS()
require.NoError(t, err)
require.True(t, isSnapshotEqual(t, compareSnap, snap))
Expand All @@ -350,7 +351,7 @@ func TestGetSnapshot(t *testing.T) {
tk.MustQuery("select * from t").Check(testkit.Rows("1", "3", "10"))
ts, err = mgr.GetStmtForUpdateTS()
require.NoError(t, err)
compareSnap2 := internal.GetSnapshotWithTS(sctx, ts)
compareSnap2 := internal.GetSnapshotWithTS(sctx, ts, nil)
snap, err = mgr.GetSnapshotWithStmtReadTS()
require.NoError(t, err)
require.True(t, isSnapshotEqual(t, compareSnap2, snap))
Expand All @@ -369,7 +370,7 @@ func TestGetSnapshot(t *testing.T) {
check: func(t *testing.T, sctx sessionctx.Context) {
ts, err := mgr.GetStmtReadTS()
require.NoError(t, err)
compareSnap := internal.GetSnapshotWithTS(sctx, ts)
compareSnap := internal.GetSnapshotWithTS(sctx, ts, nil)
snap, err := mgr.GetSnapshotWithStmtReadTS()
require.NoError(t, err)
require.True(t, isSnapshotEqual(t, compareSnap, snap))
Expand All @@ -379,7 +380,7 @@ func TestGetSnapshot(t *testing.T) {
tk.MustQuery("select * from t for update").Check(testkit.Rows("1", "3"))
ts, err = mgr.GetStmtForUpdateTS()
require.NoError(t, err)
compareSnap2 := internal.GetSnapshotWithTS(sctx, ts)
compareSnap2 := internal.GetSnapshotWithTS(sctx, ts, nil)
snap, err = mgr.GetSnapshotWithStmtReadTS()
require.NoError(t, err)
require.True(t, isSnapshotEqual(t, compareSnap2, snap))
Expand All @@ -400,7 +401,7 @@ func TestGetSnapshot(t *testing.T) {
check: func(t *testing.T, sctx sessionctx.Context) {
ts, err := mgr.GetStmtReadTS()
require.NoError(t, err)
compareSnap := internal.GetSnapshotWithTS(sctx, ts)
compareSnap := internal.GetSnapshotWithTS(sctx, ts, nil)
snap, err := mgr.GetSnapshotWithStmtReadTS()
require.NoError(t, err)
require.True(t, isSnapshotEqual(t, compareSnap, snap))
Expand All @@ -410,7 +411,7 @@ func TestGetSnapshot(t *testing.T) {
tk.MustQuery("select * from t for update").Check(testkit.Rows("1", "3"))
ts, err = mgr.GetStmtForUpdateTS()
require.NoError(t, err)
compareSnap2 := internal.GetSnapshotWithTS(sctx, ts)
compareSnap2 := internal.GetSnapshotWithTS(sctx, ts, nil)
snap, err = mgr.GetSnapshotWithStmtReadTS()
require.NoError(t, err)
require.True(t, isSnapshotEqual(t, compareSnap2, snap))
Expand Down Expand Up @@ -497,7 +498,7 @@ func TestSnapshotInterceptor(t *testing.T) {
}

// Also check GetSnapshotWithTS
snap := internal.GetSnapshotWithTS(tk.Session(), 0)
snap := internal.GetSnapshotWithTS(tk.Session(), 0, temptable.SessionSnapshotInterceptor(tk.Session(), sessiontxn.GetTxnManager(tk.Session()).GetTxnInfoSchema()))
val, err := snap.Get(context.Background(), k)
require.NoError(t, err)
require.Equal(t, []byte("v1"), val)
Expand Down
4 changes: 2 additions & 2 deletions table/temptable/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ type TemporaryTableSnapshotInterceptor struct {
}

// SessionSnapshotInterceptor creates a new snapshot interceptor for temporary table data fetch
func SessionSnapshotInterceptor(sctx sessionctx.Context) kv.SnapshotInterceptor {
func SessionSnapshotInterceptor(sctx sessionctx.Context, is infoschema.InfoSchema) kv.SnapshotInterceptor {
return NewTemporaryTableSnapshotInterceptor(
sctx.GetInfoSchema().(infoschema.InfoSchema),
is,
getSessionData(sctx),
)
}
Expand Down

0 comments on commit 93ffafc

Please sign in to comment.