diff --git a/pkg/ddl/reorg.go b/pkg/ddl/reorg.go index 7cf13f60fbb01..c86c983dd39d2 100644 --- a/pkg/ddl/reorg.go +++ b/pkg/ddl/reorg.go @@ -84,7 +84,6 @@ func newReorgExprCtx() exprctx.ExprContext { return contextstatic.NewStaticExprContext( contextstatic.WithEvalCtx(evalCtx), - contextstatic.WithUseCache(false), ) } diff --git a/pkg/distsql/context/context.go b/pkg/distsql/context/context.go index 797ec00666126..e0d6baad8fafd 100644 --- a/pkg/distsql/context/context.go +++ b/pkg/distsql/context/context.go @@ -47,7 +47,7 @@ type DistSQLContext struct { EnabledRateLimitAction bool EnableChunkRPC bool OriginalSQL string - KVVars *tikvstore.Variables + KVVars tikvstore.Variables KvExecCounter *stmtstats.KvExecCounter SessionMemTracker *memory.Tracker diff --git a/pkg/distsql/distsql.go b/pkg/distsql/distsql.go index 3949b30ea8dc9..1f40cd0df36b9 100644 --- a/pkg/distsql/distsql.go +++ b/pkg/distsql/distsql.go @@ -88,7 +88,7 @@ func Select(ctx context.Context, dctx *distsqlctx.DistSQLContext, kvReq *kv.Requ option.AppendWarning = dctx.AppendWarning } - resp := dctx.Client.Send(ctx, kvReq, dctx.KVVars, option) + resp := dctx.Client.Send(ctx, kvReq, &dctx.KVVars, option) if resp == nil { return nil, errors.New("client returns nil response") } diff --git a/pkg/executor/BUILD.bazel b/pkg/executor/BUILD.bazel index 5d4e30394b25c..b83ab87f2ca0c 100644 --- a/pkg/executor/BUILD.bazel +++ b/pkg/executor/BUILD.bazel @@ -125,6 +125,7 @@ go_library( "//pkg/expression", "//pkg/expression/aggregation", "//pkg/expression/context", + "//pkg/expression/contextsession", "//pkg/infoschema", "//pkg/infoschema/context", "//pkg/keyspace", diff --git a/pkg/executor/table_reader.go b/pkg/executor/table_reader.go index 59c1043ec4cbc..c481edcf20e92 100644 --- a/pkg/executor/table_reader.go +++ b/pkg/executor/table_reader.go @@ -33,6 +33,7 @@ import ( internalutil "github.com/pingcap/tidb/pkg/executor/internal/util" "github.com/pingcap/tidb/pkg/expression" exprctx "github.com/pingcap/tidb/pkg/expression/context" + "github.com/pingcap/tidb/pkg/expression/contextsession" "github.com/pingcap/tidb/pkg/infoschema" isctx "github.com/pingcap/tidb/pkg/infoschema/context" "github.com/pingcap/tidb/pkg/kv" @@ -82,7 +83,7 @@ type tableReaderExecutorContext struct { dctx *distsqlctx.DistSQLContext rctx *rangerctx.RangerContext buildPBCtx *planctx.BuildPBContext - ectx exprctx.BuildContext + ectx exprctx.ExprContext stmtMemTracker *memory.Tracker @@ -102,6 +103,22 @@ func (treCtx *tableReaderExecutorContext) GetDDLOwner(ctx context.Context) (*inf return nil, errors.New("GetDDLOwner in a context without DDL") } +// IntoStatic detaches the current context from the original session context. +// +// NOTE: For `dctx`, `rctx`... most of the fields don't need to be handled specially, because they are already copied from the session context. +// some reference types like `WarnHandler` also doesn't need to copy because a new statement will always creates a new `WarnHandler`, so it's +// safe to continue to use it here. We'll need to call `IntoStatic` method for `evalCtx` and `exprCtx`, because maybe they are implemented by +// the session context directly. +func (treCtx *tableReaderExecutorContext) IntoStatic() { + if sctx, ok := treCtx.ectx.(*contextsession.SessionExprContext); ok { + staticECtx := sctx.IntoStatic() + + treCtx.rctx.IntoStatic(staticECtx) + treCtx.buildPBCtx.IntoStatic(staticECtx) + treCtx.ectx = staticECtx + } +} + func newTableReaderExecutorContext(sctx sessionctx.Context) tableReaderExecutorContext { // Explicitly get `ownerManager` out of the closure to show that the `tableReaderExecutorContext` itself doesn't // depend on `sctx` directly. diff --git a/pkg/expression/contextsession/BUILD.bazel b/pkg/expression/contextsession/BUILD.bazel index 38353c17faad7..e333ef78b8d3b 100644 --- a/pkg/expression/contextsession/BUILD.bazel +++ b/pkg/expression/contextsession/BUILD.bazel @@ -9,6 +9,7 @@ go_library( "//pkg/errctx", "//pkg/expression/context", "//pkg/expression/contextopt", + "//pkg/expression/contextstatic", "//pkg/infoschema/context", "//pkg/parser/auth", "//pkg/parser/model", diff --git a/pkg/expression/contextsession/sessionctx.go b/pkg/expression/contextsession/sessionctx.go index 2119919b6bbdf..fc27daa71c946 100644 --- a/pkg/expression/contextsession/sessionctx.go +++ b/pkg/expression/contextsession/sessionctx.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/tidb/pkg/errctx" exprctx "github.com/pingcap/tidb/pkg/expression/context" "github.com/pingcap/tidb/pkg/expression/contextopt" + "github.com/pingcap/tidb/pkg/expression/contextstatic" infoschema "github.com/pingcap/tidb/pkg/infoschema/context" "github.com/pingcap/tidb/pkg/parser/auth" "github.com/pingcap/tidb/pkg/parser/model" @@ -131,6 +132,26 @@ func (ctx *SessionExprContext) ConnectionID() uint64 { return ctx.sctx.GetSessionVars().ConnectionID } +// IntoStatic turns the SessionExprContext into a StaticExprContext. +func (ctx *SessionExprContext) IntoStatic() *contextstatic.StaticExprContext { + staticEvalContext := ctx.SessionEvalContext.IntoStatic() + return contextstatic.NewStaticExprContext( + contextstatic.WithEvalCtx(staticEvalContext), + contextstatic.WithCharset(ctx.GetCharsetInfo()), + contextstatic.WithDefaultCollationForUTF8MB4(ctx.GetDefaultCollationForUTF8MB4()), + contextstatic.WithBlockEncryptionMode(ctx.GetBlockEncryptionMode()), + contextstatic.WithSysDateIsNow(ctx.GetSysdateIsNow()), + contextstatic.WithNoopFuncsMode(ctx.GetNoopFuncsMode()), + contextstatic.WithRng(ctx.Rng()), + contextstatic.WithPlanCacheTracker(ctx.sctx.GetSessionVars().StmtCtx.PlanCacheTracker), + contextstatic.WithColumnIDAllocator( + exprctx.NewSimplePlanColumnIDAllocator(ctx.sctx.GetSessionVars().PlanColumnID.Load())), + contextstatic.WithConnectionID(ctx.ConnectionID()), + contextstatic.WithWindowingUseHighPrecision(ctx.GetWindowingUseHighPrecision()), + contextstatic.WithGroupConcatMaxLen(ctx.GetGroupConcatMaxLen()), + ) +} + // SessionEvalContext implements the `expression.EvalContext` interface to provide evaluation context in session. type SessionEvalContext struct { sctx sessionctx.Context @@ -273,6 +294,71 @@ func (ctx *SessionEvalContext) RequestDynamicVerification(privName string, grant return checker.RequestDynamicVerification(ctx.sctx.GetSessionVars().ActiveRoles, privName, grantable) } +// IntoStatic turns the SessionEvalContext into a StaticEvalContext. +func (ctx *SessionEvalContext) IntoStatic() *contextstatic.StaticEvalContext { + typeCtx := ctx.TypeCtx() + errCtx := ctx.ErrCtx() + + props := make([]exprctx.OptionalEvalPropProvider, 0, exprctx.OptPropsCnt) + for i := 0; i < exprctx.OptPropsCnt; i++ { + // TODO: check whether these `prop` is safe to copy + if prop, ok := ctx.GetOptionalPropProvider(exprctx.OptionalEvalPropKey(i)); ok { + props = append(props, prop) + } + } + + // TODO: use a more structural way to replace the closure. + // These closure makes sure the fields which may be changed in the execution of the next statement will not be embedded into them, to make + // sure it's safe to call them after the session continues to execute other statements. + staticCtx := contextstatic.NewStaticEvalContext( + contextstatic.WithWarnHandler(ctx.sctx.GetSessionVars().StmtCtx.WarnHandler), + contextstatic.WithSQLMode(ctx.SQLMode()), + contextstatic.WithTypeFlags(typeCtx.Flags()), + contextstatic.WithLocation(typeCtx.Location()), + contextstatic.WithErrLevelMap(errCtx.LevelMap()), + contextstatic.WithCurrentDB(ctx.CurrentDB()), + contextstatic.WithCurrentTime(func() func() (time.Time, error) { + currentTime, currentTimeErr := ctx.CurrentTime() + + return func() (time.Time, error) { + return currentTime, currentTimeErr + } + }()), + contextstatic.WithMaxAllowedPacket(ctx.GetMaxAllowedPacket()), + contextstatic.WithDefaultWeekFormatMode(ctx.GetDefaultWeekFormatMode()), + contextstatic.WithDivPrecisionIncrement(ctx.GetDivPrecisionIncrement()), + contextstatic.WithPrivCheck(func() func(db string, table string, column string, priv mysql.PrivilegeType) bool { + checker := privilege.GetPrivilegeManager(ctx.sctx) + activeRoles := make([]*auth.RoleIdentity, len(ctx.sctx.GetSessionVars().ActiveRoles)) + copy(activeRoles, ctx.sctx.GetSessionVars().ActiveRoles) + + return func(db string, table string, column string, priv mysql.PrivilegeType) bool { + if checker == nil { + return true + } + + return checker.RequestVerification(activeRoles, db, table, column, priv) + } + }()), + contextstatic.WithDynamicPrivCheck(func() func(privName string, grantable bool) bool { + checker := privilege.GetPrivilegeManager(ctx.sctx) + activeRoles := make([]*auth.RoleIdentity, len(ctx.sctx.GetSessionVars().ActiveRoles)) + copy(activeRoles, ctx.sctx.GetSessionVars().ActiveRoles) + + return func(privName string, grantable bool) bool { + if checker == nil { + return true + } + + return checker.RequestDynamicVerification(activeRoles, privName, grantable) + } + }()), + contextstatic.WithOptionalProperty(props...), + ) + + return staticCtx +} + func getStmtTimestamp(ctx sessionctx.Context) (time.Time, error) { if ctx != nil { staleTSO, err := ctx.GetSessionVars().StmtCtx.GetStaleTSO() diff --git a/pkg/expression/contextstatic/BUILD.bazel b/pkg/expression/contextstatic/BUILD.bazel index b8060900ffe09..aacdb8e479ab3 100644 --- a/pkg/expression/contextstatic/BUILD.bazel +++ b/pkg/expression/contextstatic/BUILD.bazel @@ -31,7 +31,7 @@ go_test( ], embed = [":contextstatic"], flaky = True, - shard_count = 9, + shard_count = 8, deps = [ "//pkg/errctx", "//pkg/expression/context", diff --git a/pkg/expression/contextstatic/exprctx.go b/pkg/expression/contextstatic/exprctx.go index f6adac24c4436..12da9738e31c1 100644 --- a/pkg/expression/contextstatic/exprctx.go +++ b/pkg/expression/contextstatic/exprctx.go @@ -15,12 +15,11 @@ package contextstatic import ( - "sync/atomic" - exprctx "github.com/pingcap/tidb/pkg/expression/context" "github.com/pingcap/tidb/pkg/parser/charset" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/sessionctx/variable" + contextutil "github.com/pingcap/tidb/pkg/util/context" "github.com/pingcap/tidb/pkg/util/intest" "github.com/pingcap/tidb/pkg/util/mathutil" ) @@ -39,8 +38,7 @@ type staticExprCtxState struct { sysDateIsNow bool noopFuncsMode int rng *mathutil.MysqlRng - canUseCache *atomic.Bool - skipCacheHandleFunc func(useCache *atomic.Bool, skipReason string) + planCacheTracker *contextutil.PlanCacheTracker columnIDAllocator exprctx.PlanColumnIDAllocator connectionID uint64 windowingUseHighPrecision bool @@ -103,17 +101,11 @@ func WithRng(rng *mathutil.MysqlRng) StaticExprCtxOption { } } -// WithUseCache sets the return value of `IsUseCache` for `StaticExprContext`. -func WithUseCache(useCache bool) StaticExprCtxOption { +// WithPlanCacheTracker sets the plan cache tracker for `StaticExprContext`. +func WithPlanCacheTracker(tracker *contextutil.PlanCacheTracker) StaticExprCtxOption { + intest.AssertNotNil(tracker) return func(s *staticExprCtxState) { - s.canUseCache.Store(useCache) - } -} - -// WithSkipCacheHandleFunc sets inner skip plan cache function for StaticExprContext -func WithSkipCacheHandleFunc(fn func(useCache *atomic.Bool, skipReason string)) StaticExprCtxOption { - return func(s *staticExprCtxState) { - s.skipCacheHandleFunc = fn + s.planCacheTracker = tracker } } @@ -171,8 +163,7 @@ func NewStaticExprContext(opts ...StaticExprCtxOption) *StaticExprContext { }, } - ctx.canUseCache = &atomic.Bool{} - ctx.canUseCache.Store(true) + ctx.planCacheTracker = contextutil.NewPlanCacheTracker(ctx.evalCtx) for _, opt := range opts { opt(&ctx.staticExprCtxState) @@ -199,9 +190,6 @@ func (ctx *StaticExprContext) Apply(opts ...StaticExprCtxOption) *StaticExprCont staticExprCtxState: ctx.staticExprCtxState, } - newCtx.canUseCache = &atomic.Bool{} - newCtx.canUseCache.Store(ctx.canUseCache.Load()) - for _, opt := range opts { opt(&newCtx.staticExprCtxState) } @@ -246,16 +234,12 @@ func (ctx *StaticExprContext) Rng() *mathutil.MysqlRng { // IsUseCache implements the `ExprContext.IsUseCache`. func (ctx *StaticExprContext) IsUseCache() bool { - return ctx.canUseCache.Load() + return ctx.planCacheTracker.UseCache() } // SetSkipPlanCache implements the `ExprContext.SetSkipPlanCache`. func (ctx *StaticExprContext) SetSkipPlanCache(reason string) { - if fn := ctx.skipCacheHandleFunc; fn != nil { - fn(ctx.canUseCache, reason) - return - } - ctx.canUseCache.Store(false) + ctx.planCacheTracker.SetSkipPlanCache(reason) } // AllocPlanColumnID implements the `ExprContext.AllocPlanColumnID`. diff --git a/pkg/expression/contextstatic/exprctx_test.go b/pkg/expression/contextstatic/exprctx_test.go index fa358402f1ed3..c0b1284a2ba5b 100644 --- a/pkg/expression/contextstatic/exprctx_test.go +++ b/pkg/expression/contextstatic/exprctx_test.go @@ -15,7 +15,6 @@ package contextstatic import ( - "sync/atomic" "testing" "time" @@ -41,16 +40,13 @@ func TestNewStaticExprCtx(t *testing.T) { func TestStaticExprCtxApplyOptions(t *testing.T) { ctx := NewStaticExprContext() - oldCanUseCache := ctx.canUseCache oldEvalCtx := ctx.evalCtx oldColumnIDAllocator := ctx.columnIDAllocator // apply with options opts, s := getExprCtxOptionsForTest() ctx2 := ctx.Apply(opts...) - require.NotSame(t, oldCanUseCache, ctx2.canUseCache) require.Equal(t, oldEvalCtx, ctx.evalCtx) - require.Same(t, oldCanUseCache, ctx.canUseCache) require.Same(t, oldColumnIDAllocator, ctx.columnIDAllocator) checkDefaultStaticExprCtx(t, ctx) checkOptionsStaticExprCtx(t, ctx2, s) @@ -59,7 +55,6 @@ func TestStaticExprCtxApplyOptions(t *testing.T) { ctx3 := ctx2.Apply() s.skipCacheArgs = nil checkOptionsStaticExprCtx(t, ctx3, s) - require.NotSame(t, ctx2.canUseCache, ctx3.canUseCache) } func checkDefaultStaticExprCtx(t *testing.T, ctx *StaticExprContext) { @@ -75,8 +70,6 @@ func checkDefaultStaticExprCtx(t *testing.T, ctx *StaticExprContext) { require.Equal(t, variable.DefSysdateIsNow, ctx.GetSysdateIsNow()) require.Equal(t, variable.TiDBOptOnOffWarn(variable.DefTiDBEnableNoopFuncs), ctx.GetNoopFuncsMode()) require.NotNil(t, ctx.Rng()) - require.True(t, ctx.IsUseCache()) - require.Nil(t, ctx.skipCacheHandleFunc) require.NotNil(t, ctx.columnIDAllocator) _, ok := ctx.columnIDAllocator.(*context.SimplePlanColumnIDAllocator) require.True(t, ok) @@ -107,10 +100,6 @@ func getExprCtxOptionsForTest() ([]StaticExprCtxOption, *exprCtxOptionsTestState WithSysDateIsNow(true), WithNoopFuncsMode(variable.WarnInt), WithRng(s.rng), - WithUseCache(false), - WithSkipCacheHandleFunc(func(useCache *atomic.Bool, skipReason string) { - s.skipCacheArgs = []any{useCache, skipReason} - }), WithColumnIDAllocator(s.colIDAlloc), WithConnectionID(778899), WithWindowingUseHighPrecision(false), @@ -131,91 +120,12 @@ func checkOptionsStaticExprCtx(t *testing.T, ctx *StaticExprContext, s *exprCtxO require.False(t, ctx.IsUseCache()) require.Nil(t, s.skipCacheArgs) ctx.SetSkipPlanCache("reason") - require.Equal(t, []any{ctx.canUseCache, "reason"}, s.skipCacheArgs) require.Same(t, s.colIDAlloc, ctx.columnIDAllocator) require.Equal(t, uint64(778899), ctx.ConnectionID()) require.False(t, ctx.GetWindowingUseHighPrecision()) require.Equal(t, uint64(2233445566), ctx.GetGroupConcatMaxLen()) } -func TestStaticExprCtxUseCache(t *testing.T) { - // default implement - ctx := NewStaticExprContext() - require.True(t, ctx.IsUseCache()) - require.Nil(t, ctx.skipCacheHandleFunc) - ctx.SetSkipPlanCache("reason") - require.False(t, ctx.IsUseCache()) - require.Empty(t, ctx.GetEvalCtx().TruncateWarnings(0)) - - ctx = NewStaticExprContext(WithUseCache(false)) - require.False(t, ctx.IsUseCache()) - require.Nil(t, ctx.skipCacheHandleFunc) - ctx.SetSkipPlanCache("reason") - require.False(t, ctx.IsUseCache()) - require.Empty(t, ctx.GetEvalCtx().TruncateWarnings(0)) - - ctx = NewStaticExprContext(WithUseCache(true)) - require.True(t, ctx.IsUseCache()) - require.Nil(t, ctx.skipCacheHandleFunc) - ctx.SetSkipPlanCache("reason") - require.False(t, ctx.IsUseCache()) - require.Empty(t, ctx.GetEvalCtx().TruncateWarnings(0)) - - // custom skip func - var args []any - calls := 0 - ctx = NewStaticExprContext(WithSkipCacheHandleFunc(func(useCache *atomic.Bool, skipReason string) { - args = []any{useCache, skipReason} - calls++ - if calls > 1 { - useCache.Store(false) - } - })) - ctx.SetSkipPlanCache("reason1") - // If we use `WithSkipCacheHandleFunc`, useCache will be set in function - require.Equal(t, 1, calls) - require.True(t, ctx.IsUseCache()) - require.Equal(t, []any{ctx.canUseCache, "reason1"}, args) - - args = nil - ctx.SetSkipPlanCache("reason2") - require.Equal(t, 2, calls) - require.False(t, ctx.IsUseCache()) - require.Equal(t, []any{ctx.canUseCache, "reason2"}, args) - - // apply - ctx = NewStaticExprContext() - require.True(t, ctx.IsUseCache()) - ctx2 := ctx.Apply(WithUseCache(false)) - require.False(t, ctx2.IsUseCache()) - require.True(t, ctx.IsUseCache()) - require.NotSame(t, ctx.canUseCache, ctx2.canUseCache) - require.Nil(t, ctx.skipCacheHandleFunc) - require.Nil(t, ctx2.skipCacheHandleFunc) - - var args2 []any - fn1 := func(useCache *atomic.Bool, skipReason string) { args = []any{useCache, skipReason} } - fn2 := func(useCache *atomic.Bool, skipReason string) { args2 = []any{useCache, skipReason} } - ctx = NewStaticExprContext(WithUseCache(false), WithSkipCacheHandleFunc(fn1)) - require.False(t, ctx.IsUseCache()) - ctx2 = ctx.Apply(WithUseCache(true), WithSkipCacheHandleFunc(fn2)) - require.NotSame(t, ctx.canUseCache, ctx2.canUseCache) - require.False(t, ctx.IsUseCache()) - require.True(t, ctx2.IsUseCache()) - - args = nil - args2 = nil - ctx.SetSkipPlanCache("reasonA") - require.Equal(t, []any{ctx.canUseCache, "reasonA"}, args) - require.Nil(t, args2) - - args = nil - args2 = nil - ctx2.SetSkipPlanCache("reasonB") - require.Nil(t, args) - require.Equal(t, []any{ctx2.canUseCache, "reasonB"}, args2) -} - func TestExprCtxColumnIDAllocator(t *testing.T) { // default ctx := NewStaticExprContext() diff --git a/pkg/planner/context/context.go b/pkg/planner/context/context.go index 99b1ca4ed2495..6eecafe46747f 100644 --- a/pkg/planner/context/context.go +++ b/pkg/planner/context/context.go @@ -106,3 +106,9 @@ func (b *BuildPBContext) GetExprCtx() exprctx.BuildContext { func (b *BuildPBContext) GetClient() kv.Client { return b.Client } + +// IntoStatic persists some fields to make sure it's safe to read/write the context after the session continues +// to execute other statements. +func (b *BuildPBContext) IntoStatic(staticExprCtx exprctx.BuildContext) { + b.ExprCtx = staticExprCtx +} diff --git a/pkg/planner/core/exhaust_physical_plans_test.go b/pkg/planner/core/exhaust_physical_plans_test.go index 80d1af2d99cab..fcc9c01fc327d 100644 --- a/pkg/planner/core/exhaust_physical_plans_test.go +++ b/pkg/planner/core/exhaust_physical_plans_test.go @@ -350,7 +350,7 @@ func checkRangeFallbackAndReset(t *testing.T, ctx base.PlanContext, expectedRang } require.Equal(t, expectedRangeFallback, hasRangeFallbackWarn) stmtCtx.PlanCacheTracker = contextutil.NewPlanCacheTracker(stmtCtx) - stmtCtx.RangeFallbackHandler = contextutil.NewRangeFallbackHandler(&stmtCtx.PlanCacheTracker, stmtCtx) + stmtCtx.RangeFallbackHandler = contextutil.NewRangeFallbackHandler(stmtCtx.PlanCacheTracker, stmtCtx) stmtCtx.SetWarnings(nil) } diff --git a/pkg/session/session.go b/pkg/session/session.go index ccfc9fb278017..18e75faa27cf0 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -2560,7 +2560,7 @@ func (s *session) GetDistSQLCtx() *distsqlctx.DistSQLContext { EnabledRateLimitAction: vars.EnabledRateLimitAction, EnableChunkRPC: vars.EnableChunkRPC, OriginalSQL: sc.OriginalSQL, - KVVars: vars.KVVars, + KVVars: *vars.KVVars, KvExecCounter: sc.KvExecCounter, SessionMemTracker: vars.MemTracker, @@ -2612,6 +2612,8 @@ func (s *session) GetRangerCtx() *rangerctx.RangerContext { rctx := sc.GetOrInitRangerCtxFromCache(func() any { return &rangerctx.RangerContext{ + WarnHandler: sc.WarnHandler, + ExprCtx: s.GetExprCtx(), TypeCtx: s.GetSessionVars().StmtCtx.TypeCtx(), ErrCtx: s.GetSessionVars().StmtCtx.ErrCtx(), @@ -2621,8 +2623,8 @@ func (s *session) GetRangerCtx() *rangerctx.RangerContext { OptPrefixIndexSingleScan: s.GetSessionVars().OptPrefixIndexSingleScan, OptimizerFixControl: s.GetSessionVars().OptimizerFixControl, - PlanCacheTracker: &s.GetSessionVars().StmtCtx.PlanCacheTracker, - RangeFallbackHandler: &s.GetSessionVars().StmtCtx.RangeFallbackHandler, + PlanCacheTracker: s.GetSessionVars().StmtCtx.PlanCacheTracker, + RangeFallbackHandler: s.GetSessionVars().StmtCtx.RangeFallbackHandler, } }) diff --git a/pkg/sessionctx/stmtctx/stmtctx.go b/pkg/sessionctx/stmtctx/stmtctx.go index 5a72c3b6ceb32..31e4ca37a84d9 100644 --- a/pkg/sessionctx/stmtctx/stmtctx.go +++ b/pkg/sessionctx/stmtctx/stmtctx.go @@ -164,8 +164,8 @@ type StatementContext struct { InPreparedPlanBuilding bool InShowWarning bool - contextutil.PlanCacheTracker - contextutil.RangeFallbackHandler + *contextutil.PlanCacheTracker + *contextutil.RangeFallbackHandler BatchCheck bool IgnoreExplainIDSuffix bool @@ -424,7 +424,7 @@ func NewStmtCtxWithTimeZone(tz *time.Location) *StatementContext { sc.typeCtx = types.NewContext(types.DefaultStmtFlags, tz, sc) sc.errCtx = newErrCtx(sc.typeCtx, DefaultStmtErrLevels, sc) sc.PlanCacheTracker = contextutil.NewPlanCacheTracker(sc) - sc.RangeFallbackHandler = contextutil.NewRangeFallbackHandler(&sc.PlanCacheTracker, sc) + sc.RangeFallbackHandler = contextutil.NewRangeFallbackHandler(sc.PlanCacheTracker, sc) sc.WarnHandler = contextutil.NewStaticWarnHandler(0) sc.ExtraWarnHandler = contextutil.NewStaticWarnHandler(0) return sc @@ -438,7 +438,7 @@ func (sc *StatementContext) Reset() { sc.typeCtx = types.NewContext(types.DefaultStmtFlags, time.UTC, sc) sc.errCtx = newErrCtx(sc.typeCtx, DefaultStmtErrLevels, sc) sc.PlanCacheTracker = contextutil.NewPlanCacheTracker(sc) - sc.RangeFallbackHandler = contextutil.NewRangeFallbackHandler(&sc.PlanCacheTracker, sc) + sc.RangeFallbackHandler = contextutil.NewRangeFallbackHandler(sc.PlanCacheTracker, sc) sc.WarnHandler = contextutil.NewStaticWarnHandler(0) sc.ExtraWarnHandler = contextutil.NewStaticWarnHandler(0) } diff --git a/pkg/util/context/plancache.go b/pkg/util/context/plancache.go index 9cac5ec81012e..cdca904485798 100644 --- a/pkg/util/context/plancache.go +++ b/pkg/util/context/plancache.go @@ -139,8 +139,8 @@ func (h *PlanCacheTracker) PlanCacheUnqualified() string { } // NewPlanCacheTracker creates a new PlanCacheTracker. -func NewPlanCacheTracker(warnHandler WarnAppender) PlanCacheTracker { - return PlanCacheTracker{ +func NewPlanCacheTracker(warnHandler WarnAppender) *PlanCacheTracker { + return &PlanCacheTracker{ warnHandler: warnHandler, } } @@ -166,8 +166,8 @@ func (h *RangeFallbackHandler) RecordRangeFallback(rangeMaxSize int64) { } // NewRangeFallbackHandler creates a new RangeFallbackHandler. -func NewRangeFallbackHandler(planCacheTracker *PlanCacheTracker, warnHandler WarnAppender) RangeFallbackHandler { - return RangeFallbackHandler{ +func NewRangeFallbackHandler(planCacheTracker *PlanCacheTracker, warnHandler WarnAppender) *RangeFallbackHandler { + return &RangeFallbackHandler{ planCacheTracker: planCacheTracker, warnHandler: warnHandler, } diff --git a/pkg/util/mock/context.go b/pkg/util/mock/context.go index a8bd28e15dc8f..a45882fc68559 100644 --- a/pkg/util/mock/context.go +++ b/pkg/util/mock/context.go @@ -256,7 +256,7 @@ func (c *Context) GetDistSQLCtx() *distsqlctx.DistSQLContext { EnabledRateLimitAction: vars.EnabledRateLimitAction, EnableChunkRPC: vars.EnableChunkRPC, OriginalSQL: sc.OriginalSQL, - KVVars: vars.KVVars, + KVVars: *vars.KVVars, KvExecCounter: sc.KvExecCounter, SessionMemTracker: vars.MemTracker, Location: sc.TimeZone(), @@ -286,8 +286,8 @@ func (c *Context) GetRangerCtx() *rangerctx.RangerContext { OptPrefixIndexSingleScan: c.GetSessionVars().OptPrefixIndexSingleScan, OptimizerFixControl: c.GetSessionVars().OptimizerFixControl, - PlanCacheTracker: &c.GetSessionVars().StmtCtx.PlanCacheTracker, - RangeFallbackHandler: &c.GetSessionVars().StmtCtx.RangeFallbackHandler, + PlanCacheTracker: c.GetSessionVars().StmtCtx.PlanCacheTracker, + RangeFallbackHandler: c.GetSessionVars().StmtCtx.RangeFallbackHandler, } } diff --git a/pkg/util/ranger/context/context.go b/pkg/util/ranger/context/context.go index 626e924981e7c..0f5610bca247d 100644 --- a/pkg/util/ranger/context/context.go +++ b/pkg/util/ranger/context/context.go @@ -23,6 +23,8 @@ import ( // RangerContext is the context used to build range. type RangerContext struct { + WarnHandler contextutil.WarnAppender + TypeCtx types.Context ErrCtx errctx.Context ExprCtx exprctx.BuildContext @@ -34,3 +36,14 @@ type RangerContext struct { RegardNULLAsPoint bool OptPrefixIndexSingleScan bool } + +// IntoStatic persists some fields to make sure it's safe to read/write the context after the session continues +// to execute other statements. +func (r *RangerContext) IntoStatic(staticExprCtx exprctx.BuildContext) { + r.ExprCtx = staticExprCtx + + fixControl := make(map[uint64]string, len(r.OptimizerFixControl)) + for k, v := range r.OptimizerFixControl { + fixControl[k] = v + } +} diff --git a/pkg/util/ranger/ranger_test.go b/pkg/util/ranger/ranger_test.go index 0a260cc3bc05a..c317f7a92bdd7 100644 --- a/pkg/util/ranger/ranger_test.go +++ b/pkg/util/ranger/ranger_test.go @@ -1856,7 +1856,7 @@ func checkRangeFallbackAndReset(t *testing.T, sctx sessionctx.Context, expectedR } require.Equal(t, expectedRangeFallback, hasRangeFallbackWarn) stmtCtx.PlanCacheTracker = contextutil.NewPlanCacheTracker(stmtCtx) - stmtCtx.RangeFallbackHandler = contextutil.NewRangeFallbackHandler(&stmtCtx.PlanCacheTracker, stmtCtx) + stmtCtx.RangeFallbackHandler = contextutil.NewRangeFallbackHandler(stmtCtx.PlanCacheTracker, stmtCtx) stmtCtx.SetWarnings(nil) }