Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gen4: Adds support for comments and comment directives in Gen4 #8547

Merged
merged 1 commit into from
Jul 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ type (
SetLock(lock Lock)
MakeDistinct()
GetColumnCount() int
SetComments(comments Comments)
GetComments() Comments
}

// DDLStatement represents any DDL Statement
Expand Down
30 changes: 30 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,16 @@ func (node *Select) GetColumnCount() int {
return len(node.SelectExprs)
}

// SetComments implements the SelectStatement interface
func (node *Select) SetComments(comments Comments) {
node.Comments = comments
}

// GetComments implements the SelectStatement interface
func (node *Select) GetComments() Comments {
return node.Comments
}

// AddWhere adds the boolean expression to the
// WHERE clause as an AND condition.
func (node *Select) AddWhere(expr Expr) {
Expand Down Expand Up @@ -806,6 +816,16 @@ func (node *ParenSelect) GetColumnCount() int {
return node.Select.GetColumnCount()
}

// SetComments implements the SelectStatement interface
func (node *ParenSelect) SetComments(comments Comments) {
node.Select.SetComments(comments)
}

// GetComments implements the SelectStatement interface
func (node *ParenSelect) GetComments() Comments {
return node.Select.GetComments()
}

// AddWhere adds the boolean expression to the
// WHERE clause as an AND condition.
func (node *Update) AddWhere(expr Expr) {
Expand Down Expand Up @@ -847,6 +867,16 @@ func (node *Union) GetColumnCount() int {
return node.FirstStatement.GetColumnCount()
}

// SetComments implements the SelectStatement interface
func (node *Union) SetComments(comments Comments) {
node.FirstStatement.SetComments(comments)
}

// GetComments implements the SelectStatement interface
func (node *Union) GetComments() Comments {
return node.FirstStatement.GetComments()
}

//Unionize returns a UNION, either creating one or adding SELECT to an existing one
func Unionize(lhs, rhs SelectStatement, distinct bool, by OrderBy, limit *Limit, lock Lock) *Union {
union, isUnion := lhs.(*Union)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/abstract/fuzz.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func FuzzAnalyse(data []byte) int {
}
switch stmt := tree.(type) {
case *sqlparser.Select:
semTable, err := semantics.Analyze(tree, "", &fakeFuzzSI{})
semTable, err := semantics.Analyze(stmt, "", &fakeFuzzSI{})
if err != nil {
return 0
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/abstract/operator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ JoinPredicates:
t.Run(fmt.Sprintf("%d %s", i, sql), func(t *testing.T) {
tree, err := sqlparser.Parse(sql)
require.NoError(t, err)
semTable, err := semantics.Analyze(tree, "", &semantics.FakeSI{})
semTable, err := semantics.Analyze(tree.(sqlparser.SelectStatement), "", &semantics.FakeSI{})
require.NoError(t, err)
optree, err := CreateOperatorFromSelect(tree.(*sqlparser.Select), semTable)
require.NoError(t, err)
Expand Down
12 changes: 8 additions & 4 deletions go/vt/vtgate/planbuilder/expand_star_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,11 @@ func TestExpandStar(t *testing.T) {
t.Run(tcase.sql, func(t *testing.T) {
ast, err := sqlparser.Parse(tcase.sql)
require.NoError(t, err)
semTable, err := semantics.Analyze(ast, cDB, schemaInfo)
selectStatement, isSelectStatement := ast.(*sqlparser.Select)
require.True(t, isSelectStatement, "analyzer expects a select statement")
semTable, err := semantics.Analyze(selectStatement, cDB, schemaInfo)
require.NoError(t, err)
expandedSelect, err := expandStar(ast.(*sqlparser.Select), semTable)
expandedSelect, err := expandStar(selectStatement, semTable)
if tcase.expErr == "" {
require.NoError(t, err)
assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect))
Expand Down Expand Up @@ -158,9 +160,11 @@ func TestSemTableDependenciesAfterExpandStar(t *testing.T) {
t.Run(tcase.sql, func(t *testing.T) {
ast, err := sqlparser.Parse(tcase.sql)
require.NoError(t, err)
semTable, err := semantics.Analyze(ast, "", schemaInfo)
selectStatement, isSelectStatement := ast.(*sqlparser.Select)
require.True(t, isSelectStatement, "analyzer expects a select statement")
semTable, err := semantics.Analyze(selectStatement, "", schemaInfo)
require.NoError(t, err)
expandedSelect, err := expandStar(ast.(*sqlparser.Select), semTable)
expandedSelect, err := expandStar(selectStatement, semTable)
require.NoError(t, err)
assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect))
if tcase.otherTbl != -1 {
Expand Down
5 changes: 3 additions & 2 deletions go/vt/vtgate/planbuilder/jointree_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import (
func transformToLogicalPlan(tree joinTree, semTable *semantics.SemTable) (logicalPlan, error) {
switch n := tree.(type) {
case *routePlan:
return transformRoutePlan(n)
return transformRoutePlan(n, semTable)

case *joinPlan:
return transformJoinPlan(n, semTable)
Expand Down Expand Up @@ -63,7 +63,7 @@ func transformJoinPlan(n *joinPlan, semTable *semantics.SemTable) (logicalPlan,
}, nil
}

func transformRoutePlan(n *routePlan) (*route, error) {
func transformRoutePlan(n *routePlan, semTable *semantics.SemTable) (*route, error) {
var tablesForSelect sqlparser.TableExprs
tableNameMap := map[string]interface{}{}

Expand Down Expand Up @@ -156,6 +156,7 @@ func transformRoutePlan(n *routePlan) (*route, error) {
SelectExprs: expressions,
From: tablesForSelect,
Where: where,
Comments: semTable.Comments,
},
tables: n.solved,
}, nil
Expand Down
16 changes: 11 additions & 5 deletions go/vt/vtgate/planbuilder/route_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,6 @@ func gen4Planner(_ string) func(sqlparser.Statement, *sqlparser.ReservedVars, Co
}

func newBuildSelectPlan(sel *sqlparser.Select, reservedVars *sqlparser.ReservedVars, vschema ContextVSchema) (logicalPlan, error) {

directives := sqlparser.ExtractCommentDirectives(sel.Comments)
if len(directives) > 0 {
return nil, semantics.Gen4NotSupportedF("comment directives")
}
ksName := ""
if ks, _ := vschema.DefaultKeyspace(); ks != nil {
ksName = ks.Name
Expand Down Expand Up @@ -118,6 +113,17 @@ func newBuildSelectPlan(sel *sqlparser.Select, reservedVars *sqlparser.ReservedV
return nil, err
}

directives := sqlparser.ExtractCommentDirectives(sel.Comments)
if directives.IsSet(sqlparser.DirectiveScatterErrorsAsWarnings) {
visit(plan, func(logicalPlan logicalPlan) (bool, logicalPlan, error) {
switch plan := logicalPlan.(type) {
case *route:
plan.eroute.ScatterErrorsAsWarnings = true
}
return true, logicalPlan, nil
})
}

return plan, nil
}

Expand Down
5 changes: 5 additions & 0 deletions go/vt/vtgate/planbuilder/testdata/select_cases.txt
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ Gen4 plan same as above
]
}
}
Gen4 plan same as above

# select limit with timeout directive sets QueryTimeout in the route
"select /*vt+ QUERY_TIMEOUT_MS=1000 */ * from user limit 10"
Expand Down Expand Up @@ -164,6 +165,7 @@ Gen4 plan same as above
]
}
}
Gen4 plan same as above

# select aggregation with partial scatter directive - added comments to try to confuse the hint extraction
"/*VT_SPAN_CONTEXT=123*/select /*vt+ SCATTER_ERRORS_AS_WARNINGS=1 */ count(*) from user"
Expand All @@ -190,6 +192,7 @@ Gen4 plan same as above
]
}
}
Gen4 plan same as above

# select limit with partial scatter directive
"select /*vt+ SCATTER_ERRORS_AS_WARNINGS=1 */ * from user limit 10"
Expand Down Expand Up @@ -807,6 +810,7 @@ Gen4 plan same as above
]
}
}
Gen4 plan same as above

# for update
"select user.col from user join user_extra for update"
Expand Down Expand Up @@ -1448,6 +1452,7 @@ Gen4 plan same as above
"Table": "unsharded"
}
}
Gen4 plan same as above

# testing SingleRow Projection
"select 42"
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/semantics/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ func newAnalyzer(dbName string, si SchemaInformation) *analyzer {
}

// Analyze analyzes the parsed query.
func Analyze(statement sqlparser.Statement, currentDb string, si SchemaInformation) (*SemTable, error) {
func Analyze(statement sqlparser.SelectStatement, currentDb string, si SchemaInformation) (*SemTable, error) {
analyzer := newAnalyzer(currentDb, si)
// Initial scope
err := analyzer.analyze(statement)
if err != nil {
return nil, err
}
return &SemTable{exprDependencies: analyzer.exprDeps, Tables: analyzer.Tables, selectScope: analyzer.rScope, ProjectionErr: analyzer.projErr}, nil
return &SemTable{exprDependencies: analyzer.exprDeps, Tables: analyzer.Tables, selectScope: analyzer.rScope, ProjectionErr: analyzer.projErr, Comments: statement.GetComments()}, nil
}

func (a *analyzer) setError(err error) {
Expand Down
18 changes: 9 additions & 9 deletions go/vt/vtgate/semantics/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func TestBindingSingleTable(t *testing.T) {
t.Run(query, func(t *testing.T) {
parse, err := sqlparser.Parse(query)
require.NoError(t, err)
_, err = Analyze(parse, "d", &FakeSI{})
_, err = Analyze(parse.(sqlparser.SelectStatement), "d", &FakeSI{})
require.Error(t, err)
})
}
Expand Down Expand Up @@ -222,7 +222,7 @@ func TestBindingSingleAliasedTable(t *testing.T) {
t.Run(query, func(t *testing.T) {
parse, err := sqlparser.Parse(query)
require.NoError(t, err)
_, err = Analyze(parse, "", &FakeSI{
_, err = Analyze(parse.(sqlparser.SelectStatement), "", &FakeSI{
Tables: map[string]*vindexes.Table{
"t": {Name: sqlparser.NewTableIdent("t")},
},
Expand Down Expand Up @@ -321,7 +321,7 @@ func TestBindingMultiTable(t *testing.T) {
t.Run(query, func(t *testing.T) {
parse, err := sqlparser.Parse(query)
require.NoError(t, err)
_, err = Analyze(parse, "d", &FakeSI{
_, err = Analyze(parse.(sqlparser.SelectStatement), "d", &FakeSI{
Tables: map[string]*vindexes.Table{
"tabl": {Name: sqlparser.NewTableIdent("tabl")},
"foo": {Name: sqlparser.NewTableIdent("foo")},
Expand Down Expand Up @@ -357,7 +357,7 @@ func TestNotUniqueTableName(t *testing.T) {
t.Skip("derived tables not implemented")
}
parse, _ := sqlparser.Parse(query)
_, err := Analyze(parse, "test", &FakeSI{})
_, err := Analyze(parse.(sqlparser.SelectStatement), "test", &FakeSI{})
require.Error(t, err)
require.Contains(t, err.Error(), "Not unique table/alias")
})
Expand All @@ -372,7 +372,7 @@ func TestMissingTable(t *testing.T) {
for _, query := range queries {
t.Run(query, func(t *testing.T) {
parse, _ := sqlparser.Parse(query)
_, err := Analyze(parse, "", &FakeSI{})
_, err := Analyze(parse.(sqlparser.SelectStatement), "", &FakeSI{})
require.Error(t, err)
require.Contains(t, err.Error(), "symbol t.col not found")
})
Expand Down Expand Up @@ -451,7 +451,7 @@ func TestUnknownColumnMap2(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
si := &FakeSI{Tables: test.schema}
tbl, err := Analyze(parse, "", si)
tbl, err := Analyze(parse.(sqlparser.SelectStatement), "", si)
require.NoError(t, err)

if test.err {
Expand Down Expand Up @@ -488,7 +488,7 @@ func TestUnknownPredicate(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
si := &FakeSI{Tables: test.schema}
_, err := Analyze(parse, "", si)
_, err := Analyze(parse.(sqlparser.SelectStatement), "", si)
if test.err {
require.Error(t, err)
} else {
Expand All @@ -512,7 +512,7 @@ func TestScoping(t *testing.T) {
t.Run(query.query, func(t *testing.T) {
parse, err := sqlparser.Parse(query.query)
require.NoError(t, err)
_, err = Analyze(parse, "user", &FakeSI{
_, err = Analyze(parse.(sqlparser.SelectStatement), "user", &FakeSI{
Tables: map[string]*vindexes.Table{
"t": {Name: sqlparser.NewTableIdent("t")},
},
Expand Down Expand Up @@ -543,7 +543,7 @@ func parseAndAnalyze(t *testing.T, query, dbName string) (sqlparser.Statement, *
Type: querypb.Type_VARCHAR,
}}

semTable, err := Analyze(parse, dbName, &FakeSI{
semTable, err := Analyze(parse.(sqlparser.SelectStatement), dbName, &FakeSI{
Tables: map[string]*vindexes.Table{
"t": {Name: sqlparser.NewTableIdent("t")},
"t1": {Name: sqlparser.NewTableIdent("t1"), Columns: cols1, ColumnListAuthoritative: true},
Expand Down
3 changes: 3 additions & 0 deletions go/vt/vtgate/semantics/semantic_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ type (
TableSet uint64 // we can only join 64 tables with this underlying data type
// TODO : change uint64 to struct to support arbitrary number of tables.

// ExprDependencies stores the tables that an expression depends on as a map
ExprDependencies map[sqlparser.Expr]TableSet

// SemTable contains semantic analysis information about the query.
Expand All @@ -82,6 +83,7 @@ type (
ProjectionErr error
exprDependencies ExprDependencies
selectScope map[*sqlparser.Select]*scope
Comments sqlparser.Comments
}

scope struct {
Expand Down Expand Up @@ -297,6 +299,7 @@ func (st *SemTable) Dependencies(expr sqlparser.Expr) TableSet {
return st.exprDependencies.Dependencies(expr)
}

// Dependencies return the table dependencies of the expression. This method finds table dependencies recursively
func (d ExprDependencies) Dependencies(expr sqlparser.Expr) TableSet {
deps, found := d[expr]
if found {
Expand Down