diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index 30770f0408e4a..89810c1ae9add 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -208,6 +208,7 @@ func (e *Execute) getPhysicalPlan(ctx sessionctx.Context, is infoschema.InfoSche func (e *Execute) rebuildRange(p Plan) error { sctx := p.context() sc := p.context().GetSessionVars().StmtCtx + var err error switch x := p.(type) { case *PhysicalTableReader: ts := x.TablePlans[0].(*PhysicalTableScan) @@ -218,7 +219,6 @@ func (e *Execute) rebuildRange(p Plan) error { } } if pkCol != nil { - var err error ts.Ranges, err = ranger.BuildTableRange(ts.AccessCondition, sc, pkCol.RetType) if err != nil { return errors.Trace(err) @@ -228,20 +228,31 @@ func (e *Execute) rebuildRange(p Plan) error { } case *PhysicalIndexReader: is := x.IndexPlans[0].(*PhysicalIndexScan) - var err error is.Ranges, err = e.buildRangeForIndexScan(sctx, is) if err != nil { return errors.Trace(err) } case *PhysicalIndexLookUpReader: is := x.IndexPlans[0].(*PhysicalIndexScan) - var err error is.Ranges, err = e.buildRangeForIndexScan(sctx, is) if err != nil { return errors.Trace(err) } + case *PointGetPlan: + if x.HandleParam != nil { + x.Handle, err = x.HandleParam.Datum.ToInt64(sc) + if err != nil { + return errors.Trace(err) + } + return nil + } + for i, param := range x.IndexValueParams { + if param != nil { + x.IndexValues[i] = param.Datum + } + } + return nil case PhysicalPlan: - var err error for _, child := range x.Children() { err = e.rebuildRange(child) if err != nil { diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index 340dd31d3dfec..f8ae66079d585 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -36,18 +36,21 @@ import ( // This plan is much faster to build and to execute because it avoid the optimization and coprocessor cost. type PointGetPlan struct { basePlan - schema *expression.Schema - TblInfo *model.TableInfo - IndexInfo *model.IndexInfo - Handle int64 - IndexValues []types.Datum - expr expression.Expression - ctx sessionctx.Context + schema *expression.Schema + TblInfo *model.TableInfo + IndexInfo *model.IndexInfo + Handle int64 + HandleParam *driver.ParamMarkerExpr + IndexValues []types.Datum + IndexValueParams []*driver.ParamMarkerExpr + expr expression.Expression + ctx sessionctx.Context } type nameValuePair struct { colName string value types.Datum + param *driver.ParamMarkerExpr } // Schema implements the Plan interface. @@ -117,10 +120,6 @@ func (p *PointGetPlan) ResolveIndices() {} // TryFastPlan tries to use the PointGetPlan for the query. func TryFastPlan(ctx sessionctx.Context, node ast.Node) Plan { - if PreparedPlanCacheEnabled() { - // Do not support plan cache. - return nil - } switch x := node.(type) { case *ast.SelectStmt: fp := tryPointGetPlan(ctx, x) @@ -185,8 +184,8 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP if pairs == nil { return nil } - handleDatum := findPKHandle(tbl, pairs) - if handleDatum.Kind() != types.KindNull { + handlePair := findPKHandle(tbl, pairs) + if handlePair.value.Kind() != types.KindNull { if len(pairs) != 1 { return nil } @@ -196,10 +195,11 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP } p := newPointGetPlan(ctx, schema, tbl) var err error - p.Handle, err = handleDatum.ToInt64(ctx.GetSessionVars().StmtCtx) + p.Handle, err = handlePair.value.ToInt64(ctx.GetSessionVars().StmtCtx) if err != nil { return nil } + p.HandleParam = handlePair.param return p } for _, idxInfo := range tbl.Indices { @@ -209,7 +209,7 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP if idxInfo.State != model.StatePublic { continue } - idxValues := getIndexValues(idxInfo, pairs) + idxValues, idxValueParams := getIndexValues(idxInfo, pairs) if idxValues == nil { continue } @@ -220,6 +220,7 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP p := newPointGetPlan(ctx, schema, tbl) p.IndexInfo = idxInfo p.IndexValues = idxValues + p.IndexValueParams = idxValueParams return p } return nil @@ -331,60 +332,74 @@ func getNameValuePairs(nvPairs []nameValuePair, expr ast.ExprNode) []nameValuePa } return nvPairs } else if binOp.Op == opcode.EQ { - colName, ok := binOp.L.(*ast.ColumnNameExpr) - if !ok { - return nil - } var d types.Datum - switch x := binOp.R.(type) { - case *driver.ValueExpr: - d = x.Datum - case *driver.ParamMarkerExpr: - d = x.Datum + var colName *ast.ColumnNameExpr + var param *driver.ParamMarkerExpr + var ok bool + if colName, ok = binOp.L.(*ast.ColumnNameExpr); ok { + switch x := binOp.R.(type) { + case *driver.ValueExpr: + d = x.Datum + case *driver.ParamMarkerExpr: + d = x.Datum + param = x + } + } else if colName, ok = binOp.R.(*ast.ColumnNameExpr); ok { + switch x := binOp.L.(type) { + case *driver.ValueExpr: + d = x.Datum + case *driver.ParamMarkerExpr: + d = x.Datum + param = x + } + } else { + return nil } if d.IsNull() { return nil } - return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d}) + return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param}) } return nil } -func findPKHandle(tblInfo *model.TableInfo, pairs []nameValuePair) (d types.Datum) { +func findPKHandle(tblInfo *model.TableInfo, pairs []nameValuePair) (handlePair nameValuePair) { if !tblInfo.PKIsHandle { - return d + return handlePair } for _, col := range tblInfo.Columns { if mysql.HasPriKeyFlag(col.Flag) { i := findInPairs(col.Name.L, pairs) if i == -1 { - return d + return handlePair } - return pairs[i].value + return pairs[i] } } - return d + return handlePair } -func getIndexValues(idxInfo *model.IndexInfo, pairs []nameValuePair) []types.Datum { +func getIndexValues(idxInfo *model.IndexInfo, pairs []nameValuePair) ([]types.Datum, []*driver.ParamMarkerExpr) { idxValues := make([]types.Datum, 0, 4) + idxValueParams := make([]*driver.ParamMarkerExpr, 0, 4) if len(idxInfo.Columns) != len(pairs) { - return nil + return nil, nil } if idxInfo.HasPrefixIndex() { - return nil + return nil, nil } for _, idxCol := range idxInfo.Columns { i := findInPairs(idxCol.Name.L, pairs) if i == -1 { - return nil + return nil, nil } idxValues = append(idxValues, pairs[i].value) + idxValueParams = append(idxValueParams, pairs[i].param) } if len(idxValues) > 0 { - return idxValues + return idxValues, idxValueParams } - return nil + return nil, nil } func findInPairs(colName string, pairs []nameValuePair) int { diff --git a/planner/core/point_get_plan_test.go b/planner/core/point_get_plan_test.go new file mode 100644 index 0000000000000..3213faa040b84 --- /dev/null +++ b/planner/core/point_get_plan_test.go @@ -0,0 +1,129 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package core_test + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/tidb/metrics" + "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/util/testleak" + dto "github.com/prometheus/client_model/go" +) + +var _ = Suite(&testPointGetSuite{}) + +type testPointGetSuite struct { +} + +func (s *testPointGetSuite) TestPointGetPlanCache(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + orgEnable := core.PreparedPlanCacheEnabled() + orgCapacity := core.PreparedPlanCacheCapacity + defer func() { + dom.Close() + store.Close() + core.SetPreparedPlanCache(orgEnable) + core.PreparedPlanCacheCapacity = orgCapacity + }() + core.SetPreparedPlanCache(true) + core.PreparedPlanCacheCapacity = 100 + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int primary key, b int, c int, key idx_bc(b,c))") + tk.MustExec("insert into t values(1, 1, 1), (2, 2, 2), (3, 3, 3)") + tk.MustQuery("explain select * from t where a = 1").Check(testkit.Rows( + "Point_Get_1 1.00 root table:t, handle:1", + )) + tk.MustQuery("explain select * from t where 1 = a").Check(testkit.Rows( + "Point_Get_1 1.00 root table:t, handle:1", + )) + tk.MustQuery("explain update t set b=b+1, c=c+1 where a = 1").Check(testkit.Rows( + "Point_Get_1 1.00 root table:t, handle:1", + )) + tk.MustQuery("explain delete from t where a = 1").Check(testkit.Rows( + "Point_Get_1 1.00 root table:t, handle:1", + )) + metrics.PlanCacheCounter.Reset() + counter := metrics.PlanCacheCounter.WithLabelValues("prepare") + pb := &dto.Metric{} + var hit float64 + // PointGetPlan for Select. + tk.MustExec(`prepare stmt1 from "select * from t where a = ?"`) + tk.MustExec(`prepare stmt2 from "select * from t where b = ? and c = ?"`) + tk.MustExec("set @param=1") + tk.MustQuery("execute stmt1 using @param").Check(testkit.Rows("1 1 1")) + counter.Write(pb) + hit = pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(0)) + tk.MustExec("set @param=2") + tk.MustQuery("execute stmt1 using @param").Check(testkit.Rows("2 2 2")) + counter.Write(pb) + hit = pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(1)) + tk.MustQuery("execute stmt2 using @param, @param").Check(testkit.Rows("2 2 2")) + counter.Write(pb) + hit = pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(1)) + tk.MustExec("set @param=1") + tk.MustQuery("execute stmt2 using @param, @param").Check(testkit.Rows("1 1 1")) + counter.Write(pb) + hit = pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(2)) + // PointGetPlan for Update. + tk.MustExec(`prepare stmt3 from "update t set b=b+1, c=c+1 where a = ?"`) + tk.MustExec(`prepare stmt4 from "update t set a=a+1 where b = ? and c = ?"`) + tk.MustExec("set @param=3") + tk.MustExec("execute stmt3 using @param") + tk.MustQuery("select * from t").Check(testkit.Rows( + "1 1 1", + "2 2 2", + "3 4 4", + )) + counter.Write(pb) + hit = pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(2)) + tk.MustExec("set @param=4") + tk.MustExec("execute stmt4 using @param, @param") + tk.MustQuery("select * from t").Check(testkit.Rows( + "1 1 1", + "2 2 2", + "4 4 4", + )) + counter.Write(pb) + hit = pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(2)) + // PointGetPlan for Delete. + tk.MustExec(`prepare stmt5 from "delete from t where a = ?"`) + tk.MustExec(`prepare stmt6 from "delete from t where b = ? and c = ?"`) + tk.MustExec("execute stmt5 using @param") + tk.MustQuery("select * from t").Check(testkit.Rows( + "1 1 1", + "2 2 2", + )) + counter.Write(pb) + hit = pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(2)) + tk.MustExec("set @param=2") + tk.MustExec("execute stmt6 using @param, @param") + tk.MustQuery("select * from t").Check(testkit.Rows( + "1 1 1", + )) + counter.Write(pb) + hit = pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(2)) +}