From 4e6af00850cdb77aa996ea85662d69cc6b31b46c Mon Sep 17 00:00:00 2001 From: Ti Chi Robot Date: Tue, 15 Aug 2023 15:59:00 +0800 Subject: [PATCH] planner: store the hints of session variable (#45814) (#46046) close pingcap/tidb#45812 --- planner/core/plan_cache.go | 11 ++++- planner/core/plan_cache_utils.go | 6 ++- planner/core/point_get_plan.go | 2 + session/session_test/BUILD.bazel | 1 + session/session_test/session_test.go | 61 ++++++++++++++++++++++++++++ sessionctx/stmtctx/BUILD.bazel | 2 +- sessionctx/stmtctx/stmtctx.go | 36 +++++++++++++++- sessionctx/stmtctx/stmtctx_test.go | 23 +++++++++++ 8 files changed, 137 insertions(+), 5 deletions(-) diff --git a/planner/core/plan_cache.go b/planner/core/plan_cache.go index b6e5c3824781b..caae34c3b4cd3 100644 --- a/planner/core/plan_cache.go +++ b/planner/core/plan_cache.go @@ -209,6 +209,9 @@ func getPointQueryPlan(stmt *ast.Prepared, sessVars *variable.SessionVars, stmtC } sessVars.FoundInPlanCache = true stmtCtx.PointExec = true + if pointGetPlan, ok := plan.(*PointGetPlan); ok && pointGetPlan != nil && pointGetPlan.stmtHints != nil { + sessVars.StmtCtx.StmtHints = *pointGetPlan.stmtHints + } return plan, names, true, nil } @@ -251,6 +254,7 @@ func getGeneralPlan(sctx sessionctx.Context, isGeneralPlanCache bool, cacheKey k planCacheCounter.Inc() } stmtCtx.SetPlanDigest(stmt.NormalizedPlan, stmt.PlanDigest) + stmtCtx.StmtHints = *cachedVal.stmtHints return cachedVal.Plan, cachedVal.OutPutNames, true, nil } @@ -289,7 +293,7 @@ func generateNewPlan(ctx context.Context, sctx sessionctx.Context, isGeneralPlan } sessVars.IsolationReadEngines[kv.TiFlash] = struct{}{} } - cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, paramTypes) + cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, paramTypes, &stmtCtx.StmtHints) stmt.NormalizedPlan, stmt.PlanDigest = NormalizePlan(p) stmtCtx.SetPlan(p) stmtCtx.SetPlanDigest(stmt.NormalizedPlan, stmt.PlanDigest) @@ -687,12 +691,15 @@ func tryCachePointPlan(_ context.Context, sctx sessionctx.Context, names types.NameSlice ) - if _, _ok := p.(*PointGetPlan); _ok { + if plan, _ok := p.(*PointGetPlan); _ok { ok, err = IsPointGetWithPKOrUniqueKeyByAutoCommit(sctx, p) names = p.OutputNames() if err != nil { return err } + if ok { + plan.stmtHints = sctx.GetSessionVars().StmtCtx.StmtHints.Clone() + } } if ok { diff --git a/planner/core/plan_cache_utils.go b/planner/core/plan_cache_utils.go index 5431f7ef71c27..ba35b76afc87d 100644 --- a/planner/core/plan_cache_utils.go +++ b/planner/core/plan_cache_utils.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" driver "github.com/pingcap/tidb/types/parser_driver" @@ -348,6 +349,8 @@ type PlanCacheValue struct { TblInfo2UnionScan map[*model.TableInfo]bool ParamTypes FieldSlice memoryUsage int64 + // stmtHints stores the hints which set session variables, because the hints won't be processed using cached plan. + stmtHints *stmtctx.StmtHints } func (v *PlanCacheValue) varTypesUnchanged(txtVarTps []*types.FieldType) bool { @@ -395,7 +398,7 @@ func (v *PlanCacheValue) MemoryUsage() (sum int64) { // NewPlanCacheValue creates a SQLCacheValue. func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.TableInfo]bool, - paramTypes []*types.FieldType) *PlanCacheValue { + paramTypes []*types.FieldType, stmtHints *stmtctx.StmtHints) *PlanCacheValue { dstMap := make(map[*model.TableInfo]bool) for k, v := range srcMap { dstMap[k] = v @@ -409,6 +412,7 @@ func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.Ta OutPutNames: names, TblInfo2UnionScan: dstMap, ParamTypes: userParamTypes, + stmtHints: stmtHints.Clone(), } } diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index 649cb3dd1bbb0..9294b7ac64788 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -96,6 +96,8 @@ type PointGetPlan struct { // probeParents records the IndexJoins and Applys with this operator in their inner children. // Please see comments in PhysicalPlan for details. probeParents []PhysicalPlan + // stmtHints should restore in executing context. + stmtHints *stmtctx.StmtHints } func (p *PointGetPlan) getEstRowCountForDisplay() float64 { diff --git a/session/session_test/BUILD.bazel b/session/session_test/BUILD.bazel index f0fa774e9f9e3..a4701232f403b 100644 --- a/session/session_test/BUILD.bazel +++ b/session/session_test/BUILD.bazel @@ -26,6 +26,7 @@ go_test( "//privilege/privileges", "//session", "//sessionctx", + "//sessionctx/stmtctx", "//sessionctx/variable", "//store/copr", "//store/mockstore", diff --git a/session/session_test/session_test.go b/session/session_test/session_test.go index 827db086dda42..87ea706a782eb 100644 --- a/session/session_test/session_test.go +++ b/session/session_test/session_test.go @@ -42,6 +42,7 @@ import ( "github.com/pingcap/tidb/privilege/privileges" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/copr" "github.com/pingcap/tidb/store/mockstore" @@ -4139,3 +4140,63 @@ func TestSQLModeOp(t *testing.T) { a = mysql.SetSQLMode(s, mysql.ModeAllowInvalidDates) require.Equal(t, mysql.ModeNoBackslashEscapes|mysql.ModeOnlyFullGroupBy|mysql.ModeAllowInvalidDates, a) } + +func TestPrepareExecuteWithSQLHints(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + se := tk.Session() + se.SetConnectionID(1) + tk.MustExec("use test") + tk.MustExec("create table t(a int primary key)") + + type hintCheck struct { + hint string + check func(*stmtctx.StmtHints) + } + + hintChecks := []hintCheck{ + { + hint: "MEMORY_QUOTA(1024 MB)", + check: func(stmtHint *stmtctx.StmtHints) { + require.True(t, stmtHint.HasMemQuotaHint) + require.Equal(t, int64(1024*1024*1024), stmtHint.MemQuotaQuery) + }, + }, + { + hint: "READ_CONSISTENT_REPLICA()", + check: func(stmtHint *stmtctx.StmtHints) { + require.True(t, stmtHint.HasReplicaReadHint) + require.Equal(t, byte(kv.ReplicaReadFollower), stmtHint.ReplicaRead) + }, + }, + { + hint: "MAX_EXECUTION_TIME(1000)", + check: func(stmtHint *stmtctx.StmtHints) { + require.True(t, stmtHint.HasMaxExecutionTime) + require.Equal(t, uint64(1000), stmtHint.MaxExecutionTime) + }, + }, + { + hint: "USE_TOJA(TRUE)", + check: func(stmtHint *stmtctx.StmtHints) { + require.True(t, stmtHint.HasAllowInSubqToJoinAndAggHint) + require.True(t, stmtHint.AllowInSubqToJoinAndAgg) + }, + }, + } + + for i, check := range hintChecks { + // common path + tk.MustExec(fmt.Sprintf("prepare stmt%d from 'select /*+ %s */ * from t'", i, check.hint)) + for j := 0; j < 10; j++ { + tk.MustQuery(fmt.Sprintf("execute stmt%d", i)) + check.check(&tk.Session().GetSessionVars().StmtCtx.StmtHints) + } + // fast path + tk.MustExec(fmt.Sprintf("prepare fast%d from 'select /*+ %s */ * from t where a = 1'", i, check.hint)) + for j := 0; j < 10; j++ { + tk.MustQuery(fmt.Sprintf("execute fast%d", i)) + check.check(&tk.Session().GetSessionVars().StmtCtx.StmtHints) + } + } +} diff --git a/sessionctx/stmtctx/BUILD.bazel b/sessionctx/stmtctx/BUILD.bazel index ed9294d5b7890..65fcbb08fc697 100644 --- a/sessionctx/stmtctx/BUILD.bazel +++ b/sessionctx/stmtctx/BUILD.bazel @@ -34,7 +34,7 @@ go_test( ], embed = [":stmtctx"], flaky = True, - shard_count = 5, + shard_count = 6, deps = [ "//kv", "//sessionctx/variable", diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index cd440b7b785f7..d8e0bc7cd5fb7 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -389,7 +389,6 @@ type StatementContext struct { type StmtHints struct { // Hint Information MemQuotaQuery int64 - ApplyCacheCapacity int64 MaxExecutionTime uint64 ReplicaRead byte AllowInSubqToJoinAndAgg bool @@ -418,6 +417,41 @@ func (sh *StmtHints) TaskMapNeedBackUp() bool { return sh.ForceNthPlan != -1 } +// Clone the StmtHints struct and returns the pointer of the new one. +func (sh *StmtHints) Clone() *StmtHints { + var ( + vars map[string]string + tableHints []*ast.TableOptimizerHint + ) + if len(sh.SetVars) > 0 { + vars = make(map[string]string, len(sh.SetVars)) + for k, v := range sh.SetVars { + vars[k] = v + } + } + if len(sh.OriginalTableHints) > 0 { + tableHints = make([]*ast.TableOptimizerHint, len(sh.OriginalTableHints)) + copy(tableHints, sh.OriginalTableHints) + } + return &StmtHints{ + MemQuotaQuery: sh.MemQuotaQuery, + MaxExecutionTime: sh.MaxExecutionTime, + ReplicaRead: sh.ReplicaRead, + AllowInSubqToJoinAndAgg: sh.AllowInSubqToJoinAndAgg, + NoIndexMergeHint: sh.NoIndexMergeHint, + StraightJoinOrder: sh.StraightJoinOrder, + EnableCascadesPlanner: sh.EnableCascadesPlanner, + ForceNthPlan: sh.ForceNthPlan, + HasAllowInSubqToJoinAndAggHint: sh.HasAllowInSubqToJoinAndAggHint, + HasMemQuotaHint: sh.HasMemQuotaHint, + HasReplicaReadHint: sh.HasReplicaReadHint, + HasMaxExecutionTime: sh.HasMaxExecutionTime, + HasEnableCascadesPlannerHint: sh.HasEnableCascadesPlannerHint, + SetVars: vars, + OriginalTableHints: tableHints, + } +} + // StmtCacheKey represents the key type in the StmtCache. type StmtCacheKey int diff --git a/sessionctx/stmtctx/stmtctx_test.go b/sessionctx/stmtctx/stmtctx_test.go index 53555184e0d46..a4964663298ab 100644 --- a/sessionctx/stmtctx/stmtctx_test.go +++ b/sessionctx/stmtctx/stmtctx_test.go @@ -19,6 +19,7 @@ import ( "encoding/json" "fmt" "math/rand" + "reflect" "sort" "testing" "time" @@ -272,3 +273,25 @@ func TestApproxRuntimeInfo(t *testing.T) { require.Equal(t, d.TotBackoffTime[backoff], timeSum) } } + +func TestStmtHintsClone(t *testing.T) { + hints := stmtctx.StmtHints{} + value := reflect.ValueOf(&hints).Elem() + for i := 0; i < value.NumField(); i++ { + field := value.Field(i) + switch field.Kind() { + case reflect.Int, reflect.Int32, reflect.Int64: + field.SetInt(1) + case reflect.Uint, reflect.Uint32, reflect.Uint64: + field.SetUint(1) + case reflect.Uint8: // byte + field.SetUint(1) + case reflect.Bool: + field.SetBool(true) + case reflect.String: + field.SetString("test") + default: + } + } + require.Equal(t, hints, *hints.Clone()) +}