From c0c136094310575a58f58d05e9ae081eb9dc2a58 Mon Sep 17 00:00:00 2001 From: Jun-Seok Heo Date: Mon, 3 Dec 2018 22:48:54 +0900 Subject: [PATCH] plan: support `?` in Order By / Group By / Limit Offset clauses (#8206) --- executor/prepared.go | 3 +- executor/prepared_test.go | 51 ++++++++++++++ expression/simple_rewriter.go | 8 ++- expression/util.go | 72 ++++++++++++++++++++ go.mod | 2 + go.sum | 4 +- planner/core/cacheable_checker.go | 14 ++++ planner/core/cacheable_checker_test.go | 14 ++++ planner/core/errors.go | 2 + planner/core/expression_rewriter.go | 31 ++++++--- planner/core/logical_plan_builder.go | 94 ++++++++++++++++++++------ planner/core/point_get_plan.go | 3 +- 12 files changed, 262 insertions(+), 36 deletions(-) diff --git a/executor/prepared.go b/executor/prepared.go index da0b14c1d9edb..e7454a29f5111 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -25,7 +25,6 @@ import ( plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" - "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/sqlexec" @@ -165,7 +164,7 @@ func (e *PrepareExec) Next(ctx context.Context, chk *chunk.Chunk) error { // We try to build the real statement of preparedStmt. for i := range prepared.Params { - prepared.Params[i].(*driver.ParamMarkerExpr).Datum = types.NewIntDatum(0) + prepared.Params[i].(*driver.ParamMarkerExpr).Datum.SetNull() } var p plannercore.Plan p, err = plannercore.BuildLogicalPlan(e.ctx, stmt, e.is) diff --git a/executor/prepared_test.go b/executor/prepared_test.go index 565f2f7077d99..bd5d13779093e 100644 --- a/executor/prepared_test.go +++ b/executor/prepared_test.go @@ -366,3 +366,54 @@ func generateBatchSQL(paramCount int) (sql string, paramSlice []interface{}) { } return "insert into t values " + strings.Join(placeholders, ","), params } + +func (s *testSuite) TestPreparedIssue8153(c *C) { + orgEnable := plannercore.PreparedPlanCacheEnabled() + orgCapacity := plannercore.PreparedPlanCacheCapacity + defer func() { + plannercore.SetPreparedPlanCache(orgEnable) + plannercore.PreparedPlanCacheCapacity = orgCapacity + }() + flags := []bool{false, true} + for _, flag := range flags { + var err error + plannercore.SetPreparedPlanCache(flag) + plannercore.PreparedPlanCacheCapacity = 100 + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a int, b int)") + tk.MustExec("insert into t (a, b) values (1,3), (2,2), (3,1)") + + tk.MustExec(`prepare stmt from 'select * from t order by ? asc'`) + r := tk.MustQuery(`execute stmt using @param;`) + r.Check(testkit.Rows("1 3", "2 2", "3 1")) + + tk.MustExec(`set @param = 1`) + r = tk.MustQuery(`execute stmt using @param;`) + r.Check(testkit.Rows("1 3", "2 2", "3 1")) + + tk.MustExec(`set @param = 2`) + r = tk.MustQuery(`execute stmt using @param;`) + r.Check(testkit.Rows("3 1", "2 2", "1 3")) + + tk.MustExec(`set @param = 3`) + _, err = tk.Exec(`execute stmt using @param;`) + c.Assert(err.Error(), Equals, "[planner:1054]Unknown column '?' in 'order clause'") + + tk.MustExec(`set @param = '##'`) + r = tk.MustQuery(`execute stmt using @param;`) + r.Check(testkit.Rows("1 3", "2 2", "3 1")) + + tk.MustExec("insert into t (a, b) values (1,1), (1,2), (2,1), (2,3), (3,2), (3,3)") + tk.MustExec(`prepare stmt from 'select ?, sum(a) from t group by ?'`) + + tk.MustExec(`set @a=1,@b=1`) + r = tk.MustQuery(`execute stmt using @a,@b;`) + r.Check(testkit.Rows("1 18")) + + tk.MustExec(`set @a=1,@b=2`) + _, err = tk.Exec(`execute stmt using @a,@b;`) + c.Assert(err.Error(), Equals, "[planner:1056]Can't group on 'sum(a)'") + } +} diff --git a/expression/simple_rewriter.go b/expression/simple_rewriter.go index 5bc87ab74d7c0..f32401041de08 100644 --- a/expression/simple_rewriter.go +++ b/expression/simple_rewriter.go @@ -156,9 +156,11 @@ func (sr *simpleRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok boo sr.inToExpression(len(v.List), v.Not, &v.Type) } case *driver.ParamMarkerExpr: - tp := types.NewFieldType(mysql.TypeUnspecified) - types.DefaultParamTypeForValue(v.GetValue(), tp) - value := &Constant{Value: v.ValueExpr.Datum, RetType: tp} + var value Expression + value, sr.err = GetParamExpression(sr.ctx, v) + if sr.err != nil { + return retNode, false + } sr.push(value) case *ast.RowExpr: sr.rowToScalarFunc(v) diff --git a/expression/util.go b/expression/util.go index 2047bf6bd924b..38778c4ad8093 100644 --- a/expression/util.go +++ b/expression/util.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" + driver "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/hack" "golang.org/x/tools/container/intsets" @@ -555,3 +556,74 @@ func DisableParseJSONFlag4Expr(expr Expression) { } expr.GetType().Flag &= ^mysql.ParseToJSONFlag } + +// DatumToConstant generates a Constant expression from a Datum. +func DatumToConstant(d types.Datum, tp byte) *Constant { + return &Constant{Value: d, RetType: types.NewFieldType(tp)} +} + +// GetParamExpression generate a getparam function expression. +func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr) (Expression, error) { + useCache := ctx.GetSessionVars().StmtCtx.UseCache + tp := types.NewFieldType(mysql.TypeUnspecified) + types.DefaultParamTypeForValue(v.GetValue(), tp) + value := &Constant{Value: v.Datum, RetType: tp} + if useCache { + f, err := NewFunctionBase(ctx, ast.GetParam, &v.Type, + DatumToConstant(types.NewIntDatum(int64(v.Order)), mysql.TypeLonglong)) + if err != nil { + return nil, errors.Trace(err) + } + f.GetType().Tp = v.Type.Tp + value.DeferredExpr = f + } + return value, nil +} + +// ConstructPositionExpr constructs PositionExpr with the given ParamMarkerExpr. +func ConstructPositionExpr(p *driver.ParamMarkerExpr) *ast.PositionExpr { + return &ast.PositionExpr{P: p} +} + +// PosFromPositionExpr generates a position value from PositionExpr. +func PosFromPositionExpr(ctx sessionctx.Context, v *ast.PositionExpr) (int, bool, error) { + if v.P == nil { + return v.N, false, nil + } + value, err := GetParamExpression(ctx, v.P.(*driver.ParamMarkerExpr)) + if err != nil { + return 0, true, err + } + pos, isNull, err := GetIntFromConstant(ctx, value) + if err != nil || isNull { + return 0, true, errors.Trace(err) + } + return pos, false, nil +} + +// GetStringFromConstant gets a string value from the Constant expression. +func GetStringFromConstant(ctx sessionctx.Context, value Expression) (string, bool, error) { + con, ok := value.(*Constant) + if !ok { + err := errors.Errorf("Not a Constant expression %+v", value) + return "", true, errors.Trace(err) + } + str, isNull, err := con.EvalString(ctx, chunk.Row{}) + if err != nil || isNull { + return "", true, errors.Trace(err) + } + return str, false, nil +} + +// GetIntFromConstant gets an interger value from the Constant expression. +func GetIntFromConstant(ctx sessionctx.Context, value Expression) (int, bool, error) { + str, isNull, err := GetStringFromConstant(ctx, value) + if err != nil || isNull { + return 0, true, errors.Trace(err) + } + intNum, err := strconv.Atoi(str) + if err != nil { + return 0, true, nil + } + return intNum, false, nil +} diff --git a/go.mod b/go.mod index 8eb18aad6377f..1a3caf5e1abbb 100644 --- a/go.mod +++ b/go.mod @@ -84,3 +84,5 @@ require ( gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) + +replace github.com/pingcap/parser => github.com/zz-jason/parser v0.0.0-20191003033834-cce7a9500e2e diff --git a/go.sum b/go.sum index 28506835b806e..45af0db02ed77 100644 --- a/go.sum +++ b/go.sum @@ -109,8 +109,6 @@ github.com/pingcap/kvproto v0.0.0-20190826051950-fc8799546726 h1:AzGIEmaYVYMtmki github.com/pingcap/kvproto v0.0.0-20190826051950-fc8799546726/go.mod h1:0gwbe1F2iBIjuQ9AH0DbQhL+Dpr5GofU8fgYyXk+ykk= github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596 h1:t2OQTpPJnrPDGlvA+3FwJptMTt6MEPdzK1Wt99oaefQ= github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596/go.mod h1:WpHUKhNZ18v116SvGrmjkA9CBhYmuUTKL+p8JC9ANEw= -github.com/pingcap/parser v0.0.0-20190910040957-e998b3c52469 h1:JS/p4qMInVXTyV0kjFz+n0DBGn/n1T0cZDjEYHdTQow= -github.com/pingcap/parser v0.0.0-20190910040957-e998b3c52469/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= github.com/pingcap/pd v2.1.12+incompatible h1:6N3LBxx2aSZqT+IWEG730EDNDttP7dXO8J6yvBh+HXw= github.com/pingcap/pd v2.1.12+incompatible/go.mod h1:nD3+EoYes4+aNNODO99ES59V83MZSI+dFbhyr667a0E= github.com/pingcap/tidb-tools v2.1.3-0.20190116051332-34c808eef588+incompatible h1:e9Gi/LP9181HT3gBfSOeSBA+5JfemuE4aEAhqNgoE4k= @@ -151,6 +149,8 @@ github.com/unrolled/render v0.0.0-20171102162132-65450fb6b2d3/go.mod h1:tu82oB5W github.com/xiang90/probing v0.0.0-20160813154853-07dd2e8dfe18 h1:MPPkRncZLN9Kh4MEFmbnK4h3BD7AUmskWv2+EeZJCCs= github.com/xiang90/probing v0.0.0-20160813154853-07dd2e8dfe18/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/yookoala/realpath v1.0.0/go.mod h1:gJJMA9wuX7AcqLy1+ffPatSCySA1FQ2S8Ya9AIoYBpE= +github.com/zz-jason/parser v0.0.0-20191003033834-cce7a9500e2e h1:oxazCGeHJ+CdDGPGVeIpIBzJ4dw0DNqDI5wdXPVZb8Q= +github.com/zz-jason/parser v0.0.0-20191003033834-cce7a9500e2e/go.mod h1:mnf7H9ngMZzobilLo3+bu86/+DSlGQBnmse9S5K8PKQ= go.etcd.io/bbolt v1.3.3 h1:MUGmc65QhB3pIlaQ5bB4LwqSj6GIonVJXpZiaKNyaKk= go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.uber.org/atomic v1.3.2 h1:2Oa65PReHzfn29GpvgsYwloV9AVFHPDk8tYxt2c2tr4= diff --git a/planner/core/cacheable_checker.go b/planner/core/cacheable_checker.go index 38f417d804456..b2f6e8665471d 100644 --- a/planner/core/cacheable_checker.go +++ b/planner/core/cacheable_checker.go @@ -51,6 +51,20 @@ func (checker *cacheableChecker) Enter(in ast.Node) (out ast.Node, skipChildren checker.cacheable = false return in, true } + case *ast.OrderByClause: + for _, item := range node.Items { + if _, isParamMarker := item.Expr.(*driver.ParamMarkerExpr); isParamMarker { + checker.cacheable = false + return in, true + } + } + case *ast.GroupByClause: + for _, item := range node.Items { + if _, isParamMarker := item.Expr.(*driver.ParamMarkerExpr); isParamMarker { + checker.cacheable = false + return in, true + } + } case *ast.Limit: if node.Count != nil { if _, isParamMarker := node.Count.(*driver.ParamMarkerExpr); isParamMarker { diff --git a/planner/core/cacheable_checker_test.go b/planner/core/cacheable_checker_test.go index 9b6c1367e3042..d2d0a7e896f48 100644 --- a/planner/core/cacheable_checker_test.go +++ b/planner/core/cacheable_checker_test.go @@ -87,4 +87,18 @@ func (s *testCacheableSuite) TestCacheable(c *C) { Limit: limitStmt, } c.Assert(Cacheable(stmt), IsTrue) + + paramExpr := &driver.ParamMarkerExpr{} + orderByClause := &ast.OrderByClause{Items: []*ast.ByItem{{Expr: paramExpr}}} + stmt = &ast.SelectStmt{ + OrderBy: orderByClause, + } + c.Assert(Cacheable(stmt), IsFalse) + + valExpr := &driver.ValueExpr{} + orderByClause = &ast.OrderByClause{Items: []*ast.ByItem{{Expr: valExpr}}} + stmt = &ast.SelectStmt{ + OrderBy: orderByClause, + } + c.Assert(Cacheable(stmt), IsTrue) } diff --git a/planner/core/errors.go b/planner/core/errors.go index 4cfaa978aeacf..1e769e95bb5e4 100644 --- a/planner/core/errors.go +++ b/planner/core/errors.go @@ -28,6 +28,7 @@ const ( codeWrongUsage = mysql.ErrWrongUsage codeAmbiguous = mysql.ErrNonUniq + codeUnknown = mysql.ErrUnknown codeUnknownColumn = mysql.ErrBadField codeUnknownTable = mysql.ErrUnknownTable codeWrongArguments = mysql.ErrWrongArguments @@ -65,6 +66,7 @@ var ( ErrWrongUsage = terror.ClassOptimizer.New(codeWrongUsage, mysql.MySQLErrName[mysql.ErrWrongUsage]) ErrAmbiguous = terror.ClassOptimizer.New(codeAmbiguous, mysql.MySQLErrName[mysql.ErrNonUniq]) + ErrUnknown = terror.ClassOptimizer.New(codeUnknown, mysql.MySQLErrName[mysql.ErrUnknown]) ErrUnknownColumn = terror.ClassOptimizer.New(codeUnknownColumn, mysql.MySQLErrName[mysql.ErrBadField]) ErrUnknownTable = terror.ClassOptimizer.New(codeUnknownTable, mysql.MySQLErrName[mysql.ErrUnknownTable]) ErrWrongArguments = terror.ClassOptimizer.New(codeWrongArguments, mysql.MySQLErrName[mysql.ErrWrongArguments]) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 07220d8851070..f41765ccacd82 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -820,11 +820,10 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok value := &expression.Constant{Value: v.Datum, RetType: &v.Type} er.ctxStack = append(er.ctxStack, value) case *driver.ParamMarkerExpr: - tp := types.NewFieldType(mysql.TypeUnspecified) - types.DefaultParamTypeForValue(v.GetValue(), tp) - value := &expression.Constant{Value: v.Datum, RetType: tp} - if er.useCache() { - value.DeferredExpr = er.getParamExpression(v) + var value expression.Expression + value, er.err = expression.GetParamExpression(er.ctx, v) + if er.err != nil { + return retNode, false } er.ctxStack = append(er.ctxStack, value) case *ast.VariableExpr: @@ -1044,10 +1043,26 @@ func (er *expressionRewriter) isNullToExpression(v *ast.IsNullExpr) { } func (er *expressionRewriter) positionToScalarFunc(v *ast.PositionExpr) { - if v.N > 0 && v.N <= er.schema.Len() { - er.ctxStack = append(er.ctxStack, er.schema.Columns[v.N-1]) + pos := v.N + str := strconv.Itoa(pos) + if v.P != nil { + stkLen := len(er.ctxStack) + val := er.ctxStack[stkLen-1] + intNum, isNull, err := expression.GetIntFromConstant(er.ctx, val) + str = "?" + if err == nil { + if isNull { + return + } + pos = intNum + er.ctxStack = er.ctxStack[:stkLen-1] + } + er.err = err + } + if er.err == nil && pos > 0 && pos <= er.schema.Len() { + er.ctxStack = append(er.ctxStack, er.schema.Columns[pos-1]) } else { - er.err = ErrUnknownColumn.GenWithStackByArgs(strconv.Itoa(v.N), clauseMsg[er.b.curClause]) + er.err = ErrUnknownColumn.GenWithStackByArgs(str, clauseMsg[er.b.curClause]) } } diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 19e9391c4ab60..eea9cda709175 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -34,11 +34,10 @@ import ( "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/types/parser_driver" + driver "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" ) @@ -867,6 +866,23 @@ func (by *ByItems) Clone() *ByItems { return &ByItems{Expr: by.Expr.Clone(), Desc: by.Desc} } +// itemTransformer transforms ParamMarkerExpr to PositionExpr in the context of ByItem +type itemTransformer struct { +} + +func (t *itemTransformer) Enter(inNode ast.Node) (ast.Node, bool) { + switch n := inNode.(type) { + case *driver.ParamMarkerExpr: + newNode := expression.ConstructPositionExpr(n) + return newNode, true + } + return inNode, false +} + +func (t *itemTransformer) Leave(inNode ast.Node) (ast.Node, bool) { + return inNode, false +} + func (b *planBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper map[*ast.AggregateFuncExpr]int) (*LogicalSort, error) { if _, isUnion := p.(*LogicalUnionAll); isUnion { b.curClause = globalOrderByClause @@ -875,7 +891,10 @@ func (b *planBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper } sort := LogicalSort{}.init(b.ctx) exprs := make([]*ByItems, 0, len(byItems)) + transformer := &itemTransformer{} for _, item := range byItems { + newExpr, _ := item.Expr.Accept(transformer) + item.Expr = newExpr.(ast.ExprNode) it, np, err := b.rewrite(item.Expr, p, aggMapper, true) if err != nil { return nil, errors.Trace(err) @@ -892,7 +911,27 @@ func (b *planBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper // getUintForLimitOffset gets uint64 value for limit/offset. // For ordinary statement, limit/offset should be uint64 constant value. // For prepared statement, limit/offset is string. We should convert it to uint64. -func getUintForLimitOffset(sc *stmtctx.StatementContext, val interface{}) (uint64, error) { +func getUintForLimitOffset(ctx sessionctx.Context, n ast.Node) (uint64, error) { + var val interface{} + switch v := n.(type) { + case *driver.ValueExpr: + val = v.GetValue() + case *driver.ParamMarkerExpr: + param, err := expression.GetParamExpression(ctx, v) + if err != nil { + return 0, errors.Trace(err) + } + str, isNull, err := expression.GetStringFromConstant(ctx, param) + if err != nil { + return 0, errors.Trace(err) + } + if isNull { + return 0, nil + } + val = str + default: + return 0, errors.Errorf("Invalid type %T for LogicalLimit/Offset", v) + } switch v := val.(type) { case uint64: return v, nil @@ -901,22 +940,23 @@ func getUintForLimitOffset(sc *stmtctx.StatementContext, val interface{}) (uint6 return uint64(v), nil } case string: + sc := ctx.GetSessionVars().StmtCtx uVal, err := types.StrToUint(sc, v) return uVal, errors.Trace(err) } return 0, errors.Errorf("Invalid type %T for LogicalLimit/Offset", val) } -func extractLimitCountOffset(sc *stmtctx.StatementContext, limit *ast.Limit) (count uint64, +func extractLimitCountOffset(ctx sessionctx.Context, limit *ast.Limit) (count uint64, offset uint64, err error) { if limit.Count != nil { - count, err = getUintForLimitOffset(sc, limit.Count.(ast.ValueExpr).GetValue()) + count, err = getUintForLimitOffset(ctx, limit.Count) if err != nil { return 0, 0, ErrWrongArguments.GenWithStackByArgs("LIMIT") } } if limit.Offset != nil { - offset, err = getUintForLimitOffset(sc, limit.Offset.(ast.ValueExpr).GetValue()) + offset, err = getUintForLimitOffset(ctx, limit.Offset) if err != nil { return 0, 0, ErrWrongArguments.GenWithStackByArgs("LIMIT") } @@ -930,8 +970,7 @@ func (b *planBuilder) buildLimit(src LogicalPlan, limit *ast.Limit) (LogicalPlan offset, count uint64 err error ) - sc := b.ctx.GetSessionVars().StmtCtx - if count, offset, err = extractLimitCountOffset(sc, limit); err != nil { + if count, offset, err = extractLimitCountOffset(b.ctx, limit); err != nil { return nil, err } @@ -1201,16 +1240,22 @@ func (b *planBuilder) extractAggFuncs(fields []*ast.SelectField) ([]*ast.Aggrega // gbyResolver resolves group by items from select fields. type gbyResolver struct { - fields []*ast.SelectField - schema *expression.Schema - err error - inExpr bool + ctx sessionctx.Context + fields []*ast.SelectField + schema *expression.Schema + err error + inExpr bool + isParam bool } func (g *gbyResolver) Enter(inNode ast.Node) (ast.Node, bool) { - switch inNode.(type) { + switch n := inNode.(type) { case *ast.SubqueryExpr, *ast.CompareSubqueryExpr, *ast.ExistsSubqueryExpr: return inNode, true + case *driver.ParamMarkerExpr: + newNode := expression.ConstructPositionExpr(n) + g.isParam = true + return newNode, true case *driver.ValueExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.ColumnName: default: g.inExpr = true @@ -1245,14 +1290,21 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { return inNode, false } case *ast.PositionExpr: - if v.N < 1 || v.N > len(g.fields) { - g.err = errors.Errorf("Unknown column '%d' in 'group statement'", v.N) + pos, isNull, err := expression.PosFromPositionExpr(g.ctx, v) + if err != nil { + g.err = ErrUnknown.GenWithStackByArgs() + } + if err != nil || isNull { return inNode, false } - ret := g.fields[v.N-1].Expr + if pos < 1 || pos > len(g.fields) { + g.err = errors.Errorf("Unknown column '%d' in 'group statement'", pos) + return inNode, false + } + ret := g.fields[pos-1].Expr ret.Accept(extractor) if len(extractor.AggFuncs) != 0 { - g.err = ErrWrongGroupField.GenWithStackByArgs(g.fields[v.N-1].Text()) + g.err = ErrWrongGroupField.GenWithStackByArgs(g.fields[pos-1].Text()) return inNode, false } return ret, true @@ -1625,6 +1677,7 @@ func (b *planBuilder) resolveGbyExprs(p LogicalPlan, gby *ast.GroupByClause, fie b.curClause = groupByClause exprs := make([]expression.Expression, 0, len(gby.Items)) resolver := &gbyResolver{ + ctx: b.ctx, fields: fields, schema: p.Schema(), } @@ -1634,9 +1687,12 @@ func (b *planBuilder) resolveGbyExprs(p LogicalPlan, gby *ast.GroupByClause, fie if resolver.err != nil { return nil, nil, errors.Trace(resolver.err) } + if !resolver.isParam { + item.Expr = retExpr.(ast.ExprNode) + } - item.Expr = retExpr.(ast.ExprNode) - expr, np, err := b.rewrite(item.Expr, p, nil, true) + itemExpr := retExpr.(ast.ExprNode) + expr, np, err := b.rewrite(itemExpr, p, nil, true) if err != nil { return nil, nil, errors.Trace(err) } diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index b1f794fbb969f..40f4ae5c8d9f4 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -153,8 +153,7 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP if selStmt.Having != nil || selStmt.LockTp != ast.SelectLockNone { return nil } else if selStmt.Limit != nil { - sc := ctx.GetSessionVars().StmtCtx - count, offset, err := extractLimitCountOffset(sc, selStmt.Limit) + count, offset, err := extractLimitCountOffset(ctx, selStmt.Limit) if err != nil || count == 0 || offset > 0 { return nil }