diff --git a/pkg/table/context/buffers.go b/pkg/table/context/buffers.go index a4c9cd373f872..12e6722947aa5 100644 --- a/pkg/table/context/buffers.go +++ b/pkg/table/context/buffers.go @@ -148,6 +148,7 @@ func (b *ColSizeDeltaBuffer) UpdateColSizeMap(m map[int64]int64) map[int64]int64 // Because inner slices are reused, you should not call the get methods again before finishing the previous usage. // Otherwise, the previous data will be overwritten. type MutateBuffers struct { + stmtBufs *variable.WriteStmtBufs encodeRow *EncodeRowBuffer checkRow *CheckRowBuffer colSizeDelta *ColSizeDeltaBuffer @@ -156,6 +157,7 @@ type MutateBuffers struct { // NewMutateBuffers creates a new `MutateBuffers`. func NewMutateBuffers(stmtBufs *variable.WriteStmtBufs) *MutateBuffers { return &MutateBuffers{ + stmtBufs: stmtBufs, encodeRow: &EncodeRowBuffer{ writeStmtBufs: stmtBufs, }, @@ -204,6 +206,11 @@ func (b *MutateBuffers) GetColSizeDeltaBufferWithCap(capacity int) *ColSizeDelta return buffer } +// GetWriteStmtBufs returns the `*variable.WriteStmtBufs` +func (b *MutateBuffers) GetWriteStmtBufs() *variable.WriteStmtBufs { + return b.stmtBufs +} + // ensureCapacityAndReset is similar to the built-in make(), // but it reuses the given slice if it has enough capacity. func ensureCapacityAndReset[T any](slice []T, size int, optCap ...int) []T { diff --git a/pkg/table/context/buffers_test.go b/pkg/table/context/buffers_test.go index 7aa0f5f3ef443..d6d20d3c9e514 100644 --- a/pkg/table/context/buffers_test.go +++ b/pkg/table/context/buffers_test.go @@ -243,6 +243,8 @@ func TestMutateBuffersGetter(t *testing.T) { colSize := buffers.GetColSizeDeltaBufferWithCap(6) require.Equal(t, 6, cap(colSize.delta)) + + require.Same(t, stmtBufs, buffers.GetWriteStmtBufs()) } func TestEnsureCapacityAndReset(t *testing.T) { diff --git a/pkg/table/context/table.go b/pkg/table/context/table.go index 02619a2eb92fb..2d6c32da8cce9 100644 --- a/pkg/table/context/table.go +++ b/pkg/table/context/table.go @@ -56,8 +56,15 @@ type MutateContext interface { // TxnRecordTempTable record the temporary table to the current transaction. // This method will be called when the temporary table is modified or should allocate id in the transaction. TxnRecordTempTable(tbl *model.TableInfo) tableutil.TempTable + // ConnectionID returns the id of the current connection. + // If the current environment is not in a query from the client, the return value is 0. + ConnectionID() uint64 // InRestrictedSQL returns whether the current context is used in restricted SQL. InRestrictedSQL() bool + // TxnAssertionLevel returns the assertion level of the current transaction. + TxnAssertionLevel() variable.AssertionLevel + // EnableMutationChecker returns whether to check data consistency for mutations. + EnableMutationChecker() bool // GetRowEncodingConfig returns the RowEncodingConfig. GetRowEncodingConfig() RowEncodingConfig // GetMutateBuffers returns the MutateBuffers, diff --git a/pkg/table/contextimpl/BUILD.bazel b/pkg/table/contextimpl/BUILD.bazel index 7e512b82cf7e2..2c263e6e2fa6e 100644 --- a/pkg/table/contextimpl/BUILD.bazel +++ b/pkg/table/contextimpl/BUILD.bazel @@ -24,6 +24,7 @@ go_test( deps = [ ":contextimpl", "//pkg/sessionctx/binloginfo", + "//pkg/sessionctx/variable", "//pkg/testkit", "//pkg/util/mock", "@com_github_pingcap_tipb//go-binlog", diff --git a/pkg/table/contextimpl/table.go b/pkg/table/contextimpl/table.go index 3c3c073ddeed4..e80eb92ecc54b 100644 --- a/pkg/table/contextimpl/table.go +++ b/pkg/table/contextimpl/table.go @@ -55,9 +55,24 @@ func (ctx *TableContextImpl) GetExprCtx() exprctx.ExprContext { return ctx.Context.GetExprCtx() } +// ConnectionID implements the MutateContext interface. +func (ctx *TableContextImpl) ConnectionID() uint64 { + return ctx.vars().ConnectionID +} + // InRestrictedSQL returns whether the current context is used in restricted SQL. func (ctx *TableContextImpl) InRestrictedSQL() bool { - return ctx.vars().StmtCtx.InRestrictedSQL + return ctx.vars().InRestrictedSQL +} + +// TxnAssertionLevel implements the MutateContext interface. +func (ctx *TableContextImpl) TxnAssertionLevel() variable.AssertionLevel { + return ctx.vars().AssertionLevel +} + +// EnableMutationChecker implements the MutateContext interface. +func (ctx *TableContextImpl) EnableMutationChecker() bool { + return ctx.vars().EnableMutationChecker } // BinlogEnabled returns whether the binlog is enabled. diff --git a/pkg/table/contextimpl/table_test.go b/pkg/table/contextimpl/table_test.go index bad7845dbdbf9..dda098578985a 100644 --- a/pkg/table/contextimpl/table_test.go +++ b/pkg/table/contextimpl/table_test.go @@ -18,6 +18,7 @@ import ( "testing" "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" + "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/table/contextimpl" "github.com/pingcap/tidb/pkg/testkit" "github.com/pingcap/tidb/pkg/util/mock" @@ -39,11 +40,24 @@ func TestMutateContextImplFields(t *testing.T) { binlogMutation := ctx.GetBinlogMutation(1234) require.NotNil(t, binlogMutation) require.Same(t, sctx.StmtGetMutation(1234), binlogMutation) + // ConnectionID + sctx.GetSessionVars().ConnectionID = 12345 + require.Equal(t, uint64(12345), ctx.ConnectionID()) // restricted SQL - sctx.GetSessionVars().StmtCtx.InRestrictedSQL = false + sctx.GetSessionVars().InRestrictedSQL = false require.False(t, ctx.InRestrictedSQL()) - sctx.GetSessionVars().StmtCtx.InRestrictedSQL = true + sctx.GetSessionVars().InRestrictedSQL = true require.True(t, ctx.InRestrictedSQL()) + // AssertionLevel + ctx.GetSessionVars().AssertionLevel = variable.AssertionLevelFast + require.Equal(t, variable.AssertionLevelFast, ctx.TxnAssertionLevel()) + ctx.GetSessionVars().AssertionLevel = variable.AssertionLevelStrict + require.Equal(t, variable.AssertionLevelStrict, ctx.TxnAssertionLevel()) + // EnableMutationChecker + ctx.GetSessionVars().EnableMutationChecker = true + require.True(t, ctx.EnableMutationChecker()) + ctx.GetSessionVars().EnableMutationChecker = false + require.False(t, ctx.EnableMutationChecker()) // encoding config sctx.GetSessionVars().EnableRowLevelChecksum = true sctx.GetSessionVars().RowEncoder.Enable = true diff --git a/pkg/table/tables/index.go b/pkg/table/tables/index.go index 68ad4a537ae83..0f9eaf6ebddc3 100644 --- a/pkg/table/tables/index.go +++ b/pkg/table/tables/index.go @@ -177,7 +177,7 @@ func (c *index) Create(sctx table.MutateContext, txn kv.Transaction, indexedValu ctx = context.TODO() } vars := sctx.GetSessionVars() - writeBufs := vars.GetWriteStmtBufs() + writeBufs := sctx.GetMutateBuffers().GetWriteStmtBufs() skipCheck := vars.StmtCtx.BatchCheck evalCtx := sctx.GetExprCtx().GetEvalCtx() loc, ec := evalCtx.Location(), evalCtx.ErrCtx() diff --git a/pkg/table/tables/tables.go b/pkg/table/tables/tables.go index 2d2ffb6191bc0..88445d3df8aeb 100644 --- a/pkg/table/tables/tables.go +++ b/pkg/table/tables/tables.go @@ -553,7 +553,7 @@ func (t *TableCommon) UpdateRecord(ctx context.Context, sctx table.MutateContext // Assert the key doesn't exist while it actually exists. This is helpful to test if assertion takes effect. // Since only the first assertion takes effect, set the injected assertion before setting the correct one to // override it. - if sctx.GetSessionVars().ConnectionID != 0 { + if sctx.ConnectionID() != 0 { logutil.BgLogger().Info("force asserting not exist on UpdateRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS())) if err = txn.SetAssertion(key, kv.SetAssertNotExist); err != nil { failpoint.Return(err) @@ -561,7 +561,7 @@ func (t *TableCommon) UpdateRecord(ctx context.Context, sctx table.MutateContext } }) - if t.shouldAssert(sessVars.AssertionLevel) { + if t.shouldAssert(sctx.TxnAssertionLevel()) { err = txn.SetAssertion(key, kv.SetAssertExist) } else { err = txn.SetAssertion(key, kv.SetAssertUnknown) @@ -573,7 +573,7 @@ func (t *TableCommon) UpdateRecord(ctx context.Context, sctx table.MutateContext if err = injectMutationError(t, txn, sh); err != nil { return err } - if sessVars.EnableMutationChecker { + if sctx.EnableMutationChecker() { if err = CheckDataConsistency(txn, tc, t, newData, oldData, memBuffer, sh); err != nil { return errors.Trace(err) } @@ -948,7 +948,7 @@ func (t *TableCommon) AddRecord(sctx table.MutateContext, r []types.Datum, opts if setPresume { flags = []kv.FlagsOp{kv.SetPresumeKeyNotExists} if !sessVars.ConstraintCheckInPlacePessimistic && sessVars.TxnCtx.IsPessimistic && sessVars.InTxn() && - !sessVars.InRestrictedSQL && sessVars.ConnectionID > 0 { + !sctx.InRestrictedSQL() && sctx.ConnectionID() > 0 { flags = append(flags, kv.SetNeedConstraintCheckInPrewrite) } } @@ -962,7 +962,7 @@ func (t *TableCommon) AddRecord(sctx table.MutateContext, r []types.Datum, opts // Assert the key exists while it actually doesn't. This is helpful to test if assertion takes effect. // Since only the first assertion takes effect, set the injected assertion before setting the correct one to // override it. - if sctx.GetSessionVars().ConnectionID != 0 { + if sctx.ConnectionID() != 0 { logutil.BgLogger().Info("force asserting exist on AddRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS())) if err = txn.SetAssertion(key, kv.SetAssertExist); err != nil { failpoint.Return(nil, err) @@ -996,7 +996,7 @@ func (t *TableCommon) AddRecord(sctx table.MutateContext, r []types.Datum, opts if err = injectMutationError(t, txn, sh); err != nil { return nil, err } - if sessVars.EnableMutationChecker { + if sctx.EnableMutationChecker() { if err = CheckDataConsistency(txn, tc, t, r, nil, memBuffer, sh); err != nil { return nil, errors.Trace(err) } @@ -1048,7 +1048,7 @@ func genIndexKeyStrs(colVals []types.Datum) ([]string, error) { // addIndices adds data into indices. If any key is duplicated, returns the original handle. func (t *TableCommon) addIndices(sctx table.MutateContext, recordID kv.Handle, r []types.Datum, txn kv.Transaction, opts []table.CreateIdxOptFunc) (kv.Handle, error) { - writeBufs := sctx.GetSessionVars().GetWriteStmtBufs() + writeBufs := sctx.GetMutateBuffers().GetWriteStmtBufs() indexVals := writeBufs.IndexValsBuf skipCheck := sctx.GetSessionVars().StmtCtx.BatchCheck for _, v := range t.Indices() { @@ -1264,7 +1264,7 @@ func (t *TableCommon) RemoveRecord(ctx table.MutateContext, h kv.Handle, r []typ } tc := ctx.GetExprCtx().GetEvalCtx().TypeCtx() - if ctx.GetSessionVars().EnableMutationChecker { + if ctx.EnableMutationChecker() { if err = CheckDataConsistency(txn, tc, t, nil, r, memBuffer, sh); err != nil { return errors.Trace(err) } @@ -1406,14 +1406,14 @@ func (t *TableCommon) removeRowData(ctx table.MutateContext, h kv.Handle) error // Assert the key doesn't exist while it actually exists. This is helpful to test if assertion takes effect. // Since only the first assertion takes effect, set the injected assertion before setting the correct one to // override it. - if ctx.GetSessionVars().ConnectionID != 0 { + if ctx.ConnectionID() != 0 { logutil.BgLogger().Info("force asserting not exist on RemoveRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS())) if err = txn.SetAssertion(key, kv.SetAssertNotExist); err != nil { failpoint.Return(err) } } }) - if t.shouldAssert(ctx.GetSessionVars().AssertionLevel) { + if t.shouldAssert(ctx.TxnAssertionLevel()) { err = txn.SetAssertion(key, kv.SetAssertExist) } else { err = txn.SetAssertion(key, kv.SetAssertUnknown)