Skip to content

Commit

Permalink
Unnest subqueries that don't need to be subqueries
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <andres@planetscale.com>
  • Loading branch information
systay committed Oct 8, 2020
1 parent 139f048 commit c1d81af
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 43 deletions.
11 changes: 11 additions & 0 deletions go/test/endtoend/vtgate/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,17 @@ func TestUseStmtInOLAP(t *testing.T) {
}
}

func TestInformationSchemaWithSubquery(t *testing.T) {
defer cluster.PanicHandler(t)
ctx := context.Background()
conn, err := mysql.Connect(ctx, &vtParams)
require.NoError(t, err)
defer conn.Close()

result := exec(t, conn, "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = (SELECT SCHEMA()) AND TABLE_NAME = 'not_exists'")
assert.Empty(t, result.Rows)
}

func assertMatches(t *testing.T, conn *mysql.Conn, query, expected string) {
t.Helper()
qr := exec(t, conn, query)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/sqlparser/ast_rewriting.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func RewriteAST(in Statement) (*RewriteASTResult, error) {
er := newExpressionRewriter()
er.shouldRewriteDatabaseFunc = shouldRewriteDatabaseFunc(in)
setRewriter := &setNormalizer{}
out, ok := Rewrite(in, er.goingDown, setRewriter.rewriteSetComingUp).(Statement)
out, ok := Rewrite(in, er.rewrite, setRewriter.rewriteSetComingUp).(Statement)
if !ok {
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "statement rewriting returned a non statement: %s", String(out))
}
Expand Down
8 changes: 7 additions & 1 deletion go/vt/sqlparser/bind_var_needs.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type BindVarNeeds struct {
NeedSystemVariable,
// NeedUserDefinedVariables keeps track of all user defined variables a query is using
NeedUserDefinedVariables []string
otherRewrites bool
}

//MergeWith adds bind vars needs coming from sub scopes
Expand Down Expand Up @@ -56,8 +57,13 @@ func (bvn *BindVarNeeds) NeedsSysVar(name string) bool {
return contains(bvn.NeedSystemVariable, name)
}

func (bvn *BindVarNeeds) NoteRewrite() {
bvn.otherRewrites = true
}

func (bvn *BindVarNeeds) HasRewrites() bool {
return len(bvn.NeedFunctionResult) > 0 ||
return bvn.otherRewrites ||
len(bvn.NeedFunctionResult) > 0 ||
len(bvn.NeedUserDefinedVariables) > 0 ||
len(bvn.NeedSystemVariable) > 0
}
Expand Down
49 changes: 45 additions & 4 deletions go/vt/sqlparser/expression_rewriting.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ const (
UserDefinedVariableName = "__vtudv"
)

func (er *expressionRewriter) rewriteAliasedExpr(cursor *Cursor, node *AliasedExpr) (*BindVarNeeds, error) {
func (er *expressionRewriter) rewriteAliasedExpr(node *AliasedExpr) (*BindVarNeeds, error) {
inner := newExpressionRewriter()
inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc
tmp := Rewrite(node.Expr, inner.goingDown, nil)
tmp := Rewrite(node.Expr, inner.rewrite, nil)
newExpr, ok := tmp.(Expr)
if !ok {
return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "failed to rewrite AST. function expected to return Expr returned a %s", String(tmp))
Expand All @@ -83,7 +83,7 @@ func (er *expressionRewriter) rewriteAliasedExpr(cursor *Cursor, node *AliasedEx
return inner.bindVars, nil
}

func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
func (er *expressionRewriter) rewrite(cursor *Cursor) bool {
switch node := cursor.Node().(type) {
// select last_insert_id() -> select :__lastInsertId as `last_insert_id()`
case *Select:
Expand All @@ -92,7 +92,7 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
if ok && aliasedExpr.As.IsEmpty() {
buf := NewTrackedBuffer(nil)
aliasedExpr.Expr.Format(buf)
innerBindVarNeeds, err := er.rewriteAliasedExpr(cursor, aliasedExpr)
innerBindVarNeeds, err := er.rewriteAliasedExpr(aliasedExpr)
if err != nil {
er.err = err
return false
Expand All @@ -112,6 +112,8 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
case DoubleAt:
er.sysVarRewrite(cursor, node)
}
case *Subquery:
er.unnestSubQueries(cursor, node)
}
return true
}
Expand Down Expand Up @@ -159,6 +161,45 @@ func (er *expressionRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) {
}
}

func (er *expressionRewriter) unnestSubQueries(cursor *Cursor, subquery *Subquery) {
sel, isSimpleSelect := subquery.Select.(*Select)
// Today, subqueries and derived tables use the same AST struct,
// so we have to check what the parent is so we don't accidentally
// rewrite a FROM clause instead of an expression
_, isDerivedTable := cursor.Parent().(*AliasedTableExpr)

if isDerivedTable || !isSimpleSelect {
return
}

if !(len(sel.SelectExprs) != 1 ||
len(sel.OrderBy) != 0 ||
len(sel.GroupBy) != 0 ||
len(sel.From) != 1 ||
sel.Where == nil ||
sel.Having == nil ||
sel.Limit == nil) && sel.Lock == NoLock {
return
}
aliasedTable, ok := sel.From[0].(*AliasedTableExpr)
if !ok {
return
}
table, ok := aliasedTable.Expr.(TableName)
if !ok || table.Name.String() != "dual" {
return
}
expr, ok := sel.SelectExprs[0].(*AliasedExpr)
if !ok {
return
}
er.bindVars.NoteRewrite()
// we need to make sure that the inner expression also gets rewritten,
// so we fire off another rewriter traversal here
rewrittenExpr := Rewrite(expr.Expr, er.rewrite, nil)
cursor.Replace(rewrittenExpr)
}

func bindVarExpression(name string) Expr {
return NewArgument([]byte(":" + name))
}
21 changes: 18 additions & 3 deletions go/vt/sqlparser/expression_rewriting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,18 @@ func TestRewrites(in *testing.T) {
expected: "SELECT :__lastInsertId + :__vtdbname as `last_insert_id() + database()`",
db: true, liid: true,
}, {
// unnest database() call
in: "select (select database()) from test",
expected: "select (select database() from dual) from test",
expected: "select database() as `(select database() from dual)` from test",
// no bindvar needs
}, {
// unnest database() call
in: "select (select database() from dual) from test",
expected: "select (select database() from dual) from test",
expected: "select database() as `(select database() from dual)` from test",
// no bindvar needs
}, {
in: "select (select database() from dual) from dual",
expected: "select (select :__vtdbname as `database()` from dual) as `(select database() from dual)` from dual",
expected: "select :__vtdbname as `(select database() from dual)` from dual",
db: true,
}, {
in: "select id from user where database()",
Expand Down Expand Up @@ -130,6 +132,19 @@ func TestRewrites(in *testing.T) {
in: "SELECT @@workload",
expected: "SELECT :__vtworkload as `@@workload`",
workload: true,
}, {
in: "select (select 42) from dual",
expected: "select 42 as `(select 42 from dual)` from dual",
}, {
in: "select * from user where col = (select 42)",
expected: "select * from user where col = 42",
}, {
in: "select * from (select 42) as t", // this is not an expression, and should not be rewritten
expected: "select * from (select 42) as t",
}, {
in: `select (select (select (select (select (select last_insert_id()))))) as x`,
expected: "select :__lastInsertId as x from dual",
liid: true,
}}

for _, tc := range tests {
Expand Down
1 change: 1 addition & 0 deletions go/vt/sqlparser/rewriter_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,5 @@ func (c *Cursor) Parent() SQLNode { return c.parent }
// replace the object with something of the wrong type, or the visitor will panic.
func (c *Cursor) Replace(newNode SQLNode) {
c.replacer(newNode, c.parent)
c.node = newNode
}
34 changes: 17 additions & 17 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,29 +491,29 @@ func TestLastInsertIDInVirtualTable(t *testing.T) {
}

func TestLastInsertIDInSubQueryExpression(t *testing.T) {
executor, sbc1, _, _ := createLegacyExecutorEnv()
executor, sbc1, sbc2, _ := createLegacyExecutorEnv()
executor.normalize = true
result1 := []*sqltypes.Result{{
masterSession.LastInsertId = 12345
defer func() {
// clean up global state
masterSession.LastInsertId = 0
}()
rs, err := executorExec(executor, "select (select last_insert_id()) as x", nil)
require.NoError(t, err)
wantResult := &sqltypes.Result{
RowsAffected: 1,
Fields: []*querypb.Field{
{Name: "id", Type: sqltypes.Int32},
{Name: "col", Type: sqltypes.Int32},
{Name: "x", Type: sqltypes.Uint64},
},
RowsAffected: 1,
InsertID: 0,
Rows: [][]sqltypes.Value{{
sqltypes.NewInt32(1),
sqltypes.NewInt32(3),
sqltypes.NewUint64(12345),
}},
}}
sbc1.SetResults(result1)
_, err := executorExec(executor, "select (select last_insert_id()) as x", nil)
require.NoError(t, err)
wantQueries := []*querypb.BoundQuery{{
Sql: "select (select :__lastInsertId as `last_insert_id()` from dual) as x from dual",
BindVariables: map[string]*querypb.BindVariable{"__lastInsertId": sqltypes.Uint64BindVariable(0)},
}}
}
utils.MustMatch(t, rs, wantResult, "Mismatch")

assert.Equal(t, wantQueries, sbc1.Queries)
// the query will get rewritten into a simpler query that can be run entirely on the vtgate
assert.Empty(t, sbc1.Queries)
assert.Empty(t, sbc2.Queries)
}

func TestSelectDatabase(t *testing.T) {
Expand Down
36 changes: 24 additions & 12 deletions go/vt/vtgate/planbuilder/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,25 +164,37 @@ func (pb *primitiveBuilder) findOrigin(expr sqlparser.Expr) (pullouts []*pullout
case *sqlparser.ComparisonExpr:
if construct.Operator == sqlparser.InOp {
// a in (subquery) -> (:__sq_has_values = 1 and (a in ::__sq))
right := &sqlparser.ComparisonExpr{
Operator: construct.Operator,
Left: construct.Left,
Right: sqlparser.ListArg("::" + sqName),
}
left := &sqlparser.ComparisonExpr{
Left: sqlparser.NewArgument([]byte(":" + hasValues)),
Operator: sqlparser.EqualOp,
Right: sqlparser.NewIntLiteral([]byte("1")),
}
newExpr := &sqlparser.AndExpr{
Left: &sqlparser.ComparisonExpr{
Left: sqlparser.NewArgument([]byte(":" + hasValues)),
Operator: sqlparser.EqualOp,
Right: sqlparser.NewIntLiteral([]byte("1")),
},
Right: sqlparser.ReplaceExpr(construct, sqi.ast, sqlparser.ListArg([]byte("::"+sqName))),
Left: left,
Right: right,
}
expr = sqlparser.ReplaceExpr(expr, construct, newExpr)
pullouts = append(pullouts, newPulloutSubquery(engine.PulloutIn, sqName, hasValues, sqi.bldr))
} else {
// a not in (subquery) -> (:__sq_has_values = 0 or (a not in ::__sq))
left := &sqlparser.ComparisonExpr{
Left: sqlparser.NewArgument([]byte(":" + hasValues)),
Operator: sqlparser.EqualOp,
Right: sqlparser.NewIntLiteral([]byte("0")),
}
right := &sqlparser.ComparisonExpr{
Operator: construct.Operator,
Left: construct.Left,
Right: sqlparser.ListArg("::" + sqName),
}
newExpr := &sqlparser.OrExpr{
Left: &sqlparser.ComparisonExpr{
Left: sqlparser.NewArgument([]byte(":" + hasValues)),
Operator: sqlparser.EqualOp,
Right: sqlparser.NewIntLiteral([]byte("0")),
},
Right: sqlparser.ReplaceExpr(construct, sqi.ast, sqlparser.ListArg([]byte("::"+sqName))),
Left: left,
Right: right,
}
expr = sqlparser.ReplaceExpr(expr, construct, newExpr)
pullouts = append(pullouts, newPulloutSubquery(engine.PulloutNotIn, sqName, hasValues, sqi.bldr))
Expand Down
21 changes: 20 additions & 1 deletion go/vt/vtgate/planbuilder/testdata/dml_cases.txt
Original file line number Diff line number Diff line change
Expand Up @@ -875,11 +875,30 @@
},
"TargetTabletType": "MASTER",
"MultiShardAutocommit": false,
"Query": "insert into unsharded values ((select 1 from dual), 1)",
"Query": "insert into unsharded values (1, 1)",
"TableName": "unsharded"
}
}

# sharded insert subquery in insert value
"insert into user(id, val) values((select 1), 1)"
{
"QueryType": "INSERT",
"Original": "insert into user(id, val) values((select 1), 1)",
"Instructions": {
"OperatorType": "Insert",
"Variant": "Sharded",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"TargetTabletType": "MASTER",
"MultiShardAutocommit": false,
"Query": "insert into user(id, val, Name, Costly) values (:_Id_0, 1, :_Name_0, :_Costly_0)",
"TableName": "user"
}
}

# insert into a routed table
"insert into route1(id) values (1)"
{
Expand Down
4 changes: 0 additions & 4 deletions go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,6 @@
"insert into user(id) select 1 from dual"
"unsupported: insert into select"

# sharded insert subquery in insert value
"insert into user(id, val) values((select 1), 1)"
"unsupported: subquery in insert values"

# sharded replace no vindex
"replace into user(val) values(1, 'foo')"
"unsupported: REPLACE INTO with sharded schema"
Expand Down

0 comments on commit c1d81af

Please sign in to comment.