Skip to content

Commit

Permalink
planner: add more test cases for non-prep plan cache (#40060)
Browse files Browse the repository at this point in the history
  • Loading branch information
qw4990 authored Dec 20, 2022
1 parent 0f3031e commit 5f1a739
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 7 deletions.
4 changes: 4 additions & 0 deletions planner/core/plan_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ func planCachePreprocess(ctx context.Context, sctx sessionctx.Context, isNonPrep
func GetPlanFromSessionPlanCache(ctx context.Context, sctx sessionctx.Context,
isNonPrepared bool, is infoschema.InfoSchema, stmt *PlanCacheStmt,
params []expression.Expression) (plan Plan, names []*types.FieldName, err error) {
if v := ctx.Value("____GetPlanFromSessionPlanCacheErr"); v != nil { // for testing
return nil, nil, errors.New("____GetPlanFromSessionPlanCacheErr")
}

if err := planCachePreprocess(ctx, sctx, isNonPrepared, is, stmt, params); err != nil {
return nil, nil, err
}
Expand Down
11 changes: 8 additions & 3 deletions planner/core/plan_cache_param.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package core

import (
"context"
"errors"
"strings"
"sync"
Expand Down Expand Up @@ -70,7 +71,7 @@ func (pr *paramReplacer) Reset() { pr.params = nil }
// ParameterizeAST parameterizes this StmtNode.
// e.g. `select * from t where a<10 and b<23` --> `select * from t where a<? and b<?`, [10, 23].
// NOTICE: this function may modify the input stmt.
func ParameterizeAST(sctx sessionctx.Context, stmt ast.StmtNode) (paramSQL string, params []*driver.ValueExpr, err error) {
func ParameterizeAST(ctx context.Context, sctx sessionctx.Context, stmt ast.StmtNode) (paramSQL string, params []*driver.ValueExpr, err error) {
pr := paramReplacerPool.Get().(*paramReplacer)
pCtx := paramCtxPool.Get().(*format.RestoreCtx)
defer func() {
Expand All @@ -81,7 +82,7 @@ func ParameterizeAST(sctx sessionctx.Context, stmt ast.StmtNode) (paramSQL strin
}()
stmt.Accept(pr)
if err := stmt.Restore(pCtx); err != nil {
err = RestoreASTWithParams(sctx, stmt, pr.params)
err = RestoreASTWithParams(ctx, sctx, stmt, pr.params) // keep the stmt unchanged if err
return "", nil, err
}
paramSQL, params = pCtx.In.(*strings.Builder).String(), pr.params
Expand Down Expand Up @@ -119,7 +120,11 @@ func (pr *paramRestorer) Reset() {

// RestoreASTWithParams restore this parameterized AST with specific parameters.
// e.g. `select * from t where a<? and b<?`, [10, 23] --> `select * from t where a<10 and b<23`.
func RestoreASTWithParams(_ sessionctx.Context, stmt ast.StmtNode, params []*driver.ValueExpr) error {
func RestoreASTWithParams(ctx context.Context, _ sessionctx.Context, stmt ast.StmtNode, params []*driver.ValueExpr) error {
if v := ctx.Value("____RestoreASTWithParamsErr"); v != nil {
return errors.New("____RestoreASTWithParamsErr")
}

pr := paramRestorerPool.Get().(*paramRestorer)
defer func() {
pr.Reset()
Expand Down
5 changes: 3 additions & 2 deletions planner/core/plan_cache_param_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package core

import (
"context"
"strings"
"testing"

Expand Down Expand Up @@ -61,15 +62,15 @@ func TestParameterize(t *testing.T) {
for _, c := range cases {
stmt, err := parser.New().ParseOneStmt(c.sql, "", "")
require.Nil(t, err)
paramSQL, params, err := ParameterizeAST(sctx, stmt)
paramSQL, params, err := ParameterizeAST(context.Background(), sctx, stmt)
require.Nil(t, err)
require.Equal(t, c.paramSQL, paramSQL)
require.Equal(t, len(c.params), len(params))
for i := range params {
require.Equal(t, c.params[i], params[i].Datum.GetValue())
}

err = RestoreASTWithParams(sctx, stmt, params)
err = RestoreASTWithParams(context.Background(), sctx, stmt, params)
require.Nil(t, err)
var buf strings.Builder
rCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &buf)
Expand Down
40 changes: 40 additions & 0 deletions planner/core/plan_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package core_test

import (
"context"
"errors"
"fmt"
"math/rand"
Expand Down Expand Up @@ -112,6 +113,45 @@ func TestNonPreparedPlanCacheWithExplain(t *testing.T) {
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1"))
}

func TestNonPreparedPlanCacheFallback(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec(`use test`)
tk.MustExec(`create table t (a int)`)
for i := 0; i < 5; i++ {
tk.MustExec(fmt.Sprintf("insert into t values (%v)", i))
}
tk.MustExec("set tidb_enable_non_prepared_plan_cache=1")

// inject a fault to GeneratePlanCacheStmtWithAST
ctx := context.WithValue(context.Background(), "____GeneratePlanCacheStmtWithASTErr", struct{}{})
tk.MustQueryWithContext(ctx, "select * from t where a in (1, 2)").Sort().Check(testkit.Rows("1", "2"))
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) // cannot generate PlanCacheStmt
tk.MustQueryWithContext(ctx, "select * from t where a in (1, 3)").Sort().Check(testkit.Rows("1", "3"))
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) // cannot generate PlanCacheStmt
tk.MustQuery("select * from t where a in (1, 2)").Sort().Check(testkit.Rows("1", "2"))
tk.MustQuery("select * from t where a in (1, 3)").Sort().Check(testkit.Rows("1", "3"))
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) // no error

// inject a fault to GetPlanFromSessionPlanCache
tk.MustQuery("select * from t where a=1").Check(testkit.Rows("1")) // cache this plan
tk.MustQuery("select * from t where a=2").Check(testkit.Rows("2")) // plan from cache
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1"))
ctx = context.WithValue(context.Background(), "____GetPlanFromSessionPlanCacheErr", struct{}{})
tk.MustQueryWithContext(ctx, "select * from t where a=3").Check(testkit.Rows("3"))
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) // fallback to the normal opt-path
tk.MustQueryWithContext(ctx, "select * from t where a=4").Check(testkit.Rows("4"))
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) // fallback to the normal opt-path
tk.MustQueryWithContext(context.Background(), "select * from t where a=0").Check(testkit.Rows("0"))
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) // use the cached plan if no error

// inject a fault to RestoreASTWithParams
ctx = context.WithValue(context.Background(), "____GetPlanFromSessionPlanCacheErr", struct{}{})
ctx = context.WithValue(ctx, "____RestoreASTWithParamsErr", struct{}{})
_, err := tk.ExecWithContext(ctx, "select * from t where a=1")
require.NotNil(t, err)
}

func TestNonPreparedPlanCacheBasically(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
Expand Down
4 changes: 4 additions & 0 deletions planner/core/plan_cache_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ func (e *paramMarkerExtractor) Leave(in ast.Node) (ast.Node, bool) {
// paramSQL is the corresponding parameterized sql like 'select * from t where a<? and b>?'.
// paramStmt is the Node of paramSQL.
func GeneratePlanCacheStmtWithAST(ctx context.Context, sctx sessionctx.Context, paramSQL string, paramStmt ast.StmtNode) (*PlanCacheStmt, Plan, int, error) {
if v := ctx.Value("____GeneratePlanCacheStmtWithASTErr"); v != nil { // for testing
return nil, nil, 0, errors.New("____GeneratePlanCacheStmtWithASTErr")
}

vars := sctx.GetSessionVars()
var extractor paramMarkerExtractor
paramStmt.Accept(&extractor)
Expand Down
11 changes: 9 additions & 2 deletions planner/optimize.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,22 @@ func matchSQLBinding(sctx sessionctx.Context, stmtNode ast.StmtNode) (bindRecord
}

// getPlanFromNonPreparedPlanCache tries to get an available cached plan from the NonPrepared Plan Cache for this stmt.
func getPlanFromNonPreparedPlanCache(ctx context.Context, sctx sessionctx.Context, stmt ast.StmtNode, is infoschema.InfoSchema) (core.Plan, types.NameSlice, bool, error) {
func getPlanFromNonPreparedPlanCache(ctx context.Context, sctx sessionctx.Context, stmt ast.StmtNode, is infoschema.InfoSchema) (p core.Plan, ns types.NameSlice, ok bool, err error) {
if sctx.GetSessionVars().StmtCtx.InPreparedPlanBuilding || // already in cached plan rebuilding phase
!core.NonPreparedPlanCacheableWithCtx(sctx, stmt, is) {
return nil, nil, false, nil
}
paramSQL, params, err := core.ParameterizeAST(sctx, stmt)
paramSQL, params, err := core.ParameterizeAST(ctx, sctx, stmt)
if err != nil {
return nil, nil, false, err
}
defer func() {
if err != nil {
// keep the stmt unchanged if err so that it can fallback to the normal optimization path.
// TODO: add metrics
err = core.RestoreASTWithParams(ctx, sctx, stmt, params)
}
}()
val := sctx.GetSessionVars().GetNonPreparedPlanCacheStmt(paramSQL)
if val == nil {
cachedStmt, _, _, err := core.GeneratePlanCacheStmtWithAST(ctx, sctx, paramSQL, stmt)
Expand Down

0 comments on commit 5f1a739

Please sign in to comment.