Skip to content

Commit

Permalink
table: expose some fields to MutateContext from GetSessionVars() (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao authored Jul 19, 2024
1 parent 267e1c4 commit a23af54
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 14 deletions.
7 changes: 7 additions & 0 deletions pkg/table/context/buffers.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ func (b *ColSizeDeltaBuffer) UpdateColSizeMap(m map[int64]int64) map[int64]int64
// Because inner slices are reused, you should not call the get methods again before finishing the previous usage.
// Otherwise, the previous data will be overwritten.
type MutateBuffers struct {
stmtBufs *variable.WriteStmtBufs
encodeRow *EncodeRowBuffer
checkRow *CheckRowBuffer
colSizeDelta *ColSizeDeltaBuffer
Expand All @@ -156,6 +157,7 @@ type MutateBuffers struct {
// NewMutateBuffers creates a new `MutateBuffers`.
func NewMutateBuffers(stmtBufs *variable.WriteStmtBufs) *MutateBuffers {
return &MutateBuffers{
stmtBufs: stmtBufs,
encodeRow: &EncodeRowBuffer{
writeStmtBufs: stmtBufs,
},
Expand Down Expand Up @@ -204,6 +206,11 @@ func (b *MutateBuffers) GetColSizeDeltaBufferWithCap(capacity int) *ColSizeDelta
return buffer
}

// GetWriteStmtBufs returns the `*variable.WriteStmtBufs`
func (b *MutateBuffers) GetWriteStmtBufs() *variable.WriteStmtBufs {
return b.stmtBufs
}

// ensureCapacityAndReset is similar to the built-in make(),
// but it reuses the given slice if it has enough capacity.
func ensureCapacityAndReset[T any](slice []T, size int, optCap ...int) []T {
Expand Down
2 changes: 2 additions & 0 deletions pkg/table/context/buffers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ func TestMutateBuffersGetter(t *testing.T) {

colSize := buffers.GetColSizeDeltaBufferWithCap(6)
require.Equal(t, 6, cap(colSize.delta))

require.Same(t, stmtBufs, buffers.GetWriteStmtBufs())
}

func TestEnsureCapacityAndReset(t *testing.T) {
Expand Down
7 changes: 7 additions & 0 deletions pkg/table/context/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,15 @@ type MutateContext interface {
// TxnRecordTempTable record the temporary table to the current transaction.
// This method will be called when the temporary table is modified or should allocate id in the transaction.
TxnRecordTempTable(tbl *model.TableInfo) tableutil.TempTable
// ConnectionID returns the id of the current connection.
// If the current environment is not in a query from the client, the return value is 0.
ConnectionID() uint64
// InRestrictedSQL returns whether the current context is used in restricted SQL.
InRestrictedSQL() bool
// TxnAssertionLevel returns the assertion level of the current transaction.
TxnAssertionLevel() variable.AssertionLevel
// EnableMutationChecker returns whether to check data consistency for mutations.
EnableMutationChecker() bool
// GetRowEncodingConfig returns the RowEncodingConfig.
GetRowEncodingConfig() RowEncodingConfig
// GetMutateBuffers returns the MutateBuffers,
Expand Down
1 change: 1 addition & 0 deletions pkg/table/contextimpl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ go_test(
deps = [
":contextimpl",
"//pkg/sessionctx/binloginfo",
"//pkg/sessionctx/variable",
"//pkg/testkit",
"//pkg/util/mock",
"@com_github_pingcap_tipb//go-binlog",
Expand Down
17 changes: 16 additions & 1 deletion pkg/table/contextimpl/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,24 @@ func (ctx *TableContextImpl) GetExprCtx() exprctx.ExprContext {
return ctx.Context.GetExprCtx()
}

// ConnectionID implements the MutateContext interface.
func (ctx *TableContextImpl) ConnectionID() uint64 {
return ctx.vars().ConnectionID
}

// InRestrictedSQL returns whether the current context is used in restricted SQL.
func (ctx *TableContextImpl) InRestrictedSQL() bool {
return ctx.vars().StmtCtx.InRestrictedSQL
return ctx.vars().InRestrictedSQL
}

// TxnAssertionLevel implements the MutateContext interface.
func (ctx *TableContextImpl) TxnAssertionLevel() variable.AssertionLevel {
return ctx.vars().AssertionLevel
}

// EnableMutationChecker implements the MutateContext interface.
func (ctx *TableContextImpl) EnableMutationChecker() bool {
return ctx.vars().EnableMutationChecker
}

// BinlogEnabled returns whether the binlog is enabled.
Expand Down
18 changes: 16 additions & 2 deletions pkg/table/contextimpl/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"testing"

"github.com/pingcap/tidb/pkg/sessionctx/binloginfo"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/table/contextimpl"
"github.com/pingcap/tidb/pkg/testkit"
"github.com/pingcap/tidb/pkg/util/mock"
Expand All @@ -39,11 +40,24 @@ func TestMutateContextImplFields(t *testing.T) {
binlogMutation := ctx.GetBinlogMutation(1234)
require.NotNil(t, binlogMutation)
require.Same(t, sctx.StmtGetMutation(1234), binlogMutation)
// ConnectionID
sctx.GetSessionVars().ConnectionID = 12345
require.Equal(t, uint64(12345), ctx.ConnectionID())
// restricted SQL
sctx.GetSessionVars().StmtCtx.InRestrictedSQL = false
sctx.GetSessionVars().InRestrictedSQL = false
require.False(t, ctx.InRestrictedSQL())
sctx.GetSessionVars().StmtCtx.InRestrictedSQL = true
sctx.GetSessionVars().InRestrictedSQL = true
require.True(t, ctx.InRestrictedSQL())
// AssertionLevel
ctx.GetSessionVars().AssertionLevel = variable.AssertionLevelFast
require.Equal(t, variable.AssertionLevelFast, ctx.TxnAssertionLevel())
ctx.GetSessionVars().AssertionLevel = variable.AssertionLevelStrict
require.Equal(t, variable.AssertionLevelStrict, ctx.TxnAssertionLevel())
// EnableMutationChecker
ctx.GetSessionVars().EnableMutationChecker = true
require.True(t, ctx.EnableMutationChecker())
ctx.GetSessionVars().EnableMutationChecker = false
require.False(t, ctx.EnableMutationChecker())
// encoding config
sctx.GetSessionVars().EnableRowLevelChecksum = true
sctx.GetSessionVars().RowEncoder.Enable = true
Expand Down
2 changes: 1 addition & 1 deletion pkg/table/tables/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (c *index) Create(sctx table.MutateContext, txn kv.Transaction, indexedValu
ctx = context.TODO()
}
vars := sctx.GetSessionVars()
writeBufs := vars.GetWriteStmtBufs()
writeBufs := sctx.GetMutateBuffers().GetWriteStmtBufs()
skipCheck := vars.StmtCtx.BatchCheck
evalCtx := sctx.GetExprCtx().GetEvalCtx()
loc, ec := evalCtx.Location(), evalCtx.ErrCtx()
Expand Down
20 changes: 10 additions & 10 deletions pkg/table/tables/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,15 +553,15 @@ func (t *TableCommon) UpdateRecord(ctx context.Context, sctx table.MutateContext
// Assert the key doesn't exist while it actually exists. This is helpful to test if assertion takes effect.
// Since only the first assertion takes effect, set the injected assertion before setting the correct one to
// override it.
if sctx.GetSessionVars().ConnectionID != 0 {
if sctx.ConnectionID() != 0 {
logutil.BgLogger().Info("force asserting not exist on UpdateRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS()))
if err = txn.SetAssertion(key, kv.SetAssertNotExist); err != nil {
failpoint.Return(err)
}
}
})

if t.shouldAssert(sessVars.AssertionLevel) {
if t.shouldAssert(sctx.TxnAssertionLevel()) {
err = txn.SetAssertion(key, kv.SetAssertExist)
} else {
err = txn.SetAssertion(key, kv.SetAssertUnknown)
Expand All @@ -573,7 +573,7 @@ func (t *TableCommon) UpdateRecord(ctx context.Context, sctx table.MutateContext
if err = injectMutationError(t, txn, sh); err != nil {
return err
}
if sessVars.EnableMutationChecker {
if sctx.EnableMutationChecker() {
if err = CheckDataConsistency(txn, tc, t, newData, oldData, memBuffer, sh); err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -948,7 +948,7 @@ func (t *TableCommon) AddRecord(sctx table.MutateContext, r []types.Datum, opts
if setPresume {
flags = []kv.FlagsOp{kv.SetPresumeKeyNotExists}
if !sessVars.ConstraintCheckInPlacePessimistic && sessVars.TxnCtx.IsPessimistic && sessVars.InTxn() &&
!sessVars.InRestrictedSQL && sessVars.ConnectionID > 0 {
!sctx.InRestrictedSQL() && sctx.ConnectionID() > 0 {
flags = append(flags, kv.SetNeedConstraintCheckInPrewrite)
}
}
Expand All @@ -962,7 +962,7 @@ func (t *TableCommon) AddRecord(sctx table.MutateContext, r []types.Datum, opts
// Assert the key exists while it actually doesn't. This is helpful to test if assertion takes effect.
// Since only the first assertion takes effect, set the injected assertion before setting the correct one to
// override it.
if sctx.GetSessionVars().ConnectionID != 0 {
if sctx.ConnectionID() != 0 {
logutil.BgLogger().Info("force asserting exist on AddRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS()))
if err = txn.SetAssertion(key, kv.SetAssertExist); err != nil {
failpoint.Return(nil, err)
Expand Down Expand Up @@ -996,7 +996,7 @@ func (t *TableCommon) AddRecord(sctx table.MutateContext, r []types.Datum, opts
if err = injectMutationError(t, txn, sh); err != nil {
return nil, err
}
if sessVars.EnableMutationChecker {
if sctx.EnableMutationChecker() {
if err = CheckDataConsistency(txn, tc, t, r, nil, memBuffer, sh); err != nil {
return nil, errors.Trace(err)
}
Expand Down Expand Up @@ -1048,7 +1048,7 @@ func genIndexKeyStrs(colVals []types.Datum) ([]string, error) {

// addIndices adds data into indices. If any key is duplicated, returns the original handle.
func (t *TableCommon) addIndices(sctx table.MutateContext, recordID kv.Handle, r []types.Datum, txn kv.Transaction, opts []table.CreateIdxOptFunc) (kv.Handle, error) {
writeBufs := sctx.GetSessionVars().GetWriteStmtBufs()
writeBufs := sctx.GetMutateBuffers().GetWriteStmtBufs()
indexVals := writeBufs.IndexValsBuf
skipCheck := sctx.GetSessionVars().StmtCtx.BatchCheck
for _, v := range t.Indices() {
Expand Down Expand Up @@ -1264,7 +1264,7 @@ func (t *TableCommon) RemoveRecord(ctx table.MutateContext, h kv.Handle, r []typ
}

tc := ctx.GetExprCtx().GetEvalCtx().TypeCtx()
if ctx.GetSessionVars().EnableMutationChecker {
if ctx.EnableMutationChecker() {
if err = CheckDataConsistency(txn, tc, t, nil, r, memBuffer, sh); err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -1406,14 +1406,14 @@ func (t *TableCommon) removeRowData(ctx table.MutateContext, h kv.Handle) error
// Assert the key doesn't exist while it actually exists. This is helpful to test if assertion takes effect.
// Since only the first assertion takes effect, set the injected assertion before setting the correct one to
// override it.
if ctx.GetSessionVars().ConnectionID != 0 {
if ctx.ConnectionID() != 0 {
logutil.BgLogger().Info("force asserting not exist on RemoveRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS()))
if err = txn.SetAssertion(key, kv.SetAssertNotExist); err != nil {
failpoint.Return(err)
}
}
})
if t.shouldAssert(ctx.GetSessionVars().AssertionLevel) {
if t.shouldAssert(ctx.TxnAssertionLevel()) {
err = txn.SetAssertion(key, kv.SetAssertExist)
} else {
err = txn.SetAssertion(key, kv.SetAssertUnknown)
Expand Down

0 comments on commit a23af54

Please sign in to comment.