diff --git a/pkg/ddl/partition.go b/pkg/ddl/partition.go index 55ef0d59b624c..dfc090b14362a 100644 --- a/pkg/ddl/partition.go +++ b/pkg/ddl/partition.go @@ -3526,15 +3526,16 @@ func (w *reorgPartitionWorker) fetchRowColVals(txn kv.Transaction, taskRange reo // Non-clustered table / not unique _tidb_rowid for the whole table // Generate new _tidb_rowid if exists. // Due to EXCHANGE PARTITION, the existing _tidb_rowid may collide between partitions! - stmtCtx := w.sessCtx.GetSessionVars().StmtCtx - if stmtCtx.BaseRowID >= stmtCtx.MaxRowID { + if reserved, ok := w.tblCtx.GetReservedRowIDAlloc(); ok && reserved.Exhausted() { // TODO: Which autoid allocator to use? ids := uint64(max(1, w.batchCnt-len(w.rowRecords))) // Keep using the original table's allocator - stmtCtx.BaseRowID, stmtCtx.MaxRowID, err = tables.AllocHandleIDs(w.ctx, w.tblCtx, w.reorgedTbl, ids) + var baseRowID, maxRowID int64 + baseRowID, maxRowID, err = tables.AllocHandleIDs(w.ctx, w.tblCtx, w.reorgedTbl, ids) if err != nil { return false, errors.Trace(err) } + reserved.Reset(baseRowID, maxRowID) } recordID, err := tables.AllocHandle(w.ctx, w.tblCtx, w.reorgedTbl) if err != nil { diff --git a/pkg/sessionctx/stmtctx/stmtctx.go b/pkg/sessionctx/stmtctx/stmtctx.go index 1bef849f15ced..ffff683a5fca4 100644 --- a/pkg/sessionctx/stmtctx/stmtctx.go +++ b/pkg/sessionctx/stmtctx/stmtctx.go @@ -105,6 +105,33 @@ func (rf *ReferenceCount) UnFreeze() { atomic.StoreInt32((*int32)(rf), ReferenceCountNoReference) } +// ReservedRowIDAlloc is used to reserve autoID for the auto_increment column. +type ReservedRowIDAlloc struct { + base int64 + max int64 +} + +// Reset resets the base and max of reserved rowIDs. +func (r *ReservedRowIDAlloc) Reset(base int64, max int64) { + r.base = base + r.max = max +} + +// Consume consumes a reserved rowID. +// If the second return value is false, it means the reserved rowID is exhausted. +func (r *ReservedRowIDAlloc) Consume() (int64, bool) { + if r.base < r.max { + r.base++ + return r.base, true + } + return 0, false +} + +// Exhausted returns whether the reserved rowID is exhausted. +func (r *ReservedRowIDAlloc) Exhausted() bool { + return r.base >= r.max +} + // StatementContext contains variables for a statement. // It should be reset before executing a statement. type StatementContext struct { @@ -223,8 +250,8 @@ type StatementContext struct { // InsertID is the given insert ID of an auto_increment column. InsertID uint64 - BaseRowID int64 - MaxRowID int64 + // ReservedRowIDAlloc is used to alloc auto ID from the reserved IDs. + ReservedRowIDAlloc ReservedRowIDAlloc // Copied from SessionVars.TimeZone. Priority mysql.PriorityEnum @@ -972,8 +999,7 @@ func (sc *StatementContext) resetMuForRetry() { // ResetForRetry resets the changed states during execution. func (sc *StatementContext) ResetForRetry() { sc.resetMuForRetry() - sc.MaxRowID = 0 - sc.BaseRowID = 0 + sc.ReservedRowIDAlloc.Reset(0, 0) sc.TableIDs = sc.TableIDs[:0] sc.IndexNames = sc.IndexNames[:0] sc.TaskID = AllocateTaskID() diff --git a/pkg/sessionctx/stmtctx/stmtctx_test.go b/pkg/sessionctx/stmtctx/stmtctx_test.go index 6238e37a52756..7809bbfb6386f 100644 --- a/pkg/sessionctx/stmtctx/stmtctx_test.go +++ b/pkg/sessionctx/stmtctx/stmtctx_test.go @@ -464,6 +464,32 @@ func TestErrCtx(t *testing.T) { require.Equal(t, errctx.NewContextWithLevels(levels, sc), sc.ErrCtx()) } +func TestReservedRowIDAlloc(t *testing.T) { + var reserved stmtctx.ReservedRowIDAlloc + // no reserved by default + require.True(t, reserved.Exhausted()) + id, ok := reserved.Consume() + require.False(t, ok) + require.Equal(t, int64(0), id) + // reset some ids + reserved.Reset(12, 15) + require.False(t, reserved.Exhausted()) + id, ok = reserved.Consume() + require.True(t, ok) + require.Equal(t, int64(13), id) + id, ok = reserved.Consume() + require.True(t, ok) + require.Equal(t, int64(14), id) + id, ok = reserved.Consume() + require.True(t, ok) + require.Equal(t, int64(15), id) + // exhausted + require.True(t, reserved.Exhausted()) + id, ok = reserved.Consume() + require.False(t, ok) + require.Equal(t, int64(0), id) +} + func BenchmarkErrCtx(b *testing.B) { sc := stmtctx.NewStmtCtx() diff --git a/pkg/sessionctx/variable/session.go b/pkg/sessionctx/variable/session.go index d51e46be1fcd2..f0134cbd44ab8 100644 --- a/pkg/sessionctx/variable/session.go +++ b/pkg/sessionctx/variable/session.go @@ -204,9 +204,7 @@ type TxnCtxNoNeedToRestore struct { StaleReadTs uint64 // ShardStep indicates the max size of continuous rowid shard in one transaction. - ShardStep int - shardRemain int - currentShard int64 + RowIDShardGenerator RowIDShardGenerator // unchangedKeys is used to store the unchanged keys that needs to lock for pessimistic transaction. unchangedKeys map[string]struct{} @@ -269,24 +267,68 @@ type SavepointRecord struct { TxnCtxSavepoint TxnCtxNeedToRestore } +// RowIDShardGenerator is used to generate shard for row id. +type RowIDShardGenerator struct { + // shardRand is used for generated rand shard + shardSeed int64 + shardRand *rand.Rand + // shardStep indicates the max size of continuous rowid shard in one transaction. + shardStep int + shardRemain int + currentShard int64 +} + +// SetShardStep sets the step of shard +func (s *RowIDShardGenerator) SetShardStep(step int) { + s.shardStep = step +} + +// GetShardStep returns the shard step +func (s *RowIDShardGenerator) GetShardStep() int { + if s.shardStep <= 0 { + return DefTiDBShardAllocateStep + } + return s.shardStep +} + +// SetShardSeed sets the rand seed to generate shard +func (s *RowIDShardGenerator) SetShardSeed(seed int64) { + s.shardSeed = seed + s.shardRand = nil +} + +// GetShardSeed returns the seed to generate shard +func (s *RowIDShardGenerator) GetShardSeed() int64 { + return s.shardSeed +} + // GetCurrentShard returns the shard for the next `count` IDs. -func (s *SessionVars) GetCurrentShard(count int) int64 { - tc := s.TxnCtx +func (s *RowIDShardGenerator) GetCurrentShard(count int) int64 { if s.shardRand == nil { - s.shardRand = rand.New(rand.NewSource(int64(tc.StartTS))) // #nosec G404 + seed := s.shardSeed + if seed == 0 { + seed = time.Now().UnixNano() + } + s.shardRand = rand.New(rand.NewSource(seed)) // #nosec G404 } - if tc.shardRemain <= 0 { - tc.updateShard(s.shardRand) - tc.shardRemain = tc.ShardStep + if s.shardRemain <= 0 { + s.updateShard(s.shardRand) + s.shardRemain = s.GetShardStep() } - tc.shardRemain -= count - return tc.currentShard + s.shardRemain -= count + return s.currentShard } -func (tc *TransactionContext) updateShard(shardRand *rand.Rand) { +func (s *RowIDShardGenerator) updateShard(shardRand *rand.Rand) { var buf [8]byte binary.LittleEndian.PutUint64(buf[:], shardRand.Uint64()) - tc.currentShard = int64(murmur3.Sum32(buf[:])) + s.currentShard = int64(murmur3.Sum32(buf[:])) +} + +// GetCurrentShard returns the shard for the next `count` IDs. +func (s *SessionVars) GetCurrentShard(count int) int64 { + txnCtx := s.TxnCtx + return txnCtx.RowIDShardGenerator.GetCurrentShard(count) } // AddUnchangedKeyForLock adds an unchanged key for pessimistic lock. @@ -1514,9 +1556,6 @@ type SessionVars struct { // StoreBatchSize indicates the batch size limit of store batch, set this field to 0 to disable store batch. StoreBatchSize int - // shardRand is used by TxnCtx, for the GetCurrentShard() method. - shardRand *rand.Rand - // Resource group name // NOTE: all statement relate operation should use StmtCtx.ResourceGroupName instead. // NOTE: please don't change it directly. Use `SetResourceGroupName`, because it'll need to inc/dec the metrics diff --git a/pkg/sessionctx/variable/session_test.go b/pkg/sessionctx/variable/session_test.go index 8bb032612e0a8..3d7600b4f8e10 100644 --- a/pkg/sessionctx/variable/session_test.go +++ b/pkg/sessionctx/variable/session_test.go @@ -602,3 +602,28 @@ func TestMapDeltaCols(t *testing.T) { } } } + +func TestRowIDShardGenerator(t *testing.T) { + var g variable.RowIDShardGenerator + // default settings + require.Equal(t, variable.DefTiDBShardAllocateStep, g.GetShardStep()) + require.Equal(t, int64(0), g.GetShardSeed()) + shard := g.GetCurrentShard(variable.DefTiDBShardAllocateStep - 1) + require.Equal(t, shard, g.GetCurrentShard(1)) + // reset state + g.SetShardStep(5) + require.Equal(t, 5, g.GetShardStep()) + g.SetShardSeed(12345) + require.Equal(t, int64(12345), g.GetShardSeed()) + // generate shard in step + shard = g.GetCurrentShard(1) + require.Equal(t, int64(3535546008), shard) + require.Equal(t, shard, g.GetCurrentShard(1)) + require.Equal(t, shard, g.GetCurrentShard(1)) + require.Equal(t, shard, g.GetCurrentShard(2)) + // generate shard in next step + shard = g.GetCurrentShard(5) + require.Equal(t, int64(1371624976), shard) + // shard should be different in each step + require.NotEqual(t, shard, g.GetCurrentShard(1)) +} diff --git a/pkg/sessiontxn/isolation/base.go b/pkg/sessiontxn/isolation/base.go index 6f5a516d2754e..b30a17af9075a 100644 --- a/pkg/sessiontxn/isolation/base.go +++ b/pkg/sessiontxn/isolation/base.go @@ -120,10 +120,10 @@ func (p *baseTxnContextProvider) OnInitialize(ctx context.Context, tp sessiontxn TxnCtxNoNeedToRestore: variable.TxnCtxNoNeedToRestore{ CreateTime: time.Now(), InfoSchema: p.infoSchema, - ShardStep: int(sessVars.ShardAllocateStep), TxnScope: sessVars.CheckAndGetTxnScope(), }, } + txnCtx.RowIDShardGenerator.SetShardStep(int(sessVars.ShardAllocateStep)) if p.onInitializeTxnCtx != nil { p.onInitializeTxnCtx(txnCtx) } @@ -294,7 +294,9 @@ func (p *baseTxnContextProvider) ActivateTxn() (kv.Transaction, error) { sessVars := p.sctx.GetSessionVars() sessVars.TxnCtxMu.Lock() - sessVars.TxnCtx.StartTS = txn.StartTS() + startTS := txn.StartTS() + sessVars.TxnCtx.StartTS = startTS + sessVars.TxnCtx.RowIDShardGenerator.SetShardSeed(int64(startTS)) sessVars.TxnCtxMu.Unlock() if sessVars.MemDBFootprint != nil { sessVars.MemDBFootprint.Detach() diff --git a/pkg/sessiontxn/isolation/main_test.go b/pkg/sessiontxn/isolation/main_test.go index 1c14db4138442..5bbb175c40e4e 100644 --- a/pkg/sessiontxn/isolation/main_test.go +++ b/pkg/sessiontxn/isolation/main_test.go @@ -81,7 +81,7 @@ func (a *txnAssert[T]) Check(t testing.TB) { require.Equal(t, a.isolation, txnCtx.Isolation) require.Equal(t, a.isolation != "", txnCtx.IsPessimistic) require.Equal(t, sessVars.CheckAndGetTxnScope(), txnCtx.TxnScope) - require.Equal(t, sessVars.ShardAllocateStep, int64(txnCtx.ShardStep)) + require.Equal(t, sessVars.ShardAllocateStep, int64(txnCtx.RowIDShardGenerator.GetShardStep())) require.False(t, txnCtx.IsStaleness) require.GreaterOrEqual(t, txnCtx.CreateTime.UnixNano(), a.minStartTime.UnixNano()) require.Equal(t, a.inTxn, sessVars.InTxn()) diff --git a/pkg/sessiontxn/staleread/provider.go b/pkg/sessiontxn/staleread/provider.go index 44679990685b0..7da1c68a4b4a1 100644 --- a/pkg/sessiontxn/staleread/provider.go +++ b/pkg/sessiontxn/staleread/provider.go @@ -120,11 +120,11 @@ func (p *StalenessTxnContextProvider) activateStaleTxn() error { InfoSchema: is, CreateTime: time.Now(), StartTS: txn.StartTS(), - ShardStep: int(sessVars.ShardAllocateStep), IsStaleness: true, TxnScope: txnScope, }, } + sessVars.TxnCtx.RowIDShardGenerator.SetShardStep(int(sessVars.ShardAllocateStep)) sessVars.TxnCtxMu.Unlock() if interceptor := temptable.SessionSnapshotInterceptor(p.sctx, is); interceptor != nil { diff --git a/pkg/table/context/table.go b/pkg/table/context/table.go index 2d6c32da8cce9..9919faac36795 100644 --- a/pkg/table/context/table.go +++ b/pkg/table/context/table.go @@ -19,6 +19,7 @@ import ( infoschema "github.com/pingcap/tidb/pkg/infoschema/context" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/util/rowcodec" "github.com/pingcap/tidb/pkg/util/tableutil" @@ -71,6 +72,10 @@ type MutateContext interface { // which is a buffer for table related structures that aims to reuse memory and // saves allocation. GetMutateBuffers() *MutateBuffers + // GetRowIDShardGenerator returns the `RowIDShardGenerator` object to shard rows. + GetRowIDShardGenerator() *variable.RowIDShardGenerator + // GetReservedRowIDAlloc returns the `ReservedRowIDAlloc` object to allocate row id from reservation. + GetReservedRowIDAlloc() (*stmtctx.ReservedRowIDAlloc, bool) } // AllocatorContext is used to provide context for method `table.Allocators`. diff --git a/pkg/table/contextimpl/table.go b/pkg/table/contextimpl/table.go index e80eb92ecc54b..3ec932e96cfd9 100644 --- a/pkg/table/contextimpl/table.go +++ b/pkg/table/contextimpl/table.go @@ -18,8 +18,10 @@ import ( exprctx "github.com/pingcap/tidb/pkg/expression/context" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/table/context" + "github.com/pingcap/tidb/pkg/util/intest" "github.com/pingcap/tidb/pkg/util/tableutil" "github.com/pingcap/tipb/go-binlog" ) @@ -99,6 +101,24 @@ func (ctx *TableContextImpl) GetMutateBuffers() *context.MutateBuffers { return ctx.mutateBuffers } +// GetRowIDShardGenerator implements the MutateContext interface. +func (ctx *TableContextImpl) GetRowIDShardGenerator() *variable.RowIDShardGenerator { + return &ctx.vars().TxnCtx.RowIDShardGenerator +} + +// GetReservedRowIDAlloc implements the MutateContext interface. +func (ctx *TableContextImpl) GetReservedRowIDAlloc() (*stmtctx.ReservedRowIDAlloc, bool) { + if sc := ctx.vars().StmtCtx; sc != nil { + return &sc.ReservedRowIDAlloc, true + } + // `StmtCtx` should not be nil in the `variable.SessionVars`. + // We just put an assertion that will panic only if in test here. + // In production code, here returns (nil, false) to make code safe + // because some old code checks `StmtCtx != nil` but we don't know why. + intest.Assert(false, "SessionVars.StmtCtx should not be nil") + return nil, false +} + func (ctx *TableContextImpl) vars() *variable.SessionVars { return ctx.Context.GetSessionVars() } diff --git a/pkg/table/contextimpl/table_test.go b/pkg/table/contextimpl/table_test.go index dda098578985a..4ebe53bcfd674 100644 --- a/pkg/table/contextimpl/table_test.go +++ b/pkg/table/contextimpl/table_test.go @@ -82,4 +82,10 @@ func TestMutateContextImplFields(t *testing.T) { require.Equal(t, sctx.GetSessionVars().IsRowLevelChecksumEnabled(), cfg.IsRowLevelChecksumEnabled) // mutate buffers require.NotNil(t, ctx.GetMutateBuffers()) + // RowIDShardGenerator + require.Same(t, &sctx.GetSessionVars().TxnCtx.RowIDShardGenerator, ctx.GetRowIDShardGenerator()) + // ReservedRowIDAlloc + reserved, ok := ctx.GetReservedRowIDAlloc() + require.True(t, ok) + require.Same(t, &sctx.GetSessionVars().StmtCtx.ReservedRowIDAlloc, reserved) } diff --git a/pkg/table/tables/tables.go b/pkg/table/tables/tables.go index 88445d3df8aeb..a0f51d1f328b7 100644 --- a/pkg/table/tables/tables.go +++ b/pkg/table/tables/tables.go @@ -838,11 +838,12 @@ func (t *TableCommon) AddRecord(sctx table.MutateContext, r []types.Datum, opts // The reserved ID could be used in the future within this statement, by the // following AddRecord() operation. // Make the IDs continuous benefit for the performance of TiKV. - sessVars := sctx.GetSessionVars() - stmtCtx := sessVars.StmtCtx - stmtCtx.BaseRowID, stmtCtx.MaxRowID, err = AllocHandleIDs(ctx, sctx, t, uint64(opt.ReserveAutoID)) - if err != nil { - return nil, err + if reserved, ok := sctx.GetReservedRowIDAlloc(); ok { + var baseRowID, maxRowID int64 + if baseRowID, maxRowID, err = AllocHandleIDs(ctx, sctx, t, uint64(opt.ReserveAutoID)); err != nil { + return nil, err + } + reserved.Reset(baseRowID, maxRowID) } } @@ -1604,11 +1605,10 @@ func GetColDefaultValue(ctx exprctx.BuildContext, col *table.Column, defaultVals func AllocHandle(ctx context.Context, mctx table.MutateContext, t table.Table) (kv.IntHandle, error) { if mctx != nil { - if stmtCtx := mctx.GetSessionVars().StmtCtx; stmtCtx != nil { + if reserved, ok := mctx.GetReservedRowIDAlloc(); ok { // First try to alloc if the statement has reserved auto ID. - if stmtCtx.BaseRowID < stmtCtx.MaxRowID { - stmtCtx.BaseRowID++ - return kv.IntHandle(stmtCtx.BaseRowID), nil + if rowID, ok := reserved.Consume(); ok { + return kv.IntHandle(rowID), nil } } } @@ -1638,7 +1638,7 @@ func AllocHandleIDs(ctx context.Context, mctx table.MutateContext, t table.Table // shard = 0010000000000000000000000000000000000000000000000000000000000000 return 0, 0, autoid.ErrAutoincReadFailed } - shard := mctx.GetSessionVars().GetCurrentShard(int(n)) + shard := mctx.GetRowIDShardGenerator().GetCurrentShard(int(n)) base = shardFmt.Compose(shard, base) maxID = shardFmt.Compose(shard, maxID) }