Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

table: expose some fields to MutateContext from GetSessionVars() #54767

Merged
merged 1 commit into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pkg/table/context/buffers.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ func (b *ColSizeDeltaBuffer) UpdateColSizeMap(m map[int64]int64) map[int64]int64
// Because inner slices are reused, you should not call the get methods again before finishing the previous usage.
// Otherwise, the previous data will be overwritten.
type MutateBuffers struct {
stmtBufs *variable.WriteStmtBufs
encodeRow *EncodeRowBuffer
checkRow *CheckRowBuffer
colSizeDelta *ColSizeDeltaBuffer
Expand All @@ -156,6 +157,7 @@ type MutateBuffers struct {
// NewMutateBuffers creates a new `MutateBuffers`.
func NewMutateBuffers(stmtBufs *variable.WriteStmtBufs) *MutateBuffers {
return &MutateBuffers{
stmtBufs: stmtBufs,
encodeRow: &EncodeRowBuffer{
writeStmtBufs: stmtBufs,
},
Expand Down Expand Up @@ -204,6 +206,11 @@ func (b *MutateBuffers) GetColSizeDeltaBufferWithCap(capacity int) *ColSizeDelta
return buffer
}

// GetWriteStmtBufs returns the `*variable.WriteStmtBufs`
func (b *MutateBuffers) GetWriteStmtBufs() *variable.WriteStmtBufs {
return b.stmtBufs
}

// ensureCapacityAndReset is similar to the built-in make(),
// but it reuses the given slice if it has enough capacity.
func ensureCapacityAndReset[T any](slice []T, size int, optCap ...int) []T {
Expand Down
2 changes: 2 additions & 0 deletions pkg/table/context/buffers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ func TestMutateBuffersGetter(t *testing.T) {

colSize := buffers.GetColSizeDeltaBufferWithCap(6)
require.Equal(t, 6, cap(colSize.delta))

require.Same(t, stmtBufs, buffers.GetWriteStmtBufs())
}

func TestEnsureCapacityAndReset(t *testing.T) {
Expand Down
7 changes: 7 additions & 0 deletions pkg/table/context/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,15 @@ type MutateContext interface {
// TxnRecordTempTable record the temporary table to the current transaction.
// This method will be called when the temporary table is modified or should allocate id in the transaction.
TxnRecordTempTable(tbl *model.TableInfo) tableutil.TempTable
// ConnectionID returns the id of the current connection.
// If the current environment is not in a query from the client, the return value is 0.
ConnectionID() uint64
// InRestrictedSQL returns whether the current context is used in restricted SQL.
InRestrictedSQL() bool
// TxnAssertionLevel returns the assertion level of the current transaction.
TxnAssertionLevel() variable.AssertionLevel
// EnableMutationChecker returns whether to check data consistency for mutations.
EnableMutationChecker() bool
// GetRowEncodingConfig returns the RowEncodingConfig.
GetRowEncodingConfig() RowEncodingConfig
// GetMutateBuffers returns the MutateBuffers,
Expand Down
1 change: 1 addition & 0 deletions pkg/table/contextimpl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ go_test(
deps = [
":contextimpl",
"//pkg/sessionctx/binloginfo",
"//pkg/sessionctx/variable",
"//pkg/testkit",
"//pkg/util/mock",
"@com_github_pingcap_tipb//go-binlog",
Expand Down
17 changes: 16 additions & 1 deletion pkg/table/contextimpl/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,24 @@ func (ctx *TableContextImpl) GetExprCtx() exprctx.ExprContext {
return ctx.Context.GetExprCtx()
}

// ConnectionID implements the MutateContext interface.
func (ctx *TableContextImpl) ConnectionID() uint64 {
return ctx.vars().ConnectionID
}

// InRestrictedSQL returns whether the current context is used in restricted SQL.
func (ctx *TableContextImpl) InRestrictedSQL() bool {
return ctx.vars().StmtCtx.InRestrictedSQL
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found StmtCtx and SessionVars all have InRestrictedSQL and they are kept the same in ResetContextOfStmt:

sc.InRestrictedSQL = vars.InRestrictedSQL

Read it from SessionVars instead of StmtCtx keeps the same behavior as previous implementations.

return ctx.vars().InRestrictedSQL
}

// TxnAssertionLevel implements the MutateContext interface.
func (ctx *TableContextImpl) TxnAssertionLevel() variable.AssertionLevel {
return ctx.vars().AssertionLevel
}

// EnableMutationChecker implements the MutateContext interface.
func (ctx *TableContextImpl) EnableMutationChecker() bool {
return ctx.vars().EnableMutationChecker
}

// BinlogEnabled returns whether the binlog is enabled.
Expand Down
18 changes: 16 additions & 2 deletions pkg/table/contextimpl/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"testing"

"github.com/pingcap/tidb/pkg/sessionctx/binloginfo"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/table/contextimpl"
"github.com/pingcap/tidb/pkg/testkit"
"github.com/pingcap/tidb/pkg/util/mock"
Expand All @@ -39,11 +40,24 @@ func TestMutateContextImplFields(t *testing.T) {
binlogMutation := ctx.GetBinlogMutation(1234)
require.NotNil(t, binlogMutation)
require.Same(t, sctx.StmtGetMutation(1234), binlogMutation)
// ConnectionID
sctx.GetSessionVars().ConnectionID = 12345
require.Equal(t, uint64(12345), ctx.ConnectionID())
// restricted SQL
sctx.GetSessionVars().StmtCtx.InRestrictedSQL = false
sctx.GetSessionVars().InRestrictedSQL = false
require.False(t, ctx.InRestrictedSQL())
sctx.GetSessionVars().StmtCtx.InRestrictedSQL = true
sctx.GetSessionVars().InRestrictedSQL = true
require.True(t, ctx.InRestrictedSQL())
// AssertionLevel
ctx.GetSessionVars().AssertionLevel = variable.AssertionLevelFast
require.Equal(t, variable.AssertionLevelFast, ctx.TxnAssertionLevel())
ctx.GetSessionVars().AssertionLevel = variable.AssertionLevelStrict
require.Equal(t, variable.AssertionLevelStrict, ctx.TxnAssertionLevel())
// EnableMutationChecker
ctx.GetSessionVars().EnableMutationChecker = true
require.True(t, ctx.EnableMutationChecker())
ctx.GetSessionVars().EnableMutationChecker = false
require.False(t, ctx.EnableMutationChecker())
// encoding config
sctx.GetSessionVars().EnableRowLevelChecksum = true
sctx.GetSessionVars().RowEncoder.Enable = true
Expand Down
2 changes: 1 addition & 1 deletion pkg/table/tables/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (c *index) Create(sctx table.MutateContext, txn kv.Transaction, indexedValu
ctx = context.TODO()
}
vars := sctx.GetSessionVars()
writeBufs := vars.GetWriteStmtBufs()
writeBufs := sctx.GetMutateBuffers().GetWriteStmtBufs()
skipCheck := vars.StmtCtx.BatchCheck
evalCtx := sctx.GetExprCtx().GetEvalCtx()
loc, ec := evalCtx.Location(), evalCtx.ErrCtx()
Expand Down
20 changes: 10 additions & 10 deletions pkg/table/tables/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,15 +553,15 @@ func (t *TableCommon) UpdateRecord(ctx context.Context, sctx table.MutateContext
// Assert the key doesn't exist while it actually exists. This is helpful to test if assertion takes effect.
// Since only the first assertion takes effect, set the injected assertion before setting the correct one to
// override it.
if sctx.GetSessionVars().ConnectionID != 0 {
if sctx.ConnectionID() != 0 {
logutil.BgLogger().Info("force asserting not exist on UpdateRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS()))
if err = txn.SetAssertion(key, kv.SetAssertNotExist); err != nil {
failpoint.Return(err)
}
}
})

if t.shouldAssert(sessVars.AssertionLevel) {
if t.shouldAssert(sctx.TxnAssertionLevel()) {
err = txn.SetAssertion(key, kv.SetAssertExist)
} else {
err = txn.SetAssertion(key, kv.SetAssertUnknown)
Expand All @@ -573,7 +573,7 @@ func (t *TableCommon) UpdateRecord(ctx context.Context, sctx table.MutateContext
if err = injectMutationError(t, txn, sh); err != nil {
return err
}
if sessVars.EnableMutationChecker {
if sctx.EnableMutationChecker() {
if err = CheckDataConsistency(txn, tc, t, newData, oldData, memBuffer, sh); err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -948,7 +948,7 @@ func (t *TableCommon) AddRecord(sctx table.MutateContext, r []types.Datum, opts
if setPresume {
flags = []kv.FlagsOp{kv.SetPresumeKeyNotExists}
if !sessVars.ConstraintCheckInPlacePessimistic && sessVars.TxnCtx.IsPessimistic && sessVars.InTxn() &&
!sessVars.InRestrictedSQL && sessVars.ConnectionID > 0 {
!sctx.InRestrictedSQL() && sctx.ConnectionID() > 0 {
flags = append(flags, kv.SetNeedConstraintCheckInPrewrite)
}
}
Expand All @@ -962,7 +962,7 @@ func (t *TableCommon) AddRecord(sctx table.MutateContext, r []types.Datum, opts
// Assert the key exists while it actually doesn't. This is helpful to test if assertion takes effect.
// Since only the first assertion takes effect, set the injected assertion before setting the correct one to
// override it.
if sctx.GetSessionVars().ConnectionID != 0 {
if sctx.ConnectionID() != 0 {
logutil.BgLogger().Info("force asserting exist on AddRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS()))
if err = txn.SetAssertion(key, kv.SetAssertExist); err != nil {
failpoint.Return(nil, err)
Expand Down Expand Up @@ -996,7 +996,7 @@ func (t *TableCommon) AddRecord(sctx table.MutateContext, r []types.Datum, opts
if err = injectMutationError(t, txn, sh); err != nil {
return nil, err
}
if sessVars.EnableMutationChecker {
if sctx.EnableMutationChecker() {
if err = CheckDataConsistency(txn, tc, t, r, nil, memBuffer, sh); err != nil {
return nil, errors.Trace(err)
}
Expand Down Expand Up @@ -1048,7 +1048,7 @@ func genIndexKeyStrs(colVals []types.Datum) ([]string, error) {

// addIndices adds data into indices. If any key is duplicated, returns the original handle.
func (t *TableCommon) addIndices(sctx table.MutateContext, recordID kv.Handle, r []types.Datum, txn kv.Transaction, opts []table.CreateIdxOptFunc) (kv.Handle, error) {
writeBufs := sctx.GetSessionVars().GetWriteStmtBufs()
writeBufs := sctx.GetMutateBuffers().GetWriteStmtBufs()
indexVals := writeBufs.IndexValsBuf
skipCheck := sctx.GetSessionVars().StmtCtx.BatchCheck
for _, v := range t.Indices() {
Expand Down Expand Up @@ -1264,7 +1264,7 @@ func (t *TableCommon) RemoveRecord(ctx table.MutateContext, h kv.Handle, r []typ
}

tc := ctx.GetExprCtx().GetEvalCtx().TypeCtx()
if ctx.GetSessionVars().EnableMutationChecker {
if ctx.EnableMutationChecker() {
if err = CheckDataConsistency(txn, tc, t, nil, r, memBuffer, sh); err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -1406,14 +1406,14 @@ func (t *TableCommon) removeRowData(ctx table.MutateContext, h kv.Handle) error
// Assert the key doesn't exist while it actually exists. This is helpful to test if assertion takes effect.
// Since only the first assertion takes effect, set the injected assertion before setting the correct one to
// override it.
if ctx.GetSessionVars().ConnectionID != 0 {
if ctx.ConnectionID() != 0 {
logutil.BgLogger().Info("force asserting not exist on RemoveRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS()))
if err = txn.SetAssertion(key, kv.SetAssertNotExist); err != nil {
failpoint.Return(err)
}
}
})
if t.shouldAssert(ctx.GetSessionVars().AssertionLevel) {
if t.shouldAssert(ctx.TxnAssertionLevel()) {
err = txn.SetAssertion(key, kv.SetAssertExist)
} else {
err = txn.SetAssertion(key, kv.SetAssertUnknown)
Expand Down