From e8be77848d8586ad25b61f8b522b581c3b76760f Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Mon, 28 Feb 2022 14:17:45 +0800 Subject: [PATCH] planner: fix the usage of ParamMaker in BatchPointGet (#32534) ref pingcap/tidb#31056 --- planner/core/common_plans.go | 13 +++++-- planner/core/point_get_plan.go | 67 +++++++++++++++++++++++++++------- planner/core/prepare_test.go | 50 ++++++++++++++++++++++--- 3 files changed, 108 insertions(+), 22 deletions(-) diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index cab9b46085a7b..059710367bdcc 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -769,8 +769,11 @@ func (e *Execute) rebuildRange(p Plan) error { } for i, param := range x.HandleParams { if param != nil { - var iv int64 - iv, err = param.Datum.ToInt64(sc) + dVal, err := convertConstant2Datum(sc, param, x.HandleType) + if err != nil { + return err + } + iv, err := dVal.ToInt64(sc) if err != nil { return err } @@ -783,7 +786,11 @@ func (e *Execute) rebuildRange(p Plan) error { } for j, param := range params { if param != nil { - x.IndexValues[i][j] = param.Datum + dVal, err := convertConstant2Datum(sc, param, x.IndexColTypes[j]) + if err != nil { + return err + } + x.IndexValues[i][j] = *dVal } } } diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index 9f3f2ce903e3e..dafd04997e1dc 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -281,9 +281,11 @@ type BatchPointGetPlan struct { TblInfo *model.TableInfo IndexInfo *model.IndexInfo Handles []kv.Handle - HandleParams []*driver.ParamMarkerExpr + HandleType *types.FieldType + HandleParams []*expression.Constant // record all Parameters for Plan-Cache IndexValues [][]types.Datum - IndexValueParams [][]*driver.ParamMarkerExpr + IndexValueParams [][]*expression.Constant // record all Parameters for Plan-Cache + IndexColTypes []*types.FieldType AccessConditions []expression.Expression IdxCols []*expression.Column IdxColLens []int @@ -587,20 +589,27 @@ func newBatchPointGetPlan( } if handleCol != nil { var handles = make([]kv.Handle, len(patternInExpr.List)) - var handleParams = make([]*driver.ParamMarkerExpr, len(patternInExpr.List)) + var handleParams = make([]*expression.Constant, len(patternInExpr.List)) for i, item := range patternInExpr.List { // SELECT * FROM t WHERE (key) in ((1), (2)) if p, ok := item.(*ast.ParenthesesExpr); ok { item = p.Expr } var d types.Datum - var param *driver.ParamMarkerExpr + var con *expression.Constant switch x := item.(type) { case *driver.ValueExpr: d = x.Datum case *driver.ParamMarkerExpr: - d = x.Datum - param = x + var err error + con, err = expression.ParamMarkerExpression(ctx, x, true) + if err != nil { + return nil + } + d, err = con.Eval(chunk.Row{}) + if err != nil { + return nil + } default: return nil } @@ -612,12 +621,13 @@ func newBatchPointGetPlan( return nil } handles[i] = kv.IntHandle(intDatum.GetInt64()) - handleParams[i] = param + handleParams[i] = con } return BatchPointGetPlan{ TblInfo: tbl, Handles: handles, HandleParams: handleParams, + HandleType: &handleCol.FieldType, PartitionExpr: partitionExpr, }.Init(ctx, statsInfo, schema, names, 0) } @@ -672,14 +682,15 @@ func newBatchPointGetPlan( } indexValues := make([][]types.Datum, len(patternInExpr.List)) - indexValueParams := make([][]*driver.ParamMarkerExpr, len(patternInExpr.List)) + indexValueParams := make([][]*expression.Constant, len(patternInExpr.List)) + var indexTypes []*types.FieldType for i, item := range patternInExpr.List { // SELECT * FROM t WHERE (key) in ((1), (2)) if p, ok := item.(*ast.ParenthesesExpr); ok { item = p.Expr } var values []types.Datum - var valuesParams []*driver.ParamMarkerExpr + var valuesParams []*expression.Constant switch x := item.(type) { case *ast.RowExpr: // The `len(values) == len(valuesParams)` should be satisfied in this mode @@ -687,7 +698,12 @@ func newBatchPointGetPlan( return nil } values = make([]types.Datum, len(x.Values)) - valuesParams = make([]*driver.ParamMarkerExpr, len(x.Values)) + valuesParams = make([]*expression.Constant, len(x.Values)) + initTypes := false + if indexTypes == nil { // only init once + indexTypes = make([]*types.FieldType, len(x.Values)) + initTypes = true + } for index, inner := range x.Values { permIndex := permutations[index] switch innerX := inner.(type) { @@ -698,12 +714,23 @@ func newBatchPointGetPlan( } values[permIndex] = innerX.Datum case *driver.ParamMarkerExpr: - dval := getPointGetValue(stmtCtx, colInfos[index], &innerX.Datum) + con, err := expression.ParamMarkerExpression(ctx, innerX, true) + if err != nil { + return nil + } + d, err := con.Eval(chunk.Row{}) + if err != nil { + return nil + } + dval := getPointGetValue(stmtCtx, colInfos[index], &d) if dval == nil { return nil } values[permIndex] = innerX.Datum - valuesParams[permIndex] = innerX + valuesParams[permIndex] = con + if initTypes { + indexTypes[permIndex] = &colInfos[index].FieldType + } default: return nil } @@ -723,12 +750,23 @@ func newBatchPointGetPlan( if len(whereColNames) != 1 { return nil } - dval := getPointGetValue(stmtCtx, colInfos[0], &x.Datum) + con, err := expression.ParamMarkerExpression(ctx, x, true) + if err != nil { + return nil + } + d, err := con.Eval(chunk.Row{}) + if err != nil { + return nil + } + dval := getPointGetValue(stmtCtx, colInfos[0], &d) if dval == nil { return nil } values = []types.Datum{*dval} - valuesParams = []*driver.ParamMarkerExpr{x} + valuesParams = []*expression.Constant{con} + if indexTypes == nil { // only init once + indexTypes = []*types.FieldType{&colInfos[0].FieldType} + } default: return nil } @@ -740,6 +778,7 @@ func newBatchPointGetPlan( IndexInfo: matchIdxInfo, IndexValues: indexValues, IndexValueParams: indexValueParams, + IndexColTypes: indexTypes, PartitionColPos: pos, PartitionExpr: partitionExpr, }.Init(ctx, statsInfo, schema, names, 0) diff --git a/planner/core/prepare_test.go b/planner/core/prepare_test.go index d7f754771a64e..7d5911ece814c 100644 --- a/planner/core/prepare_test.go +++ b/planner/core/prepare_test.go @@ -1541,7 +1541,7 @@ func TestParamMarker4FastPlan(t *testing.T) { require.NoError(t, err) tk := testkit.NewTestKitWithSession(t, store, se) - // test handle + // test handle for point get tk.MustExec(`use test`) tk.MustExec("drop table if exists t") tk.MustExec("create table t(pk int primary key)") @@ -1565,7 +1565,7 @@ func TestParamMarker4FastPlan(t *testing.T) { tk.MustQuery(`execute stmt using @a1`).Check(testkit.Rows()) tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) - // test indexValues + // test indexValues for point get tk.MustExec("drop table if exists t") tk.MustExec("create table t(pk int, unique index idx(pk))") tk.MustExec("insert into t values(1)") @@ -1588,7 +1588,7 @@ func TestParamMarker4FastPlan(t *testing.T) { tk.MustQuery(`execute stmt using @a1`).Check(testkit.Rows()) tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) - // test _tidb_rowid + // test _tidb_rowid for point get tk.MustExec(`use test`) tk.MustExec("drop table if exists t") tk.MustExec("create table t (a int, b int);") @@ -1599,6 +1599,46 @@ func TestParamMarker4FastPlan(t *testing.T) { tk.MustExec(`set @a=1`) tk.MustQuery("execute stmt using @a;").Check(testkit.Rows("1 7")) tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + // test handle for batch point get + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(pk int primary key)") + tk.MustExec("insert into t values (1), (2), (3), (4), (5)") + tk.MustExec(`prepare stmt from 'select * from t where pk in (1, ?, ?)'`) + tk.MustExec(`set @a0=0, @a1=1, @a2=2, @a3=3, @a1_1=1.1, @a4=4, @a5=5`) + tk.MustQuery(`execute stmt using @a2, @a3`).Sort().Check(testkit.Rows("1", "2", "3")) + tk.MustQuery(`execute stmt using @a2, @a3`).Sort().Check(testkit.Rows("1", "2", "3")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + tk.MustQuery(`execute stmt using @a0, @a4`).Sort().Check(testkit.Rows("1", "4")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + tk.MustQuery(`execute stmt using @a1_1, @a5`).Sort().Check(testkit.Rows("1", "5")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) + + // test indexValues for batch point get + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(pk int, unique index idx(pk))") + tk.MustExec("insert into t values (1), (2), (3), (4), (5)") + tk.MustExec(`prepare stmt from 'select * from t where pk in (1, ?, ?)'`) + tk.MustExec(`set @a0=0, @a1=1, @a2=2, @a3=3, @a1_1=1.1, @a4=4, @a5=5`) + tk.MustQuery(`execute stmt using @a2, @a3`).Sort().Check(testkit.Rows("1", "2", "3")) + tk.MustQuery(`execute stmt using @a2, @a3`).Sort().Check(testkit.Rows("1", "2", "3")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + tk.MustQuery(`execute stmt using @a0, @a4`).Sort().Check(testkit.Rows("1", "4")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + tk.MustQuery(`execute stmt using @a1_1, @a5`).Sort().Check(testkit.Rows("1", "5")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) + + // test _tidb_rowid for batch point get + tk.MustExec(`use test`) + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a int, b int);") + tk.MustExec("insert t values (1, 7), (1, 8), (1, 9), (1, 10);") + tk.MustExec(`prepare stmt from 'select * from t where _tidb_rowid in (1, ?, ?)'`) + tk.MustExec(`set @a2=2, @a3=3`) + tk.MustQuery("execute stmt using @a2, @a3;").Sort().Check(testkit.Rows("1 7", "1 8", "1 9")) + tk.MustExec(`set @a2=4, @a3=2`) + tk.MustQuery("execute stmt using @a2, @a3;").Sort().Check(testkit.Rows("1 10", "1 7", "1 8")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) } func TestIssue29565(t *testing.T) { @@ -2125,7 +2165,7 @@ func TestIssue29993(t *testing.T) { tk.MustQuery("execute stmt using @b").Check(testkit.Rows()) tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) tk.MustQuery("execute stmt using @z").Check(testkit.Rows()) - tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) // invalid since 'z' is not in enum('a', 'b') tk.MustQuery("execute stmt using @z").Check(testkit.Rows()) // test PointGet + non cluster index @@ -2153,7 +2193,7 @@ func TestIssue29993(t *testing.T) { tk.MustQuery("execute stmt using @b").Check(testkit.Rows()) tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) tk.MustQuery("execute stmt using @z").Check(testkit.Rows()) - tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) // invalid since 'z' is not in enum('a', 'b') tk.MustQuery("execute stmt using @z").Check(testkit.Rows()) }