Skip to content

Commit

Permalink
plan: support ? in Order By / Group By / Limit Offset clauses (#8206)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbjoa authored and zz-jason committed Dec 3, 2018
1 parent 04682ce commit c677187
Show file tree
Hide file tree
Showing 10 changed files with 234 additions and 28 deletions.
3 changes: 1 addition & 2 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/parser_driver"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/sqlexec"
Expand Down Expand Up @@ -161,7 +160,7 @@ func (e *PrepareExec) Next(ctx context.Context, chk *chunk.Chunk) error {

// We try to build the real statement of preparedStmt.
for i := range prepared.Params {
prepared.Params[i].(*driver.ParamMarkerExpr).Datum = types.NewIntDatum(0)
prepared.Params[i].(*driver.ParamMarkerExpr).Datum.SetNull()
}
var p plannercore.Plan
p, err = plannercore.BuildLogicalPlan(e.ctx, stmt, e.is)
Expand Down
57 changes: 57 additions & 0 deletions executor/prepared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -688,3 +688,60 @@ func (s *testSuite) TestPrepareDealloc(c *C) {
tk.MustExec("deallocate prepare stmt4")
c.Assert(tk.Se.PreparedPlanCache().Size(), Equals, 0)
}

func (s *testSuite) TestPreparedIssue8153(c *C) {
orgEnable := plannercore.PreparedPlanCacheEnabled()
orgCapacity := plannercore.PreparedPlanCacheCapacity
orgMemGuardRatio := plannercore.PreparedPlanCacheMemoryGuardRatio
orgMaxMemory := plannercore.PreparedPlanCacheMaxMemory
defer func() {
plannercore.SetPreparedPlanCache(orgEnable)
plannercore.PreparedPlanCacheCapacity = orgCapacity
plannercore.PreparedPlanCacheMemoryGuardRatio = orgMemGuardRatio
plannercore.PreparedPlanCacheMaxMemory = orgMaxMemory
}()
flags := []bool{false, true}
for _, flag := range flags {
var err error
plannercore.SetPreparedPlanCache(flag)
plannercore.PreparedPlanCacheCapacity = 100
plannercore.PreparedPlanCacheMemoryGuardRatio = 0.1
plannercore.PreparedPlanCacheMaxMemory, err = memory.MemTotal()
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a int, b int)")
tk.MustExec("insert into t (a, b) values (1,3), (2,2), (3,1)")

tk.MustExec(`prepare stmt from 'select * from t order by ? asc'`)
r := tk.MustQuery(`execute stmt using @param;`)
r.Check(testkit.Rows("1 3", "2 2", "3 1"))

tk.MustExec(`set @param = 1`)
r = tk.MustQuery(`execute stmt using @param;`)
r.Check(testkit.Rows("1 3", "2 2", "3 1"))

tk.MustExec(`set @param = 2`)
r = tk.MustQuery(`execute stmt using @param;`)
r.Check(testkit.Rows("3 1", "2 2", "1 3"))

tk.MustExec(`set @param = 3`)
_, err = tk.Exec(`execute stmt using @param;`)
c.Assert(err.Error(), Equals, "[planner:1054]Unknown column '?' in 'order clause'")

tk.MustExec(`set @param = '##'`)
r = tk.MustQuery(`execute stmt using @param;`)
r.Check(testkit.Rows("1 3", "2 2", "3 1"))

tk.MustExec("insert into t (a, b) values (1,1), (1,2), (2,1), (2,3), (3,2), (3,3)")
tk.MustExec(`prepare stmt from 'select ?, sum(a) from t group by ?'`)

tk.MustExec(`set @a=1,@b=1`)
r = tk.MustQuery(`execute stmt using @a,@b;`)
r.Check(testkit.Rows("1 18"))

tk.MustExec(`set @a=1,@b=2`)
_, err = tk.Exec(`execute stmt using @a,@b;`)
c.Assert(err.Error(), Equals, "[planner:1056]Can't group on 'sum(a)'")
}
}
2 changes: 1 addition & 1 deletion expression/simple_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func (sr *simpleRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok boo
}
case *driver.ParamMarkerExpr:
var value Expression
value, sr.err = GetParamExpression(sr.ctx, v, sr.useCache())
value, sr.err = GetParamExpression(sr.ctx, v)
if sr.err != nil {
return retNode, false
}
Expand Down
51 changes: 50 additions & 1 deletion expression/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,8 @@ func DatumToConstant(d types.Datum, tp byte) *Constant {
}

// GetParamExpression generate a getparam function expression.
func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr, useCache bool) (Expression, error) {
func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr) (Expression, error) {
useCache := ctx.GetSessionVars().StmtCtx.UseCache
tp := types.NewFieldType(mysql.TypeUnspecified)
types.DefaultParamTypeForValue(v.GetValue(), tp)
value := &Constant{Value: v.Datum, RetType: tp}
Expand All @@ -526,3 +527,51 @@ func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr, useCa
}
return value, nil
}

// ConstructPositionExpr constructs PositionExpr with the given ParamMarkerExpr.
func ConstructPositionExpr(p *driver.ParamMarkerExpr) *ast.PositionExpr {
return &ast.PositionExpr{P: p}
}

// PosFromPositionExpr generates a position value from PositionExpr.
func PosFromPositionExpr(ctx sessionctx.Context, v *ast.PositionExpr) (int, bool, error) {
if v.P == nil {
return v.N, false, nil
}
value, err := GetParamExpression(ctx, v.P.(*driver.ParamMarkerExpr))
if err != nil {
return 0, true, err
}
pos, isNull, err := GetIntFromConstant(ctx, value)
if err != nil || isNull {
return 0, true, errors.Trace(err)
}
return pos, false, nil
}

