Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gen4: expand star in projection list #8325

Merged
merged 8 commits into from
Jun 20, 2021
1 change: 1 addition & 0 deletions go/mysql/sql_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ var stateToMysqlCode = map[vterrors.State]struct {
vterrors.AccessDeniedError: {num: ERAccessDeniedError, state: SSAccessDeniedError},
vterrors.BadDb: {num: ERBadDb, state: SSClientError},
vterrors.BadFieldError: {num: ERBadFieldError, state: SSBadFieldError},
vterrors.BadTableError: {num: ERBadTable, state: SSUnknownTable},
vterrors.CantUseOptionHere: {num: ERCantUseOptionHere, state: SSClientError},
vterrors.DataOutOfRange: {num: ERDataOutOfRange, state: SSDataOutOfRange},
vterrors.DbCreateExists: {num: ERDbCreateExists, state: SSUnknownSQLState},
Expand Down
1 change: 1 addition & 0 deletions go/vt/vterrors/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const (

// invalid argument
BadFieldError
BadTableError
CantUseOptionHere
DataOutOfRange
EmptyQuery
Expand Down
110 changes: 110 additions & 0 deletions go/vt/vtgate/planbuilder/route_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ func newBuildSelectPlan(sel *sqlparser.Select, vschema ContextVSchema) (engine.P
return nil, err
}

sel, err = expandStar(sel, semTable)
if err != nil {
return nil, err
}

qgraph, err := createQGFromSelect(sel, semTable)
if err != nil {
return nil, err
Expand Down Expand Up @@ -97,6 +102,111 @@ func newBuildSelectPlan(sel *sqlparser.Select, vschema ContextVSchema) (engine.P
return plan.Primitive(), nil
}

type starRewriter struct {
err error
semTable *semantics.SemTable
}

func (sr *starRewriter) starRewrite(cursor *sqlparser.Cursor) bool {
switch node := cursor.Node().(type) {
case *sqlparser.Select:
tables := sr.semTable.GetSelectTables(node)
var selExprs sqlparser.SelectExprs
for _, selectExpr := range node.SelectExprs {
starExpr, isStarExpr := selectExpr.(*sqlparser.StarExpr)
if !isStarExpr {
selExprs = append(selExprs, selectExpr)
continue
}
colNames, expStar, err := expandTableColumns(tables, starExpr)
if err != nil {
sr.err = err
return false
}
if !expStar.proceed {
selExprs = append(selExprs, selectExpr)
continue
}
selExprs = append(selExprs, colNames...)
for tbl, cols := range expStar.tblColMap {
sr.semTable.AddExprs(tbl, cols)
}
}
node.SelectExprs = selExprs
}
return true
}

func expandTableColumns(tables []*semantics.TableInfo, starExpr *sqlparser.StarExpr) (sqlparser.SelectExprs, *expandStarInfo, error) {
unknownTbl := true
var colNames sqlparser.SelectExprs
expStar := &expandStarInfo{
tblColMap: map[*sqlparser.AliasedTableExpr]sqlparser.SelectExprs{},
}

for _, tbl := range tables {
if !starExpr.TableName.IsEmpty() {
if !tbl.ASTNode.As.IsEmpty() {
if !starExpr.TableName.Qualifier.IsEmpty() {
continue
}
if starExpr.TableName.Name.String() != tbl.ASTNode.As.String() {
continue
}
} else {
if !starExpr.TableName.Qualifier.IsEmpty() {
if starExpr.TableName.Qualifier.String() != tbl.Table.Keyspace.Name {
continue
}
}
tblName := tbl.ASTNode.Expr.(sqlparser.TableName)
if starExpr.TableName.Name.String() != tblName.Name.String() {
continue
}
}
}
unknownTbl = false
if tbl.Table == nil || !tbl.Table.ColumnListAuthoritative {
expStar.proceed = false
break
}
expStar.proceed = true
tblName, err := tbl.ASTNode.TableName()
if err != nil {
return nil, nil, err
}
for _, col := range tbl.Table.Columns {
colNames = append(colNames, &sqlparser.AliasedExpr{
Expr: sqlparser.NewColNameWithQualifier(col.Name.String(), tblName),
As: sqlparser.NewColIdent(col.Name.String()),
})
}
expStar.tblColMap[tbl.ASTNode] = colNames
}

if unknownTbl {
// This will only happen for case when starExpr has qualifier.
return nil, nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadDb, "Unknown table '%s'", sqlparser.String(starExpr.TableName))
}
return colNames, expStar, nil
}

type expandStarInfo struct {
proceed bool
tblColMap map[*sqlparser.AliasedTableExpr]sqlparser.SelectExprs
}

func expandStar(sel *sqlparser.Select, semTable *semantics.SemTable) (*sqlparser.Select, error) {
// TODO we could store in semTable whether there are any * in the query that needs expanding or not
sr := &starRewriter{semTable: semTable}

_ = sqlparser.Rewrite(sel, sr.starRewrite, nil)
if sr.err != nil {
return nil, sr.err
}
return sel, nil
}

func planLimit(limit *sqlparser.Limit, plan logicalPlan) (logicalPlan, error) {
if limit == nil {
return plan, nil
Expand Down
155 changes: 155 additions & 0 deletions go/vt/vtgate/planbuilder/route_planning_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ import (
"fmt"
"testing"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/sqltypes"

"vitess.io/vitess/go/vt/vtgate/semantics"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -122,3 +126,154 @@ func TestClone(t *testing.T) {
assert.NotNil(t, clonedRP.vindexPreds[0].foundVindex)
assert.Nil(t, original.vindexPreds[0].foundVindex)
}

func TestExpandStar(t *testing.T) {
schemaInfo := &fakeSI{
tables: map[string]*vindexes.Table{
"t1": {
Name: sqlparser.NewTableIdent("t1"),
Columns: []vindexes.Column{{
Name: sqlparser.NewColIdent("a"),
Type: sqltypes.VarChar,
}, {
Name: sqlparser.NewColIdent("b"),
Type: sqltypes.VarChar,
}, {
Name: sqlparser.NewColIdent("c"),
Type: sqltypes.VarChar,
}},
ColumnListAuthoritative: true,
},
"t2": {
Name: sqlparser.NewTableIdent("t2"),
Columns: []vindexes.Column{{
Name: sqlparser.NewColIdent("c1"),
Type: sqltypes.VarChar,
}, {
Name: sqlparser.NewColIdent("c2"),
Type: sqltypes.VarChar,
}},
ColumnListAuthoritative: true,
},
"t3": { // non authoritative table.
Name: sqlparser.NewTableIdent("t3"),
Columns: []vindexes.Column{{
Name: sqlparser.NewColIdent("col"),
Type: sqltypes.VarChar,
}},
ColumnListAuthoritative: false,
},
},
}
cDB := "db"
tcases := []struct {
sql string
expSQL string
expErr string
}{{
sql: "select * from t1",
expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1",
}, {
sql: "select t1.* from t1",
expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1",
}, {
sql: "select *, 42, t1.* from t1",
expSQL: "select t1.a as a, t1.b as b, t1.c as c, 42, t1.a as a, t1.b as b, t1.c as c from t1",
}, {
sql: "select 42, t1.* from t1",
expSQL: "select 42, t1.a as a, t1.b as b, t1.c as c from t1",
}, {
sql: "select * from t1, t2",
expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1, t2",
}, {
sql: "select t1.* from t1, t2",
expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1, t2",
}, {
sql: "select *, t1.* from t1, t2",
expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2, t1.a as a, t1.b as b, t1.c as c from t1, t2",
}, { // aliased table
sql: "select * from t1 a, t2 b",
expSQL: "select a.a as a, a.b as b, a.c as c, b.c1 as c1, b.c2 as c2 from t1 as a, t2 as b",
}, { // t3 is non-authoritative table
sql: "select * from t3",
expSQL: "select * from t3",
}, { // t3 is non-authoritative table
sql: "select * from t1, t2, t3",
expSQL: "select * from t1, t2, t3",
}, { // t3 is non-authoritative table
sql: "select t1.*, t2.*, t3.* from t1, t2, t3",
expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2, t3.* from t1, t2, t3",
}, {
sql: "select foo.* from t1, t2",
expErr: "Unknown table 'foo'",
}}
for _, tcase := range tcases {
t.Run(tcase.sql, func(t *testing.T) {
ast, err := sqlparser.Parse(tcase.sql)
require.NoError(t, err)
semTable, err := semantics.Analyze(ast, cDB, schemaInfo)
require.NoError(t, err)
expandedSelect, err := expandStar(ast.(*sqlparser.Select), semTable)
if tcase.expErr == "" {
require.NoError(t, err)
assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect))
} else {
require.EqualError(t, err, tcase.expErr)
}
})
}
}

func TestSemTableDependenciesAfterExpandStar(t *testing.T) {
schemaInfo := &fakeSI{tables: map[string]*vindexes.Table{
"t1": {
Name: sqlparser.NewTableIdent("t1"),
Columns: []vindexes.Column{{
Name: sqlparser.NewColIdent("a"),
Type: sqltypes.VarChar,
}},
ColumnListAuthoritative: true,
}}}
tcases := []struct {
sql string
expSQL string
sameTbl int
otherTbl int
expandedCol int
}{{
sql: "select a, * from t1",
expSQL: "select a, t1.a as a from t1",
otherTbl: -1, sameTbl: 0, expandedCol: 1,
}, {
sql: "select t2.a, t1.a, t1.* from t1, t2",
expSQL: "select t2.a, t1.a, t1.a as a from t1, t2",
otherTbl: 0, sameTbl: 1, expandedCol: 2,
}, {
sql: "select t2.a, t.a, t.* from t1 t, t2",
expSQL: "select t2.a, t.a, t.a as a from t1 as t, t2",
otherTbl: 0, sameTbl: 1, expandedCol: 2,
}}
for _, tcase := range tcases {
t.Run(tcase.sql, func(t *testing.T) {
ast, err := sqlparser.Parse(tcase.sql)
require.NoError(t, err)
semTable, err := semantics.Analyze(ast, "", schemaInfo)
require.NoError(t, err)
expandedSelect, err := expandStar(ast.(*sqlparser.Select), semTable)
require.NoError(t, err)
assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect))
if tcase.otherTbl != -1 {
assert.NotEqual(t,
semTable.Dependencies(expandedSelect.SelectExprs[tcase.otherTbl].(*sqlparser.AliasedExpr).Expr),
semTable.Dependencies(expandedSelect.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr),
)
}
if tcase.sameTbl != -1 {
assert.Equal(t,
semTable.Dependencies(expandedSelect.SelectExprs[tcase.sameTbl].(*sqlparser.AliasedExpr).Expr),
semTable.Dependencies(expandedSelect.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr),
)
}
})
}
}
15 changes: 10 additions & 5 deletions go/vt/vtgate/semantics/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,18 @@ type (
exprDeps map[sqlparser.Expr]TableSet
err error
currentDb string

selectScope map[*sqlparser.Select]*scope
}
)

