Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: include schema name when checking duplicate table aliases #24663

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions planner/core/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3721,6 +3721,27 @@ func (s *testIntegrationSerialSuite) TestMergeContinuousSelections(c *C) {
}
}

func (s *testIntegrationSuite) TestDuplicateAliasCheck(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("drop database if exists db1")
tk.MustExec("create database db1")
tk.MustExec("create table db1.t(a int)")
tk.MustExec("insert into db1.t values(1)")
tk.MustExec("drop database if exists db2")
tk.MustExec("create database db2")
tk.MustExec("create table db2.t(a int)")
tk.MustExec("insert into db2.t values(2)")
tk.MustExec("use db1")
err := tk.ExecToErr("select * from t, db1.t as t")
c.Assert(err, NotNil)
c.Assert(err.Error(), Equals, "[planner:1066]Not unique table/alias: 't'")

tk.MustQuery("select * from t, db2.t as t").Check(testkit.Rows("1 2"))
err = tk.ExecToErr("select * from t, db2.t as t where t.a = t.a")
c.Assert(err, NotNil)
c.Assert(err.Error(), Equals, "[planner:1052]Column 'a' in field list is ambiguous")
}

func (s *testIntegrationSerialSuite) TestEnforceMPP(c *C) {
tk := testkit.NewTestKit(c, s.store)

Expand Down
40 changes: 22 additions & 18 deletions planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TryAddExtraLimit(ctx sessionctx.Context, node ast.StmtNode) ast.StmtNode {

// Preprocess resolves table names of the node, and checks some statements validation.
func Preprocess(ctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema, preprocessOpt ...PreprocessOpt) error {
v := preprocessor{is: is, ctx: ctx, tableAliasInJoin: make([]map[string]interface{}, 0)}
v := preprocessor{is: is, ctx: ctx, tableAliasInJoin: make([]map[string]string, 0)}
for _, optFn := range preprocessOpt {
optFn(&v)
}
Expand Down Expand Up @@ -120,7 +120,7 @@ type preprocessor struct {

// tableAliasInJoin is a stack that keeps the table alias names for joins.
// len(tableAliasInJoin) may bigger than 1 because the left/right child of join may be subquery that contains `JOIN`
tableAliasInJoin []map[string]interface{}
tableAliasInJoin []map[string]string
}

func (p *preprocessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
Expand Down Expand Up @@ -720,40 +720,44 @@ func (p *preprocessor) checkDropTableNames(tables []*ast.TableName) {

func (p *preprocessor) checkNonUniqTableAlias(stmt *ast.Join) {
if p.flag&parentIsJoin == 0 {
p.tableAliasInJoin = append(p.tableAliasInJoin, make(map[string]interface{}))
p.tableAliasInJoin = append(p.tableAliasInJoin, make(map[string]string))
}
tableAliases := p.tableAliasInJoin[len(p.tableAliasInJoin)-1]
isOracleMode := p.ctx.GetSessionVars().SQLMode&mysql.ModeOracle != 0
if !isOracleMode {
if err := isTableAliasDuplicate(stmt.Left, tableAliases); err != nil {
if err := isTableAliasDuplicate(p.ctx, stmt.Left, tableAliases); err != nil {
p.err = err
return
}
if err := isTableAliasDuplicate(stmt.Right, tableAliases); err != nil {
if err := isTableAliasDuplicate(p.ctx, stmt.Right, tableAliases); err != nil {
p.err = err
return
}
}
p.flag |= parentIsJoin
}

func isTableAliasDuplicate(node ast.ResultSetNode, tableAliases map[string]interface{}) error {
func isTableAliasDuplicate(ctx sessionctx.Context, node ast.ResultSetNode, tableAliases map[string]string) error {
if ts, ok := node.(*ast.TableSource); ok {
tabName := ts.AsName
if tabName.L == "" {
if tableNode, ok := ts.Source.(*ast.TableName); ok {
if tableNode.Schema.L != "" {
tabName = model.NewCIStr(fmt.Sprintf("%s.%s", tableNode.Schema.L, tableNode.Name.L))
} else {
tabName = tableNode.Name
}
var fullName string
tabName := ts.AsName.L
dbName := strings.ToLower(ctx.GetSessionVars().CurrentDB)
if tableNode, ok := ts.Source.(*ast.TableName); ok {
if tableNode.Schema.L != "" {
dbName = tableNode.Schema.L
}
if tabName == "" {
tabName = tableNode.Name.L
}
fullName = fmt.Sprintf("%s.%s", dbName, tabName)
} else if tabName != "" {
fullName = fmt.Sprintf("%s.%s", dbName, tabName)
}
_, exists := tableAliases[tabName.L]
if len(tabName.L) != 0 && exists {
return ErrNonUniqTable.GenWithStackByArgs(tabName)
existsName, exists := tableAliases[fullName]
if len(fullName) != 0 && exists {
return ErrNonUniqTable.GenWithStackByArgs(existsName)
}
tableAliases[tabName.L] = nil
tableAliases[fullName] = tabName
}
return nil
}
Expand Down