diff --git a/executor/adapter.go b/executor/adapter.go
index 3a4ace6ba582b..89980ab3779ae 100644
--- a/executor/adapter.go
+++ b/executor/adapter.go
@@ -259,7 +259,7 @@ func (a *ExecStmt) PointGet(ctx context.Context, is infoschema.InfoSchema) (*rec
 	} else {
 		// CachedPlan type is already checked in last step
 		pointGetPlan := a.PsStmt.PreparedAst.CachedPlan.(*plannercore.PointGetPlan)
-		exec.Init(pointGetPlan, startTs)
+		exec.Init(pointGetPlan)
 		a.PsStmt.Executor = exec
 	}
 }
diff --git a/executor/batch_point_get.go b/executor/batch_point_get.go
index b5eb68a8b12de..f502804f016c8 100644
--- a/executor/batch_point_get.go
+++ b/executor/batch_point_get.go
@@ -22,8 +22,6 @@ import (
 	"github.com/pingcap/errors"
 	"github.com/pingcap/failpoint"
-	"github.com/pingcap/kvproto/pkg/metapb"
-	"github.com/pingcap/tidb/ddl/placement"
 	"github.com/pingcap/tidb/expression"
 	"github.com/pingcap/tidb/kv"
 	"github.com/pingcap/tidb/parser/model"
@@ -40,35 +38,31 @@ import (
 	"github.com/pingcap/tidb/util/logutil/consistency"
 	"github.com/pingcap/tidb/util/mathutil"
 	"github.com/pingcap/tidb/util/rowcodec"
-	"github.com/tikv/client-go/v2/txnkv/txnsnapshot"
 )

 // BatchPointGetExec executes a bunch of point select queries.
 type BatchPointGetExec struct {
 	baseExecutor

-	tblInfo          *model.TableInfo
-	idxInfo          *model.IndexInfo
-	handles          []kv.Handle
-	physIDs          []int64
-	partExpr         *tables.PartitionExpr
-	partPos          int
-	singlePart       bool
-	partTblID        int64
-	idxVals          [][]types.Datum
-	readReplicaScope string
-	isStaleness      bool
-	snapshotTS       uint64
-	txn              kv.Transaction
-	lock             bool
-	waitTime         int64
-	inited           uint32
-	values           [][]byte
-	index            int
-	rowDecoder       *rowcodec.ChunkDecoder
-	keepOrder        bool
-	desc             bool
-	batchGetter      kv.BatchGetter
+	tblInfo     *model.TableInfo
+	idxInfo     *model.IndexInfo
+	handles     []kv.Handle
+	physIDs     []int64
+	partExpr    *tables.PartitionExpr
+	partPos     int
+	singlePart  bool
+	partTblID   int64
+	idxVals     [][]types.Datum
+	txn         kv.Transaction
+	lock        bool
+	waitTime    int64
+	inited      uint32
+	values      [][]byte
+	index       int
+	rowDecoder  *rowcodec.ChunkDecoder
+	keepOrder   bool
+	desc        bool
+	batchGetter kv.BatchGetter

 	columns []*model.ColumnInfo
 	// virtualColumnIndex records all the indices of virtual columns and sort them in definition
@@ -78,9 +72,8 @@ type BatchPointGetExec struct {
 	// virtualColumnRetFieldTypes records the RetFieldTypes of virtual columns.
 	virtualColumnRetFieldTypes []*types.FieldType

-	snapshot   kv.Snapshot
-	stats      *runtimeStatsWithSnapshot
-	cacheTable kv.MemBuffer
+	snapshot kv.Snapshot
+	stats    *runtimeStatsWithSnapshot
 }

 // buildVirtualColumnInfo saves virtual column indices and sort them in definition order
@@ -98,69 +91,24 @@ func (e *BatchPointGetExec) buildVirtualColumnInfo() {
 func (e *BatchPointGetExec) Open(context.Context) error {
 	sessVars := e.ctx.GetSessionVars()
 	txnCtx := sessVars.TxnCtx
-	stmtCtx := sessVars.StmtCtx
 	txn, err := e.ctx.Txn(false)
 	if err != nil {
 		return err
 	}
 	e.txn = txn
-	var snapshot kv.Snapshot
-	if txn.Valid() && txnCtx.StartTS == txnCtx.GetForUpdateTS() && txnCtx.StartTS == e.snapshotTS {
-		// We can safely reuse the transaction snapshot if snapshotTS is equal to forUpdateTS.
-		// The snapshot may contain cache that can reduce RPC call.
-		snapshot = txn.GetSnapshot()
-	} else {
-		snapshot = e.ctx.GetSnapshotWithTS(e.snapshotTS)
-	}
-	if e.ctx.GetSessionVars().StmtCtx.RCCheckTS {
-		snapshot.SetOption(kv.IsolationLevel, kv.RCCheckTS)
-	}
-	if e.cacheTable != nil {
-		snapshot = cacheTableSnapshot{snapshot, e.cacheTable}
-	}
-	if e.runtimeStats != nil {
-		snapshotStats := &txnsnapshot.SnapshotRuntimeStats{}
-		e.stats = &runtimeStatsWithSnapshot{
-			SnapshotRuntimeStats: snapshotStats,
-		}
-		snapshot.SetOption(kv.CollectRuntimeStats, snapshotStats)
-		stmtCtx.RuntimeStatsColl.RegisterStats(e.id, e.stats)
-	}
-	replicaReadType := e.ctx.GetSessionVars().GetReplicaRead()
-	if replicaReadType.IsFollowerRead() && !e.ctx.GetSessionVars().StmtCtx.RCCheckTS {
-		snapshot.SetOption(kv.ReplicaRead, replicaReadType)
-	}
-	snapshot.SetOption(kv.TaskID, stmtCtx.TaskID)
-	snapshot.SetOption(kv.ReadReplicaScope, e.readReplicaScope)
-	snapshot.SetOption(kv.IsStalenessReadOnly, e.isStaleness)
-	failpoint.Inject("assertBatchPointReplicaOption", func(val failpoint.Value) {
-		assertScope := val.(string)
-		if replicaReadType.IsClosestRead() && assertScope != e.readReplicaScope {
-			panic("batch point get replica option fail")
-		}
-	})
-
-	if replicaReadType.IsClosestRead() && e.readReplicaScope != kv.GlobalTxnScope {
-		snapshot.SetOption(kv.MatchStoreLabels, []*metapb.StoreLabel{
-			{
-				Key:   placement.DCLabelKey,
-				Value: e.readReplicaScope,
-			},
-		})
-	}
-	setOptionForTopSQL(stmtCtx, snapshot)
-	var batchGetter kv.BatchGetter = snapshot
+
+	setOptionForTopSQL(e.ctx.GetSessionVars().StmtCtx, e.snapshot)
+	var batchGetter kv.BatchGetter = e.snapshot
 	if txn.Valid() {
 		lock := e.tblInfo.Lock
 		if e.lock {
-			batchGetter = driver.NewBufferBatchGetter(txn.GetMemBuffer(), &PessimisticLockCacheGetter{txnCtx: txnCtx}, snapshot)
+			batchGetter = driver.NewBufferBatchGetter(txn.GetMemBuffer(), &PessimisticLockCacheGetter{txnCtx: txnCtx}, e.snapshot)
 		} else if lock != nil && (lock.Tp == model.TableLockRead || lock.Tp == model.TableLockReadOnly) && e.ctx.GetSessionVars().EnablePointGetCache {
-			batchGetter = newCacheBatchGetter(e.ctx, e.tblInfo.ID, snapshot)
+			batchGetter = newCacheBatchGetter(e.ctx, e.tblInfo.ID, e.snapshot)
 		} else {
-			batchGetter = driver.NewBufferBatchGetter(txn.GetMemBuffer(), nil, snapshot)
+			batchGetter = driver.NewBufferBatchGetter(txn.GetMemBuffer(), nil, e.snapshot)
 		}
 	}
-	e.snapshot = snapshot
 	e.batchGetter = batchGetter
 	return nil
 }
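Editor's note: with the snapshot options moved into the builder and the transaction providers (later in this diff), Open is reduced to layering the batch getter on top of the injected e.snapshot. The self-contained toy below sketches the layering that driver.NewBufferBatchGetter provides: keys are served from the transaction-local buffer first and only the misses fall through to the snapshot. The type names here (batchGetter, mapGetter, bufferBatchGetter) are illustrative stand-ins, not TiDB APIs.

```go
package main

import (
	"context"
	"fmt"
)

// batchGetter mirrors the shape of TiDB's kv.BatchGetter (assumed here,
// with string keys for brevity).
type batchGetter interface {
	BatchGet(ctx context.Context, keys []string) (map[string][]byte, error)
}

// mapGetter serves values from an in-memory map, standing in for a kv.Snapshot
// or a transaction membuffer.
type mapGetter map[string][]byte

func (m mapGetter) BatchGet(_ context.Context, keys []string) (map[string][]byte, error) {
	res := make(map[string][]byte, len(keys))
	for _, k := range keys {
		if v, ok := m[k]; ok {
			res[k] = v
		}
	}
	return res, nil
}

// bufferBatchGetter checks the transaction-local buffer first and falls back
// to the snapshot for the missing keys, like driver.NewBufferBatchGetter.
type bufferBatchGetter struct {
	buffer   batchGetter
	snapshot batchGetter
}

func (b bufferBatchGetter) BatchGet(ctx context.Context, keys []string) (map[string][]byte, error) {
	res, err := b.buffer.BatchGet(ctx, keys)
	if err != nil {
		return nil, err
	}
	var remain []string
	for _, k := range keys {
		if _, ok := res[k]; !ok {
			remain = append(remain, k)
		}
	}
	fromSnap, err := b.snapshot.BatchGet(ctx, remain)
	if err != nil {
		return nil, err
	}
	for k, v := range fromSnap {
		res[k] = v
	}
	return res, nil
}

func main() {
	buffer := mapGetter{"k1": []byte("uncommitted")}
	snapshot := mapGetter{"k1": []byte("old"), "k2": []byte("committed")}
	getter := bufferBatchGetter{buffer: buffer, snapshot: snapshot}
	res, _ := getter.BatchGet(context.Background(), []string{"k1", "k2"})
	fmt.Printf("%s %s\n", res["k1"], res["k2"]) // uncommitted committed
}
```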
diff --git a/executor/builder.go b/executor/builder.go
index 0a5e3b60ee4e0..9335231d217e3 100644
--- a/executor/builder.go
+++ b/executor/builder.go
@@ -29,8 +29,10 @@ import (
 	"github.com/pingcap/errors"
 	"github.com/pingcap/failpoint"
 	"github.com/pingcap/kvproto/pkg/diagnosticspb"
+	"github.com/pingcap/kvproto/pkg/metapb"
 	"github.com/pingcap/tidb/config"
 	"github.com/pingcap/tidb/ddl"
+	"github.com/pingcap/tidb/ddl/placement"
 	"github.com/pingcap/tidb/distsql"
 	"github.com/pingcap/tidb/domain"
 	"github.com/pingcap/tidb/executor/aggfuncs"
@@ -68,6 +70,7 @@ import (
 	"github.com/pingcap/tidb/util/timeutil"
 	"github.com/pingcap/tipb/go-tipb"
 	"github.com/tikv/client-go/v2/tikv"
+	"github.com/tikv/client-go/v2/txnkv/txnsnapshot"
 )

 var (
@@ -1528,6 +1531,39 @@ func (b *executorBuilder) getSnapshotTS() (uint64, error) {
 	return txnManager.GetStmtReadTS()
 }

+// getSnapshot gets the appropriate snapshot from the txnManager and sets
+// the relevant snapshot options before returning.
+func (b *executorBuilder) getSnapshot() (kv.Snapshot, error) {
+	var snapshot kv.Snapshot
+	var err error
+
+	txnManager := sessiontxn.GetTxnManager(b.ctx)
+	if b.inInsertStmt || b.inUpdateStmt || b.inDeleteStmt || b.inSelectLockStmt {
+		snapshot, err = txnManager.GetForUpdateSnapshot()
+	} else {
+		snapshot, err = txnManager.GetReadSnapshot()
+	}
+	if err != nil {
+		return nil, err
+	}
+
+	sessVars := b.ctx.GetSessionVars()
+	replicaReadType := sessVars.GetReplicaRead()
+	snapshot.SetOption(kv.ReadReplicaScope, b.readReplicaScope)
+	snapshot.SetOption(kv.TaskID, sessVars.StmtCtx.TaskID)
+
+	if replicaReadType.IsClosestRead() && b.readReplicaScope != kv.GlobalTxnScope {
+		snapshot.SetOption(kv.MatchStoreLabels, []*metapb.StoreLabel{
+			{
+				Key:   placement.DCLabelKey,
+				Value: b.readReplicaScope,
+			},
+		})
+	}
+
+	return snapshot, nil
+}
+
 func (b *executorBuilder) buildMemTable(v *plannercore.PhysicalMemTable) Executor {
 	switch v.DBName.L {
 	case util.MetricSchemaName.L:
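Editor's note: getSnapshot above applies the options that are the same for every executor; MatchStoreLabels is only set for closest-read when the statement scope is not global, which confines the read to stores in the statement's DC. The standalone sketch below restates the label matching this asks of client-go; storeLabel and matchesAll are illustrative names, and the "zone" key follows TiDB's placement.DCLabelKey convention (an assumption here).

```go
package main

import "fmt"

// storeLabel mirrors the shape of metapb.StoreLabel (assumed).
type storeLabel struct{ Key, Value string }

// matchesAll reports whether a store carries every required label, which is
// essentially the predicate kv.MatchStoreLabels requests from the client.
func matchesAll(store, required []storeLabel) bool {
	for _, r := range required {
		found := false
		for _, l := range store {
			if l.Key == r.Key && l.Value == r.Value {
				found = true
				break
			}
		}
		if !found {
			return false
		}
	}
	return true
}

func main() {
	store := []storeLabel{{"zone", "bj"}}
	required := []storeLabel{{"zone", "bj"}} // placement.DCLabelKey is assumed to be "zone"
	fmt.Println(matchesAll(store, required)) // true
}
```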
@@ -4549,7 +4585,8 @@ func NewRowDecoder(ctx sessionctx.Context, schema *expression.Schema, tbl *model
 }

 func (b *executorBuilder) buildBatchPointGet(plan *plannercore.BatchPointGetPlan) Executor {
-	if err := b.validCanReadTemporaryOrCacheTable(plan.TblInfo); err != nil {
+	var err error
+	if err = b.validCanReadTemporaryOrCacheTable(plan.TblInfo); err != nil {
 		b.err = err
 		return nil
 	}
@@ -4561,34 +4598,53 @@ func (b *executorBuilder) buildBatchPointGet(plan *plannercore.BatchPointGetPlan
 		}()
 	}

-	snapshotTS, err := b.getSnapshotTS()
+	decoder := NewRowDecoder(b.ctx, plan.Schema(), plan.TblInfo)
+	e := &BatchPointGetExec{
+		baseExecutor: newBaseExecutor(b.ctx, plan.Schema(), plan.ID()),
+		tblInfo:      plan.TblInfo,
+		idxInfo:      plan.IndexInfo,
+		rowDecoder:   decoder,
+		keepOrder:    plan.KeepOrder,
+		desc:         plan.Desc,
+		lock:         plan.Lock,
+		waitTime:     plan.LockWaitTime,
+		partExpr:     plan.PartitionExpr,
+		partPos:      plan.PartitionColPos,
+		singlePart:   plan.SinglePart,
+		partTblID:    plan.PartTblID,
+		columns:      plan.Columns,
+	}
+
+	e.snapshot, err = b.getSnapshot()
 	if err != nil {
 		b.err = err
 		return nil
 	}
-
-	decoder := NewRowDecoder(b.ctx, plan.Schema(), plan.TblInfo)
-	e := &BatchPointGetExec{
-		baseExecutor:     newBaseExecutor(b.ctx, plan.Schema(), plan.ID()),
-		tblInfo:          plan.TblInfo,
-		idxInfo:          plan.IndexInfo,
-		rowDecoder:       decoder,
-		snapshotTS:       snapshotTS,
-		readReplicaScope: b.readReplicaScope,
-		isStaleness:      b.isStaleness,
-		keepOrder:        plan.KeepOrder,
-		desc:             plan.Desc,
-		lock:             plan.Lock,
-		waitTime:         plan.LockWaitTime,
-		partExpr:         plan.PartitionExpr,
-		partPos:          plan.PartitionColPos,
-		singlePart:       plan.SinglePart,
-		partTblID:        plan.PartTblID,
-		columns:          plan.Columns,
+	if e.runtimeStats != nil {
+		snapshotStats := &txnsnapshot.SnapshotRuntimeStats{}
+		e.stats = &runtimeStatsWithSnapshot{
+			SnapshotRuntimeStats: snapshotStats,
+		}
+		e.snapshot.SetOption(kv.CollectRuntimeStats, snapshotStats)
+		b.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.id, e.stats)
 	}
+	failpoint.Inject("assertBatchPointReplicaOption", func(val failpoint.Value) {
+		assertScope := val.(string)
+		if e.ctx.GetSessionVars().GetReplicaRead().IsClosestRead() && assertScope != b.readReplicaScope {
+			panic("batch point get replica option fail")
+		}
+	})
+
+	snapshotTS, err := b.getSnapshotTS()
+	if err != nil {
+		b.err = err
+		return nil
+	}
 	if plan.TblInfo.TableCacheStatusType == model.TableCacheStatusEnable {
-		e.cacheTable = b.getCacheTable(plan.TblInfo, snapshotTS)
+		if cacheTable := b.getCacheTable(plan.TblInfo, snapshotTS); cacheTable != nil {
+			e.snapshot = cacheTableSnapshot{e.snapshot, cacheTable}
+		}
 	}

 	if plan.TblInfo.TempTableType != model.TempTableNone {
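Editor's note: the assertBatchPointReplicaOption failpoint above replaces the one previously armed in Open. A hedged sketch of how a test might arm it, assuming the conventional package-path naming used by pingcap/failpoint (the exact path string is an assumption, not taken from this diff):

```go
// Assumed failpoint path: importing package path + failpoint name.
fp := "github.com/pingcap/tidb/executor/assertBatchPointReplicaOption"
require.NoError(t, failpoint.Enable(fp, `return("bj")`)) // assert scope "bj"
defer func() {
	require.NoError(t, failpoint.Disable(fp))
}()
```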
diff --git a/executor/point_get.go b/executor/point_get.go
index 1b4d6666663b5..881aa758ecd1f 100644
--- a/executor/point_get.go
+++ b/executor/point_get.go
@@ -20,8 +20,6 @@ import (
 	"github.com/pingcap/errors"
 	"github.com/pingcap/failpoint"
-	"github.com/pingcap/kvproto/pkg/metapb"
-	"github.com/pingcap/tidb/ddl/placement"
 	"github.com/pingcap/tidb/distsql"
 	"github.com/pingcap/tidb/expression"
 	"github.com/pingcap/tidb/infoschema"
@@ -44,7 +42,8 @@ import (
 )

 func (b *executorBuilder) buildPointGet(p *plannercore.PointGetPlan) Executor {
-	if err := b.validCanReadTemporaryOrCacheTable(p.TblInfo); err != nil {
+	var err error
+	if err = b.validCanReadTemporaryOrCacheTable(p.TblInfo); err != nil {
 		b.err = err
 		return nil
 	}
@@ -56,25 +55,47 @@ func (b *executorBuilder) buildPointGet(p *plannercore.PointGetPlan) Executor {
 		}()
 	}

-	snapshotTS, err := b.getSnapshotTS()
-	if err != nil {
-		b.err = err
-		return nil
-	}
-
 	e := &PointGetExecutor{
 		baseExecutor:     newBaseExecutor(b.ctx, p.Schema(), p.ID()),
 		readReplicaScope: b.readReplicaScope,
 		isStaleness:      b.isStaleness,
 	}

-	if p.TblInfo.TableCacheStatusType == model.TableCacheStatusEnable {
-		e.cacheTable = b.getCacheTable(p.TblInfo, snapshotTS)
-	}
-
 	e.base().initCap = 1
 	e.base().maxChunkSize = 1
-	e.Init(p, snapshotTS)
+	e.Init(p)
+
+	e.snapshot, err = b.getSnapshot()
+	if err != nil {
+		b.err = err
+		return nil
+	}
+	if e.runtimeStats != nil {
+		snapshotStats := &txnsnapshot.SnapshotRuntimeStats{}
+		e.stats = &runtimeStatsWithSnapshot{
+			SnapshotRuntimeStats: snapshotStats,
+		}
+		e.snapshot.SetOption(kv.CollectRuntimeStats, snapshotStats)
+		b.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.id, e.stats)
+	}
+
+	failpoint.Inject("assertPointReplicaOption", func(val failpoint.Value) {
+		assertScope := val.(string)
+		if e.ctx.GetSessionVars().GetReplicaRead().IsClosestRead() && assertScope != e.readReplicaScope {
+			panic("point get replica option fail")
+		}
+	})
+
+	snapshotTS, err := b.getSnapshotTS()
+	if err != nil {
+		b.err = err
+		return nil
+	}
+	if p.TblInfo.TableCacheStatusType == model.TableCacheStatusEnable {
+		if cacheTable := b.getCacheTable(p.TblInfo, snapshotTS); cacheTable != nil {
+			e.snapshot = cacheTableSnapshot{e.snapshot, cacheTable}
+		}
+	}

 	if e.lock {
 		b.hasLock = true
@@ -94,7 +115,6 @@ type PointGetExecutor struct {
 	idxKey           kv.Key
 	handleVal        []byte
 	idxVals          []types.Datum
-	snapshotTS       uint64
 	readReplicaScope string
 	isStaleness      bool
 	txn              kv.Transaction
@@ -112,18 +132,16 @@ type PointGetExecutor struct {
 	// virtualColumnRetFieldTypes records the RetFieldTypes of virtual columns.
 	virtualColumnRetFieldTypes []*types.FieldType

-	stats      *runtimeStatsWithSnapshot
-	cacheTable kv.MemBuffer
+	stats *runtimeStatsWithSnapshot
 }

 // Init set fields needed for PointGetExecutor reuse, this does NOT change baseExecutor field
-func (e *PointGetExecutor) Init(p *plannercore.PointGetPlan, snapshotTS uint64) {
+func (e *PointGetExecutor) Init(p *plannercore.PointGetPlan) {
 	decoder := NewRowDecoder(e.ctx, p.Schema(), p.TblInfo)
 	e.tblInfo = p.TblInfo
 	e.handle = p.Handle
 	e.idxInfo = p.IndexInfo
 	e.idxVals = p.IndexValues
-	e.snapshotTS = snapshotTS
 	e.done = false
 	if e.tblInfo.TempTableType == model.TempTableNone {
 		e.lock = p.Lock
@@ -152,56 +170,14 @@ func (e *PointGetExecutor) buildVirtualColumnInfo() {

 // Open implements the Executor interface.
 func (e *PointGetExecutor) Open(context.Context) error {
-	txnCtx := e.ctx.GetSessionVars().TxnCtx
-	snapshotTS := e.snapshotTS
 	var err error
 	e.txn, err = e.ctx.Txn(false)
 	if err != nil {
 		return err
 	}
-	if e.txn.Valid() && txnCtx.StartTS == txnCtx.GetForUpdateTS() && txnCtx.StartTS == snapshotTS {
-		e.snapshot = e.txn.GetSnapshot()
-	} else {
-		e.snapshot = e.ctx.GetSnapshotWithTS(snapshotTS)
-	}
-	if e.ctx.GetSessionVars().StmtCtx.RCCheckTS {
-		e.snapshot.SetOption(kv.IsolationLevel, kv.RCCheckTS)
-	}
-	if e.cacheTable != nil {
-		e.snapshot = cacheTableSnapshot{e.snapshot, e.cacheTable}
-	}
 	if err := e.verifyTxnScope(); err != nil {
 		return err
 	}
-	if e.runtimeStats != nil {
-		snapshotStats := &txnsnapshot.SnapshotRuntimeStats{}
-		e.stats = &runtimeStatsWithSnapshot{
-			SnapshotRuntimeStats: snapshotStats,
-		}
-		e.snapshot.SetOption(kv.CollectRuntimeStats, snapshotStats)
-		e.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.id, e.stats)
-	}
-	readReplicaType := e.ctx.GetSessionVars().GetReplicaRead()
-	if readReplicaType.IsFollowerRead() && !e.ctx.GetSessionVars().StmtCtx.RCCheckTS {
-		e.snapshot.SetOption(kv.ReplicaRead, readReplicaType)
-	}
-	e.snapshot.SetOption(kv.TaskID, e.ctx.GetSessionVars().StmtCtx.TaskID)
-	e.snapshot.SetOption(kv.ReadReplicaScope, e.readReplicaScope)
-	e.snapshot.SetOption(kv.IsStalenessReadOnly, e.isStaleness)
-	if readReplicaType.IsClosestRead() && e.readReplicaScope != kv.GlobalTxnScope {
-		e.snapshot.SetOption(kv.MatchStoreLabels, []*metapb.StoreLabel{
-			{
-				Key:   placement.DCLabelKey,
-				Value: e.readReplicaScope,
-			},
-		})
-	}
-	failpoint.Inject("assertPointReplicaOption", func(val failpoint.Value) {
-		assertScope := val.(string)
-		if readReplicaType.IsClosestRead() && assertScope != e.readReplicaScope {
-			panic("point get replica option fail")
-		}
-	})
 	setOptionForTopSQL(e.ctx.GetSessionVars().StmtCtx, e.snapshot)
 	return nil
 }
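Editor's note: because the ts no longer lives on the executor, Init loses its snapshotTS parameter and a cached executor can be re-armed per statement while the builder injects a fresh snapshot. A hedged sketch of the resulting lifecycle, using identifiers from the executor package as shown in this diff (error handling elided; newFirstChunk is the package's usual chunk helper, assumed here):

```go
// Build once per statement: the builder injects e.snapshot via getSnapshot().
exec := b.buildPointGet(plan)

// Open is now just a txn-scope check plus the TopSQL option.
err := exec.Open(ctx)

// Next reads through the pre-configured e.snapshot.
req := newFirstChunk(exec)
err = exec.Next(ctx, req)
err = exec.Close()
```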
diff --git a/session/session.go b/session/session.go
index 57d9dad27d10a..5ae3f2165b146 100644
--- a/session/session.go
+++ b/session/session.go
@@ -3155,13 +3155,6 @@ func (s *session) RefreshTxnCtx(ctx context.Context) error {
 	return sessiontxn.NewTxn(ctx, s)
 }

-// GetSnapshotWithTS returns a snapshot with ts.
-func (s *session) GetSnapshotWithTS(ts uint64) kv.Snapshot {
-	snap := s.GetStore().GetSnapshot(kv.Version{Ver: ts})
-	snap.SetOption(kv.SnapInterceptor, s.getSnapshotInterceptor())
-	return snap
-}
-
 // GetStore gets the store of session.
 func (s *session) GetStore() kv.Storage {
 	return s.store
diff --git a/session/txnmanager.go b/session/txnmanager.go
index ce93c0ded44da..63c5340e41e3a 100644
--- a/session/txnmanager.go
+++ b/session/txnmanager.go
@@ -94,6 +94,20 @@ func (m *txnManager) GetStmtForUpdateTS() (uint64, error) {
 	return ts, nil
 }

+func (m *txnManager) GetReadSnapshot() (kv.Snapshot, error) {
+	if m.ctxProvider == nil {
+		return nil, errors.New("context provider not set")
+	}
+	return m.ctxProvider.GetSnapshotWithStmtReadTS()
+}
+
+func (m *txnManager) GetForUpdateSnapshot() (kv.Snapshot, error) {
+	if m.ctxProvider == nil {
+		return nil, errors.New("context provider not set")
+	}
+	return m.ctxProvider.GetSnapshotWithStmtForUpdateTS()
+}
+
 func (m *txnManager) GetContextProvider() sessiontxn.TxnContextProvider {
 	return m.ctxProvider
 }
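Editor's note: the txnManager methods are thin delegations, so any code holding a sessionctx.Context can obtain a correctly configured snapshot without computing a ts and asking the session for one. A hedged usage sketch (sctx, ctx, and key are assumed to be in scope; only APIs added or shown in this diff are used):

```go
// Instead of: ts := ...; snap := sctx.GetSnapshotWithTS(ts)
mgr := sessiontxn.GetTxnManager(sctx)
snap, err := mgr.GetReadSnapshot()
if err != nil {
	return err
}
val, err := snap.Get(ctx, key) // kv.Snapshot point read, options already applied
```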
diff --git a/sessionctx/context.go b/sessionctx/context.go
index 3a320fe9078ef..21b89ae72a351 100644
--- a/sessionctx/context.go
+++ b/sessionctx/context.go
@@ -109,9 +109,6 @@ type Context interface {
 	// only used to daemon session like `statsHandle` to detect global variable change.
 	RefreshVars(context.Context) error

-	// GetSnapshotWithTS returns a snapshot with start ts
-	GetSnapshotWithTS(ts uint64) kv.Snapshot
-
 	// GetStore returns the store of session.
 	GetStore() kv.Storage
diff --git a/sessiontxn/interface.go b/sessiontxn/interface.go
index d9c23acf70df6..6d809fa923c38 100644
--- a/sessiontxn/interface.go
+++ b/sessiontxn/interface.go
@@ -121,6 +121,10 @@ type TxnContextProvider interface {
 	GetStmtReadTS() (uint64, error)
 	// GetStmtForUpdateTS returns the read timestamp used by update/insert/delete or select ... for update
 	GetStmtForUpdateTS() (uint64, error)
+	// GetSnapshotWithStmtReadTS gets a snapshot with the statement read ts
+	GetSnapshotWithStmtReadTS() (kv.Snapshot, error)
+	// GetSnapshotWithStmtForUpdateTS gets a snapshot with the statement for-update ts
+	GetSnapshotWithStmtForUpdateTS() (kv.Snapshot, error)

 	// OnInitialize is the hook that should be called when enter a new txn with this provider
 	OnInitialize(ctx context.Context, enterNewTxnType EnterNewTxnType) error
@@ -147,6 +151,10 @@ type TxnManager interface {
 	GetStmtForUpdateTS() (uint64, error)
 	// GetContextProvider returns the current TxnContextProvider
 	GetContextProvider() TxnContextProvider
+	// GetReadSnapshot gets a snapshot with the statement read ts
+	GetReadSnapshot() (kv.Snapshot, error)
+	// GetForUpdateSnapshot gets a snapshot with the statement for-update ts
+	GetForUpdateSnapshot() (kv.Snapshot, error)

 	// EnterNewTxn enters a new transaction.
 	EnterNewTxn(ctx context.Context, req *EnterNewTxnRequest) error
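Editor's note: every TxnContextProvider must now also supply the two snapshot getters. A conventional compile-time check that the concrete providers touched below still satisfy the widened interface; this assertion is illustrative and not part of this diff:

```go
var (
	_ sessiontxn.TxnContextProvider = (*isolation.PessimisticRCTxnContextProvider)(nil)
	_ sessiontxn.TxnContextProvider = (*staleread.StalenessTxnContextProvider)(nil)
)
```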
diff --git a/sessiontxn/isolation/base.go b/sessiontxn/isolation/base.go
index 11c543a17bcf4..4bcd657974053 100644
--- a/sessiontxn/isolation/base.go
+++ b/sessiontxn/isolation/base.go
@@ -279,3 +279,47 @@ func (p *baseTxnContextProvider) AdviseWarmup() error {
 func (p *baseTxnContextProvider) AdviseOptimizeWithPlan(_ interface{}) error {
 	return nil
 }
+
+// GetSnapshotWithStmtReadTS gets a snapshot with the statement read ts
+func (p *baseTxnContextProvider) GetSnapshotWithStmtReadTS() (kv.Snapshot, error) {
+	ts, err := p.GetStmtReadTS()
+	if err != nil {
+		return nil, err
+	}
+
+	return p.getSnapshotByTS(ts)
+}
+
+// GetSnapshotWithStmtForUpdateTS gets a snapshot with the statement for-update ts
+func (p *baseTxnContextProvider) GetSnapshotWithStmtForUpdateTS() (kv.Snapshot, error) {
+	ts, err := p.GetStmtForUpdateTS()
+	if err != nil {
+		return nil, err
+	}
+
+	return p.getSnapshotByTS(ts)
+}
+
+// getSnapshotByTS gets a snapshot from the store according to snapshotTS and sets
+// the transaction-related options before returning
+func (p *baseTxnContextProvider) getSnapshotByTS(snapshotTS uint64) (kv.Snapshot, error) {
+	txn, err := p.sctx.Txn(false)
+	if err != nil {
+		return nil, err
+	}
+
+	txnCtx := p.sctx.GetSessionVars().TxnCtx
+	if txn.Valid() && txnCtx.StartTS == txnCtx.GetForUpdateTS() && txnCtx.StartTS == snapshotTS {
+		return txn.GetSnapshot(), nil
+	}
+
+	sessVars := p.sctx.GetSessionVars()
+	snapshot := sessiontxn.GetSnapshotWithTS(p.sctx, snapshotTS)
+
+	replicaReadType := sessVars.GetReplicaRead()
+	if replicaReadType.IsFollowerRead() && !sessVars.StmtCtx.RCCheckTS {
+		snapshot.SetOption(kv.ReplicaRead, replicaReadType)
+	}
+
+	return snapshot, nil
+}
diff --git a/sessiontxn/isolation/readcommitted.go b/sessiontxn/isolation/readcommitted.go
index ead853459fdb1..d2afad4c2ea26 100644
--- a/sessiontxn/isolation/readcommitted.go
+++ b/sessiontxn/isolation/readcommitted.go
@@ -257,3 +257,17 @@ func (p *PessimisticRCTxnContextProvider) AdviseOptimizeWithPlan(val interface{}

 	return nil
 }
+
+// GetSnapshotWithStmtReadTS gets a snapshot with the statement read ts
+func (p *PessimisticRCTxnContextProvider) GetSnapshotWithStmtReadTS() (kv.Snapshot, error) {
+	snapshot, err := p.baseTxnContextProvider.GetSnapshotWithStmtForUpdateTS()
+	if err != nil {
+		return nil, err
+	}
+
+	if p.sctx.GetSessionVars().StmtCtx.RCCheckTS {
+		snapshot.SetOption(kv.IsolationLevel, kv.RCCheckTS)
+	}
+
+	return snapshot, nil
+}
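Editor's note: getSnapshotByTS reuses the transaction's own snapshot only when the statement ts coincides with both StartTS and the for-update ts, because that snapshot may carry a cache that saves RPCs; any divergence (for example, a pessimistic retry that advanced forUpdateTS) forces a fresh snapshot. This is also why the RC provider's read path delegates to the for-update variant: under READ COMMITTED each statement reads at the latest ts. A self-contained restatement of the reuse predicate:

```go
package main

import "fmt"

// canReuseTxnSnapshot mirrors the condition in getSnapshotByTS above.
func canReuseTxnSnapshot(txnValid bool, startTS, forUpdateTS, snapshotTS uint64) bool {
	return txnValid && startTS == forUpdateTS && startTS == snapshotTS
}

func main() {
	// Optimistic txn: forUpdateTS never advances, so the txn snapshot is reused.
	fmt.Println(canReuseTxnSnapshot(true, 100, 100, 100)) // true
	// Pessimistic txn after a conflict bumped forUpdateTS: build a fresh snapshot.
	fmt.Println(canReuseTxnSnapshot(true, 100, 120, 120)) // false
}
```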
diff --git a/sessiontxn/staleread/provider.go b/sessiontxn/staleread/provider.go
index 34c8ec985371d..289d4bcc024e8 100644
--- a/sessiontxn/staleread/provider.go
+++ b/sessiontxn/staleread/provider.go
@@ -149,3 +149,32 @@ func (p *StalenessTxnContextProvider) AdviseWarmup() error {
 func (p *StalenessTxnContextProvider) AdviseOptimizeWithPlan(_ interface{}) error {
 	return nil
 }
+
+// GetSnapshotWithStmtReadTS gets a snapshot with the statement read ts and sets
+// the transaction-related options before returning
+func (p *StalenessTxnContextProvider) GetSnapshotWithStmtReadTS() (kv.Snapshot, error) {
+	txn, err := p.sctx.Txn(false)
+	if err != nil {
+		return nil, err
+	}
+
+	if txn.Valid() {
+		return txn.GetSnapshot(), nil
+	}
+
+	sessVars := p.sctx.GetSessionVars()
+	snapshot := sessiontxn.GetSnapshotWithTS(p.sctx, p.ts)
+
+	replicaReadType := sessVars.GetReplicaRead()
+	if replicaReadType.IsFollowerRead() {
+		snapshot.SetOption(kv.ReplicaRead, replicaReadType)
+	}
+	snapshot.SetOption(kv.IsStalenessReadOnly, true)
+
+	return snapshot, nil
+}
+
+// GetSnapshotWithStmtForUpdateTS gets a snapshot with the statement for-update ts
+func (p *StalenessTxnContextProvider) GetSnapshotWithStmtForUpdateTS() (kv.Snapshot, error) {
+	return nil, errors.New("GetSnapshotWithStmtForUpdateTS not supported for stalenessTxnProvider")
+}
diff --git a/sessiontxn/txn.go b/sessiontxn/txn.go
index 890c17590939c..6955f9c5ee7fb 100644
--- a/sessiontxn/txn.go
+++ b/sessiontxn/txn.go
@@ -22,6 +22,7 @@ import (
 	"github.com/pingcap/tidb/kv"
 	"github.com/pingcap/tidb/sessionctx"
 	"github.com/pingcap/tidb/sessionctx/variable"
+	"github.com/pingcap/tidb/table/temptable"
 	"github.com/tikv/client-go/v2/oracle"
 )

@@ -71,6 +72,13 @@ func CanReuseTxnWhenExplicitBegin(sctx sessionctx.Context) bool {
 	return txnCtx.History == nil && !txnCtx.IsStaleness && sessVars.SnapshotTS == 0
 }

+// GetSnapshotWithTS returns a snapshot with ts.
+func GetSnapshotWithTS(s sessionctx.Context, ts uint64) kv.Snapshot {
+	snap := s.GetStore().GetSnapshot(kv.Version{Ver: ts})
+	snap.SetOption(kv.SnapInterceptor, temptable.SessionSnapshotInterceptor(s))
+	return snap
+}
+
 // SetTxnAssertionLevel sets assertion level of a transaction. Note that assertion level should be set only once just
 // after creating a new transaction.
 func SetTxnAssertionLevel(txn kv.Transaction, assertionLevel variable.AssertionLevel) {
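Editor's note: GetSnapshotWithTS moving from a session method to a package-level function in sessiontxn changes call sites as follows; the interceptor that keeps temporary-table reads consistent now comes from table/temptable instead of a private session helper. Both lines are quoted from this diff rather than new API:

```go
// Before: method on sessionctx.Context (removed above)
snap := sctx.GetSnapshotWithTS(ts)

// After: free function; the temptable interceptor is wired in for the caller
snap := sessiontxn.GetSnapshotWithTS(sctx, ts)
```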
diff --git a/sessiontxn/txn_manager_test.go b/sessiontxn/txn_manager_test.go
index 29d7da62a7afb..137a06bcd7ce1 100644
--- a/sessiontxn/txn_manager_test.go
+++ b/sessiontxn/txn_manager_test.go
@@ -15,6 +15,7 @@
 package sessiontxn_test

 import (
+	"bytes"
 	"context"
 	"testing"

@@ -250,6 +251,195 @@ func TestEnterNewTxn(t *testing.T) {
 	}
 }

+func TestGetSnapshot(t *testing.T) {
+	store, _, clean := testkit.CreateMockStoreAndDomain(t)
+	defer clean()
+
+	tk := testkit.NewTestKit(t, store)
+	tk2 := testkit.NewTestKit(t, store)
+	tk2.MustExec("use test")
+	tk.MustExec("use test")
+	tk.MustExec("create table t (id int primary key)")
+
+	isSnapshotEqual := func(t *testing.T, snap1 kv.Snapshot, snap2 kv.Snapshot) bool {
+		require.NotNil(t, snap1)
+		require.NotNil(t, snap2)
+
+		iter1, err := snap1.Iter([]byte{}, []byte{})
+		require.NoError(t, err)
+		iter2, err := snap2.Iter([]byte{}, []byte{})
+		require.NoError(t, err)
+
+		for {
+			if iter1.Valid() && iter2.Valid() {
+				if iter1.Key().Cmp(iter2.Key()) != 0 {
+					return false
+				}
+				if !bytes.Equal(iter1.Value(), iter2.Value()) {
+					return false
+				}
+				err = iter1.Next()
+				require.NoError(t, err)
+				err = iter2.Next()
+				require.NoError(t, err)
+			} else if !iter1.Valid() && !iter2.Valid() {
+				return true
+			} else {
+				return false
+			}
+		}
+	}
+
+	mgr := sessiontxn.GetTxnManager(tk.Session())
+
+	cases := []struct {
+		isolation string
+		prepare   func(t *testing.T)
+		check     func(t *testing.T, sctx sessionctx.Context)
+	}{
+		{
+			isolation: "Pessimistic Repeatable Read",
+			prepare: func(t *testing.T) {
+				tk.MustExec("set @@tx_isolation='REPEATABLE-READ'")
+				tk.MustExec("begin pessimistic")
+			},
+			check: func(t *testing.T, sctx sessionctx.Context) {
+				ts, err := mgr.GetStmtReadTS()
+				require.NoError(t, err)
+				compareSnap := sessiontxn.GetSnapshotWithTS(sctx, ts)
+				snap, err := mgr.GetReadSnapshot()
+				require.NoError(t, err)
+				require.True(t, isSnapshotEqual(t, compareSnap, snap))
+
+				tk2.MustExec("insert into t values(10)")
+
+				tk.MustQuery("select * from t for update").Check(testkit.Rows("1", "3", "10"))
+				ts, err = mgr.GetStmtForUpdateTS()
+				require.NoError(t, err)
+				compareSnap2 := sessiontxn.GetSnapshotWithTS(sctx, ts)
+				snap, err = mgr.GetReadSnapshot()
+				require.NoError(t, err)
+				require.False(t, isSnapshotEqual(t, compareSnap2, snap))
+				snap, err = mgr.GetForUpdateSnapshot()
+				require.NoError(t, err)
+				require.True(t, isSnapshotEqual(t, compareSnap2, snap))
+
+				require.False(t, isSnapshotEqual(t, compareSnap, snap))
+			},
+		},
+		{
+			isolation: "Pessimistic Read Committed",
+			prepare: func(t *testing.T) {
+				tk.MustExec("set tx_isolation = 'READ-COMMITTED'")
+				tk.MustExec("begin pessimistic")
+			},
+			check: func(t *testing.T, sctx sessionctx.Context) {
+				ts, err := mgr.GetStmtReadTS()
+				require.NoError(t, err)
+				compareSnap := sessiontxn.GetSnapshotWithTS(sctx, ts)
+				snap, err := mgr.GetReadSnapshot()
+				require.NoError(t, err)
+				require.True(t, isSnapshotEqual(t, compareSnap, snap))
+
+				tk2.MustExec("insert into t values(10)")
+
+				tk.MustQuery("select * from t").Check(testkit.Rows("1", "3", "10"))
+				ts, err = mgr.GetStmtForUpdateTS()
+				require.NoError(t, err)
+				compareSnap2 := sessiontxn.GetSnapshotWithTS(sctx, ts)
+				snap, err = mgr.GetReadSnapshot()
+				require.NoError(t, err)
+				require.True(t, isSnapshotEqual(t, compareSnap2, snap))
+				snap, err = mgr.GetForUpdateSnapshot()
+				require.NoError(t, err)
+				require.True(t, isSnapshotEqual(t, compareSnap2, snap))
+
+				require.False(t, isSnapshotEqual(t, compareSnap, snap))
+			},
+		},
+		{
+			isolation: "Optimistic",
+			prepare: func(t *testing.T) {
+				tk.MustExec("begin optimistic")
+			},
+			check: func(t *testing.T, sctx sessionctx.Context) {
+				ts, err := mgr.GetStmtReadTS()
+				require.NoError(t, err)
+				compareSnap := sessiontxn.GetSnapshotWithTS(sctx, ts)
+				snap, err := mgr.GetReadSnapshot()
+				require.NoError(t, err)
+				require.True(t, isSnapshotEqual(t, compareSnap, snap))
+
+				tk2.MustExec("insert into t values(10)")
+
+				tk.MustQuery("select * from t for update").Check(testkit.Rows("1", "3"))
+				ts, err = mgr.GetStmtForUpdateTS()
+				require.NoError(t, err)
+				compareSnap2 := sessiontxn.GetSnapshotWithTS(sctx, ts)
+				snap, err = mgr.GetReadSnapshot()
+				require.NoError(t, err)
+				require.True(t, isSnapshotEqual(t, compareSnap2, snap))
+				snap, err = mgr.GetForUpdateSnapshot()
+				require.NoError(t, err)
+				require.True(t, isSnapshotEqual(t, compareSnap2, snap))
+
+				require.True(t, isSnapshotEqual(t, compareSnap, snap))
+			},
+		},
+		{
+			isolation: "Pessimistic Serializable",
+			prepare: func(t *testing.T) {
+				tk.MustExec("set tidb_skip_isolation_level_check = 1")
+				tk.MustExec("set tx_isolation = 'SERIALIZABLE'")
+				tk.MustExec("begin pessimistic")
+			},
+			check: func(t *testing.T, sctx sessionctx.Context) {
+				ts, err := mgr.GetStmtReadTS()
+				require.NoError(t, err)
+				compareSnap := sessiontxn.GetSnapshotWithTS(sctx, ts)
+				snap, err := mgr.GetReadSnapshot()
+				require.NoError(t, err)
+				require.True(t, isSnapshotEqual(t, compareSnap, snap))
+
+				tk2.MustExec("insert into t values(10)")
+
+				tk.MustQuery("select * from t for update").Check(testkit.Rows("1", "3"))
+				ts, err = mgr.GetStmtForUpdateTS()
+				require.NoError(t, err)
+				compareSnap2 := sessiontxn.GetSnapshotWithTS(sctx, ts)
+				snap, err = mgr.GetReadSnapshot()
+				require.NoError(t, err)
+				require.True(t, isSnapshotEqual(t, compareSnap2, snap))
+				snap, err = mgr.GetForUpdateSnapshot()
+				require.NoError(t, err)
+				require.True(t, isSnapshotEqual(t, compareSnap2, snap))
+
+				require.True(t, isSnapshotEqual(t, compareSnap, snap))
+			},
+		},
+	}
+
+	for _, c := range cases {
+		t.Run(c.isolation, func(t *testing.T) {
+			se := tk.Session()
+			tk.MustExec("truncate t")
+			tk.MustExec("set @@tidb_txn_mode=''")
+			tk.MustExec("set @@autocommit=1")
+			tk.MustExec("insert into t values(1), (3)")
+			tk.MustExec("commit")
+
+			if c.prepare != nil {
+				c.prepare(t)
+			}
+
+			if c.check != nil {
+				c.check(t, se)
+			}
+			tk.MustExec("rollback")
+		})
+	}
+}
+
 func checkBasicActiveTxn(t *testing.T, sctx sessionctx.Context) kv.Transaction {
 	txn, err := sctx.Txn(false)
 	require.NoError(t, err)
diff --git a/tests/realtikvtest/sessiontest/temporary_table_test.go b/tests/realtikvtest/sessiontest/temporary_table_test.go
index 796f51cd68287..67f94d381fbf7 100644
--- a/tests/realtikvtest/sessiontest/temporary_table_test.go
+++ b/tests/realtikvtest/sessiontest/temporary_table_test.go
@@ -337,7 +337,7 @@ func TestTemporaryTableInterceptor(t *testing.T) {
 	}

 	// Also check GetSnapshotWithTS
-	snap := tk.Session().GetSnapshotWithTS(0)
+	snap := sessiontxn.GetSnapshotWithTS(tk.Session(), 0)
 	val, err := snap.Get(context.Background(), k)
 	require.NoError(t, err)
 	require.Equal(t, []byte("v1"), val)
diff --git a/util/mock/context.go b/util/mock/context.go
index fcd182a64297a..b8be5b9ddaaf9 100644
--- a/util/mock/context.go
+++ b/util/mock/context.go
@@ -254,11 +254,6 @@ func (c *Context) NewStaleTxnWithStartTS(ctx context.Context, startTS uint64) er
 	return c.NewTxn(ctx)
 }

-// GetSnapshotWithTS return a snapshot with ts
-func (c *Context) GetSnapshotWithTS(ts uint64) kv.Snapshot {
-	return c.Store.GetSnapshot(kv.Version{Ver: ts})
-}
-
 // RefreshTxnCtx implements the sessionctx.Context interface.
 func (c *Context) RefreshTxnCtx(ctx context.Context) error {
 	return errors.Trace(c.NewTxn(ctx))
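Editor's note: for callers that relied on mock.Context.GetSnapshotWithTS, the deleted method body shows the one-line replacement, reading straight from the store. This sketch assumes the caller only needs an uninterpreted snapshot (the mock never wired a temp-table interceptor anyway):

```go
// Equivalent of the removed mock method; mockCtx is a *mock.Context.
snap := mockCtx.GetStore().GetSnapshot(kv.Version{Ver: ts})
```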