// newAnalyzer create the semantic analyzer
func newAnalyzer(dbName string, si SchemaInformation) *analyzer {
return &analyzer{
exprDeps: map[sqlparser.Expr]TableSet{},
currentDb: dbName,
si: si,
exprDeps: map[sqlparser.Expr]TableSet{},
selectScope: map[*sqlparser.Select]*scope{},
currentDb: dbName,
si: si,
}
}

Expand All @@ -54,7 +57,7 @@ func Analyze(statement sqlparser.Statement, currentDb string, si SchemaInformati
if err != nil {
return nil, err
}
return &SemTable{exprDependencies: analyzer.exprDeps, Tables: analyzer.Tables}, nil
return &SemTable{exprDependencies: analyzer.exprDeps, Tables: analyzer.Tables, selectScope: analyzer.selectScope}, nil
}

// analyzeDown pushes new scopes when we encounter sub queries,
Expand All @@ -65,6 +68,7 @@ func (a *analyzer) analyzeDown(cursor *sqlparser.Cursor) bool {
switch node := n.(type) {
case *sqlparser.Select:
a.push(newScope(current))
a.selectScope[node] = a.currentScope()
if err := a.analyzeTableExprs(node.From); err != nil {
a.err = err
return false
Expand All @@ -87,6 +91,7 @@ func (a *analyzer) analyzeDown(cursor *sqlparser.Cursor) bool {
a.exprDeps[node] = t
}
}

// this is the visitor going down the tree. Returning false here would just not visit the children
// to the current node, but that is not what we want if we have encountered an error.
// In order to abort the whole visitation, we have to return true here and then return false in the `analyzeUp` method
Expand Down Expand Up @@ -163,7 +168,7 @@ func (a *analyzer) resolveUnQualifiedColumn(current *scope, expr *sqlparser.ColN

var tblInfo *TableInfo
for _, tbl := range current.tables {
if !tbl.Table.ColumnListAuthoritative {
if tbl.Table == nil || !tbl.Table.ColumnListAuthoritative {
return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.NonUniqError, fmt.Sprintf("Column '%s' in field list is ambiguous", sqlparser.String(expr)))
}
for _, col := range tbl.Table.Columns {
Expand Down
15 changes: 15 additions & 0 deletions go/vt/vtgate/semantics/semantic_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type (
SemTable struct {
Tables []*TableInfo
exprDependencies map[sqlparser.Expr]TableSet
selectScope map[*sqlparser.Select]*scope
}

scope struct {
Expand Down Expand Up @@ -100,6 +101,20 @@ func (st *SemTable) Dependencies(expr sqlparser.Expr) TableSet {
return deps
}

// GetSelectTables returns the table in the select.
func (st *SemTable) GetSelectTables(node *sqlparser.Select) []*TableInfo {
scope := st.selectScope[node]
return scope.tables
}

// AddExprs adds new select exprs to the SemTable.
func (st *SemTable) AddExprs(tbl *sqlparser.AliasedTableExpr, cols sqlparser.SelectExprs) {
tableSet := st.TableSetFor(tbl)
for _, col := range cols {
st.exprDependencies[col.(*sqlparser.AliasedExpr).Expr] = tableSet
}
}

func newScope(parent *scope) *scope {
return &scope{parent: parent}
}
Expand Down