Skip to content

Commit

Permalink
fix issue with json unmarshalling of operators with space in them (#1…
Browse files Browse the repository at this point in the history
…6905)

Signed-off-by: Andres Taylor <andres@planetscale.com>
  • Loading branch information
systay authored Oct 8, 2024
1 parent 2cef46e commit f40e076
Show file tree
Hide file tree
Showing 5 changed files with 224 additions and 161 deletions.
30 changes: 15 additions & 15 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -1562,36 +1562,36 @@ func (op ComparisonExprOperator) ToString() string {
}
}

func ComparisonExprOperatorFromJson(s string) ComparisonExprOperator {
func ComparisonExprOperatorFromJson(s string) (ComparisonExprOperator, error) {
switch s {
case EqualStr:
return EqualOp
return EqualOp, nil
case JsonLessThanStr:
return LessThanOp
return LessThanOp, nil
case JsonGreaterThanStr:
return GreaterThanOp
return GreaterThanOp, nil
case JsonLessThanOrEqualStr:
return LessEqualOp
return LessEqualOp, nil
case JsonGreaterThanOrEqualStr:
return GreaterEqualOp
return GreaterEqualOp, nil
case NotEqualStr:
return NotEqualOp
return NotEqualOp, nil
case NullSafeEqualStr:
return NullSafeEqualOp
return NullSafeEqualOp, nil
case InStr:
return InOp
return InOp, nil
case NotInStr:
return NotInOp
return NotInOp, nil
case LikeStr:
return LikeOp
return LikeOp, nil
case NotLikeStr:
return NotLikeOp
return NotLikeOp, nil
case RegexpStr:
return RegexpOp
return RegexpOp, nil
case NotRegexpStr:
return NotRegexpOp
return NotRegexpOp, nil
default:
return 0
return 0, fmt.Errorf("unknown ComparisonExpOperator: %s", s)
}
}

Expand Down
177 changes: 40 additions & 137 deletions go/vt/vtgate/executor_vexplain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ package vtgate

import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -115,153 +118,53 @@ func TestSimpleVexplainTrace(t *testing.T) {
}

func TestVExplainKeys(t *testing.T) {
tests := []struct {
query string
expectedRowString string
}{
{
query: "select count(*), col2 from music group by col2",
expectedRowString: `{
"statementType": "SELECT",
"groupingColumns": [
"music.col2"
],
"selectColumns": [
"music.col2"
]
}`,
}, {
query: "select * from user u join user_extra ue on u.id = ue.user_id where u.col1 > 100 and ue.noLimit = 'foo'",
expectedRowString: `{
"statementType": "SELECT",
"joinColumns": [
"user.id =",
"user_extra.user_id ="
],
"filterColumns": [
"user.col1 gt",
"user_extra.noLimit ="
]
}`,
}, {
// same as above, but written differently
query: "select * from user_extra ue, user u where ue.noLimit = 'foo' and u.col1 > 100 and ue.user_id = u.id",
expectedRowString: `{
"statementType": "SELECT",
"joinColumns": [
"user.id =",
"user_extra.user_id ="
],
"filterColumns": [
"user.col1 gt",
"user_extra.noLimit ="
]
}`,
},
{
query: "select u.foo, ue.bar, count(*) from user u join user_extra ue on u.id = ue.user_id where u.name = 'John Doe' group by 1, 2",
expectedRowString: `{
"statementType": "SELECT",
"groupingColumns": [
"user.foo",
"user_extra.bar"
],
"joinColumns": [
"user.id =",
"user_extra.user_id ="
],
"filterColumns": [
"user.name ="
],
"selectColumns": [
"user.foo",
"user_extra.bar"
]
}`,
},
{
query: "select * from (select * from user) as derived where derived.amount > 1000",
expectedRowString: `{
"statementType": "SELECT"
}`,
},
{
query: "select name, sum(amount) from user group by name",
expectedRowString: `{
"statementType": "SELECT",
"groupingColumns": [
"user.name"
],
"selectColumns": [
"user.amount",
"user.name"
]
}`,
},
{
query: "select name from user where age > 30",
expectedRowString: `{
"statementType": "SELECT",
"filterColumns": [
"user.age gt"
],
"selectColumns": [
"user.name"
]
}`,
},
{
query: "select * from user where name = 'apa' union select * from user_extra where name = 'monkey'",
expectedRowString: `{
"statementType": "SELECT",
"filterColumns": [
"user.name =",
"user_extra.name ="
]
}`,
},
{
query: "update user set name = 'Jane Doe' where id = 1",
expectedRowString: `{
"statementType": "UPDATE",
"filterColumns": [
"user.id ="
]
}`,
},
{
query: "delete from user where order_date < '2023-01-01'",
expectedRowString: `{
"statementType": "DELETE",
"filterColumns": [
"user.order_date lt"
]
}`,
},
{
query: "select * from user where name between 'A' and 'C'",
expectedRowString: `{
"statementType": "SELECT",
"filterColumns": [
"user.name ge",
"user.name le"
]
}`,
},
type testCase struct {
Query string `json:"query"`
Expected json.RawMessage `json:"expected"`
}

var tests []testCase
data, err := os.ReadFile("testdata/executor_vexplain.json")
require.NoError(t, err)

err = json.Unmarshal(data, &tests)
require.NoError(t, err)

var updatedTests []testCase

for _, tt := range tests {
t.Run(tt.query, func(t *testing.T) {
t.Run(tt.Query, func(t *testing.T) {
executor, _, _, _, _ := createExecutorEnv(t)
session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary"})
gotResult, err := executor.Execute(context.Background(), nil, "Execute", session, "vexplain keys "+tt.query, nil)
gotResult, err := executor.Execute(context.Background(), nil, "Execute", session, "vexplain keys "+tt.Query, nil)
require.NoError(t, err)

gotRowString := gotResult.Rows[0][0].ToString()
assert.Equal(t, tt.expectedRowString, gotRowString)
assert.JSONEq(t, string(tt.Expected), gotRowString)

updatedTests = append(updatedTests, testCase{
Query: tt.Query,
Expected: json.RawMessage(gotRowString),
})

if t.Failed() {
fmt.Println(gotRowString)
fmt.Println("Test failed for query:", tt.Query)
fmt.Println("Got result:", gotRowString)
}
})
}

// If anything failed, write the updated test cases to a temp file
if t.Failed() {
tempFilePath := filepath.Join(os.TempDir(), "updated_vexplain_keys_tests.json")
fmt.Println("Writing updated tests to:", tempFilePath)

updatedTestsData, err := json.MarshalIndent(updatedTests, "", "\t")
require.NoError(t, err)

err = os.WriteFile(tempFilePath, updatedTestsData, 0644)
require.NoError(t, err)

fmt.Println("Updated tests written to:", tempFilePath)
}
}
39 changes: 33 additions & 6 deletions go/vt/vtgate/planbuilder/operators/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,37 @@ func (cu *ColumnUse) UnmarshalJSON(data []byte) error {
if err := json.Unmarshal(data, &s); err != nil {
return err
}
parts := strings.Fields(s)
if len(parts) != 2 {
spaceIdx := strings.LastIndex(s, " ")
if spaceIdx == -1 {
return fmt.Errorf("invalid ColumnUse format: %s", s)
}
if err := cu.Column.UnmarshalJSON([]byte(`"` + parts[0] + `"`)); err != nil {
return err

for i := spaceIdx - 1; i >= 0; i-- {
// table.column not like
// table.`tricky not` like
if s[i] == '`' || s[i] == '.' {
break
}
if s[i] == ' ' {
spaceIdx = i
break
}
if i == 0 {
return fmt.Errorf("invalid ColumnUse format: %s", s)
}
}

colStr, opStr := s[:spaceIdx], s[spaceIdx+1:]

err := cu.Column.UnmarshalJSON([]byte(`"` + colStr + `"`))
if err != nil {
return fmt.Errorf("failed to unmarshal column: %w", err)
}

cu.Uses, err = sqlparser.ComparisonExprOperatorFromJson(strings.ToLower(opStr))
if err != nil {
return fmt.Errorf("failed to unmarshal operator: %w", err)
}
cu.Uses = sqlparser.ComparisonExprOperatorFromJson(strings.ToLower(parts[1]))
return nil
}

Expand Down Expand Up @@ -209,5 +232,9 @@ func createColumn(ctx *plancontext.PlanningContext, col *sqlparser.ColName) *Col
if table == nil {
return nil
}
return &Column{Table: table.Name.String(), Name: col.Name.String()}
return &Column{
// we want the escaped versions of the names
Table: sqlparser.String(table.Name),
Name: sqlparser.String(col.Name),
}
}
7 changes: 4 additions & 3 deletions go/vt/vtgate/planbuilder/operators/keys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,21 @@ func TestMarshalUnmarshal(t *testing.T) {
StatementType: "SELECT",
TableName: []string{"users", "orders"},
GroupingColumns: []Column{
{Table: "", Name: "category"},
{Table: "orders", Name: "category"},
{Table: "users", Name: "department"},
},
JoinColumns: []ColumnUse{
{Column: Column{Table: "users", Name: "id"}, Uses: sqlparser.EqualOp},
{Column: Column{Table: "orders", Name: "user_id"}, Uses: sqlparser.EqualOp},
},
FilterColumns: []ColumnUse{
{Column: Column{Table: "", Name: "age"}, Uses: sqlparser.GreaterThanOp},
{Column: Column{Table: "users", Name: "age"}, Uses: sqlparser.GreaterThanOp},
{Column: Column{Table: "orders", Name: "total"}, Uses: sqlparser.LessThanOp},
{Column: Column{Table: "orders", Name: "`tricky name not`"}, Uses: sqlparser.InOp},
},
SelectColumns: []Column{
{Table: "users", Name: "name"},
{Table: "", Name: "email"},
{Table: "users", Name: "email"},
{Table: "orders", Name: "amount"},
},
}
Expand Down
Loading

0 comments on commit f40e076

Please sign in to comment.