Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: store the hints of session variable #45814

Merged
merged 2 commits into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions planner/core/plan_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,9 @@ func getCachedPointPlan(stmt *ast.Prepared, sessVars *variable.SessionVars, stmt
}
sessVars.FoundInPlanCache = true
stmtCtx.PointExec = true
if pointGetPlan, ok := plan.(*PointGetPlan); ok && pointGetPlan != nil && pointGetPlan.stmtHints != nil {
sessVars.StmtCtx.StmtHints = *pointGetPlan.stmtHints
}
return plan, names, true, nil
}

Expand Down Expand Up @@ -285,6 +288,7 @@ func getCachedPlan(sctx sessionctx.Context, isNonPrepared bool, cacheKey kvcache
core_metrics.GetPlanCacheHitCounter(isNonPrepared).Inc()
}
stmtCtx.SetPlanDigest(stmt.NormalizedPlan, stmt.PlanDigest)
stmtCtx.StmtHints = *cachedVal.stmtHints
return cachedVal.Plan, cachedVal.OutPutNames, true, nil
}

Expand Down Expand Up @@ -327,7 +331,7 @@ func generateNewPlan(ctx context.Context, sctx sessionctx.Context, isNonPrepared
}
sessVars.IsolationReadEngines[kv.TiFlash] = struct{}{}
}
cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, matchOpts)
cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, matchOpts, &stmtCtx.StmtHints)
stmt.NormalizedPlan, stmt.PlanDigest = NormalizePlan(p)
stmtCtx.SetPlan(p)
stmtCtx.SetPlanDigest(stmt.NormalizedPlan, stmt.PlanDigest)
Expand Down Expand Up @@ -757,12 +761,15 @@ func tryCachePointPlan(_ context.Context, sctx sessionctx.Context,
names types.NameSlice
)

if _, _ok := p.(*PointGetPlan); _ok {
if plan, _ok := p.(*PointGetPlan); _ok {
ok, err = IsPointGetWithPKOrUniqueKeyByAutoCommit(sctx, p)
names = p.OutputNames()
if err != nil {
return err
}
if ok {
plan.stmtHints = sctx.GetSessionVars().StmtCtx.StmtHints.Clone()
}
}

if ok {
Expand Down
6 changes: 5 additions & 1 deletion planner/core/plan_cache_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/pingcap/tidb/planner/util"
"github.com/pingcap/tidb/planner/util/fixcontrol"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/statistics"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -337,6 +338,8 @@ type PlanCacheValue struct {

// matchOpts stores some fields help to choose a suitable plan
matchOpts *utilpc.PlanCacheMatchOpts
// stmtHints stores the hints which set session variables, because the hints won't be processed using cached plan.
stmtHints *stmtctx.StmtHints
}

// unKnownMemoryUsage represent the memory usage of uncounted structure, maybe need implement later
Expand Down Expand Up @@ -383,7 +386,7 @@ func (v *PlanCacheValue) MemoryUsage() (sum int64) {

// NewPlanCacheValue creates a SQLCacheValue.
func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.TableInfo]bool,
matchOpts *utilpc.PlanCacheMatchOpts) *PlanCacheValue {
matchOpts *utilpc.PlanCacheMatchOpts, stmtHints *stmtctx.StmtHints) *PlanCacheValue {
dstMap := make(map[*model.TableInfo]bool)
for k, v := range srcMap {
dstMap[k] = v
Expand All @@ -397,6 +400,7 @@ func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.Ta
OutPutNames: names,
TblInfo2UnionScan: dstMap,
matchOpts: matchOpts,
stmtHints: stmtHints.Clone(),
}
}

Expand Down
2 changes: 2 additions & 0 deletions planner/core/point_get_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ type PointGetPlan struct {
// probeParents records the IndexJoins and Applys with this operator in their inner children.
// Please see comments in PhysicalPlan for details.
probeParents []PhysicalPlan
// stmtHints should restore in executing context.
stmtHints *stmtctx.StmtHints
}

