diff --git a/ast/ast.go b/ast/ast.go index 3cc2da1f20e1b..eef7ac63016fc 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -64,14 +64,6 @@ type ExprNode interface { SetType(tp *types.FieldType) // GetType gets the evaluation type of the expression. GetType() *types.FieldType - // SetValue sets value to the expression. - SetValue(val interface{}) - // GetValue gets value of the expression. - GetValue() interface{} - // SetDatum sets datum to the expression. - SetDatum(datum types.Datum) - // GetDatum gets datum of the expression. - GetDatum() *types.Datum // SetFlag sets flag to the expression. // Flag indicates whether the expression contains // parameter marker, reference, aggregate function... @@ -154,15 +146,6 @@ type RecordSet interface { Close() error } -// RowToDatums converts row to datum slice. -func RowToDatums(row chunk.Row, fields []*ResultField) []types.Datum { - datums := make([]types.Datum, len(fields)) - for i, f := range fields { - datums[i] = row.GetDatum(i, &f.Column.FieldType) - } - return datums -} - // ResultSetNode interface has a ResultFields property, represents a Node that returns result set. // Implementations include SelectStmt, SubqueryExpr, TableSource, TableName and Join. type ResultSetNode interface { diff --git a/ast/base.go b/ast/base.go index 75ad175dce039..c62aa6d8a02f2 100644 --- a/ast/base.go +++ b/ast/base.go @@ -62,20 +62,12 @@ func (dn *dmlNode) dmlStatement() {} // Expression implementations should embed it in. type exprNode struct { node - types.Datum Type types.FieldType flag uint64 } -// SetDatum implements ExprNode interface. -func (en *exprNode) SetDatum(datum types.Datum) { - en.Datum = datum -} - -// GetDatum implements ExprNode interface. -func (en *exprNode) GetDatum() *types.Datum { - return &en.Datum -} +// TexprNode is exported for parser driver. +type TexprNode = exprNode // SetType implements ExprNode interface. func (en *exprNode) SetType(tp *types.FieldType) { diff --git a/ast/expressions.go b/ast/expressions.go index ca22017330e0f..56e95b0f963c5 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -17,13 +17,10 @@ import ( "fmt" "io" "regexp" - "strconv" "strings" "github.com/pingcap/tidb/model" - "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" - "github.com/pingcap/tidb/types" ) var ( @@ -36,7 +33,6 @@ var ( _ ExprNode = &ExistsSubqueryExpr{} _ ExprNode = &IsNullExpr{} _ ExprNode = &IsTruthExpr{} - _ ExprNode = &ParamMarkerExpr{} _ ExprNode = &ParenthesesExpr{} _ ExprNode = &PatternInExpr{} _ ExprNode = &PatternLikeExpr{} @@ -45,7 +41,6 @@ var ( _ ExprNode = &RowExpr{} _ ExprNode = &SubqueryExpr{} _ ExprNode = &UnaryOperationExpr{} - _ ExprNode = &ValueExpr{} _ ExprNode = &ValuesExpr{} _ ExprNode = &VariableExpr{} @@ -53,81 +48,22 @@ var ( _ Node = &WhenClause{} ) -// ValueExpr is the simple value expression. -type ValueExpr struct { - exprNode - projectionOffset int -} - -// Format the ExprNode into a Writer. -func (n *ValueExpr) Format(w io.Writer) { - var s string - switch n.Kind() { - case types.KindNull: - s = "NULL" - case types.KindInt64: - if n.Type.Flag&mysql.IsBooleanFlag != 0 { - if n.GetInt64() > 0 { - s = "TRUE" - } else { - s = "FALSE" - } - } else { - s = strconv.FormatInt(n.GetInt64(), 10) - } - case types.KindUint64: - s = strconv.FormatUint(n.GetUint64(), 10) - case types.KindFloat32: - s = strconv.FormatFloat(n.GetFloat64(), 'e', -1, 32) - case types.KindFloat64: - s = strconv.FormatFloat(n.GetFloat64(), 'e', -1, 64) - case types.KindString, types.KindBytes: - s = strconv.Quote(n.GetString()) - case types.KindMysqlDecimal: - s = n.GetMysqlDecimal().String() - case types.KindBinaryLiteral: - if n.Type.Flag&mysql.UnsignedFlag != 0 { - s = fmt.Sprintf("x'%x'", n.GetBytes()) - } else { - s = n.GetBinaryLiteral().ToBitLiteralString(true) - } - default: - panic("Can't format to string") - } - fmt.Fprint(w, s) +// ValueExpr define a interface for ValueExpr. +type ValueExpr interface { + ExprNode + SetValue(val interface{}) + GetValue() interface{} + GetDatumString() string + GetString() string + GetProjectionOffset() int + SetProjectionOffset(offset int) } // NewValueExpr creates a ValueExpr with value, and sets default field type. -func NewValueExpr(value interface{}) *ValueExpr { - if ve, ok := value.(*ValueExpr); ok { - return ve - } - ve := &ValueExpr{} - ve.SetValue(value) - types.DefaultTypeForValue(value, &ve.Type) - ve.projectionOffset = -1 - return ve -} +var NewValueExpr func(interface{}) ValueExpr -// SetProjectionOffset sets ValueExpr.projectionOffset for logical plan builder. -func (n *ValueExpr) SetProjectionOffset(offset int) { - n.projectionOffset = offset -} - -// GetProjectionOffset returns ValueExpr.projectionOffset. -func (n *ValueExpr) GetProjectionOffset() int { - return n.projectionOffset -} - -// Accept implements Node interface. -func (n *ValueExpr) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*ValueExpr) - return v.Leave(n) -} +// NewParamMarkerExpr creates a ParamMarkerExpr. +var NewParamMarkerExpr func(offset int) ParamMarkerExpr // BetweenExpr is for "between and" or "not between and" expression. type BetweenExpr struct { @@ -721,25 +657,9 @@ func (n *PatternLikeExpr) Accept(v Visitor) (Node, bool) { // ParamMarkerExpr expression holds a place for another expression. // Used in parsing prepare statement. -type ParamMarkerExpr struct { - exprNode - Offset int - Order int -} - -// Format the ExprNode into a Writer. -func (n *ParamMarkerExpr) Format(w io.Writer) { - panic("Not implemented") -} - -// Accept implements Node Accept interface. -func (n *ParamMarkerExpr) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*ParamMarkerExpr) - return v.Leave(n) +type ParamMarkerExpr interface { + ValueExpr + SetOrder(int) } // ParenthesesExpr is the parentheses expression. diff --git a/ast/expressions_test.go b/ast/expressions_test.go index 9176e36b70f3a..99fe630afaa37 100644 --- a/ast/expressions_test.go +++ b/ast/expressions_test.go @@ -16,6 +16,7 @@ package ast_test import ( . "github.com/pingcap/check" . "github.com/pingcap/tidb/ast" + _ "github.com/pingcap/tidb/types/parser_driver" ) var _ = Suite(&testExpressionsSuite{}) @@ -79,7 +80,7 @@ func (tc *testExpressionsSuite) TestExpresionsVisitorCover(c *C) { {&ExistsSubqueryExpr{Sel: ce}, 1, 1}, {&IsNullExpr{Expr: ce}, 1, 1}, {&IsTruthExpr{Expr: ce}, 1, 1}, - {&ParamMarkerExpr{}, 0, 0}, + {NewParamMarkerExpr(0), 0, 0}, {&ParenthesesExpr{Expr: ce}, 1, 1}, {&PatternInExpr{Expr: ce, List: []ExprNode{ce, ce, ce}, Sel: ce}, 5, 5}, {&PatternLikeExpr{Expr: ce, Pattern: ce}, 2, 2}, @@ -87,7 +88,7 @@ func (tc *testExpressionsSuite) TestExpresionsVisitorCover(c *C) { {&PositionExpr{}, 0, 0}, {&RowExpr{Values: []ExprNode{ce, ce}}, 2, 2}, {&UnaryOperationExpr{V: ce}, 1, 1}, - {&ValueExpr{}, 0, 0}, + {NewValueExpr(0), 0, 0}, {&ValuesExpr{Column: &ColumnNameExpr{Name: &ColumnName{}}}, 0, 0}, {&VariableExpr{Value: ce}, 1, 1}, } diff --git a/ast/flag.go b/ast/flag.go index 6883f82e138e6..773a2b44483e7 100644 --- a/ast/flag.go +++ b/ast/flag.go @@ -32,6 +32,9 @@ func (f *flagSetter) Enter(in Node) (Node, bool) { } func (f *flagSetter) Leave(in Node) (Node, bool) { + if x, ok := in.(ParamMarkerExpr); ok { + x.SetFlag(FlagHasParamMarker) + } switch x := in.(type) { case *AggregateFuncExpr: f.aggregateFunc(x) @@ -57,8 +60,6 @@ func (f *flagSetter) Leave(in Node) (Node, bool) { x.SetFlag(x.Expr.GetFlag()) case *IsTruthExpr: x.SetFlag(x.Expr.GetFlag()) - case *ParamMarkerExpr: - x.SetFlag(FlagHasParamMarker) case *ParenthesesExpr: x.SetFlag(x.Expr.GetFlag()) case *PatternInExpr: @@ -75,7 +76,6 @@ func (f *flagSetter) Leave(in Node) (Node, bool) { x.SetFlag(FlagHasSubquery) case *UnaryOperationExpr: x.SetFlag(x.V.GetFlag()) - case *ValueExpr: case *ValuesExpr: x.SetFlag(FlagHasReference) case *VariableExpr: diff --git a/ast/functions.go b/ast/functions.go index bdd81485b92d6..99e3eda756ff8 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -346,10 +346,10 @@ func (n *FuncCallExpr) specialFormatArgs(w io.Writer) bool { n.Args[0].Format(w) fmt.Fprint(w, ", INTERVAL ") n.Args[1].Format(w) - fmt.Fprintf(w, " %s", n.Args[2].GetDatum().GetString()) + fmt.Fprintf(w, " %s", n.Args[2].(ValueExpr).GetDatumString()) return true case TimestampAdd, TimestampDiff: - fmt.Fprintf(w, "%s, ", n.Args[0].GetDatum().GetString()) + fmt.Fprintf(w, "%s, ", n.Args[0].(ValueExpr).GetDatumString()) n.Args[1].Format(w) fmt.Fprint(w, ", ") n.Args[2].Format(w) diff --git a/ast/functions_test.go b/ast/functions_test.go index f54120c717fd2..01c37fa1b1eb1 100644 --- a/ast/functions_test.go +++ b/ast/functions_test.go @@ -24,10 +24,11 @@ type testFunctionsSuite struct { } func (ts *testFunctionsSuite) TestFunctionsVisitorCover(c *C) { + valueExpr := NewValueExpr(42) stmts := []Node{ - &AggregateFuncExpr{Args: []ExprNode{&ValueExpr{}}}, - &FuncCallExpr{Args: []ExprNode{&ValueExpr{}}}, - &FuncCastExpr{Expr: &ValueExpr{}}, + &AggregateFuncExpr{Args: []ExprNode{valueExpr}}, + &FuncCallExpr{Args: []ExprNode{valueExpr}}, + &FuncCastExpr{Expr: valueExpr}, } for _, stmt := range stmts { diff --git a/ast/misc.go b/ast/misc.go index c4493a24ccd66..54d2b3ef2a335 100644 --- a/ast/misc.go +++ b/ast/misc.go @@ -187,7 +187,7 @@ func (n *DeallocateStmt) Accept(v Visitor) (Node, bool) { // Prepared represents a prepared statement. type Prepared struct { Stmt StmtNode - Params []*ParamMarkerExpr + Params []ParamMarkerExpr SchemaVersion int64 UseCache bool } @@ -321,7 +321,7 @@ type VariableAssignment struct { // VariableAssignment should be able to store information for SetCharset/SetPWD Stmt. // For SetCharsetStmt, Value is charset, ExtendValue is collation. // TODO: Use SetStmt to implement set password statement. - ExtendValue *ValueExpr + ExtendValue ValueExpr } // Accept implements Node interface. diff --git a/ast/misc_test.go b/ast/misc_test.go index d9c14f753fe5f..0786d38e14e75 100644 --- a/ast/misc_test.go +++ b/ast/misc_test.go @@ -44,6 +44,7 @@ func (visitor1) Enter(in Node) (Node, bool) { } func (ts *testMiscSuite) TestMiscVisitorCover(c *C) { + valueExpr := NewValueExpr(42) stmts := []Node{ &AdminStmt{}, &AlterUserStmt{}, @@ -53,15 +54,15 @@ func (ts *testMiscSuite) TestMiscVisitorCover(c *C) { &CreateUserStmt{}, &DeallocateStmt{}, &DoStmt{}, - &ExecuteStmt{UsingVars: []ExprNode{&ValueExpr{}}}, + &ExecuteStmt{UsingVars: []ExprNode{valueExpr}}, &ExplainStmt{Stmt: &ShowStmt{}}, &GrantStmt{}, - &PrepareStmt{SQLVar: &VariableExpr{Value: &ValueExpr{}}}, + &PrepareStmt{SQLVar: &VariableExpr{Value: valueExpr}}, &RollbackStmt{}, &SetPwdStmt{}, &SetStmt{Variables: []*VariableAssignment{ { - Value: &ValueExpr{}, + Value: valueExpr, }, }}, &UseStmt{}, @@ -72,7 +73,7 @@ func (ts *testMiscSuite) TestMiscVisitorCover(c *C) { }, &FlushStmt{}, &PrivElem{}, - &VariableAssignment{Value: &ValueExpr{}}, + &VariableAssignment{Value: valueExpr}, &KillStmt{}, &DropStatsStmt{Table: &TableName{}}, } diff --git a/cmd/importer/parser.go b/cmd/importer/parser.go index 13926e1928f58..f97563211ce0e 100644 --- a/cmd/importer/parser.go +++ b/cmd/importer/parser.go @@ -117,7 +117,7 @@ func (col *column) parseColumnOptions(ops []*ast.ColumnOption) { case ast.ColumnOptionPrimaryKey, ast.ColumnOptionUniqKey, ast.ColumnOptionAutoIncrement: col.table.uniqIndices[col.name] = col case ast.ColumnOptionComment: - col.comment = op.Expr.GetDatum().GetString() + col.comment = op.Expr.(ast.ValueExpr).GetDatumString() } } } diff --git a/executor/prepared.go b/executor/prepared.go index a006bf0d6edd3..5fded52276104 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/sqlexec" "github.com/pkg/errors" @@ -39,7 +40,7 @@ var ( ) type paramMarkerSorter struct { - markers []*ast.ParamMarkerExpr + markers []ast.ParamMarkerExpr } func (p *paramMarkerSorter) Len() int { @@ -47,7 +48,7 @@ func (p *paramMarkerSorter) Len() int { } func (p *paramMarkerSorter) Less(i, j int) bool { - return p.markers[i].Offset < p.markers[j].Offset + return p.markers[i].(*driver.ParamMarkerExpr).Offset < p.markers[j].(*driver.ParamMarkerExpr).Offset } func (p *paramMarkerSorter) Swap(i, j int) { @@ -55,7 +56,7 @@ func (p *paramMarkerSorter) Swap(i, j int) { } type paramMarkerExtractor struct { - markers []*ast.ParamMarkerExpr + markers []ast.ParamMarkerExpr } func (e *paramMarkerExtractor) Enter(in ast.Node) (ast.Node, bool) { @@ -63,7 +64,7 @@ func (e *paramMarkerExtractor) Enter(in ast.Node) (ast.Node, bool) { } func (e *paramMarkerExtractor) Leave(in ast.Node) (ast.Node, bool) { - if x, ok := in.(*ast.ParamMarkerExpr); ok { + if x, ok := in.(*driver.ParamMarkerExpr); ok { e.markers = append(e.markers, x) } return in, true @@ -145,7 +146,7 @@ func (e *PrepareExec) Next(ctx context.Context, chk *chunk.Chunk) error { sort.Sort(sorter) e.ParamCount = len(sorter.markers) for i := 0; i < e.ParamCount; i++ { - sorter.markers[i].Order = i + sorter.markers[i].SetOrder(i) } prepared := &ast.Prepared{ Stmt: stmt, @@ -156,7 +157,7 @@ func (e *PrepareExec) Next(ctx context.Context, chk *chunk.Chunk) error { // We try to build the real statement of preparedStmt. for i := range prepared.Params { - prepared.Params[i].SetDatum(types.NewIntDatum(0)) + prepared.Params[i].(*driver.ParamMarkerExpr).Datum = types.NewIntDatum(0) } var p plannercore.Plan p, err = plannercore.BuildLogicalPlan(e.ctx, stmt, e.is) diff --git a/expression/helper.go b/expression/helper.go index fe193651712a1..aa2a04c1a8740 100644 --- a/expression/helper.go +++ b/expression/helper.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/types/parser_driver" "github.com/pkg/errors" ) @@ -74,7 +75,7 @@ func GetTimeValue(ctx sessionctx.Context, v interface{}, tp byte, fsp int) (d ty return d, errors.Trace(err) } } - case *ast.ValueExpr: + case *driver.ValueExpr: switch x.Kind() { case types.KindString: value, err = types.ParseTime(sc, x.GetString(), tp, fsp) diff --git a/expression/simple_rewriter.go b/expression/simple_rewriter.go index 70ef1e6541321..6cc6335b96d9f 100644 --- a/expression/simple_rewriter.go +++ b/expression/simple_rewriter.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/types/parser_driver" "github.com/pkg/errors" ) @@ -118,7 +119,7 @@ func (sr *simpleRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok boo return originInNode, false } sr.push(column) - case *ast.ValueExpr: + case *driver.ValueExpr: value := &Constant{Value: v.Datum, RetType: &v.Type} sr.push(value) case *ast.FuncCallExpr: @@ -148,10 +149,10 @@ func (sr *simpleRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok boo if v.Sel == nil { sr.inToExpression(len(v.List), v.Not, &v.Type) } - case *ast.ParamMarkerExpr: + case *driver.ParamMarkerExpr: tp := types.NewFieldType(mysql.TypeUnspecified) types.DefaultParamTypeForValue(v.GetValue(), tp) - value := &Constant{Value: v.Datum, RetType: tp} + value := &Constant{Value: v.ValueExpr.Datum, RetType: tp} sr.push(value) case *ast.RowExpr: sr.rowToScalarFunc(v) diff --git a/parser/parser.y b/parser/parser.y index 1721e01a487ff..3634e5c7725e5 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -3126,7 +3126,7 @@ StringLiteral: } | StringLiteral stringLit { - valExpr := $1.(*ast.ValueExpr) + valExpr := $1.(ast.ValueExpr) strLit := valExpr.GetString() expr := ast.NewValueExpr(strLit+$2) // Fix #4239, use first string literal as projection name. @@ -3159,7 +3159,7 @@ ByItem: Expression Order { expr := $1 - valueExpr, ok := expr.(*ast.ValueExpr) + valueExpr, ok := expr.(ast.ValueExpr) if ok { position, isPosition := valueExpr.GetValue().(int64) if isPosition { @@ -3308,9 +3308,7 @@ SimpleExpr: | Literal | paramMarker { - $$ = &ast.ParamMarkerExpr{ - Offset: yyS[yypt].offset, - } + $$ = ast.NewParamMarkerExpr(yyS[yypt].offset) } | Variable | SumExpr @@ -4619,7 +4617,7 @@ LimitClause: } | "LIMIT" LimitOption { - $$ = &ast.Limit{Count: $2.(ast.ExprNode)} + $$ = &ast.Limit{Count: $2.(ast.ValueExpr)} } LimitOption: @@ -4629,9 +4627,7 @@ LimitOption: } | paramMarker { - $$ = &ast.ParamMarkerExpr{ - Offset: yyS[yypt].offset, - } + $$ = ast.NewParamMarkerExpr(yyS[yypt].offset) } SelectStmtLimit: diff --git a/parser/parser_test.go b/parser/parser_test.go index 85d5d7d3ade4f..5bae76ff4ac71 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/terror" + _ "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/testleak" "github.com/pkg/errors" @@ -135,7 +136,7 @@ func (s *testParserSuite) TestSimple(c *C) { c.Assert(ok, IsTrue) c.Assert(is.Lists, HasLen, 1) c.Assert(is.Lists[0], HasLen, 1) - c.Assert(is.Lists[0][0].GetDatum().GetString(), Equals, "/*! truncated */") + c.Assert(is.Lists[0][0].(ast.ValueExpr).GetDatumString(), Equals, "/*! truncated */") // Testcase for CONVERT(expr,type) src = "SELECT CONVERT('111', SIGNED);" @@ -2215,12 +2216,12 @@ func (s *testParserSuite) TestTimestampDiffUnit(c *C) { expr := fields[0].Expr f, ok := expr.(*ast.FuncCallExpr) c.Assert(ok, IsTrue) - c.Assert(f.Args[0].GetDatum().GetString(), Equals, "MONTH") + c.Assert(f.Args[0].(ast.ValueExpr).GetDatumString(), Equals, "MONTH") expr = fields[1].Expr f, ok = expr.(*ast.FuncCallExpr) c.Assert(ok, IsTrue) - c.Assert(f.Args[0].GetDatum().GetString(), Equals, "MONTH") + c.Assert(f.Args[0].(ast.ValueExpr).GetDatumString(), Equals, "MONTH") // Test Illegal TimeUnit for TimestampDiff table := []testCase{ @@ -2383,7 +2384,7 @@ func (s *testParserSuite) TestSetTransaction(c *C) { c.Assert(vars.Name, Equals, "tx_isolation") c.Assert(vars.IsGlobal, Equals, t.isGlobal) c.Assert(vars.IsSystem, Equals, true) - c.Assert(vars.Value.GetValue(), Equals, t.value) + c.Assert(vars.Value.(ast.ValueExpr).GetValue(), Equals, t.value) } } diff --git a/planner/core/cacheable_checker.go b/planner/core/cacheable_checker.go index 26b8f9d6ef4eb..be927cc18986d 100644 --- a/planner/core/cacheable_checker.go +++ b/planner/core/cacheable_checker.go @@ -16,6 +16,7 @@ package core import ( "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/types/parser_driver" ) // Cacheable checks whether the input ast is cacheable. @@ -52,13 +53,13 @@ func (checker *cacheableChecker) Enter(in ast.Node) (out ast.Node, skipChildren } case *ast.Limit: if node.Count != nil { - if _, isParamMarker := node.Count.(*ast.ParamMarkerExpr); isParamMarker { + if _, isParamMarker := node.Count.(*driver.ParamMarkerExpr); isParamMarker { checker.cacheable = false return in, true } } if node.Offset != nil { - if _, isParamMarker := node.Offset.(*ast.ParamMarkerExpr); isParamMarker { + if _, isParamMarker := node.Offset.(*driver.ParamMarkerExpr); isParamMarker { checker.cacheable = false return in, true } diff --git a/planner/core/cacheable_checker_test.go b/planner/core/cacheable_checker_test.go index ebb4a7d7e44a9..47725b5b6f122 100644 --- a/planner/core/cacheable_checker_test.go +++ b/planner/core/cacheable_checker_test.go @@ -18,6 +18,7 @@ import ( "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/types/parser_driver" ) var _ = Suite(&testCacheableSuite{}) @@ -66,7 +67,7 @@ func (s *testCacheableSuite) TestCacheable(c *C) { c.Assert(Cacheable(stmt), IsFalse) limitStmt := &ast.Limit{ - Count: &ast.ParamMarkerExpr{}, + Count: &driver.ParamMarkerExpr{}, } stmt = &ast.SelectStmt{ Limit: limitStmt, @@ -74,7 +75,7 @@ func (s *testCacheableSuite) TestCacheable(c *C) { c.Assert(Cacheable(stmt), IsFalse) limitStmt = &ast.Limit{ - Offset: &ast.ParamMarkerExpr{}, + Offset: &driver.ParamMarkerExpr{}, } stmt = &ast.SelectStmt{ Limit: limitStmt, diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index 2525bac168ed3..94892124244c3 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/auth" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/kvcache" @@ -157,7 +158,7 @@ func (e *Execute) OptimizePreparedPlan(ctx sessionctx.Context, is infoschema.Inf if err != nil { return errors.Trace(err) } - prepared.Params[i].SetDatum(val) + prepared.Params[i].(*driver.ParamMarkerExpr).Datum = val vars.PreparedParams = append(vars.PreparedParams, val) } if prepared.SchemaVersion != is.SchemaMetaVersion() { diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 60d1e938b19b7..960d64bfc67f9 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" "github.com/pkg/errors" ) @@ -36,7 +37,7 @@ var EvalSubquery func(p PhysicalPlan, is infoschema.InfoSchema, ctx sessionctx.C // evalAstExpr evaluates ast expression directly. func evalAstExpr(ctx sessionctx.Context, expr ast.ExprNode) (types.Datum, error) { - if val, ok := expr.(*ast.ValueExpr); ok { + if val, ok := expr.(*driver.ValueExpr); ok { return val.Datum, nil } b := &PlanBuilder{ @@ -753,10 +754,10 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok switch v := inNode.(type) { case *ast.AggregateFuncExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.WhenClause, *ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr, *ast.ValuesExpr: - case *ast.ValueExpr: + case *driver.ValueExpr: value := &expression.Constant{Value: v.Datum, RetType: &v.Type} er.ctxStack = append(er.ctxStack, value) - case *ast.ParamMarkerExpr: + case *driver.ParamMarkerExpr: tp := types.NewFieldType(mysql.TypeUnspecified) types.DefaultParamTypeForValue(v.GetValue(), tp) value := &expression.Constant{Value: v.Datum, RetType: tp} @@ -820,7 +821,7 @@ func datumToConstant(d types.Datum, tp byte) *expression.Constant { return &expression.Constant{Value: d, RetType: types.NewFieldType(tp)} } -func (er *expressionRewriter) getParamExpression(v *ast.ParamMarkerExpr) expression.Expression { +func (er *expressionRewriter) getParamExpression(v *driver.ParamMarkerExpr) expression.Expression { f, err := expression.NewFunction(er.ctx, ast.GetParam, &v.Type, diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index a82ca17db0947..c8a1259f55ba7 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -37,6 +37,7 @@ import ( "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" "github.com/pkg/errors" ) @@ -519,7 +520,7 @@ func (b *PlanBuilder) buildProjectionFieldNameFromExpressions(field *ast.SelectF } innerExpr := getInnerFromParentheses(field.Expr) - valueExpr, isValueExpr := innerExpr.(*ast.ValueExpr) + valueExpr, isValueExpr := innerExpr.(*driver.ValueExpr) // Non-literal: Output as inputed, except that comments need to be removed. if !isValueExpr { @@ -851,13 +852,13 @@ func getUintForLimitOffset(sc *stmtctx.StatementContext, val interface{}) (uint6 func extractLimitCountOffset(sc *stmtctx.StatementContext, limit *ast.Limit) (count uint64, offset uint64, err error) { if limit.Count != nil { - count, err = getUintForLimitOffset(sc, limit.Count.GetValue()) + count, err = getUintForLimitOffset(sc, limit.Count.(ast.ValueExpr).GetValue()) if err != nil { return 0, 0, ErrWrongArguments.GenWithStackByArgs("LIMIT") } } if limit.Offset != nil { - offset, err = getUintForLimitOffset(sc, limit.Offset.GetValue()) + offset, err = getUintForLimitOffset(sc, limit.Offset.(ast.ValueExpr).GetValue()) if err != nil { return 0, 0, ErrWrongArguments.GenWithStackByArgs("LIMIT") } @@ -970,7 +971,7 @@ func (a *havingAndOrderbyExprResolver) Enter(n ast.Node) (node ast.Node, skipChi switch n.(type) { case *ast.AggregateFuncExpr: a.inAggFunc = true - case *ast.ParamMarkerExpr, *ast.ColumnNameExpr, *ast.ColumnName: + case *driver.ParamMarkerExpr, *ast.ColumnNameExpr, *ast.ColumnName: case *ast.SubqueryExpr, *ast.ExistsSubqueryExpr: // Enter a new context, skip it. // For example: select sum(c) + c + exists(select c from t) from t; @@ -1152,7 +1153,7 @@ func (g *gbyResolver) Enter(inNode ast.Node) (ast.Node, bool) { switch inNode.(type) { case *ast.SubqueryExpr, *ast.CompareSubqueryExpr, *ast.ExistsSubqueryExpr: return inNode, true - case *ast.ValueExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.ColumnName: + case *driver.ValueExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.ColumnName: default: g.inExpr = true } diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 3962fd67f60e6..570c67ed87a7c 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -29,6 +29,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/ranger" "github.com/pkg/errors" ) @@ -261,8 +262,8 @@ func (b *PlanBuilder) buildSet(v *ast.SetStmt) (Plan, error) { } if vars.ExtendValue != nil { assign.ExtendValue = &expression.Constant{ - Value: vars.ExtendValue.Datum, - RetType: &vars.ExtendValue.Type, + Value: vars.ExtendValue.(*driver.ValueExpr).Datum, + RetType: &vars.ExtendValue.(*driver.ValueExpr).Type, } } p.VarAssigns = append(p.VarAssigns, assign) @@ -399,7 +400,8 @@ func (b *PlanBuilder) buildPrepare(x *ast.PrepareStmt) Plan { Name: x.Name, } if x.SQLVar != nil { - p.SQLText, _ = x.SQLVar.GetValue().(string) + // TODO: Prepared statement from variable expression do not work as expected. + // p.SQLText, _ = x.SQLVar.GetValue().(string) } else { p.SQLText = x.SQLText } @@ -1275,7 +1277,7 @@ func (b *PlanBuilder) buildValuesListOfInsert(insert *ast.InsertStmt, insertPlan } else { expr, err = b.getDefaultValue(affectedValuesCols[j]) } - case *ast.ValueExpr: + case *driver.ValueExpr: expr = &expression.Constant{ Value: x.Datum, RetType: &x.Type, diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index c3dde6f418aa4..64e809e576207 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tipb/go-tipb" "github.com/pkg/errors" ) @@ -336,9 +337,9 @@ func getNameValuePairs(nvPairs []nameValuePair, expr ast.ExprNode) []nameValuePa } var d types.Datum switch x := binOp.R.(type) { - case *ast.ValueExpr: + case *driver.ValueExpr: d = x.Datum - case *ast.ParamMarkerExpr: + case *driver.ParamMarkerExpr: d = x.Datum } if d.IsNull() { diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index d0a17036e1e7f..15ba5ee800fbe 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -25,6 +25,7 @@ import ( "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/charset" "github.com/pkg/errors" ) @@ -95,7 +96,7 @@ func (p *preprocessor) Leave(in ast.Node) (out ast.Node, ok bool) { p.checkContainDotColumn(x) case *ast.DropTableStmt, *ast.AlterTableStmt, *ast.RenameTableStmt: p.inCreateOrDropTable = false - case *ast.ParamMarkerExpr: + case *driver.ParamMarkerExpr: if !p.inPrepare { p.err = parser.ErrSyntax.GenWithStack("syntax error, unexpected '?'") return @@ -134,14 +135,20 @@ func checkAutoIncrementOp(colDef *ast.ColumnDef, num int) (bool, error) { return hasAutoIncrement, nil } for _, op := range colDef.Options[num+1:] { - if op.Tp == ast.ColumnOptionDefaultValue && !op.Expr.GetDatum().IsNull() { - return hasAutoIncrement, errors.Errorf("Invalid default value for '%s'", colDef.Name.Name.O) + if op.Tp == ast.ColumnOptionDefaultValue { + if tmp, ok := op.Expr.(*driver.ValueExpr); ok { + if !tmp.Datum.IsNull() { + return hasAutoIncrement, errors.Errorf("Invalid default value for '%s'", colDef.Name.Name.O) + } + } } } } if colDef.Options[num].Tp == ast.ColumnOptionDefaultValue && len(colDef.Options) != num+1 { - if colDef.Options[num].Expr.GetDatum().IsNull() { - return hasAutoIncrement, nil + if tmp, ok := colDef.Options[num].Expr.(*driver.ValueExpr); ok { + if tmp.Datum.IsNull() { + return hasAutoIncrement, nil + } } for _, op := range colDef.Options[num+1:] { if op.Tp == ast.ColumnOptionAutoIncrement { diff --git a/session/bootstrap_test.go b/session/bootstrap_test.go index cff66c5b1b9bb..5790f96fe10ca 100644 --- a/session/bootstrap_test.go +++ b/session/bootstrap_test.go @@ -17,13 +17,13 @@ import ( "fmt" . "github.com/pingcap/check" - "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/util/auth" "github.com/pingcap/tidb/util/testleak" "golang.org/x/net/context" @@ -55,7 +55,7 @@ func (s *testBootstrapSuite) TestBootstrap(c *C) { err := r.Next(ctx, chk) c.Assert(err, IsNil) c.Assert(chk.NumRows() == 0, IsFalse) - datums := ast.RowToDatums(chk.GetRow(0), r.Fields()) + datums := statistics.RowToDatums(chk.GetRow(0), r.Fields()) match(c, datums, []byte(`%`), []byte("root"), []byte(""), "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y") c.Assert(se.Auth(&auth.UserIdentity{Username: "root", Hostname: "anyhost"}, []byte(""), []byte("")), IsTrue) @@ -91,7 +91,7 @@ func (s *testBootstrapSuite) TestBootstrap(c *C) { chk = r.NewChunk() err = r.Next(ctx, chk) c.Assert(err, IsNil) - datums = ast.RowToDatums(chk.GetRow(0), r.Fields()) + datums = statistics.RowToDatums(chk.GetRow(0), r.Fields()) match(c, datums, 3) mustExecSQL(c, se, "drop table if exists t") se.Close() @@ -159,7 +159,7 @@ func (s *testBootstrapSuite) TestBootstrapWithError(c *C) { c.Assert(err, IsNil) c.Assert(chk.NumRows() == 0, IsFalse) row := chk.GetRow(0) - datums := ast.RowToDatums(row, r.Fields()) + datums := statistics.RowToDatums(row, r.Fields()) match(c, datums, []byte(`%`), []byte("root"), []byte(""), "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y") c.Assert(r.Close(), IsNil) diff --git a/statistics/sample.go b/statistics/sample.go index 6b58fb5916c38..cac1d4229cdaa 100644 --- a/statistics/sample.go +++ b/statistics/sample.go @@ -175,7 +175,7 @@ func (s SampleBuilder) CollectColumnStats() ([]*SampleCollector, *SortedBuilder, panic(fmt.Sprintf("%T", s.RecordSet)) } for row := it.Begin(); row != it.End(); row = it.Next() { - datums := ast.RowToDatums(row, s.RecordSet.Fields()) + datums := RowToDatums(row, s.RecordSet.Fields()) if s.PkBuilder != nil { err = s.PkBuilder.Iterate(datums[0]) if err != nil { @@ -192,3 +192,12 @@ func (s SampleBuilder) CollectColumnStats() ([]*SampleCollector, *SortedBuilder, } } } + +// RowToDatums converts row to datum slice. +func RowToDatums(row chunk.Row, fields []*ast.ResultField) []types.Datum { + datums := make([]types.Datum, len(fields)) + for i, f := range fields { + datums[i] = row.GetDatum(i, &f.Column.FieldType) + } + return datums +} diff --git a/statistics/statistics_test.go b/statistics/statistics_test.go index d719edad22df9..02c9da7ae279b 100644 --- a/statistics/statistics_test.go +++ b/statistics/statistics_test.go @@ -183,7 +183,7 @@ func buildPK(sctx sessionctx.Context, numBuckets, id int64, records ast.RecordSe } it := chunk.NewIterator4Chunk(chk) for row := it.Begin(); row != it.End(); row = it.Next() { - datums := ast.RowToDatums(row, records.Fields()) + datums := RowToDatums(row, records.Fields()) err = b.Iterate(datums[0]) if err != nil { return 0, nil, errors.Trace(err) @@ -208,7 +208,7 @@ func buildIndex(sctx sessionctx.Context, numBuckets, id int64, records ast.Recor break } for row := it.Begin(); row != it.End(); row = it.Next() { - datums := ast.RowToDatums(row, records.Fields()) + datums := RowToDatums(row, records.Fields()) buf, err := codec.EncodeKey(sctx.GetSessionVars().StmtCtx, nil, datums...) if err != nil { return 0, nil, nil, errors.Trace(err) diff --git a/types/parser_driver/value_expr.go b/types/parser_driver/value_expr.go new file mode 100644 index 0000000000000..0a87fa2fc9b83 --- /dev/null +++ b/types/parser_driver/value_expr.go @@ -0,0 +1,163 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package driver + +import ( + "fmt" + "io" + "strconv" + + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/types" +) + +// The purpose of driver package is to decompose the dependency of the parser and +// types package. +// It provides the NewValueExpr function for the ast package, so the ast package +// do not depends on the concrete definition of `types.Datum`, thus get rid of +// the dependency of the types package. +// The parser package depends on the ast package, but not the types package. +// The whole relationship: +// ast imports [] +// types imports [] +// parser imports [ast] +// driver imports [ast, types] +// tidb imports [parser, driver] + +func init() { + ast.NewValueExpr = newValueExpr + ast.NewParamMarkerExpr = newParamMarkerExpr +} + +var ( + _ ast.ParamMarkerExpr = &ParamMarkerExpr{} + _ ast.ValueExpr = &ValueExpr{} +) + +// ValueExpr is the simple value expression. +type ValueExpr struct { + ast.TexprNode + types.Datum + projectionOffset int +} + +// GetDatumString implements the ast.ValueExpr interface. +func (n *ValueExpr) GetDatumString() string { + return n.GetString() +} + +// Format the ExprNode into a Writer. +func (n *ValueExpr) Format(w io.Writer) { + var s string + switch n.Kind() { + case types.KindNull: + s = "NULL" + case types.KindInt64: + if n.Type.Flag&mysql.IsBooleanFlag != 0 { + if n.GetInt64() > 0 { + s = "TRUE" + } else { + s = "FALSE" + } + } else { + s = strconv.FormatInt(n.GetInt64(), 10) + } + case types.KindUint64: + s = strconv.FormatUint(n.GetUint64(), 10) + case types.KindFloat32: + s = strconv.FormatFloat(n.GetFloat64(), 'e', -1, 32) + case types.KindFloat64: + s = strconv.FormatFloat(n.GetFloat64(), 'e', -1, 64) + case types.KindString, types.KindBytes: + s = strconv.Quote(n.GetString()) + case types.KindMysqlDecimal: + s = n.GetMysqlDecimal().String() + case types.KindBinaryLiteral: + if n.Type.Flag&mysql.UnsignedFlag != 0 { + s = fmt.Sprintf("x'%x'", n.GetBytes()) + } else { + s = n.GetBinaryLiteral().ToBitLiteralString(true) + } + default: + panic("Can't format to string") + } + fmt.Fprint(w, s) +} + +// newValueExpr creates a ValueExpr with value, and sets default field type. +func newValueExpr(value interface{}) ast.ValueExpr { + if ve, ok := value.(*ValueExpr); ok { + return ve + } + ve := &ValueExpr{} + ve.SetValue(value) + types.DefaultTypeForValue(value, &ve.Type) + ve.projectionOffset = -1 + return ve +} + +// SetProjectionOffset sets ValueExpr.projectionOffset for logical plan builder. +func (n *ValueExpr) SetProjectionOffset(offset int) { + n.projectionOffset = offset +} + +// GetProjectionOffset returns ValueExpr.projectionOffset. +func (n *ValueExpr) GetProjectionOffset() int { + return n.projectionOffset +} + +// Accept implements Node interface. +func (n *ValueExpr) Accept(v ast.Visitor) (ast.Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*ValueExpr) + return v.Leave(n) +} + +// ParamMarkerExpr expression holds a place for another expression. +// Used in parsing prepare statement. +type ParamMarkerExpr struct { + ValueExpr + Offset int + Order int +} + +func newParamMarkerExpr(offset int) ast.ParamMarkerExpr { + return &ParamMarkerExpr{ + Offset: offset, + } +} + +// Format the ExprNode into a Writer. +func (n *ParamMarkerExpr) Format(w io.Writer) { + panic("Not implemented") +} + +// Accept implements Node Accept interface. +func (n *ParamMarkerExpr) Accept(v ast.Visitor) (ast.Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*ParamMarkerExpr) + return v.Leave(n) +} + +// SetOrder implements the ast.ParamMarkerExpr interface. +func (n *ParamMarkerExpr) SetOrder(order int) { + n.Order = order +}