Skip to content

Commit

Permalink
bindinfo,planner: report error when creating sql binding on temporary…
Browse files Browse the repository at this point in the history
… table (#25058)
  • Loading branch information
tiancaiamao authored Jun 15, 2021
1 parent 39c503e commit bd11917
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 28 deletions.
16 changes: 16 additions & 0 deletions bindinfo/bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/bindinfo"
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/meta/autoid"
"github.com/pingcap/tidb/metrics"
Expand Down Expand Up @@ -2098,3 +2099,18 @@ func (s *testSuite) TestBindingWithoutCharset(c *C) {
c.Assert(rows[0][0], Equals, "select * from `test` . `t` where `a` = ?")
c.Assert(rows[0][1], Equals, "SELECT * FROM `test`.`t` WHERE `a` = 'aa'")
}

func (s *testSuite) TestTemporaryTable(c *C) {
tk := testkit.NewTestKit(c, s.store)
s.cleanBindingEnv(tk)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("set tidb_enable_global_temporary_table = true")
tk.MustExec("create global temporary table t(a int, b int, key(a), key(b)) on commit delete rows")
tk.MustExec("create table t2(a int, b int, key(a), key(b))")
tk.MustGetErrCode("create session binding for select * from t where b = 123 using select * from t ignore index(b) where b = 123;", errno.ErrOptOnTemporaryTable)
tk.MustGetErrCode("create binding for insert into t select * from t2 where t2.b = 1 and t2.c > 1 using insert into t select /*+ use_index(t2,c) */ * from t2 where t2.b = 1 and t2.c > 1", errno.ErrOptOnTemporaryTable)
tk.MustGetErrCode("create binding for replace into t select * from t2 where t2.b = 1 and t2.c > 1 using replace into t select /*+ use_index(t2,c) */ * from t2 where t2.b = 1 and t2.c > 1", errno.ErrOptOnTemporaryTable)
tk.MustGetErrCode("create binding for update t set a = 1 where b = 1 and c > 1 using update /*+ use_index(t, c) */ t set a = 1 where b = 1 and c > 1", errno.ErrOptOnTemporaryTable)
tk.MustGetErrCode("create binding for delete from t where b = 1 and c > 1 using delete /*+ use_index(t, c) */ from t where b = 1 and c > 1", errno.ErrOptOnTemporaryTable)
}
97 changes: 69 additions & 28 deletions planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ import (
"github.com/pingcap/parser/format"
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/ddl"
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/meta/autoid"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/types"
driver "github.com/pingcap/tidb/types/parser_driver"
"github.com/pingcap/tidb/util"
Expand Down Expand Up @@ -337,6 +339,37 @@ func bindableStmtType(node ast.StmtNode) byte {
return TypeInvalid
}

func (p *preprocessor) tableByName(tn *ast.TableName) (table.Table, error) {
currentDB := p.ctx.GetSessionVars().CurrentDB
if tn.Schema.String() != "" {
currentDB = tn.Schema.L
}
if currentDB == "" {
return nil, errors.Trace(ErrNoDB)
}
sName := model.NewCIStr(currentDB)
tbl, err := p.ensureInfoSchema().TableByName(sName, tn.Name)
if err != nil {
// We should never leak that the table doesn't exist (i.e. attach ErrTableNotExists)
// unless we know that the user has permissions to it, should it exist.
// By checking here, this makes all SELECT/SHOW/INSERT/UPDATE/DELETE statements safe.
currentUser, activeRoles := p.ctx.GetSessionVars().User, p.ctx.GetSessionVars().ActiveRoles
if pm := privilege.GetPrivilegeManager(p.ctx); pm != nil {
if !pm.RequestVerification(activeRoles, sName.L, tn.Name.O, "", mysql.AllPrivMask) {
u := currentUser.Username
h := currentUser.Hostname
if currentUser.AuthHostname != "" {
u = currentUser.AuthUsername
h = currentUser.AuthHostname
}
return nil, ErrTableaccessDenied.GenWithStackByArgs(p.stmtType(), u, h, tn.Name.O)
}
}
return nil, err
}
return tbl, err
}

func (p *preprocessor) checkBindGrammar(originNode, hintedNode ast.StmtNode, defaultDB string) {
origTp := bindableStmtType(originNode)
hintedTp := bindableStmtType(hintedNode)
Expand All @@ -355,6 +388,39 @@ func (p *preprocessor) checkBindGrammar(originNode, hintedNode ast.StmtNode, def
return
}
}

// Check the bind operation is not on any temporary table.
var resNode ast.ResultSetNode
switch n := originNode.(type) {
case *ast.SelectStmt:
resNode = n.From.TableRefs
case *ast.DeleteStmt:
resNode = n.TableRefs.TableRefs
case *ast.UpdateStmt:
resNode = n.TableRefs.TableRefs
case *ast.InsertStmt:
resNode = n.Table.TableRefs
}
if resNode != nil {
tblNames := extractTableList(resNode, nil, false)
for _, tn := range tblNames {
tbl, err := p.tableByName(tn)
if err != nil {
// If the operation is order is: drop table -> drop binding
// The table doesn't exist, it is not an error.
if terror.ErrorEqual(err, infoschema.ErrTableNotExists) {
continue
}
p.err = err
return
}
if tbl.Meta().TempTableType != model.TempTableNone {
p.err = ddl.ErrOptOnTemporaryTable.GenWithStackByArgs("create binding")
return
}
}
}

originSQL := parser.Normalize(utilparser.RestoreWithDefaultDB(originNode, defaultDB, originNode.Text()))
hintedSQL := parser.Normalize(utilparser.RestoreWithDefaultDB(hintedNode, defaultDB, hintedNode.Text()))
if originSQL != hintedSQL {
Expand Down Expand Up @@ -598,17 +664,7 @@ func (p *preprocessor) checkDropDatabaseGrammar(stmt *ast.DropDatabaseStmt) {

func (p *preprocessor) checkAdminCheckTableGrammar(stmt *ast.AdminStmt) {
for _, table := range stmt.Tables {
currentDB := p.ctx.GetSessionVars().CurrentDB
if table.Schema.String() != "" {
currentDB = table.Schema.L
}
if currentDB == "" {
p.err = errors.Trace(ErrNoDB)
return
}
sName := model.NewCIStr(currentDB)
tName := table.Name
tableInfo, err := p.ensureInfoSchema().TableByName(sName, tName)
tableInfo, err := p.tableByName(table)
if err != nil {
p.err = err
return
Expand Down Expand Up @@ -1270,27 +1326,12 @@ func (p *preprocessor) handleTableName(tn *ast.TableName) {
return
}

table, err := p.ensureInfoSchema().TableByName(tn.Schema, tn.Name)
table, err := p.tableByName(tn)
if err != nil {
// We should never leak that the table doesn't exist (i.e. attach ErrTableNotExists)
// unless we know that the user has permissions to it, should it exist.
// By checking here, this makes all SELECT/SHOW/INSERT/UPDATE/DELETE statements safe.
currentUser, activeRoles := p.ctx.GetSessionVars().User, p.ctx.GetSessionVars().ActiveRoles
if pm := privilege.GetPrivilegeManager(p.ctx); pm != nil {
if !pm.RequestVerification(activeRoles, tn.Schema.L, tn.Name.O, "", mysql.AllPrivMask) {
u := currentUser.Username
h := currentUser.Hostname
if currentUser.AuthHostname != "" {
u = currentUser.AuthUsername
h = currentUser.AuthHostname
}
p.err = ErrTableaccessDenied.GenWithStackByArgs(p.stmtType(), u, h, tn.Name.O)
return
}
}
p.err = err
return
}

tableInfo := table.Meta()
dbInfo, _ := p.ensureInfoSchema().SchemaByName(tn.Schema)
// tableName should be checked as sequence object.
Expand Down

0 comments on commit bd11917

Please sign in to comment.