Skip to content

Commit

Permalink
Merge pull request #9118 from planetscale/NormalizeHexValuesInQueries
Browse files Browse the repository at this point in the history
Ensure that hex query predicates are normalized for planner cache
  • Loading branch information
systay authored Nov 4, 2021
2 parents 133acbe + e622ebd commit f085f12
Show file tree
Hide file tree
Showing 13 changed files with 303 additions and 20 deletions.
10 changes: 10 additions & 0 deletions go/sqltypes/bind_variables.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ func BuildBindVariables(in map[string]interface{}) (map[string]*querypb.BindVari
return out, nil
}

// HexNumBindVariable converts bytes representing a hex number to a bind var.
func HexNumBindVariable(v []byte) *querypb.BindVariable {
return ValueBindVariable(NewHexNum(v))
}

// HexValBindVariable converts bytes representing a hex encoded string to a bind var.
func HexValBindVariable(v []byte) *querypb.BindVariable {
return ValueBindVariable(NewHexVal(v))
}

// Int8BindVariable converts an int8 to a bind var.
func Int8BindVariable(v int8) *querypb.BindVariable {
return ValueBindVariable(NewInt8(v))
Expand Down
2 changes: 2 additions & 0 deletions go/sqltypes/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ const (
Geometry = querypb.Type_GEOMETRY
TypeJSON = querypb.Type_JSON
Expression = querypb.Type_EXPRESSION
HexNum = querypb.Type_HEXNUM
HexVal = querypb.Type_HEXVAL
)

// bit-shift the mysql flags by two byte so we
Expand Down
10 changes: 9 additions & 1 deletion go/sqltypes/type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ func TestTypeValues(t *testing.T) {
}, {
defined: Expression,
expected: 31,
}, {
defined: HexNum,
expected: 32 | flagIsText,
}, {
defined: HexVal,
expected: 33 | flagIsText,
}}
for _, tcase := range testcases {
if int(tcase.defined) != tcase.expected {
Expand Down Expand Up @@ -162,6 +168,8 @@ func TestCategory(t *testing.T) {
Geometry,
TypeJSON,
Expression,
HexNum,
HexVal,
}
for _, typ := range alltypes {
matched := false
Expand Down Expand Up @@ -192,7 +200,7 @@ func TestCategory(t *testing.T) {
}
matched = true
}
if typ == Null || typ == Decimal || typ == Expression || typ == Bit {
if typ == Null || typ == Decimal || typ == Expression || typ == Bit || typ == HexNum || typ == HexVal {
if matched {
t.Errorf("%v matched more than one category", typ)
}
Expand Down
12 changes: 11 additions & 1 deletion go/sqltypes/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func NewValue(typ querypb.Type, val []byte) (v Value, err error) {
return NULL, err
}
return MakeTrusted(typ, val), nil
case IsQuoted(typ) || typ == Bit || typ == Null:
case IsQuoted(typ) || typ == Bit || typ == HexNum || typ == HexVal || typ == Null:
return MakeTrusted(typ, val), nil
}
// All other types are unsafe or invalid.
Expand All @@ -102,6 +102,16 @@ func MakeTrusted(typ querypb.Type, val []byte) Value {
return Value{typ: typ, val: val}
}

// NewHexNum builds an Hex Value.
func NewHexNum(v []byte) Value {
return MakeTrusted(HexNum, v)
}

// NewHexVal builds a HexVal Value.
func NewHexVal(v []byte) Value {
return MakeTrusted(HexVal, v)
}

// NewInt64 builds an Int64 Value.
func NewInt64(v int64) Value {
return MakeTrusted(Int64, strconv.AppendInt(nil, v, 10))
Expand Down
92 changes: 92 additions & 0 deletions go/test/endtoend/vtgate/queries/normalize/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
Copyright 2021 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package normalize

import (
"flag"
"os"
"testing"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/test/endtoend/cluster"
)

var (
clusterInstance *cluster.LocalProcessCluster
vtParams mysql.ConnParams
KeyspaceName = "ks_normalize"
Cell = "test_normalize"
SchemaSQL = `
create table t1(
id bigint unsigned not null,
charcol char(10),
vcharcol varchar(50),
bincol binary(50),
varbincol varbinary(50),
floatcol float,
deccol decimal(5,2),
bitcol bit,
datecol date,
enumcol enum('small', 'medium', 'large'),
setcol set('a', 'b', 'c'),
jsoncol json,
geocol geometry,
primary key(id)
) Engine=InnoDB;
`
)

func TestMain(m *testing.M) {
defer cluster.PanicHandler(nil)
flag.Parse()

exitCode := func() int {
clusterInstance = cluster.NewCluster(Cell, "localhost")
defer clusterInstance.Teardown()

// Start topo server
err := clusterInstance.StartTopo()
if err != nil {
return 1
}

// Start keyspace
keyspace := &cluster.Keyspace{
Name: KeyspaceName,
SchemaSQL: SchemaSQL,
}
clusterInstance.VtGateExtraArgs = []string{}
clusterInstance.VtTabletExtraArgs = []string{}
err = clusterInstance.StartKeyspace(*keyspace, []string{"-"}, 1, false)
if err != nil {
return 1
}

// Start vtgate
err = clusterInstance.StartVtgate()
if err != nil {
return 1
}

vtParams = mysql.ConnParams{
Host: clusterInstance.Hostname,
Port: clusterInstance.VtgateMySQLPort,
}
return m.Run()
}()
os.Exit(exitCode)
}
82 changes: 82 additions & 0 deletions go/test/endtoend/vtgate/queries/normalize/normalize_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
Copyright 2021 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package normalize

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/test/endtoend/vtgate/utils"
)

func TestNormalizeAllFields(t *testing.T) {
conn, err := mysql.Connect(context.Background(), &vtParams)
require.NoError(t, err)
defer conn.Close()

insertQuery := string(`insert into t1 values (1, "chars", "variable chars", x'73757265', 0x676F, 0.33, 9.99, 1, "1976-06-08", "small", "b", "{\"key\":\"value\"}", point(1,5))`)
normalizedInsertQuery := string(`insert into t1 values (:vtg1, :vtg2, :vtg3, :vtg4, :vtg5, :vtg6, :vtg7, :vtg8, :vtg9, :vtg10, :vtg11, :vtg12, point(:vtg13, :vtg14))`)
selectQuery := "select * from t1"
utils.Exec(t, conn, insertQuery)
qr := utils.Exec(t, conn, selectQuery)
assert.Equal(t, 1, len(qr.Rows), "wrong number of table rows, expected 1 but had %d. Results: %v", len(qr.Rows), qr.Rows)

// Now need to figure out the best way to check the normalized query in the planner cache...
results, err := getPlanCache(fmt.Sprintf("%s:%d", clusterInstance.Hostname, clusterInstance.VtgateProcess.Port))
require.Nil(t, err)
found := false
for _, record := range results {
key := record["Key"].(string)
if key == normalizedInsertQuery {
found = true
break
}
}
assert.True(t, found, "correctly normalized record not found in planner cache")
}

func getPlanCache(vtgateHostPort string) ([]map[string]interface{}, error) {
var results []map[string]interface{}
client := http.Client{
Timeout: 10 * time.Second,
}
resp, err := client.Get(fmt.Sprintf("http://%s/debug/query_plans", vtgateHostPort))
if err != nil {
return results, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return results, err
}

err = json.Unmarshal(body, &results)
if err != nil {
return results, err
}

return results, nil
}
32 changes: 22 additions & 10 deletions go/vt/proto/query/query.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package sqlparser

import (
"bytes"
"encoding/hex"
"encoding/json"
"strings"
Expand Down Expand Up @@ -480,6 +481,28 @@ func (node *Literal) HexDecode() ([]byte, error) {
return hex.DecodeString(node.Val)
}

// EncodeHexValToMySQLQueryFormat encodes the hexval back into the query format
// for passing on to MySQL as a bind var
func (node *Literal) encodeHexValToMySQLQueryFormat() ([]byte, error) {
nb := node.Bytes()
if node.Type != HexVal {
return nb, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Literal value is not a HexVal")
}

// Let's make this idempotent in case it's called more than once
if nb[0] == 'x' && nb[1] == '0' && nb[len(nb)-1] == '\'' {
return nb, nil
}

var bb bytes.Buffer
bb.WriteByte('x')
bb.WriteByte('\'')
bb.WriteString(string(nb))
bb.WriteByte('\'')
nb = bb.Bytes()
return nb, nil
}

// Equal returns true if the column names match.
func (node *ColName) Equal(c *ColName) bool {
// Failsafe: ColName should not be empty.
Expand Down
11 changes: 11 additions & 0 deletions go/vt/sqlparser/normalizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,17 @@ func (nz *normalizer) sqlToBindvar(node SQLNode) *querypb.BindVariable {
v, err = sqltypes.NewValue(sqltypes.Int64, node.Bytes())
case FloatVal:
v, err = sqltypes.NewValue(sqltypes.Float64, node.Bytes())
case HexNum:
v, err = sqltypes.NewValue(sqltypes.HexNum, node.Bytes())
case HexVal:
// We parse the `x'7b7d'` string literal into a hex encoded string of `7b7d` in the parser
// We need to re-encode it back to the original MySQL query format before passing it on as a bindvar value to MySQL
var vbytes []byte
vbytes, err = node.encodeHexValToMySQLQueryFormat()
if err != nil {
return nil
}
v, err = sqltypes.NewValue(sqltypes.HexVal, vbytes)
default:
return nil
}
Expand Down
Loading

0 comments on commit f085f12

Please sign in to comment.