Skip to content

Commit

Permalink
planner: fix the usage of ParamMaker in BatchPointGet (#32534)
Browse files Browse the repository at this point in the history
ref #31056
  • Loading branch information
qw4990 authored Feb 28, 2022
1 parent f0d8352 commit e8be778
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 22 deletions.
13 changes: 10 additions & 3 deletions planner/core/common_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -769,8 +769,11 @@ func (e *Execute) rebuildRange(p Plan) error {
}
for i, param := range x.HandleParams {
if param != nil {
var iv int64
iv, err = param.Datum.ToInt64(sc)
dVal, err := convertConstant2Datum(sc, param, x.HandleType)
if err != nil {
return err
}
iv, err := dVal.ToInt64(sc)
if err != nil {
return err
}
Expand All @@ -783,7 +786,11 @@ func (e *Execute) rebuildRange(p Plan) error {
}
for j, param := range params {
if param != nil {
x.IndexValues[i][j] = param.Datum
dVal, err := convertConstant2Datum(sc, param, x.IndexColTypes[j])
if err != nil {
return err
}
x.IndexValues[i][j] = *dVal
}
}
}
Expand Down
67 changes: 53 additions & 14 deletions planner/core/point_get_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,11 @@ type BatchPointGetPlan struct {
TblInfo *model.TableInfo
IndexInfo *model.IndexInfo
Handles []kv.Handle
HandleParams []*driver.ParamMarkerExpr
HandleType *types.FieldType
HandleParams []*expression.Constant // record all Parameters for Plan-Cache
IndexValues [][]types.Datum
IndexValueParams [][]*driver.ParamMarkerExpr
IndexValueParams [][]*expression.Constant // record all Parameters for Plan-Cache
IndexColTypes []*types.FieldType
AccessConditions []expression.Expression
IdxCols []*expression.Column
IdxColLens []int
Expand Down Expand Up @@ -587,20 +589,27 @@ func newBatchPointGetPlan(
}
if handleCol != nil {
var handles = make([]kv.Handle, len(patternInExpr.List))
var handleParams = make([]*driver.ParamMarkerExpr, len(patternInExpr.List))
var handleParams = make([]*expression.Constant, len(patternInExpr.List))
for i, item := range patternInExpr.List {
// SELECT * FROM t WHERE (key) in ((1), (2))
if p, ok := item.(*ast.ParenthesesExpr); ok {
item = p.Expr
}
var d types.Datum
var param *driver.ParamMarkerExpr
var con *expression.Constant
switch x := item.(type) {
case *driver.ValueExpr:
d = x.Datum
case *driver.ParamMarkerExpr:
d = x.Datum
param = x
var err error
con, err = expression.ParamMarkerExpression(ctx, x, true)
if err != nil {
return nil
}
d, err = con.Eval(chunk.Row{})
if err != nil {
return nil
}
default:
return nil
}
Expand All @@ -612,12 +621,13 @@ func newBatchPointGetPlan(
return nil
}
handles[i] = kv.IntHandle(intDatum.GetInt64())
handleParams[i] = param
handleParams[i] = con
}
return BatchPointGetPlan{
TblInfo: tbl,
Handles: handles,
HandleParams: handleParams,
HandleType: &handleCol.FieldType,
PartitionExpr: partitionExpr,
}.Init(ctx, statsInfo, schema, names, 0)
}
Expand Down Expand Up @@ -672,22 +682,28 @@ func newBatchPointGetPlan(
}

indexValues := make([][]types.Datum, len(patternInExpr.List))
indexValueParams := make([][]*driver.ParamMarkerExpr, len(patternInExpr.List))
indexValueParams := make([][]*expression.Constant, len(patternInExpr.List))
var indexTypes []*types.FieldType
for i, item := range patternInExpr.List {
// SELECT * FROM t WHERE (key) in ((1), (2))
if p, ok := item.(*ast.ParenthesesExpr); ok {
item = p.Expr
}
var values []types.Datum
var valuesParams []*driver.ParamMarkerExpr
var valuesParams []*expression.Constant
switch x := item.(type) {
case *ast.RowExpr:
// The `len(values) == len(valuesParams)` should be satisfied in this mode
if len(x.Values) != len(whereColNames) {
return nil
}
values = make([]types.Datum, len(x.Values))
valuesParams = make([]*driver.ParamMarkerExpr, len(x.Values))
valuesParams = make([]*expression.Constant, len(x.Values))
initTypes := false
if indexTypes == nil { // only init once
indexTypes = make([]*types.FieldType, len(x.Values))
initTypes = true
}
for index, inner := range x.Values {
permIndex := permutations[index]
switch innerX := inner.(type) {
Expand All @@ -698,12 +714,23 @@ func newBatchPointGetPlan(
}
values[permIndex] = innerX.Datum
case *driver.ParamMarkerExpr:
dval := getPointGetValue(stmtCtx, colInfos[index], &innerX.Datum)
con, err := expression.ParamMarkerExpression(ctx, innerX, true)
if err != nil {
return nil
}
d, err := con.Eval(chunk.Row{})
if err != nil {
return nil
}
dval := getPointGetValue(stmtCtx, colInfos[index], &d)
if dval == nil {
return nil
}
values[permIndex] = innerX.Datum
valuesParams[permIndex] = innerX
valuesParams[permIndex] = con
if initTypes {
indexTypes[permIndex] = &colInfos[index].FieldType
}
default:
return nil
}
Expand All @@ -723,12 +750,23 @@ func newBatchPointGetPlan(
if len(whereColNames) != 1 {
return nil
}
dval := getPointGetValue(stmtCtx, colInfos[0], &x.Datum)
con, err := expression.ParamMarkerExpression(ctx, x, true)
if err != nil {
return nil
}
d, err := con.Eval(chunk.Row{})
if err != nil {
return nil
}
dval := getPointGetValue(stmtCtx, colInfos[0], &d)
if dval == nil {
return nil
}
values = []types.Datum{*dval}
valuesParams = []*driver.ParamMarkerExpr{x}
valuesParams = []*expression.Constant{con}
if indexTypes == nil { // only init once
indexTypes = []*types.FieldType{&colInfos[0].FieldType}
}
default:
return nil
}
Expand All @@ -740,6 +778,7 @@ func newBatchPointGetPlan(
IndexInfo: matchIdxInfo,
IndexValues: indexValues,
IndexValueParams: indexValueParams,
IndexColTypes: indexTypes,
PartitionColPos: pos,
PartitionExpr: partitionExpr,
}.Init(ctx, statsInfo, schema, names, 0)
Expand Down
50 changes: 45 additions & 5 deletions planner/core/prepare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1541,7 +1541,7 @@ func TestParamMarker4FastPlan(t *testing.T) {
require.NoError(t, err)
tk := testkit.NewTestKitWithSession(t, store, se)

// test handle
// test handle for point get
tk.MustExec(`use test`)
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(pk int primary key)")
Expand All @@ -1565,7 +1565,7 @@ func TestParamMarker4FastPlan(t *testing.T) {
tk.MustQuery(`execute stmt using @a1`).Check(testkit.Rows())
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0"))

// test indexValues
// test indexValues for point get
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(pk int, unique index idx(pk))")
tk.MustExec("insert into t values(1)")
Expand All @@ -1588,7 +1588,7 @@ func TestParamMarker4FastPlan(t *testing.T) {
tk.MustQuery(`execute stmt using @a1`).Check(testkit.Rows())
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0"))

// test _tidb_rowid
// test _tidb_rowid for point get
tk.MustExec(`use test`)
tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a int, b int);")
Expand All @@ -1599,6 +1599,46 @@ func TestParamMarker4FastPlan(t *testing.T) {
tk.MustExec(`set @a=1`)
tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1 7"))
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1"))

// test handle for batch point get
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(pk int primary key)")
tk.MustExec("insert into t values (1), (2), (3), (4), (5)")
tk.MustExec(`prepare stmt from 'select * from t where pk in (1, ?, ?)'`)
tk.MustExec(`set @a0=0, @a1=1, @a2=2, @a3=3, @a1_1=1.1, @a4=4, @a5=5`)
tk.MustQuery(`execute stmt using @a2, @a3`).Sort().Check(testkit.Rows("1", "2", "3"))
tk.MustQuery(`execute stmt using @a2, @a3`).Sort().Check(testkit.Rows("1", "2", "3"))
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1"))
tk.MustQuery(`execute stmt using @a0, @a4`).Sort().Check(testkit.Rows("1", "4"))
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1"))
tk.MustQuery(`execute stmt using @a1_1, @a5`).Sort().Check(testkit.Rows("1", "5"))
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0"))

// test indexValues for batch point get
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(pk int, unique index idx(pk))")
tk.MustExec("insert into t values (1), (2), (3), (4), (5)")
tk.MustExec(`prepare stmt from 'select * from t where pk in (1, ?, ?)'`)
tk.MustExec(`set @a0=0, @a1=1, @a2=2, @a3=3, @a1_1=1.1, @a4=4, @a5=5`)
tk.MustQuery(`execute stmt using @a2, @a3`).Sort().Check(testkit.Rows("1", "2", "3"))
tk.MustQuery(`execute stmt using @a2, @a3`).Sort().Check(testkit.Rows("1", "2", "3"))
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1"))
tk.MustQuery(`execute stmt using @a0, @a4`).Sort().Check(testkit.Rows("1", "4"))
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1"))
tk.MustQuery(`execute stmt using @a1_1, @a5`).Sort().Check(testkit.Rows("1", "5"))
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0"))

// test _tidb_rowid for batch point get
tk.MustExec(`use test`)
tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a int, b int);")
tk.MustExec("insert t values (1, 7), (1, 8), (1, 9), (1, 10);")
tk.MustExec(`prepare stmt from 'select * from t where _tidb_rowid in (1, ?, ?)'`)
tk.MustExec(`set @a2=2, @a3=3`)
tk.MustQuery("execute stmt using @a2, @a3;").Sort().Check(testkit.Rows("1 7", "1 8", "1 9"))
tk.MustExec(`set @a2=4, @a3=2`)
tk.MustQuery("execute stmt using @a2, @a3;").Sort().Check(testkit.Rows("1 10", "1 7", "1 8"))
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1"))
}

func TestIssue29565(t *testing.T) {
Expand Down Expand Up @@ -2125,7 +2165,7 @@ func TestIssue29993(t *testing.T) {
tk.MustQuery("execute stmt using @b").Check(testkit.Rows())
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1"))
tk.MustQuery("execute stmt using @z").Check(testkit.Rows())
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1"))
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) // invalid since 'z' is not in enum('a', 'b')
tk.MustQuery("execute stmt using @z").Check(testkit.Rows())

// test PointGet + non cluster index
Expand Down Expand Up @@ -2153,7 +2193,7 @@ func TestIssue29993(t *testing.T) {
tk.MustQuery("execute stmt using @b").Check(testkit.Rows())
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1"))
tk.MustQuery("execute stmt using @z").Check(testkit.Rows())
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1"))
tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) // invalid since 'z' is not in enum('a', 'b')
tk.MustQuery("execute stmt using @z").Check(testkit.Rows())
}

Expand Down

0 comments on commit e8be778

Please sign in to comment.