// GetStringFromConstant gets a string value from the Constant expression.
func GetStringFromConstant(ctx sessionctx.Context, value Expression) (string, bool, error) {
con, ok := value.(*Constant)
if !ok {
err := errors.Errorf("Not a Constant expression %+v", value)
return "", true, errors.Trace(err)
}
str, isNull, err := con.EvalString(ctx, chunk.Row{})
if err != nil || isNull {
return "", true, errors.Trace(err)
}
return str, false, nil
}

// GetIntFromConstant gets an interger value from the Constant expression.
func GetIntFromConstant(ctx sessionctx.Context, value Expression) (int, bool, error) {
str, isNull, err := GetStringFromConstant(ctx, value)
if err != nil || isNull {
return 0, true, errors.Trace(err)
}
intNum, err := strconv.Atoi(str)
if err != nil {
return 0, true, nil
}
return intNum, false, nil
}
14 changes: 14 additions & 0 deletions planner/core/cacheable_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,20 @@ func (checker *cacheableChecker) Enter(in ast.Node) (out ast.Node, skipChildren
checker.cacheable = false
return in, true
}
case *ast.OrderByClause:
for _, item := range node.Items {
if _, isParamMarker := item.Expr.(*driver.ParamMarkerExpr); isParamMarker {
checker.cacheable = false
return in, true
}
}
case *ast.GroupByClause:
for _, item := range node.Items {
if _, isParamMarker := item.Expr.(*driver.ParamMarkerExpr); isParamMarker {
checker.cacheable = false
return in, true
}
}
case *ast.Limit:
if node.Count != nil {
if _, isParamMarker := node.Count.(*driver.ParamMarkerExpr); isParamMarker {
Expand Down
14 changes: 14 additions & 0 deletions planner/core/cacheable_checker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,18 @@ func (s *testCacheableSuite) TestCacheable(c *C) {
Limit: limitStmt,
}
c.Assert(Cacheable(stmt), IsTrue)

paramExpr := &driver.ParamMarkerExpr{}
orderByClause := &ast.OrderByClause{Items: []*ast.ByItem{{Expr: paramExpr}}}
stmt = &ast.SelectStmt{
OrderBy: orderByClause,
}
c.Assert(Cacheable(stmt), IsFalse)

valExpr := &driver.ValueExpr{}
orderByClause = &ast.OrderByClause{Items: []*ast.ByItem{{Expr: valExpr}}}
stmt = &ast.SelectStmt{
OrderBy: orderByClause,
}
c.Assert(Cacheable(stmt), IsTrue)
}
2 changes: 2 additions & 0 deletions planner/core/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ const (

codeWrongUsage = mysql.ErrWrongUsage
codeAmbiguous = mysql.ErrNonUniq
codeUnknown = mysql.ErrUnknown
codeUnknownColumn = mysql.ErrBadField
codeUnknownTable = mysql.ErrUnknownTable
codeWrongArguments = mysql.ErrWrongArguments
Expand Down Expand Up @@ -64,6 +65,7 @@ var (

ErrWrongUsage = terror.ClassOptimizer.New(codeWrongUsage, mysql.MySQLErrName[mysql.ErrWrongUsage])
ErrAmbiguous = terror.ClassOptimizer.New(codeAmbiguous, mysql.MySQLErrName[mysql.ErrNonUniq])
ErrUnknown = terror.ClassOptimizer.New(codeUnknown, mysql.MySQLErrName[mysql.ErrUnknown])
ErrUnknownColumn = terror.ClassOptimizer.New(codeUnknownColumn, mysql.MySQLErrName[mysql.ErrBadField])
ErrUnknownTable = terror.ClassOptimizer.New(codeUnknownTable, mysql.MySQLErrName[mysql.ErrUnknownTable])
ErrWrongArguments = terror.ClassOptimizer.New(codeWrongArguments, mysql.MySQLErrName[mysql.ErrWrongArguments])
Expand Down
24 changes: 20 additions & 4 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
er.ctxStack = append(er.ctxStack, value)
case *driver.ParamMarkerExpr:
var value expression.Expression
value, er.err = expression.GetParamExpression(er.ctx, v, er.useCache())
value, er.err = expression.GetParamExpression(er.ctx, v)
if er.err != nil {
return retNode, false
}
Expand Down Expand Up @@ -941,10 +941,26 @@ func (er *expressionRewriter) isNullToExpression(v *ast.IsNullExpr) {
}

func (er *expressionRewriter) positionToScalarFunc(v *ast.PositionExpr) {
if v.N > 0 && v.N <= er.schema.Len() {
er.ctxStack = append(er.ctxStack, er.schema.Columns[v.N-1])
pos := v.N
str := strconv.Itoa(pos)
if v.P != nil {
stkLen := len(er.ctxStack)
val := er.ctxStack[stkLen-1]
intNum, isNull, err := expression.GetIntFromConstant(er.ctx, val)
str = "?"
if err == nil {
if isNull {
return
}
pos = intNum
er.ctxStack = er.ctxStack[:stkLen-1]
}
er.err = err
}
if er.err == nil && pos > 0 && pos <= er.schema.Len() {
er.ctxStack = append(er.ctxStack, er.schema.Columns[pos-1])
} else {
er.err = ErrUnknownColumn.GenWithStackByArgs(strconv.Itoa(v.N), clauseMsg[er.b.curClause])
er.err = ErrUnknownColumn.GenWithStackByArgs(str, clauseMsg[er.b.curClause])
}
}

Expand Down
Loading

0 comments on commit c677187

Please sign in to comment.