func (p *PointGetPlan) getEstRowCountForDisplay() float64 {
Expand Down
3 changes: 2 additions & 1 deletion session/test/vars/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ go_test(
"vars_test.go",
],
flaky = True,
shard_count = 12,
shard_count = 13,
deps = [
"//config",
"//domain",
"//errno",
"//kv",
"//parser/mysql",
"//parser/terror",
"//sessionctx/stmtctx",
"//sessionctx/variable",
"//testkit",
"//testkit/testmain",
Expand Down
68 changes: 68 additions & 0 deletions session/test/vars/vars_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
tikv "github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/testkit"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -615,3 +616,70 @@ func TestGetSysVariables(t *testing.T) {
tk.MustExec("select @@local.performance_schema_max_mutex_classes")
tk.MustGetErrMsg("select @@global.last_insert_id", "[variable:1238]Variable 'last_insert_id' is a SESSION variable")
}

func TestPrepareExecuteWithSQLHints(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
se := tk.Session()
se.SetConnectionID(1)
tk.MustExec("use test")
tk.MustExec("create table t(a int primary key)")

type hintCheck struct {
hint string
check func(*stmtctx.StmtHints)
}

hintChecks := []hintCheck{
{
hint: "MEMORY_QUOTA(1024 MB)",
check: func(stmtHint *stmtctx.StmtHints) {
require.True(t, stmtHint.HasMemQuotaHint)
require.Equal(t, int64(1024*1024*1024), stmtHint.MemQuotaQuery)
},
},
{
hint: "READ_CONSISTENT_REPLICA()",
check: func(stmtHint *stmtctx.StmtHints) {
require.True(t, stmtHint.HasReplicaReadHint)
require.Equal(t, byte(tikv.ReplicaReadFollower), stmtHint.ReplicaRead)
},
},
{
hint: "MAX_EXECUTION_TIME(1000)",
check: func(stmtHint *stmtctx.StmtHints) {
require.True(t, stmtHint.HasMaxExecutionTime)
require.Equal(t, uint64(1000), stmtHint.MaxExecutionTime)
},
},
{
hint: "USE_TOJA(TRUE)",
check: func(stmtHint *stmtctx.StmtHints) {
require.True(t, stmtHint.HasAllowInSubqToJoinAndAggHint)
require.True(t, stmtHint.AllowInSubqToJoinAndAgg)
},
},
{
hint: "RESOURCE_GROUP(rg1)",
check: func(stmtHint *stmtctx.StmtHints) {
require.True(t, stmtHint.HasResourceGroup)
require.Equal(t, "rg1", stmtHint.ResourceGroup)
},
},
}

for i, check := range hintChecks {
// common path
tk.MustExec(fmt.Sprintf("prepare stmt%d from 'select /*+ %s */ * from t'", i, check.hint))
you06 marked this conversation as resolved.
Show resolved Hide resolved
for j := 0; j < 10; j++ {
tk.MustQuery(fmt.Sprintf("execute stmt%d", i))
check.check(&tk.Session().GetSessionVars().StmtCtx.StmtHints)
}
// fast path
tk.MustExec(fmt.Sprintf("prepare fast%d from 'select /*+ %s */ * from t where a = 1'", i, check.hint))
for j := 0; j < 10; j++ {
tk.MustQuery(fmt.Sprintf("execute fast%d", i))
check.check(&tk.Session().GetSessionVars().StmtCtx.StmtHints)
}
}
}
2 changes: 1 addition & 1 deletion sessionctx/stmtctx/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ go_test(
],
embed = [":stmtctx"],
flaky = True,
shard_count = 5,
shard_count = 6,
deps = [
"//kv",
"//sessionctx/variable",
Expand Down
40 changes: 39 additions & 1 deletion sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,6 @@
type StmtHints struct {
// Hint Information
MemQuotaQuery int64
ApplyCacheCapacity int64
MaxExecutionTime uint64
TidbKvReadTimeout uint64
ReplicaRead byte
Expand Down Expand Up @@ -454,6 +453,45 @@
return sh.ForceNthPlan != -1
}

// Clone the StmtHints struct and returns the pointer of the new one.
func (sh *StmtHints) Clone() *StmtHints {
var (
vars map[string]string
tableHints []*ast.TableOptimizerHint
)
if len(sh.SetVars) > 0 {
vars = make(map[string]string, len(sh.SetVars))
for k, v := range sh.SetVars {
vars[k] = v
}

Check warning on line 466 in sessionctx/stmtctx/stmtctx.go

View check run for this annotation

Codecov / codecov/patch

sessionctx/stmtctx/stmtctx.go#L463-L466

Added lines #L463 - L466 were not covered by tests
}
if len(sh.OriginalTableHints) > 0 {
tableHints = make([]*ast.TableOptimizerHint, len(sh.OriginalTableHints))
copy(tableHints, sh.OriginalTableHints)
}
return &StmtHints{
MemQuotaQuery: sh.MemQuotaQuery,
MaxExecutionTime: sh.MaxExecutionTime,
TidbKvReadTimeout: sh.TidbKvReadTimeout,
ReplicaRead: sh.ReplicaRead,
AllowInSubqToJoinAndAgg: sh.AllowInSubqToJoinAndAgg,
NoIndexMergeHint: sh.NoIndexMergeHint,
StraightJoinOrder: sh.StraightJoinOrder,
EnableCascadesPlanner: sh.EnableCascadesPlanner,
ForceNthPlan: sh.ForceNthPlan,
ResourceGroup: sh.ResourceGroup,
HasAllowInSubqToJoinAndAggHint: sh.HasAllowInSubqToJoinAndAggHint,
HasMemQuotaHint: sh.HasMemQuotaHint,
HasReplicaReadHint: sh.HasReplicaReadHint,
HasMaxExecutionTime: sh.HasMaxExecutionTime,
HasTidbKvReadTimeout: sh.HasTidbKvReadTimeout,
HasEnableCascadesPlannerHint: sh.HasEnableCascadesPlannerHint,
HasResourceGroup: sh.HasResourceGroup,
SetVars: vars,
OriginalTableHints: tableHints,
}
}

// StmtCacheKey represents the key type in the StmtCache.
type StmtCacheKey int

Expand Down
23 changes: 23 additions & 0 deletions sessionctx/stmtctx/stmtctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/json"
"fmt"
"math/rand"
"reflect"
"sort"
"testing"
"time"
Expand Down Expand Up @@ -273,3 +274,25 @@ func TestApproxRuntimeInfo(t *testing.T) {
require.Equal(t, d.TotBackoffTime[backoff], timeSum)
}
}

func TestStmtHintsClone(t *testing.T) {
hints := stmtctx.StmtHints{}
value := reflect.ValueOf(&hints).Elem()
for i := 0; i < value.NumField(); i++ {
field := value.Field(i)
switch field.Kind() {
case reflect.Int, reflect.Int32, reflect.Int64:
field.SetInt(1)
case reflect.Uint, reflect.Uint32, reflect.Uint64:
field.SetUint(1)
case reflect.Uint8: // byte
field.SetUint(1)
case reflect.Bool:
field.SetBool(true)
case reflect.String:
field.SetString("test")
default:
}
}
require.Equal(t, hints, *hints.Clone())
}