Skip to content

Commit

Permalink
Merge pull request #6157 from planetscale/fix-init-db
Browse files Browse the repository at this point in the history
Fix all com_init_db and mysql client related issue with use <dbname>
  • Loading branch information
harshit-gangal authored May 7, 2020
2 parents 1915215 + d90138e commit e019bc7
Show file tree
Hide file tree
Showing 14 changed files with 27 additions and 52 deletions.
5 changes: 1 addition & 4 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,10 +780,7 @@ func (c *Conn) handleNextCommand(handler Handler) error {
case ComInitDB:
db := c.parseComInitDB(data)
c.recycleReadPacket()
c.schemaName = db
handler.ComInitDB(c, db)
if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil {
log.Errorf("Error writing ComInitDB result to %s: %v", c, err)
if err := c.execQuery(fmt.Sprintf("use `%s`", db), handler, false); err != nil {
return err
}
case ComQuery:
Expand Down
4 changes: 0 additions & 4 deletions go/mysql/fakesqldb/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,10 +312,6 @@ func (db *DB) ConnectionClosed(c *mysql.Conn) {
delete(db.connections, c.ConnectionID)
}

// ComInitDB is part of the mysql.Handler interface.
func (db *DB) ComInitDB(c *mysql.Conn, schemaName string) {
}

// ComQuery is part of the mysql.Handler interface.
func (db *DB) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error {
return db.Handler.HandleQuery(c, query, callback)
Expand Down
19 changes: 12 additions & 7 deletions go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package mysql

import (
"crypto/tls"
"fmt"
"io"
"net"
"strings"
Expand Down Expand Up @@ -90,10 +91,6 @@ type Handler interface {
// ConnectionClosed is called when a connection is closed.
ConnectionClosed(c *Conn)

// InitDB is called once at the beginning to set db name,
// and subsequently for every ComInitDB event.
ComInitDB(c *Conn, schemaName string)

// ComQuery is called when a connection receives a query.
// Note the contents of the query slice may change after
// the first call to callback. So the Handler should not
Expand Down Expand Up @@ -441,6 +438,17 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti
defer connCountPerUser.Add(c.User, -1)
}

// Set initial db name.
if c.schemaName != "" {
err = l.handler.ComQuery(c, fmt.Sprintf("use `%s`", c.schemaName), func(result *sqltypes.Result) error {
return nil
})
if err != nil {
c.writeErrorPacketFromError(err)
return
}
}

// Negotiation worked, send OK packet.
if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil {
log.Errorf("Cannot write OK packet to %s: %v", c, err)
Expand All @@ -457,9 +465,6 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti
log.Warningf("Slow connection from %s: %v", c, connectTime)
}

// Set initial db name.
l.handler.ComInitDB(c, c.schemaName)

for {
err := c.handleNextCommand(l.handler)
if err != nil {
Expand Down
3 changes: 0 additions & 3 deletions go/mysql/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,6 @@ func (th *testHandler) NewConnection(c *Conn) {
func (th *testHandler) ConnectionClosed(c *Conn) {
}

func (th *testHandler) ComInitDB(c *Conn, schemaName string) {
}

func (th *testHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error {
if result := th.Result(); result != nil {
callback(result)
Expand Down
7 changes: 1 addition & 6 deletions go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,12 +230,7 @@ func (e *Executor) execute(ctx context.Context, safeSession *SafeSession, sql st
// addNeededBindVars adds bind vars that are needed by the plan
func (e *Executor) addNeededBindVars(bindVarNeeds sqlparser.BindVarNeeds, bindVars map[string]*querypb.BindVariable, session *SafeSession) error {
if bindVarNeeds.NeedDatabase {
keyspace, _, _, _ := e.ParseDestinationTarget(session.TargetString)
if keyspace == "" {
bindVars[sqlparser.DBVarName] = sqltypes.NullBindVariable
} else {
bindVars[sqlparser.DBVarName] = sqltypes.StringBindVariable(keyspace)
}
bindVars[sqlparser.DBVarName] = sqltypes.StringBindVariable(session.TargetString)
}

if bindVarNeeds.NeedLastInsertID {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ func TestSelectDatabase(t *testing.T) {
{Name: "database()", Type: sqltypes.VarBinary},
},
Rows: [][]sqltypes.Value{{
sqltypes.NewVarBinary("TestExecutor"),
sqltypes.NewVarBinary("TestExecutor@master"),
}},
}
require.NoError(t, err)
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func TestDirectTargetRewrites(t *testing.T) {
require.NoError(t, err)
testQueries(t, "sbclookup", sbclookup, []*querypb.BoundQuery{{
Sql: "select :__vtdbname as `database()` from dual",
BindVariables: map[string]*querypb.BindVariable{"__vtdbname": sqltypes.StringBindVariable("TestUnsharded")},
BindVariables: map[string]*querypb.BindVariable{"__vtdbname": sqltypes.StringBindVariable("TestUnsharded/0@master")},
}})
}

Expand Down Expand Up @@ -1038,7 +1038,7 @@ func TestExecutorUse(t *testing.T) {
}

_, err = executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{}), "use UnexistentKeyspace", nil)
wantErr = "invalid keyspace provided: UnexistentKeyspace"
wantErr = "Unknown database 'UnexistentKeyspace' (errno 1049) (sqlstate 42000)"
if err == nil || err.Error() != wantErr {
t.Errorf("got: %v, want %v", err, wantErr)
}
Expand Down
16 changes: 3 additions & 13 deletions go/vt/vtgate/mysql_protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func TestMySQLProtocolExecuteUseStatement(t *testing.T) {
// No such keyspace this will fail
_, err = c.ExecuteFetch("use InvalidKeyspace", 0, false)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid keyspace provided: InvalidKeyspace")
assert.Contains(t, err.Error(), "Unknown database 'InvalidKeyspace' (errno 1049) (sqlstate 42000)")

// That doesn't reset the vitess_target
qr, err = c.ExecuteFetch("show vitess_target", 1, false)
Expand All @@ -135,18 +135,8 @@ func TestMySQLProtocolExecuteUseStatement(t *testing.T) {
}

func TestMysqlProtocolInvalidDB(t *testing.T) {
c, err := mysqlConnect(&mysql.ConnParams{DbName: "invalidDB"})
if err != nil {
t.Fatal(err)
}
defer c.Close()

_, err = c.ExecuteFetch("select id from t1", 10, true /* wantfields */)
c.Close()
want := "vtgate: : keyspace invalidDB not found in vschema (errno 1105) (sqlstate HY000) during query: select id from t1"
if err == nil || err.Error() != want {
t.Errorf("exec with db:\n%v, want\n%s", err, want)
}
_, err := mysqlConnect(&mysql.ConnParams{DbName: "invalidDB"})
require.EqualError(t, err, "vtgate: : Unknown database 'invalidDB' (errno 1049) (sqlstate 42000) (errno 1049) (sqlstate 42000)")
}

func TestMySQLProtocolClientFoundRows(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/plan_executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ func TestPlanSelectDatabase(t *testing.T) {
{Name: "database()", Type: sqltypes.VarBinary},
},
Rows: [][]sqltypes.Value{{
sqltypes.NewVarBinary("TestExecutor"),
sqltypes.NewVarBinary("TestExecutor@master"),
}},
}
require.NoError(t, err)
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/plan_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func TestPlanDirectTargetRewrites(t *testing.T) {
require.NoError(t, err)
testQueries(t, "sbclookup", sbclookup, []*querypb.BoundQuery{{
Sql: "select :__vtdbname as `database()` from dual",
BindVariables: map[string]*querypb.BindVariable{"__vtdbname": sqltypes.StringBindVariable("TestUnsharded")},
BindVariables: map[string]*querypb.BindVariable{"__vtdbname": sqltypes.StringBindVariable("TestUnsharded/0@master")},
}})
}

Expand Down Expand Up @@ -996,7 +996,7 @@ func TestPlanExecutorUse(t *testing.T) {
}

_, err = executor.Execute(context.Background(), "TestExecute", NewSafeSession(&vtgatepb.Session{}), "use UnexistentKeyspace", nil)
wantErr = "invalid keyspace provided: UnexistentKeyspace"
wantErr = "Unknown database 'UnexistentKeyspace' (errno 1049) (sqlstate 42000)"
if err == nil || err.Error() != wantErr {
t.Errorf("got: %v, want %v", err, wantErr)
}
Expand Down
4 changes: 0 additions & 4 deletions go/vt/vtgate/plugin_mysql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,6 @@ func startSpan(ctx context.Context, query, label string) (trace.Span, context.Co
return startSpanTestable(ctx, query, label, trace.NewSpan, trace.NewFromString)
}

func (vh *vtgateHandler) ComInitDB(c *mysql.Conn, schemaName string) {
vh.session(c).TargetString = schemaName
}

func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error {
ctx := context.Background()
var cancel context.CancelFunc
Expand Down
3 changes: 0 additions & 3 deletions go/vt/vtgate/plugin_mysql_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ func (th *testHandler) NewConnection(c *mysql.Conn) {
func (th *testHandler) ConnectionClosed(c *mysql.Conn) {
}

func (th *testHandler) ComInitDB(c *mysql.Conn, schemaName string) {
}

func (th *testHandler) ComQuery(c *mysql.Conn, q string, callback func(*sqltypes.Result) error) error {
return nil
}
Expand Down
4 changes: 3 additions & 1 deletion go/vt/vtgate/vcursor_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"sync/atomic"
"time"

"vitess.io/vitess/go/mysql"

"vitess.io/vitess/go/vt/callerid"
vschemapb "vitess.io/vitess/go/vt/proto/vschema"
"vitess.io/vitess/go/vt/topotools"
Expand Down Expand Up @@ -345,7 +347,7 @@ func (vc *vcursorImpl) SetTarget(target string) error {
return err
}
if _, ok := vc.vschema.Keyspaces[keyspace]; keyspace != "" && !ok {
return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid keyspace provided: %s", keyspace)
return mysql.NewSQLError(mysql.ERBadDb, "42000", "Unknown database '%s'", keyspace)
}

if vc.safeSession.InTransaction() && tabletType != topodatapb.TabletType_MASTER {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/vcursor_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func TestSetTarget(t *testing.T) {
}, {
vschema: vschemaWith2KS,
targetString: "ks3",
expectedError: "invalid keyspace provided: ks3",
expectedError: "Unknown database 'ks3' (errno 1049) (sqlstate 42000)",
}, {
vschema: vschemaWith2KS,
targetString: "ks2@replica",
Expand Down

0 comments on commit e019bc7

Please sign in to comment.