Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

only rewrite database() against dual #5793

Merged
merged 8 commits into from
Feb 6, 2020
28 changes: 25 additions & 3 deletions go/vt/sqlparser/expression_rewriting.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func PrepareAST(in Statement, bindVars map[string]*querypb.BindVariable, prefix
// RewriteAST rewrites the whole AST, replacing function calls and adding column aliases to queries
func RewriteAST(in Statement) (*RewriteASTResult, error) {
er := new(expressionRewriter)
er.shouldRewriteDatabaseFunc = shouldRewriteDatabaseFunc(in)
Rewrite(in, er.goingDown, nil)

return &RewriteASTResult{
Expand All @@ -41,6 +42,25 @@ func RewriteAST(in Statement) (*RewriteASTResult, error) {
}, nil
}

func shouldRewriteDatabaseFunc(in Statement) bool {
selct, ok := in.(*Select)
if !ok {
return false
}
if len(selct.From) != 1 {
return false
}
aliasedTable, ok := selct.From[0].(*AliasedTableExpr)
if !ok {
return false
}
tableName, ok := aliasedTable.Expr.(TableName)
if !ok {
return false
}
return tableName.Name.String() == "dual"
}

// RewriteASTResult contains the rewritten ast and meta information about it
type RewriteASTResult struct {
AST Statement
Expand All @@ -49,8 +69,9 @@ type RewriteASTResult struct {
}

type expressionRewriter struct {
lastInsertID, database bool
err error
lastInsertID, database bool
shouldRewriteDatabaseFunc bool
err error
}

const (
Expand All @@ -67,6 +88,7 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
buf := NewTrackedBuffer(nil)
node.Expr.Format(buf)
inner := new(expressionRewriter)
inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc
tmp := Rewrite(node.Expr, inner.goingDown, nil)
newExpr, ok := tmp.(Expr)
if !ok {
Expand All @@ -91,7 +113,7 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
cursor.Replace(bindVarExpression(LastInsertIDName))
er.lastInsertID = true
}
case node.Name.EqualString("database"):
case node.Name.EqualString("database") && er.shouldRewriteDatabaseFunc:
if len(node.Exprs) > 0 {
er.err = vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Syntax error. DATABASE() takes no arguments")
} else {
Expand Down
38 changes: 26 additions & 12 deletions go/vt/sqlparser/expression_rewriting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ func TestRewrites(in *testing.T) {
expected: "SELECT :__vtdbname as `database()`",
db: true, liid: false,
},
{
in: "SELECT database() from test",
expected: "SELECT database() from test",
db: false, liid: false,
deepthi marked this conversation as resolved.
Show resolved Hide resolved
},
{
in: "SELECT last_insert_id() as test",
expected: "SELECT :__lastInsertId as test",
Expand All @@ -55,14 +60,29 @@ func TestRewrites(in *testing.T) {
db: true, liid: true,
},
{
in: "select (select database() from test) from test",
expected: "select (select :__vtdbname as `database()` from test) as `(select database() from test)` from test",
in: "select (select database()) from test",
expected: "select (select database() from dual) from test",
db: false, liid: false,
},
{
in: "select (select database() from dual) from test",
expected: "select (select database() from dual) from test",
db: false, liid: false,
deepthi marked this conversation as resolved.
Show resolved Hide resolved
},
{
in: "select (select database() from dual) from dual",
expected: "select (select :__vtdbname as `database()` from dual) as `(select database() from dual)` from dual",
db: true, liid: false,
},
{
in: "select id from user where database()",
expected: "select id from user where :__vtdbname",
db: true, liid: false,
expected: "select id from user where database()",
db: false, liid: false,
},
{
in: "select table_name from information_schema.tables where table_schema = database()",
expected: "select table_name from information_schema.tables where table_schema = database()",
db: false, liid: false,
},
}

Expand All @@ -77,16 +97,10 @@ func TestRewrites(in *testing.T) {
expected, err := Parse(tc.expected)
require.NoError(t, err)

s := toString(expected)
require.Equal(t, s, toString(result.AST))
s := String(expected)
require.Equal(t, s, String(result.AST))
require.Equal(t, tc.liid, result.NeedLastInsertID, "should need last insert id")
require.Equal(t, tc.db, result.NeedDatabase, "should need database name")
})
}
}

func toString(node SQLNode) string {
buf := NewTrackedBuffer(nil)
node.Format(buf)
return buf.String()
}
1 change: 0 additions & 1 deletion go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,6 @@ func (e *Executor) handleExec(ctx context.Context, safeSession *SafeSession, sql
sql = comments.Leading + normalized + comments.Trailing
if rewriteResult.NeedDatabase {
keyspace, _, _, _ := e.ParseDestinationTarget(safeSession.TargetString)
log.Warningf("This is the keyspace name: ---> %v", keyspace)
if keyspace == "" {
bindVars[sqlparser.DBVarName] = sqltypes.NullBindVariable
} else {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/testdata/filter_cases.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1078,7 +1078,7 @@
"Name": "user",
"Sharded": true
},
"Query": "select id from user where :__vtdbname",
"Query": "select id from user where database()",
"FieldQuery": "select id from user where 1 != 1",
"Table": "user"
}
Expand Down