Skip to content

Commit

Permalink
Merge pull request #6642 from planetscale/ast-names
Browse files Browse the repository at this point in the history
AST struct name rewording
  • Loading branch information
harshit-gangal authored Sep 1, 2020
2 parents 7f5dbdd + da23db3 commit c8b47c8
Show file tree
Hide file tree
Showing 42 changed files with 342 additions and 341 deletions.
2 changes: 1 addition & 1 deletion go/cmd/query_analyzer/query_analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func analyze(line []byte) {
}

func formatWithBind(buf *sqlparser.TrackedBuffer, node sqlparser.SQLNode) {
v, ok := node.(*sqlparser.SQLVal)
v, ok := node.(*sqlparser.Literal)
if !ok {
node.Format(buf)
return
Expand Down
4 changes: 2 additions & 2 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -908,8 +908,8 @@ func (c *Conn) handleNextCommand(handler Handler) error {
paramsCount := uint16(0)
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
switch node := node.(type) {
case *sqlparser.SQLVal:
if strings.HasPrefix(string(node.Val), ":v") {
case sqlparser.Argument:
if strings.HasPrefix(string(node), ":v") {
paramsCount++
}
}
Expand Down
12 changes: 7 additions & 5 deletions go/vt/sqlparser/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,11 @@ func IsColName(node Expr) bool {
// NULL is not considered to be a value.
func IsValue(node Expr) bool {
switch v := node.(type) {
case *SQLVal:
case Argument:
return true
case *Literal:
switch v.Type {
case StrVal, HexVal, IntVal, ValArg:
case StrVal, HexVal, IntVal:
return true
}
}
Expand Down Expand Up @@ -358,10 +360,10 @@ func IsSimpleTuple(node Expr) bool {
// NewPlanValue builds a sqltypes.PlanValue from an Expr.
func NewPlanValue(node Expr) (sqltypes.PlanValue, error) {
switch node := node.(type) {
case *SQLVal:
case Argument:
return sqltypes.PlanValue{Key: string(node[1:])}, nil
case *Literal:
switch node.Type {
case ValArg:
return sqltypes.PlanValue{Key: string(node.Val[1:])}, nil
case IntVal:
n, err := sqltypes.NewIntegral(string(node.Val))
if err != nil {
Expand Down
87 changes: 42 additions & 45 deletions go/vt/sqlparser/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func TestIsColName(t *testing.T) {
in: &ColName{},
out: true,
}, {
in: newHexVal(""),
in: newHexLiteral(""),
}}
for _, tc := range testcases {
out := IsColName(tc.in)
Expand All @@ -256,32 +256,35 @@ func TestIsValue(t *testing.T) {
in Expr
out bool
}{{
in: newStrVal("aa"),
in: newStrLiteral("aa"),
out: true,
}, {
in: newHexVal("3131"),
in: newHexLiteral("3131"),
out: true,
}, {
in: newIntVal("1"),
in: newIntLiteral("1"),
out: true,
}, {
in: newValArg(":a"),
in: newArgument(":a"),
out: true,
}, {
in: &NullVal{},
out: false,
}}
for _, tc := range testcases {
out := IsValue(tc.in)
if out != tc.out {
t.Errorf("IsValue(%T): %v, want %v", tc.in, out, tc.out)
}
if tc.out {
// NewPlanValue should not fail for valid values.
if _, err := NewPlanValue(tc.in); err != nil {
t.Error(err)
t.Run(String(tc.in), func(t *testing.T) {
out := IsValue(tc.in)
if out != tc.out {
t.Errorf("IsValue(%T): %v, want %v", tc.in, out, tc.out)
}
}
if tc.out {
// NewPlanValue should not fail for valid values.
if _, err := NewPlanValue(tc.in); err != nil {
t.Error(err)
}
}

})
}
}

Expand All @@ -293,7 +296,7 @@ func TestIsNull(t *testing.T) {
in: &NullVal{},
out: true,
}, {
in: newStrVal(""),
in: newStrLiteral(""),
}}
for _, tc := range testcases {
out := IsNull(tc.in)
Expand All @@ -308,7 +311,7 @@ func TestIsSimpleTuple(t *testing.T) {
in Expr
out bool
}{{
in: ValTuple{newStrVal("aa")},
in: ValTuple{newStrLiteral("aa")},
out: true,
}, {
in: ValTuple{&ColName{}},
Expand Down Expand Up @@ -338,43 +341,40 @@ func TestNewPlanValue(t *testing.T) {
out sqltypes.PlanValue
err string
}{{
in: &SQLVal{
Type: ValArg,
Val: []byte(":valarg"),
},
in: Argument(":valarg"),
out: sqltypes.PlanValue{Key: "valarg"},
}, {
in: &SQLVal{
in: &Literal{
Type: IntVal,
Val: []byte("10"),
},
out: sqltypes.PlanValue{Value: sqltypes.NewInt64(10)},
}, {
in: &SQLVal{
in: &Literal{
Type: IntVal,
Val: []byte("1111111111111111111111111111111111111111"),
},
err: "value out of range",
}, {
in: &SQLVal{
in: &Literal{
Type: StrVal,
Val: []byte("strval"),
},
out: sqltypes.PlanValue{Value: sqltypes.NewVarBinary("strval")},
}, {
in: &SQLVal{
in: &Literal{
Type: BitVal,
Val: []byte("01100001"),
},
err: "expression is too complex",
}, {
in: &SQLVal{
in: &Literal{
Type: HexVal,
Val: []byte("3131"),
},
out: sqltypes.PlanValue{Value: sqltypes.NewVarBinary("11")},
}, {
in: &SQLVal{
in: &Literal{
Type: HexVal,
Val: []byte("313"),
},
Expand All @@ -384,11 +384,8 @@ func TestNewPlanValue(t *testing.T) {
out: sqltypes.PlanValue{ListKey: "list"},
}, {
in: ValTuple{
&SQLVal{
Type: ValArg,
Val: []byte(":valarg"),
},
&SQLVal{
Argument(":valarg"),
&Literal{
Type: StrVal,
Val: []byte("strval"),
},
Expand All @@ -409,15 +406,15 @@ func TestNewPlanValue(t *testing.T) {
in: &NullVal{},
out: sqltypes.PlanValue{},
}, {
in: &SQLVal{
in: &Literal{
Type: FloatVal,
Val: []byte("2.1"),
},
out: sqltypes.PlanValue{Value: sqltypes.NewFloat64(2.1)},
}, {
in: &UnaryExpr{
Operator: Latin1Str,
Expr: &SQLVal{
Expr: &Literal{
Type: StrVal,
Val: []byte("strval"),
},
Expand All @@ -426,7 +423,7 @@ func TestNewPlanValue(t *testing.T) {
}, {
in: &UnaryExpr{
Operator: UBinaryStr,
Expr: &SQLVal{
Expr: &Literal{
Type: StrVal,
Val: []byte("strval"),
},
Expand All @@ -435,7 +432,7 @@ func TestNewPlanValue(t *testing.T) {
}, {
in: &UnaryExpr{
Operator: Utf8mb4Str,
Expr: &SQLVal{
Expr: &Literal{
Type: StrVal,
Val: []byte("strval"),
},
Expand All @@ -444,7 +441,7 @@ func TestNewPlanValue(t *testing.T) {
}, {
in: &UnaryExpr{
Operator: Utf8Str,
Expr: &SQLVal{
Expr: &Literal{
Type: StrVal,
Val: []byte("strval"),
},
Expand All @@ -453,7 +450,7 @@ func TestNewPlanValue(t *testing.T) {
}, {
in: &UnaryExpr{
Operator: MinusStr,
Expr: &SQLVal{
Expr: &Literal{
Type: FloatVal,
Val: []byte("2.1"),
},
Expand Down Expand Up @@ -482,18 +479,18 @@ var mustMatch = utils.MustMatchFn(
[]string{".Conn"}, // ignored fields
)

func newStrVal(in string) *SQLVal {
return NewStrVal([]byte(in))
func newStrLiteral(in string) *Literal {
return NewStrLiteral([]byte(in))
}

func newIntVal(in string) *SQLVal {
return NewIntVal([]byte(in))
func newIntLiteral(in string) *Literal {
return NewIntLiteral([]byte(in))
}

func newHexVal(in string) *SQLVal {
return NewHexVal([]byte(in))
func newHexLiteral(in string) *Literal {
return NewHexLiteral([]byte(in))
}

func newValArg(in string) *SQLVal {
return NewValArg([]byte(in))
func newArgument(in string) Expr {
return NewArgument([]byte(in))
}
35 changes: 21 additions & 14 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,13 +353,13 @@ type ColumnType struct {
Autoincrement BoolVal
Default Expr
OnUpdate Expr
Comment *SQLVal
Comment *Literal

// Numeric field options
Length *SQLVal
Length *Literal
Unsigned BoolVal
Zerofill BoolVal
Scale *SQLVal
Scale *Literal

// Text field options
Charset string
Expand Down Expand Up @@ -614,12 +614,15 @@ type (
Subquery *Subquery
}

// SQLVal represents a single value.
SQLVal struct {
// Literal represents a fixed value.
Literal struct {
Type ValType
Val []byte
}

// Argument represents bindvariable expression
Argument []byte

// NullVal represents a NULL value.
NullVal struct{}

Expand Down Expand Up @@ -711,7 +714,7 @@ type (
// In this case StrVal will be set instead of Name.
SubstrExpr struct {
Name *ColName
StrVal *SQLVal
StrVal *Literal
From Expr
To Expr
}
Expand Down Expand Up @@ -771,7 +774,8 @@ func (*ComparisonExpr) iExpr() {}
func (*RangeCond) iExpr() {}
func (*IsExpr) iExpr() {}
func (*ExistsExpr) iExpr() {}
func (*SQLVal) iExpr() {}
func (*Literal) iExpr() {}
func (Argument) iExpr() {}
func (*NullVal) iExpr() {}
func (BoolVal) iExpr() {}
func (*ColName) iExpr() {}
Expand Down Expand Up @@ -805,8 +809,8 @@ func (ListArg) iColTuple() {}
// ConvertType represents the type in call to CONVERT(expr, type)
type ConvertType struct {
Type string
Length *SQLVal
Scale *SQLVal
Length *Literal
Scale *Literal
Operator string
Charset string
}
Expand Down Expand Up @@ -1579,7 +1583,7 @@ func (node *ExistsExpr) Format(buf *TrackedBuffer) {
}

// Format formats the node.
func (node *SQLVal) Format(buf *TrackedBuffer) {
func (node *Literal) Format(buf *TrackedBuffer) {
switch node.Type {
case StrVal:
sqltypes.MakeTrusted(sqltypes.VarBinary, node.Val).EncodeSQL(buf)
Expand All @@ -1589,13 +1593,16 @@ func (node *SQLVal) Format(buf *TrackedBuffer) {
buf.astPrintf(node, "X'%s'", node.Val)
case BitVal:
buf.astPrintf(node, "B'%s'", node.Val)
case ValArg:
buf.WriteArg(string(node.Val))
default:
panic("unexpected")
}
}

// Format formats the node.
func (node Argument) Format(buf *TrackedBuffer) {
buf.WriteArg(string(node))
}

// Format formats the node.
func (node *NullVal) Format(buf *TrackedBuffer) {
buf.astPrintf(node, "null")
Expand Down Expand Up @@ -1864,8 +1871,8 @@ func (node *SetExpr) Format(buf *TrackedBuffer) {
case node.Name.EqualString("charset") || node.Name.EqualString("names"):
buf.astPrintf(node, "%s %v", node.Name.String(), node.Expr)
case node.Name.EqualString(TransactionStr):
sqlVal := node.Expr.(*SQLVal)
buf.astPrintf(node, "%s %s", node.Name.String(), strings.ToLower(string(sqlVal.Val)))
literal := node.Expr.(*Literal)
buf.astPrintf(node, "%s %s", node.Name.String(), strings.ToLower(string(literal.Val)))
default:
buf.astPrintf(node, "%v = %v", node.Name, node.Expr)
}
Expand Down
Loading

0 comments on commit c8b47c8

Please sign in to comment.