diff --git a/go/mysql/sql_error.go b/go/mysql/sql_error.go index 369e7a6f0eb..c9265280e5a 100644 --- a/go/mysql/sql_error.go +++ b/go/mysql/sql_error.go @@ -163,6 +163,7 @@ var stateToMysqlCode = map[vterrors.State]struct { vterrors.AccessDeniedError: {num: ERAccessDeniedError, state: SSAccessDeniedError}, vterrors.BadDb: {num: ERBadDb, state: SSClientError}, vterrors.BadFieldError: {num: ERBadFieldError, state: SSBadFieldError}, + vterrors.BadTableError: {num: ERBadTable, state: SSUnknownTable}, vterrors.CantUseOptionHere: {num: ERCantUseOptionHere, state: SSClientError}, vterrors.DataOutOfRange: {num: ERDataOutOfRange, state: SSDataOutOfRange}, vterrors.DbCreateExists: {num: ERDbCreateExists, state: SSUnknownSQLState}, diff --git a/go/vt/vterrors/state.go b/go/vt/vterrors/state.go index 29575de8dd5..d4ad6eea7e1 100644 --- a/go/vt/vterrors/state.go +++ b/go/vt/vterrors/state.go @@ -25,6 +25,7 @@ const ( // invalid argument BadFieldError + BadTableError CantUseOptionHere DataOutOfRange EmptyQuery diff --git a/go/vt/vtgate/planbuilder/route_planning.go b/go/vt/vtgate/planbuilder/route_planning.go index 6ed1768f407..3caaca20399 100644 --- a/go/vt/vtgate/planbuilder/route_planning.go +++ b/go/vt/vtgate/planbuilder/route_planning.go @@ -55,6 +55,11 @@ func newBuildSelectPlan(sel *sqlparser.Select, vschema ContextVSchema) (engine.P return nil, err } + sel, err = expandStar(sel, semTable) + if err != nil { + return nil, err + } + qgraph, err := createQGFromSelect(sel, semTable) if err != nil { return nil, err @@ -97,6 +102,111 @@ func newBuildSelectPlan(sel *sqlparser.Select, vschema ContextVSchema) (engine.P return plan.Primitive(), nil } +type starRewriter struct { + err error + semTable *semantics.SemTable +} + +func (sr *starRewriter) starRewrite(cursor *sqlparser.Cursor) bool { + switch node := cursor.Node().(type) { + case *sqlparser.Select: + tables := sr.semTable.GetSelectTables(node) + var selExprs sqlparser.SelectExprs + for _, selectExpr := range node.SelectExprs { + starExpr, isStarExpr := selectExpr.(*sqlparser.StarExpr) + if !isStarExpr { + selExprs = append(selExprs, selectExpr) + continue + } + colNames, expStar, err := expandTableColumns(tables, starExpr) + if err != nil { + sr.err = err + return false + } + if !expStar.proceed { + selExprs = append(selExprs, selectExpr) + continue + } + selExprs = append(selExprs, colNames...) + for tbl, cols := range expStar.tblColMap { + sr.semTable.AddExprs(tbl, cols) + } + } + node.SelectExprs = selExprs + } + return true +} + +func expandTableColumns(tables []*semantics.TableInfo, starExpr *sqlparser.StarExpr) (sqlparser.SelectExprs, *expandStarInfo, error) { + unknownTbl := true + var colNames sqlparser.SelectExprs + expStar := &expandStarInfo{ + tblColMap: map[*sqlparser.AliasedTableExpr]sqlparser.SelectExprs{}, + } + + for _, tbl := range tables { + if !starExpr.TableName.IsEmpty() { + if !tbl.ASTNode.As.IsEmpty() { + if !starExpr.TableName.Qualifier.IsEmpty() { + continue + } + if starExpr.TableName.Name.String() != tbl.ASTNode.As.String() { + continue + } + } else { + if !starExpr.TableName.Qualifier.IsEmpty() { + if starExpr.TableName.Qualifier.String() != tbl.Table.Keyspace.Name { + continue + } + } + tblName := tbl.ASTNode.Expr.(sqlparser.TableName) + if starExpr.TableName.Name.String() != tblName.Name.String() { + continue + } + } + } + unknownTbl = false + if tbl.Table == nil || !tbl.Table.ColumnListAuthoritative { + expStar.proceed = false + break + } + expStar.proceed = true + tblName, err := tbl.ASTNode.TableName() + if err != nil { + return nil, nil, err + } + for _, col := range tbl.Table.Columns { + colNames = append(colNames, &sqlparser.AliasedExpr{ + Expr: sqlparser.NewColNameWithQualifier(col.Name.String(), tblName), + As: sqlparser.NewColIdent(col.Name.String()), + }) + } + expStar.tblColMap[tbl.ASTNode] = colNames + } + + if unknownTbl { + // This will only happen for case when starExpr has qualifier. + return nil, nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadDb, "Unknown table '%s'", sqlparser.String(starExpr.TableName)) + } + return colNames, expStar, nil +} + +type expandStarInfo struct { + proceed bool + tblColMap map[*sqlparser.AliasedTableExpr]sqlparser.SelectExprs +} + +func expandStar(sel *sqlparser.Select, semTable *semantics.SemTable) (*sqlparser.Select, error) { + // TODO we could store in semTable whether there are any * in the query that needs expanding or not + sr := &starRewriter{semTable: semTable} + + _ = sqlparser.Rewrite(sel, sr.starRewrite, nil) + if sr.err != nil { + return nil, sr.err + } + return sel, nil +} + func planLimit(limit *sqlparser.Limit, plan logicalPlan) (logicalPlan, error) { if limit == nil { return plan, nil diff --git a/go/vt/vtgate/planbuilder/route_planning_test.go b/go/vt/vtgate/planbuilder/route_planning_test.go index 70161799c35..0def30a8d07 100644 --- a/go/vt/vtgate/planbuilder/route_planning_test.go +++ b/go/vt/vtgate/planbuilder/route_planning_test.go @@ -20,6 +20,10 @@ import ( "fmt" "testing" + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/vtgate/semantics" "github.com/stretchr/testify/assert" @@ -122,3 +126,154 @@ func TestClone(t *testing.T) { assert.NotNil(t, clonedRP.vindexPreds[0].foundVindex) assert.Nil(t, original.vindexPreds[0].foundVindex) } + +func TestExpandStar(t *testing.T) { + schemaInfo := &fakeSI{ + tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewTableIdent("t1"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewColIdent("a"), + Type: sqltypes.VarChar, + }, { + Name: sqlparser.NewColIdent("b"), + Type: sqltypes.VarChar, + }, { + Name: sqlparser.NewColIdent("c"), + Type: sqltypes.VarChar, + }}, + ColumnListAuthoritative: true, + }, + "t2": { + Name: sqlparser.NewTableIdent("t2"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewColIdent("c1"), + Type: sqltypes.VarChar, + }, { + Name: sqlparser.NewColIdent("c2"), + Type: sqltypes.VarChar, + }}, + ColumnListAuthoritative: true, + }, + "t3": { // non authoritative table. + Name: sqlparser.NewTableIdent("t3"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewColIdent("col"), + Type: sqltypes.VarChar, + }}, + ColumnListAuthoritative: false, + }, + }, + } + cDB := "db" + tcases := []struct { + sql string + expSQL string + expErr string + }{{ + sql: "select * from t1", + expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1", + }, { + sql: "select t1.* from t1", + expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1", + }, { + sql: "select *, 42, t1.* from t1", + expSQL: "select t1.a as a, t1.b as b, t1.c as c, 42, t1.a as a, t1.b as b, t1.c as c from t1", + }, { + sql: "select 42, t1.* from t1", + expSQL: "select 42, t1.a as a, t1.b as b, t1.c as c from t1", + }, { + sql: "select * from t1, t2", + expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1, t2", + }, { + sql: "select t1.* from t1, t2", + expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1, t2", + }, { + sql: "select *, t1.* from t1, t2", + expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2, t1.a as a, t1.b as b, t1.c as c from t1, t2", + }, { // aliased table + sql: "select * from t1 a, t2 b", + expSQL: "select a.a as a, a.b as b, a.c as c, b.c1 as c1, b.c2 as c2 from t1 as a, t2 as b", + }, { // t3 is non-authoritative table + sql: "select * from t3", + expSQL: "select * from t3", + }, { // t3 is non-authoritative table + sql: "select * from t1, t2, t3", + expSQL: "select * from t1, t2, t3", + }, { // t3 is non-authoritative table + sql: "select t1.*, t2.*, t3.* from t1, t2, t3", + expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2, t3.* from t1, t2, t3", + }, { + sql: "select foo.* from t1, t2", + expErr: "Unknown table 'foo'", + }} + for _, tcase := range tcases { + t.Run(tcase.sql, func(t *testing.T) { + ast, err := sqlparser.Parse(tcase.sql) + require.NoError(t, err) + semTable, err := semantics.Analyze(ast, cDB, schemaInfo) + require.NoError(t, err) + expandedSelect, err := expandStar(ast.(*sqlparser.Select), semTable) + if tcase.expErr == "" { + require.NoError(t, err) + assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect)) + } else { + require.EqualError(t, err, tcase.expErr) + } + }) + } +} + +func TestSemTableDependenciesAfterExpandStar(t *testing.T) { + schemaInfo := &fakeSI{tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewTableIdent("t1"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewColIdent("a"), + Type: sqltypes.VarChar, + }}, + ColumnListAuthoritative: true, + }}} + tcases := []struct { + sql string + expSQL string + sameTbl int + otherTbl int + expandedCol int + }{{ + sql: "select a, * from t1", + expSQL: "select a, t1.a as a from t1", + otherTbl: -1, sameTbl: 0, expandedCol: 1, + }, { + sql: "select t2.a, t1.a, t1.* from t1, t2", + expSQL: "select t2.a, t1.a, t1.a as a from t1, t2", + otherTbl: 0, sameTbl: 1, expandedCol: 2, + }, { + sql: "select t2.a, t.a, t.* from t1 t, t2", + expSQL: "select t2.a, t.a, t.a as a from t1 as t, t2", + otherTbl: 0, sameTbl: 1, expandedCol: 2, + }} + for _, tcase := range tcases { + t.Run(tcase.sql, func(t *testing.T) { + ast, err := sqlparser.Parse(tcase.sql) + require.NoError(t, err) + semTable, err := semantics.Analyze(ast, "", schemaInfo) + require.NoError(t, err) + expandedSelect, err := expandStar(ast.(*sqlparser.Select), semTable) + require.NoError(t, err) + assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect)) + if tcase.otherTbl != -1 { + assert.NotEqual(t, + semTable.Dependencies(expandedSelect.SelectExprs[tcase.otherTbl].(*sqlparser.AliasedExpr).Expr), + semTable.Dependencies(expandedSelect.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr), + ) + } + if tcase.sameTbl != -1 { + assert.Equal(t, + semTable.Dependencies(expandedSelect.SelectExprs[tcase.sameTbl].(*sqlparser.AliasedExpr).Expr), + semTable.Dependencies(expandedSelect.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr), + ) + } + }) + } +} diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index cc82c982669..0594b86309e 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -34,15 +34,18 @@ type ( exprDeps map[sqlparser.Expr]TableSet err error currentDb string + + selectScope map[*sqlparser.Select]*scope } ) // newAnalyzer create the semantic analyzer func newAnalyzer(dbName string, si SchemaInformation) *analyzer { return &analyzer{ - exprDeps: map[sqlparser.Expr]TableSet{}, - currentDb: dbName, - si: si, + exprDeps: map[sqlparser.Expr]TableSet{}, + selectScope: map[*sqlparser.Select]*scope{}, + currentDb: dbName, + si: si, } } @@ -54,7 +57,7 @@ func Analyze(statement sqlparser.Statement, currentDb string, si SchemaInformati if err != nil { return nil, err } - return &SemTable{exprDependencies: analyzer.exprDeps, Tables: analyzer.Tables}, nil + return &SemTable{exprDependencies: analyzer.exprDeps, Tables: analyzer.Tables, selectScope: analyzer.selectScope}, nil } // analyzeDown pushes new scopes when we encounter sub queries, @@ -65,6 +68,7 @@ func (a *analyzer) analyzeDown(cursor *sqlparser.Cursor) bool { switch node := n.(type) { case *sqlparser.Select: a.push(newScope(current)) + a.selectScope[node] = a.currentScope() if err := a.analyzeTableExprs(node.From); err != nil { a.err = err return false @@ -87,6 +91,7 @@ func (a *analyzer) analyzeDown(cursor *sqlparser.Cursor) bool { a.exprDeps[node] = t } } + // this is the visitor going down the tree. Returning false here would just not visit the children // to the current node, but that is not what we want if we have encountered an error. // In order to abort the whole visitation, we have to return true here and then return false in the `analyzeUp` method @@ -163,7 +168,7 @@ func (a *analyzer) resolveUnQualifiedColumn(current *scope, expr *sqlparser.ColN var tblInfo *TableInfo for _, tbl := range current.tables { - if !tbl.Table.ColumnListAuthoritative { + if tbl.Table == nil || !tbl.Table.ColumnListAuthoritative { return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.NonUniqError, fmt.Sprintf("Column '%s' in field list is ambiguous", sqlparser.String(expr))) } for _, col := range tbl.Table.Columns { diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 3fbfb749ba4..648ad473138 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -43,6 +43,7 @@ type ( SemTable struct { Tables []*TableInfo exprDependencies map[sqlparser.Expr]TableSet + selectScope map[*sqlparser.Select]*scope } scope struct { @@ -100,6 +101,20 @@ func (st *SemTable) Dependencies(expr sqlparser.Expr) TableSet { return deps } +// GetSelectTables returns the table in the select. +func (st *SemTable) GetSelectTables(node *sqlparser.Select) []*TableInfo { + scope := st.selectScope[node] + return scope.tables +} + +// AddExprs adds new select exprs to the SemTable. +func (st *SemTable) AddExprs(tbl *sqlparser.AliasedTableExpr, cols sqlparser.SelectExprs) { + tableSet := st.TableSetFor(tbl) + for _, col := range cols { + st.exprDependencies[col.(*sqlparser.AliasedExpr).Expr] = tableSet + } +} + func newScope(parent *scope) *scope { return &scope{parent: parent} }