diff --git a/planner/core/cache.go b/planner/core/cache.go index 8bcc99ac577bf..84dbf879344e2 100644 --- a/planner/core/cache.go +++ b/planner/core/cache.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/codec" @@ -181,6 +182,8 @@ type PlanCacheValue struct { BinVarTypes []byte // variable types under binary protocol IsBinProto bool // whether this plan is under binary protocol BindSQL string + // stmtHints stores the hints that set session variables, because the hints won't be processed when using a cached plan. + stmtHints *stmtctx.StmtHints } func (v *PlanCacheValue) varTypesUnchanged(binVarTps []byte, txtVarTps []*types.FieldType) bool { @@ -192,7 +195,7 @@ func (v *PlanCacheValue) varTypesUnchanged(binVarTps []byte, txtVarTps []*types. // NewPlanCacheValue creates a SQLCacheValue. 
func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.TableInfo]bool, - isBinProto bool, binVarTypes []byte, txtVarTps []*types.FieldType, bindSQL string) *PlanCacheValue { + isBinProto bool, binVarTypes []byte, txtVarTps []*types.FieldType, stmtCtx *stmtctx.StatementContext) *PlanCacheValue { dstMap := make(map[*model.TableInfo]bool) for k, v := range srcMap { dstMap[k] = v @@ -208,7 +211,8 @@ func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.Ta TxtVarTypes: userVarTypes, BinVarTypes: binVarTypes, IsBinProto: isBinProto, - BindSQL: bindSQL, + BindSQL: stmtCtx.BindSQL, + stmtHints: stmtCtx.StmtHints.Clone(), } } diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index 55ab4b98edf54..597ae94569f5a 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -509,6 +509,11 @@ func (e *Execute) getPhysicalPlan(ctx context.Context, sctx sessionctx.Context, e.names = names e.Plan = plan stmtCtx.PointExec = true + if pointPlan, ok := plan.(*PointGetPlan); ok { + if pointPlan.stmtHints != nil { + stmtCtx.StmtHints = *pointPlan.stmtHints + } + } return nil } if prepared.UseCache && !ignorePlanCache { // for general plans @@ -564,6 +569,9 @@ func (e *Execute) getPhysicalPlan(ctx context.Context, sctx sessionctx.Context, e.names = cachedVal.OutPutNames e.Plan = cachedVal.Plan stmtCtx.SetPlanDigest(preparedStmt.NormalizedPlan, preparedStmt.PlanDigest) + if cachedVal.stmtHints != nil { + stmtCtx.StmtHints = *cachedVal.stmtHints + } return nil } break @@ -601,7 +609,7 @@ REBUILD: } sessVars.IsolationReadEngines[kv.TiFlash] = struct{}{} } - cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, isBinProtocol, binVarTypes, txtVarTypes, sessVars.StmtCtx.BindSQL) + cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, isBinProtocol, binVarTypes, txtVarTypes, sessVars.StmtCtx) preparedStmt.NormalizedPlan, preparedStmt.PlanDigest = NormalizePlan(p) 
stmtCtx.SetPlanDigest(preparedStmt.NormalizedPlan, preparedStmt.PlanDigest) if cacheVals, exists := sctx.PreparedPlanCache().Get(cacheKey); exists { @@ -673,13 +681,14 @@ func (e *Execute) tryCachePointPlan(ctx context.Context, sctx sessionctx.Context err error names types.NameSlice ) - switch p.(type) { + switch pointPlan := p.(type) { case *PointGetPlan: ok, err = IsPointGetWithPKOrUniqueKeyByAutoCommit(sctx, p) names = p.OutputNames() if err != nil { return err } + pointPlan.stmtHints = sctx.GetSessionVars().StmtCtx.StmtHints.Clone() } if ok { // just cache point plan now diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index 5d309606afc06..2b6e36b219d77 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -87,6 +87,8 @@ type PointGetPlan struct { planCost float64 // accessCols represents actual columns the PointGet will access, which are used to calculate row-size accessCols []*expression.Column + // stmtHints should be restored in the executing context. + stmtHints *stmtctx.StmtHints } type nameValuePair struct { diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index d5f19e8fe6dd5..a78b7ee759035 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -300,7 +300,6 @@ type StatementContext struct { type StmtHints struct { // Hint Information MemQuotaQuery int64 - ApplyCacheCapacity int64 MaxExecutionTime uint64 ReplicaRead byte AllowInSubqToJoinAndAgg bool @@ -329,6 +328,41 @@ func (sh *StmtHints) TaskMapNeedBackUp() bool { return sh.ForceNthPlan != -1 } +// Clone clones the StmtHints struct and returns a pointer to the new one. 
+func (sh *StmtHints) Clone() *StmtHints { + var ( + vars map[string]string + tableHints []*ast.TableOptimizerHint + ) + if len(sh.SetVars) > 0 { + vars = make(map[string]string, len(sh.SetVars)) + for k, v := range sh.SetVars { + vars[k] = v + } + } + if len(sh.OriginalTableHints) > 0 { + tableHints = make([]*ast.TableOptimizerHint, len(sh.OriginalTableHints)) + copy(tableHints, sh.OriginalTableHints) + } + return &StmtHints{ + MemQuotaQuery: sh.MemQuotaQuery, + MaxExecutionTime: sh.MaxExecutionTime, + ReplicaRead: sh.ReplicaRead, + AllowInSubqToJoinAndAgg: sh.AllowInSubqToJoinAndAgg, + NoIndexMergeHint: sh.NoIndexMergeHint, + StraightJoinOrder: sh.StraightJoinOrder, + EnableCascadesPlanner: sh.EnableCascadesPlanner, + ForceNthPlan: sh.ForceNthPlan, + HasAllowInSubqToJoinAndAggHint: sh.HasAllowInSubqToJoinAndAggHint, + HasMemQuotaHint: sh.HasMemQuotaHint, + HasReplicaReadHint: sh.HasReplicaReadHint, + HasMaxExecutionTime: sh.HasMaxExecutionTime, + HasEnableCascadesPlannerHint: sh.HasEnableCascadesPlannerHint, + SetVars: vars, + OriginalTableHints: tableHints, + } +} + // StmtCacheKey represents the key type in the StmtCache. 
type StmtCacheKey int diff --git a/sessionctx/stmtctx/stmtctx_test.go b/sessionctx/stmtctx/stmtctx_test.go index 7a4ec77a90660..d1a083043ae2f 100644 --- a/sessionctx/stmtctx/stmtctx_test.go +++ b/sessionctx/stmtctx/stmtctx_test.go @@ -17,6 +17,7 @@ package stmtctx_test import ( "context" "fmt" + "reflect" "testing" "time" @@ -143,3 +144,25 @@ func TestWeakConsistencyRead(t *testing.T) { execAndCheck("execute s", testkit.Rows("1 1 2"), kv.SI) tk.MustExec("rollback") } + +func TestStmtHintsClone(t *testing.T) { + hints := stmtctx.StmtHints{} + value := reflect.ValueOf(&hints).Elem() + for i := 0; i < value.NumField(); i++ { + field := value.Field(i) + switch field.Kind() { + case reflect.Int, reflect.Int32, reflect.Int64: + field.SetInt(1) + case reflect.Uint, reflect.Uint32, reflect.Uint64: + field.SetUint(1) + case reflect.Uint8: // byte + field.SetUint(1) + case reflect.Bool: + field.SetBool(true) + case reflect.String: + field.SetString("test") + default: + } + } + require.Equal(t, hints, *hints.Clone()) +} diff --git a/tests/realtikvtest/sessiontest/session_test.go b/tests/realtikvtest/sessiontest/session_test.go index 075e0d0ef161b..ac11e3f59e811 100644 --- a/tests/realtikvtest/sessiontest/session_test.go +++ b/tests/realtikvtest/sessiontest/session_test.go @@ -38,6 +38,7 @@ import ( "github.com/pingcap/tidb/privilege/privileges" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/copr" "github.com/pingcap/tidb/store/mockstore" @@ -3792,3 +3793,64 @@ func TestSQLModeOp(t *testing.T) { a = mysql.SetSQLMode(s, mysql.ModeAllowInvalidDates) require.Equal(t, mysql.ModeNoBackslashEscapes|mysql.ModeOnlyFullGroupBy|mysql.ModeAllowInvalidDates, a) } + +func TestPrepareExecuteWithSQLHints(t *testing.T) { + store, clean := realtikvtest.CreateMockStoreAndSetup(t) + defer clean() + tk := testkit.NewTestKit(t, store) + se := 
tk.Session() + se.SetConnectionID(1) + tk.MustExec("use test") + tk.MustExec("create table t(a int primary key)") + + type hintCheck struct { + hint string + check func(*stmtctx.StmtHints) + } + + hintChecks := []hintCheck{ + { + hint: "MEMORY_QUOTA(1024 MB)", + check: func(stmtHint *stmtctx.StmtHints) { + require.True(t, stmtHint.HasMemQuotaHint) + require.Equal(t, int64(1024*1024*1024), stmtHint.MemQuotaQuery) + }, + }, + { + hint: "READ_CONSISTENT_REPLICA()", + check: func(stmtHint *stmtctx.StmtHints) { + require.True(t, stmtHint.HasReplicaReadHint) + require.Equal(t, byte(kv.ReplicaReadFollower), stmtHint.ReplicaRead) + }, + }, + { + hint: "MAX_EXECUTION_TIME(1000)", + check: func(stmtHint *stmtctx.StmtHints) { + require.True(t, stmtHint.HasMaxExecutionTime) + require.Equal(t, uint64(1000), stmtHint.MaxExecutionTime) + }, + }, + { + hint: "USE_TOJA(TRUE)", + check: func(stmtHint *stmtctx.StmtHints) { + require.True(t, stmtHint.HasAllowInSubqToJoinAndAggHint) + require.True(t, stmtHint.AllowInSubqToJoinAndAgg) + }, + }, + } + + for i, check := range hintChecks { + // common path + tk.MustExec(fmt.Sprintf("prepare stmt%d from 'select /*+ %s */ * from t'", i, check.hint)) + for j := 0; j < 10; j++ { + tk.MustQuery(fmt.Sprintf("execute stmt%d", i)) + check.check(&tk.Session().GetSessionVars().StmtCtx.StmtHints) + } + // fast path + tk.MustExec(fmt.Sprintf("prepare fast%d from 'select /*+ %s */ * from t where a = 1'", i, check.hint)) + for j := 0; j < 10; j++ { + tk.MustQuery(fmt.Sprintf("execute fast%d", i)) + check.check(&tk.Session().GetSessionVars().StmtCtx.StmtHints) + } + } +}