diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 90503f46c2c..03917fef944 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -47,6 +47,8 @@ type ( SetLock(lock Lock) MakeDistinct() GetColumnCount() int + SetComments(comments Comments) + GetComments() Comments } // DDLStatement represents any DDL Statement diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index f00f52cebe1..f695489b7fc 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -749,6 +749,16 @@ func (node *Select) GetColumnCount() int { return len(node.SelectExprs) } +// SetComments implements the SelectStatement interface +func (node *Select) SetComments(comments Comments) { + node.Comments = comments +} + +// GetComments implements the SelectStatement interface +func (node *Select) GetComments() Comments { + return node.Comments +} + // AddWhere adds the boolean expression to the // WHERE clause as an AND condition. func (node *Select) AddWhere(expr Expr) { @@ -806,6 +816,16 @@ func (node *ParenSelect) GetColumnCount() int { return node.Select.GetColumnCount() } +// SetComments implements the SelectStatement interface +func (node *ParenSelect) SetComments(comments Comments) { + node.Select.SetComments(comments) +} + +// GetComments implements the SelectStatement interface +func (node *ParenSelect) GetComments() Comments { + return node.Select.GetComments() +} + // AddWhere adds the boolean expression to the // WHERE clause as an AND condition. func (node *Update) AddWhere(expr Expr) { @@ -847,6 +867,16 @@ func (node *Union) GetColumnCount() int { return node.FirstStatement.GetColumnCount() } +// SetComments implements the SelectStatement interface +func (node *Union) SetComments(comments Comments) { + node.FirstStatement.SetComments(comments) +} + +// GetComments implements the SelectStatement interface +func (node *Union) GetComments() Comments { + return node.FirstStatement.GetComments() +} + //Unionize returns a UNION, either creating one or adding SELECT to an existing one func Unionize(lhs, rhs SelectStatement, distinct bool, by OrderBy, limit *Limit, lock Lock) *Union { union, isUnion := lhs.(*Union) diff --git a/go/vt/vtgate/planbuilder/abstract/fuzz.go b/go/vt/vtgate/planbuilder/abstract/fuzz.go index 0a032ade2ae..71f61e33b1d 100644 --- a/go/vt/vtgate/planbuilder/abstract/fuzz.go +++ b/go/vt/vtgate/planbuilder/abstract/fuzz.go @@ -50,7 +50,7 @@ func FuzzAnalyse(data []byte) int { } switch stmt := tree.(type) { case *sqlparser.Select: - semTable, err := semantics.Analyze(tree, "", &fakeFuzzSI{}) + semTable, err := semantics.Analyze(stmt, "", &fakeFuzzSI{}) if err != nil { return 0 } diff --git a/go/vt/vtgate/planbuilder/abstract/operator_test.go b/go/vt/vtgate/planbuilder/abstract/operator_test.go index 5eedf851c55..bf9971cf973 100644 --- a/go/vt/vtgate/planbuilder/abstract/operator_test.go +++ b/go/vt/vtgate/planbuilder/abstract/operator_test.go @@ -162,7 +162,7 @@ JoinPredicates: t.Run(fmt.Sprintf("%d %s", i, sql), func(t *testing.T) { tree, err := sqlparser.Parse(sql) require.NoError(t, err) - semTable, err := semantics.Analyze(tree, "", &semantics.FakeSI{}) + semTable, err := semantics.Analyze(tree.(sqlparser.SelectStatement), "", &semantics.FakeSI{}) require.NoError(t, err) optree, err := CreateOperatorFromSelect(tree.(*sqlparser.Select), semTable) require.NoError(t, err) diff --git a/go/vt/vtgate/planbuilder/expand_star_test.go b/go/vt/vtgate/planbuilder/expand_star_test.go index 62ffb23a4d5..4eb9e566816 100644 --- a/go/vt/vtgate/planbuilder/expand_star_test.go +++ b/go/vt/vtgate/planbuilder/expand_star_test.go @@ -112,9 +112,11 @@ func TestExpandStar(t *testing.T) { t.Run(tcase.sql, func(t *testing.T) { ast, err := sqlparser.Parse(tcase.sql) require.NoError(t, err) - semTable, err := semantics.Analyze(ast, cDB, schemaInfo) + selectStatement, isSelectStatement := ast.(*sqlparser.Select) + require.True(t, isSelectStatement, "analyzer expects a select statement") + semTable, err := semantics.Analyze(selectStatement, cDB, schemaInfo) require.NoError(t, err) - expandedSelect, err := expandStar(ast.(*sqlparser.Select), semTable) + expandedSelect, err := expandStar(selectStatement, semTable) if tcase.expErr == "" { require.NoError(t, err) assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect)) @@ -158,9 +160,11 @@ func TestSemTableDependenciesAfterExpandStar(t *testing.T) { t.Run(tcase.sql, func(t *testing.T) { ast, err := sqlparser.Parse(tcase.sql) require.NoError(t, err) - semTable, err := semantics.Analyze(ast, "", schemaInfo) + selectStatement, isSelectStatement := ast.(*sqlparser.Select) + require.True(t, isSelectStatement, "analyzer expects a select statement") + semTable, err := semantics.Analyze(selectStatement, "", schemaInfo) require.NoError(t, err) - expandedSelect, err := expandStar(ast.(*sqlparser.Select), semTable) + expandedSelect, err := expandStar(selectStatement, semTable) require.NoError(t, err) assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect)) if tcase.otherTbl != -1 { diff --git a/go/vt/vtgate/planbuilder/jointree_transformers.go b/go/vt/vtgate/planbuilder/jointree_transformers.go index 44aae7f5f4a..aff58d9d786 100644 --- a/go/vt/vtgate/planbuilder/jointree_transformers.go +++ b/go/vt/vtgate/planbuilder/jointree_transformers.go @@ -32,7 +32,7 @@ import ( func transformToLogicalPlan(tree joinTree, semTable *semantics.SemTable) (logicalPlan, error) { switch n := tree.(type) { case *routePlan: - return transformRoutePlan(n) + return transformRoutePlan(n, semTable) case *joinPlan: return transformJoinPlan(n, semTable) @@ -63,7 +63,7 @@ func transformJoinPlan(n *joinPlan, semTable *semantics.SemTable) (logicalPlan, }, nil } -func transformRoutePlan(n *routePlan) (*route, error) { +func transformRoutePlan(n *routePlan, semTable *semantics.SemTable) (*route, error) { var tablesForSelect sqlparser.TableExprs tableNameMap := map[string]interface{}{} @@ -156,6 +156,7 @@ func transformRoutePlan(n *routePlan) (*route, error) { SelectExprs: expressions, From: tablesForSelect, Where: where, + Comments: semTable.Comments, }, tables: n.solved, }, nil diff --git a/go/vt/vtgate/planbuilder/route_planning.go b/go/vt/vtgate/planbuilder/route_planning.go index 74a618f267f..f9807528001 100644 --- a/go/vt/vtgate/planbuilder/route_planning.go +++ b/go/vt/vtgate/planbuilder/route_planning.go @@ -63,11 +63,6 @@ func gen4Planner(_ string) func(sqlparser.Statement, *sqlparser.ReservedVars, Co } func newBuildSelectPlan(sel *sqlparser.Select, reservedVars *sqlparser.ReservedVars, vschema ContextVSchema) (logicalPlan, error) { - - directives := sqlparser.ExtractCommentDirectives(sel.Comments) - if len(directives) > 0 { - return nil, semantics.Gen4NotSupportedF("comment directives") - } ksName := "" if ks, _ := vschema.DefaultKeyspace(); ks != nil { ksName = ks.Name @@ -118,6 +113,17 @@ func newBuildSelectPlan(sel *sqlparser.Select, reservedVars *sqlparser.ReservedV return nil, err } + directives := sqlparser.ExtractCommentDirectives(sel.Comments) + if directives.IsSet(sqlparser.DirectiveScatterErrorsAsWarnings) { + visit(plan, func(logicalPlan logicalPlan) (bool, logicalPlan, error) { + switch plan := logicalPlan.(type) { + case *route: + plan.eroute.ScatterErrorsAsWarnings = true + } + return true, logicalPlan, nil + }) + } + return plan, nil } diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.txt b/go/vt/vtgate/planbuilder/testdata/select_cases.txt index 8c71190c5b2..a3899c346e1 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.txt @@ -95,6 +95,7 @@ Gen4 plan same as above ] } } +Gen4 plan same as above # select limit with timeout directive sets QueryTimeout in the route "select /*vt+ QUERY_TIMEOUT_MS=1000 */ * from user limit 10" @@ -164,6 +165,7 @@ Gen4 plan same as above ] } } +Gen4 plan same as above # select aggregation with partial scatter directive - added comments to try to confuse the hint extraction "/*VT_SPAN_CONTEXT=123*/select /*vt+ SCATTER_ERRORS_AS_WARNINGS=1 */ count(*) from user" @@ -190,6 +192,7 @@ Gen4 plan same as above ] } } +Gen4 plan same as above # select limit with partial scatter directive "select /*vt+ SCATTER_ERRORS_AS_WARNINGS=1 */ * from user limit 10" @@ -807,6 +810,7 @@ Gen4 plan same as above ] } } +Gen4 plan same as above # for update "select user.col from user join user_extra for update" @@ -1448,6 +1452,7 @@ Gen4 plan same as above "Table": "unsharded" } } +Gen4 plan same as above # testing SingleRow Projection "select 42" diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 2defc115955..44bc88e0f51 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -59,14 +59,14 @@ func newAnalyzer(dbName string, si SchemaInformation) *analyzer { } // Analyze analyzes the parsed query. -func Analyze(statement sqlparser.Statement, currentDb string, si SchemaInformation) (*SemTable, error) { +func Analyze(statement sqlparser.SelectStatement, currentDb string, si SchemaInformation) (*SemTable, error) { analyzer := newAnalyzer(currentDb, si) // Initial scope err := analyzer.analyze(statement) if err != nil { return nil, err } - return &SemTable{exprDependencies: analyzer.exprDeps, Tables: analyzer.Tables, selectScope: analyzer.rScope, ProjectionErr: analyzer.projErr}, nil + return &SemTable{exprDependencies: analyzer.exprDeps, Tables: analyzer.Tables, selectScope: analyzer.rScope, ProjectionErr: analyzer.projErr, Comments: statement.GetComments()}, nil } func (a *analyzer) setError(err error) { diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index efee37c6bce..2caec9cf235 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -99,7 +99,7 @@ func TestBindingSingleTable(t *testing.T) { t.Run(query, func(t *testing.T) { parse, err := sqlparser.Parse(query) require.NoError(t, err) - _, err = Analyze(parse, "d", &FakeSI{}) + _, err = Analyze(parse.(sqlparser.SelectStatement), "d", &FakeSI{}) require.Error(t, err) }) } @@ -222,7 +222,7 @@ func TestBindingSingleAliasedTable(t *testing.T) { t.Run(query, func(t *testing.T) { parse, err := sqlparser.Parse(query) require.NoError(t, err) - _, err = Analyze(parse, "", &FakeSI{ + _, err = Analyze(parse.(sqlparser.SelectStatement), "", &FakeSI{ Tables: map[string]*vindexes.Table{ "t": {Name: sqlparser.NewTableIdent("t")}, }, @@ -321,7 +321,7 @@ func TestBindingMultiTable(t *testing.T) { t.Run(query, func(t *testing.T) { parse, err := sqlparser.Parse(query) require.NoError(t, err) - _, err = Analyze(parse, "d", &FakeSI{ + _, err = Analyze(parse.(sqlparser.SelectStatement), "d", &FakeSI{ Tables: map[string]*vindexes.Table{ "tabl": {Name: sqlparser.NewTableIdent("tabl")}, "foo": {Name: sqlparser.NewTableIdent("foo")}, @@ -357,7 +357,7 @@ func TestNotUniqueTableName(t *testing.T) { t.Skip("derived tables not implemented") } parse, _ := sqlparser.Parse(query) - _, err := Analyze(parse, "test", &FakeSI{}) + _, err := Analyze(parse.(sqlparser.SelectStatement), "test", &FakeSI{}) require.Error(t, err) require.Contains(t, err.Error(), "Not unique table/alias") }) @@ -372,7 +372,7 @@ func TestMissingTable(t *testing.T) { for _, query := range queries { t.Run(query, func(t *testing.T) { parse, _ := sqlparser.Parse(query) - _, err := Analyze(parse, "", &FakeSI{}) + _, err := Analyze(parse.(sqlparser.SelectStatement), "", &FakeSI{}) require.Error(t, err) require.Contains(t, err.Error(), "symbol t.col not found") }) @@ -451,7 +451,7 @@ func TestUnknownColumnMap2(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { si := &FakeSI{Tables: test.schema} - tbl, err := Analyze(parse, "", si) + tbl, err := Analyze(parse.(sqlparser.SelectStatement), "", si) require.NoError(t, err) if test.err { @@ -488,7 +488,7 @@ func TestUnknownPredicate(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { si := &FakeSI{Tables: test.schema} - _, err := Analyze(parse, "", si) + _, err := Analyze(parse.(sqlparser.SelectStatement), "", si) if test.err { require.Error(t, err) } else { @@ -512,7 +512,7 @@ func TestScoping(t *testing.T) { t.Run(query.query, func(t *testing.T) { parse, err := sqlparser.Parse(query.query) require.NoError(t, err) - _, err = Analyze(parse, "user", &FakeSI{ + _, err = Analyze(parse.(sqlparser.SelectStatement), "user", &FakeSI{ Tables: map[string]*vindexes.Table{ "t": {Name: sqlparser.NewTableIdent("t")}, }, @@ -543,7 +543,7 @@ func parseAndAnalyze(t *testing.T, query, dbName string) (sqlparser.Statement, * Type: querypb.Type_VARCHAR, }} - semTable, err := Analyze(parse, dbName, &FakeSI{ + semTable, err := Analyze(parse.(sqlparser.SelectStatement), dbName, &FakeSI{ Tables: map[string]*vindexes.Table{ "t": {Name: sqlparser.NewTableIdent("t")}, "t1": {Name: sqlparser.NewTableIdent("t1"), Columns: cols1, ColumnListAuthoritative: true}, diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 1db90280295..7e0321e4266 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -72,6 +72,7 @@ type ( TableSet uint64 // we can only join 64 tables with this underlying data type // TODO : change uint64 to struct to support arbitrary number of tables. + // ExprDependencies stores the tables that an expression depends on as a map ExprDependencies map[sqlparser.Expr]TableSet // SemTable contains semantic analysis information about the query. @@ -82,6 +83,7 @@ type ( ProjectionErr error exprDependencies ExprDependencies selectScope map[*sqlparser.Select]*scope + Comments sqlparser.Comments } scope struct { @@ -297,6 +299,7 @@ func (st *SemTable) Dependencies(expr sqlparser.Expr) TableSet { return st.exprDependencies.Dependencies(expr) } +// Dependencies return the table dependencies of the expression. This method finds table dependencies recursively func (d ExprDependencies) Dependencies(expr sqlparser.Expr) TableSet { deps, found := d[expr] if found {