Skip to content

Commit

Permalink
Merge pull request #6298 from luisfmcalado/prep-stmt-bind-variables
Browse files Browse the repository at this point in the history
Fix prepared statements in column specs
  • Loading branch information
sougou authored Jun 15, 2020
2 parents 0b3a934 + cb803d9 commit 2bc5eac
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 7 deletions.
8 changes: 7 additions & 1 deletion go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -926,9 +926,15 @@ func (c *Conn) handleNextCommand(handler Handler) error {
prepare.BindVars = make(map[string]*querypb.BindVariable, paramsCount)
}

bindVars := make(map[string]*querypb.BindVariable, paramsCount)
for i := uint16(0); i < paramsCount; i++ {
parameterID := fmt.Sprintf("v%d", i+1)
bindVars[parameterID] = &querypb.BindVariable{}
}

c.PrepareData[c.StatementID] = prepare

fld, err := handler.ComPrepare(c, queries[0])
fld, err := handler.ComPrepare(c, queries[0], bindVars)

if err != nil {
if werr := c.writeErrorPacketFromError(err); werr != nil {
Expand Down
2 changes: 1 addition & 1 deletion go/mysql/fakesqldb/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ func (db *DB) comQueryOrdered(query string) (*sqltypes.Result, error) {
}

// ComPrepare is part of the mysql.Handler interface.
func (db *DB) ComPrepare(c *mysql.Conn, query string) ([]*querypb.Field, error) {
func (db *DB) ComPrepare(c *mysql.Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) {
return nil, nil
}

Expand Down
2 changes: 1 addition & 1 deletion go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ type Handler interface {

// ComPrepare is called when a connection receives a prepared
// statement query.
ComPrepare(c *Conn, query string) ([]*querypb.Field, error)
ComPrepare(c *Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error)

// ComStmtExecute is called when a connection receives a statement
// execute query.
Expand Down
2 changes: 1 addition & 1 deletion go/mysql/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ func (th *testHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.R
return nil
}

func (th *testHandler) ComPrepare(c *Conn, query string) ([]*querypb.Field, error) {
func (th *testHandler) ComPrepare(c *Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) {
return nil, nil
}

Expand Down
40 changes: 40 additions & 0 deletions go/test/endtoend/preparestmt/stmt_methods_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,46 @@ func reconnectAndTest(t *testing.T) {

}

// TestColumnParameter query database using column
// parameter.
func TestColumnParameter(t *testing.T) {
defer cluster.PanicHandler(t)
dbo := Connect(t)
defer dbo.Close()

id := 1000
parameter1 := "param1"
message := "TestColumnParameter"
insertStmt := "INSERT INTO " + tableName + " (id, msg, keyspace_id) VALUES (?, ?, ?);"
values := []interface{}{
id,
message,
2000,
}
exec(t, dbo, insertStmt, values...)

var param, msg string
var recID int

selectStmt := "SELECT COALESCE(?, id), msg FROM " + tableName + " WHERE msg = ? LIMIT ?"

results1, err := dbo.Query(selectStmt, parameter1, message, 1)
require.Nil(t, err)
require.True(t, results1.Next())

results1.Scan(&param, &msg)
assert.Equal(t, parameter1, param)
assert.Equal(t, message, msg)

results2, err := dbo.Query(selectStmt, nil, message, 1)
require.Nil(t, err)
require.True(t, results2.Next())

results2.Scan(&recID, &msg)
assert.Equal(t, id, recID)
assert.Equal(t, message, msg)
}

// TestWrongTableName query database using invalid
// tablename and validate error.
func TestWrongTableName(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/plugin_mysql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sq
}

// ComPrepare is the handler for command prepare.
func (vh *vtgateHandler) ComPrepare(c *mysql.Conn, query string) ([]*querypb.Field, error) {
func (vh *vtgateHandler) ComPrepare(c *mysql.Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) {
var ctx context.Context
var cancel context.CancelFunc
if *mysqlQueryTimeout != 0 {
Expand Down Expand Up @@ -252,7 +252,7 @@ func (vh *vtgateHandler) ComPrepare(c *mysql.Conn, query string) ([]*querypb.Fie
}
}()

session, fld, err := vh.vtg.Prepare(ctx, session, query, make(map[string]*querypb.BindVariable))
session, fld, err := vh.vtg.Prepare(ctx, session, query, bindVars)
err = mysql.NewSQLErrorFromError(err)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/plugin_mysql_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (th *testHandler) ComQuery(c *mysql.Conn, q string, callback func(*sqltypes
return nil
}

func (th *testHandler) ComPrepare(c *mysql.Conn, q string) ([]*querypb.Field, error) {
func (th *testHandler) ComPrepare(c *mysql.Conn, q string, b map[string]*querypb.BindVariable) ([]*querypb.Field, error) {
return nil, nil
}

Expand Down

0 comments on commit 2bc5eac

Please sign in to comment.