diff --git a/sessiontxn/internal/txn.go b/sessiontxn/internal/txn.go index 00db4561b979b..304ad91450c61 100644 --- a/sessiontxn/internal/txn.go +++ b/sessiontxn/internal/txn.go @@ -21,7 +21,6 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" - "github.com/pingcap/tidb/table/temptable" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" ) @@ -62,9 +61,11 @@ func CommitBeforeEnterNewTxn(ctx context.Context, sctx sessionctx.Context) error } // GetSnapshotWithTS returns a snapshot with ts. -func GetSnapshotWithTS(s sessionctx.Context, ts uint64) kv.Snapshot { +func GetSnapshotWithTS(s sessionctx.Context, ts uint64, interceptor kv.SnapshotInterceptor) kv.Snapshot { snap := s.GetStore().GetSnapshot(kv.Version{Ver: ts}) - snap.SetOption(kv.SnapInterceptor, temptable.SessionSnapshotInterceptor(s)) + if interceptor != nil { + snap.SetOption(kv.SnapInterceptor, interceptor) + } if s.GetSessionVars().InRestrictedSQL { snap.SetOption(kv.RequestSourceInternal, true) } diff --git a/sessiontxn/isolation/base.go b/sessiontxn/isolation/base.go index 47c57ea0306e8..7bded0b3d7baf 100644 --- a/sessiontxn/isolation/base.go +++ b/sessiontxn/isolation/base.go @@ -251,7 +251,7 @@ func (p *baseTxnContextProvider) ActivateTxn() (kv.Transaction, error) { if readReplicaType.IsFollowerRead() { txn.SetOption(kv.ReplicaRead, readReplicaType) } - txn.SetOption(kv.SnapInterceptor, temptable.SessionSnapshotInterceptor(p.sctx)) + txn.SetOption(kv.SnapInterceptor, temptable.SessionSnapshotInterceptor(p.sctx, p.infoSchema)) if sessVars.StmtCtx.WeakConsistency { txn.SetOption(kv.IsolationLevel, kv.RC) @@ -393,7 +393,11 @@ func (p *baseTxnContextProvider) getSnapshotByTS(snapshotTS uint64) (kv.Snapshot } sessVars := p.sctx.GetSessionVars() - snapshot := internal.GetSnapshotWithTS(p.sctx, snapshotTS) + snapshot := internal.GetSnapshotWithTS( + p.sctx, + snapshotTS, + temptable.SessionSnapshotInterceptor(p.sctx, p.infoSchema), + ) replicaReadType := sessVars.GetReplicaRead() if replicaReadType.IsFollowerRead() && !sessVars.StmtCtx.RCCheckTS { diff --git a/sessiontxn/staleread/provider.go b/sessiontxn/staleread/provider.go index b7ff889c0e29c..184ac33e97964 100644 --- a/sessiontxn/staleread/provider.go +++ b/sessiontxn/staleread/provider.go @@ -124,7 +124,7 @@ func (p *StalenessTxnContextProvider) activateStaleTxn() error { TxnScope: txnScope, }, } - txn.SetOption(kv.SnapInterceptor, temptable.SessionSnapshotInterceptor(p.sctx)) + txn.SetOption(kv.SnapInterceptor, temptable.SessionSnapshotInterceptor(p.sctx, is)) p.is = is err = p.sctx.GetSessionVars().SetSystemVar(variable.TiDBSnapshot, "") @@ -209,7 +209,11 @@ func (p *StalenessTxnContextProvider) GetSnapshotWithStmtReadTS() (kv.Snapshot, } sessVars := p.sctx.GetSessionVars() - snapshot := internal.GetSnapshotWithTS(p.sctx, p.ts) + snapshot := internal.GetSnapshotWithTS( + p.sctx, + p.ts, + temptable.SessionSnapshotInterceptor(p.sctx, p.is), + ) replicaReadType := sessVars.GetReplicaRead() if replicaReadType.IsFollowerRead() { diff --git a/sessiontxn/txn_manager_test.go b/sessiontxn/txn_manager_test.go index 609f09ff5e945..a76d7a85704be 100644 --- a/sessiontxn/txn_manager_test.go +++ b/sessiontxn/txn_manager_test.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/tidb/sessiontxn" "github.com/pingcap/tidb/sessiontxn/internal" "github.com/pingcap/tidb/sessiontxn/staleread" + "github.com/pingcap/tidb/table/temptable" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/tests/realtikvtest" @@ -308,7 +309,7 @@ func TestGetSnapshot(t *testing.T) { check: func(t *testing.T, sctx sessionctx.Context) { ts, err := mgr.GetStmtReadTS() require.NoError(t, err) - compareSnap := internal.GetSnapshotWithTS(sctx, ts) + compareSnap := internal.GetSnapshotWithTS(sctx, ts, nil) snap, err := mgr.GetSnapshotWithStmtReadTS() require.NoError(t, err) require.True(t, isSnapshotEqual(t, compareSnap, snap)) @@ -318,7 +319,7 @@ func TestGetSnapshot(t *testing.T) { tk.MustQuery("select * from t for update").Check(testkit.Rows("1", "3", "10")) ts, err = mgr.GetStmtForUpdateTS() require.NoError(t, err) - compareSnap2 := internal.GetSnapshotWithTS(sctx, ts) + compareSnap2 := internal.GetSnapshotWithTS(sctx, ts, nil) snap, err = mgr.GetSnapshotWithStmtReadTS() require.NoError(t, err) require.False(t, isSnapshotEqual(t, compareSnap2, snap)) @@ -338,7 +339,7 @@ func TestGetSnapshot(t *testing.T) { check: func(t *testing.T, sctx sessionctx.Context) { ts, err := mgr.GetStmtReadTS() require.NoError(t, err) - compareSnap := internal.GetSnapshotWithTS(sctx, ts) + compareSnap := internal.GetSnapshotWithTS(sctx, ts, nil) snap, err := mgr.GetSnapshotWithStmtReadTS() require.NoError(t, err) require.True(t, isSnapshotEqual(t, compareSnap, snap)) @@ -348,7 +349,7 @@ func TestGetSnapshot(t *testing.T) { tk.MustQuery("select * from t").Check(testkit.Rows("1", "3", "10")) ts, err = mgr.GetStmtForUpdateTS() require.NoError(t, err) - compareSnap2 := internal.GetSnapshotWithTS(sctx, ts) + compareSnap2 := internal.GetSnapshotWithTS(sctx, ts, nil) snap, err = mgr.GetSnapshotWithStmtReadTS() require.NoError(t, err) require.True(t, isSnapshotEqual(t, compareSnap2, snap)) @@ -367,7 +368,7 @@ func TestGetSnapshot(t *testing.T) { check: func(t *testing.T, sctx sessionctx.Context) { ts, err := mgr.GetStmtReadTS() require.NoError(t, err) - compareSnap := internal.GetSnapshotWithTS(sctx, ts) + compareSnap := internal.GetSnapshotWithTS(sctx, ts, nil) snap, err := mgr.GetSnapshotWithStmtReadTS() require.NoError(t, err) require.True(t, isSnapshotEqual(t, compareSnap, snap)) @@ -377,7 +378,7 @@ func TestGetSnapshot(t *testing.T) { tk.MustQuery("select * from t for update").Check(testkit.Rows("1", "3")) ts, err = mgr.GetStmtForUpdateTS() require.NoError(t, err) - compareSnap2 := internal.GetSnapshotWithTS(sctx, ts) + compareSnap2 := internal.GetSnapshotWithTS(sctx, ts, nil) snap, err = mgr.GetSnapshotWithStmtReadTS() require.NoError(t, err) require.True(t, isSnapshotEqual(t, compareSnap2, snap)) @@ -398,7 +399,7 @@ func TestGetSnapshot(t *testing.T) { check: func(t *testing.T, sctx sessionctx.Context) { ts, err := mgr.GetStmtReadTS() require.NoError(t, err) - compareSnap := internal.GetSnapshotWithTS(sctx, ts) + compareSnap := internal.GetSnapshotWithTS(sctx, ts, nil) snap, err := mgr.GetSnapshotWithStmtReadTS() require.NoError(t, err) require.True(t, isSnapshotEqual(t, compareSnap, snap)) @@ -408,7 +409,7 @@ func TestGetSnapshot(t *testing.T) { tk.MustQuery("select * from t for update").Check(testkit.Rows("1", "3")) ts, err = mgr.GetStmtForUpdateTS() require.NoError(t, err) - compareSnap2 := internal.GetSnapshotWithTS(sctx, ts) + compareSnap2 := internal.GetSnapshotWithTS(sctx, ts, nil) snap, err = mgr.GetSnapshotWithStmtReadTS() require.NoError(t, err) require.True(t, isSnapshotEqual(t, compareSnap2, snap)) @@ -494,7 +495,7 @@ func TestSnapshotInterceptor(t *testing.T) { } // Also check GetSnapshotWithTS - snap := internal.GetSnapshotWithTS(tk.Session(), 0) + snap := internal.GetSnapshotWithTS(tk.Session(), 0, temptable.SessionSnapshotInterceptor(tk.Session(), sessiontxn.GetTxnManager(tk.Session()).GetTxnInfoSchema())) val, err := snap.Get(context.Background(), k) require.NoError(t, err) require.Equal(t, []byte("v1"), val) diff --git a/table/temptable/interceptor.go b/table/temptable/interceptor.go index aa61eae9db0b7..e260c7628edde 100644 --- a/table/temptable/interceptor.go +++ b/table/temptable/interceptor.go @@ -40,9 +40,9 @@ type TemporaryTableSnapshotInterceptor struct { } // SessionSnapshotInterceptor creates a new snapshot interceptor for temporary table data fetch -func SessionSnapshotInterceptor(sctx sessionctx.Context) kv.SnapshotInterceptor { +func SessionSnapshotInterceptor(sctx sessionctx.Context, is infoschema.InfoSchema) kv.SnapshotInterceptor { return NewTemporaryTableSnapshotInterceptor( - sctx.GetInfoSchema().(infoschema.InfoSchema), + is, getSessionData(sctx), ) }