Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

only rewrite database() against dual #5793

Merged
merged 8 commits into from
Feb 6, 2020
119 changes: 119 additions & 0 deletions go/test/endtoend/tabletmanager/dbnameoverride/tablet_master_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
Copyright 2019 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package master

import (
"context"
"flag"
"os"
"testing"

"github.com/stretchr/testify/require"
"vitess.io/vitess/go/mysql"

"vitess.io/vitess/go/test/endtoend/cluster"
)

var (
clusterInstance *cluster.LocalProcessCluster
vtParams mysql.ConnParams
hostname = "localhost"
keyspaceName = "ks"
cell = "zone1"
sqlSchema = `
create table t1(
id bigint,
value varchar(16),
primary key(id)
) Engine=InnoDB;
`

vSchema = `
{
"sharded": true,
"vindexes": {
"hash": {
"type": "hash"
}
},
"tables": {
"t1": {
"column_vindexes": [
{
"column": "id",
"name": "hash"
}
]
}
}
}`
)

const dbName = "myDbName"

func TestMain(m *testing.M) {
flag.Parse()

exitCode := func() int {
clusterInstance = cluster.NewCluster(cell, hostname)
defer clusterInstance.Teardown()

// Start topo server
err := clusterInstance.StartTopo()
if err != nil {
return 1
}

// Set extra tablet args for lock timeout
clusterInstance.VtTabletExtraArgs = []string{
"-init_db_name_override", dbName,
}

// Start keyspace
keyspace := &cluster.Keyspace{
Name: keyspaceName,
SchemaSQL: sqlSchema,
VSchema: vSchema,
}

if err = clusterInstance.StartUnshardedKeyspace(*keyspace, 1, false); err != nil {
return 1
}

if err = clusterInstance.StartVtgate(); err != nil {
return 1
}
vtParams = mysql.ConnParams{
Host: clusterInstance.Hostname,
Port: clusterInstance.VtgateMySQLPort,
}

return m.Run()
}()
os.Exit(exitCode)
}

func TestDbNameOverride(t *testing.T) {
ctx := context.Background()
conn, err := mysql.Connect(ctx, &vtParams)
require.NoError(t, err)
defer conn.Close()
qr, err := conn.ExecuteFetch("SELECT database() FROM information_schema.tables WHERE table_schema = database()", 1000, true)

require.NoError(t, err)
require.Equal(t, 1, len(qr.Rows), "did not get enough rows back")
require.Equal(t, dbName, qr.Rows[0][0].ToString())
}
28 changes: 25 additions & 3 deletions go/vt/sqlparser/expression_rewriting.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func PrepareAST(in Statement, bindVars map[string]*querypb.BindVariable, prefix
// RewriteAST rewrites the whole AST, replacing function calls and adding column aliases to queries
func RewriteAST(in Statement) (*RewriteASTResult, error) {
er := new(expressionRewriter)
er.shouldRewriteDatabaseFunc = shouldRewriteDatabaseFunc(in)
Rewrite(in, er.goingDown, nil)

return &RewriteASTResult{
Expand All @@ -41,6 +42,25 @@ func RewriteAST(in Statement) (*RewriteASTResult, error) {
}, nil
}

func shouldRewriteDatabaseFunc(in Statement) bool {
selct, ok := in.(*Select)
if !ok {
return false
}
if len(selct.From) != 1 {
return false
}
aliasedTable, ok := selct.From[0].(*AliasedTableExpr)
if !ok {
return false
}
tableName, ok := aliasedTable.Expr.(TableName)
if !ok {
return false
}
return tableName.Name.String() == "dual"
}

// RewriteASTResult contains the rewritten ast and meta information about it
type RewriteASTResult struct {
AST Statement
Expand All @@ -49,8 +69,9 @@ type RewriteASTResult struct {
}

type expressionRewriter struct {
lastInsertID, database bool
err error
lastInsertID, database bool
shouldRewriteDatabaseFunc bool
err error
}

const (
Expand All @@ -67,6 +88,7 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
buf := NewTrackedBuffer(nil)
node.Expr.Format(buf)
inner := new(expressionRewriter)
inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc
tmp := Rewrite(node.Expr, inner.goingDown, nil)
newExpr, ok := tmp.(Expr)
if !ok {
Expand All @@ -91,7 +113,7 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
cursor.Replace(bindVarExpression(LastInsertIDName))
er.lastInsertID = true
}
case node.Name.EqualString("database"):
case node.Name.EqualString("database") && er.shouldRewriteDatabaseFunc:
if len(node.Exprs) > 0 {
er.err = vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Syntax error. DATABASE() takes no arguments")
} else {
Expand Down
38 changes: 26 additions & 12 deletions go/vt/sqlparser/expression_rewriting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ func TestRewrites(in *testing.T) {
expected: "SELECT :__vtdbname as `database()`",
db: true, liid: false,
},
{
in: "SELECT database() from test",
expected: "SELECT database() from test",
db: false, liid: false,
deepthi marked this conversation as resolved.
Show resolved Hide resolved
},
{
in: "SELECT last_insert_id() as test",
expected: "SELECT :__lastInsertId as test",
Expand All @@ -55,14 +60,29 @@ func TestRewrites(in *testing.T) {
db: true, liid: true,
},
{
in: "select (select database() from test) from test",
expected: "select (select :__vtdbname as `database()` from test) as `(select database() from test)` from test",
in: "select (select database()) from test",
expected: "select (select database() from dual) from test",
db: false, liid: false,
},
{
in: "select (select database() from dual) from test",
expected: "select (select database() from dual) from test",
db: false, liid: false,
deepthi marked this conversation as resolved.
Show resolved Hide resolved
},
{
in: "select (select database() from dual) from dual",
expected: "select (select :__vtdbname as `database()` from dual) as `(select database() from dual)` from dual",
db: true, liid: false,
},
{
in: "select id from user where database()",
expected: "select id from user where :__vtdbname",
db: true, liid: false,
expected: "select id from user where database()",
db: false, liid: false,
},
{
in: "select table_name from information_schema.tables where table_schema = database()",
expected: "select table_name from information_schema.tables where table_schema = database()",
db: false, liid: false,
},
}

Expand All @@ -77,16 +97,10 @@ func TestRewrites(in *testing.T) {
expected, err := Parse(tc.expected)
require.NoError(t, err)

s := toString(expected)
require.Equal(t, s, toString(result.AST))
s := String(expected)
require.Equal(t, s, String(result.AST))
require.Equal(t, tc.liid, result.NeedLastInsertID, "should need last insert id")
require.Equal(t, tc.db, result.NeedDatabase, "should need database name")
})
}
}

func toString(node SQLNode) string {
buf := NewTrackedBuffer(nil)
node.Format(buf)
return buf.String()
}
1 change: 0 additions & 1 deletion go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,6 @@ func (e *Executor) handleExec(ctx context.Context, safeSession *SafeSession, sql
sql = comments.Leading + normalized + comments.Trailing
if rewriteResult.NeedDatabase {
keyspace, _, _, _ := e.ParseDestinationTarget(safeSession.TargetString)
log.Warningf("This is the keyspace name: ---> %v", keyspace)
if keyspace == "" {
bindVars[sqlparser.DBVarName] = sqltypes.NullBindVariable
} else {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/testdata/filter_cases.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1078,7 +1078,7 @@
"Name": "user",
"Sharded": true
},
"Query": "select id from user where :__vtdbname",
"Query": "select id from user where database()",
"FieldQuery": "select id from user where 1 != 1",
"Table": "user"
}
Expand Down