diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 80021de840a36..780e1e4b9cc23 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -2145,7 +2145,8 @@ func (b *planBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { var tableList []*ast.TableName tableList = extractTableList(sel.From.TableRefs, tableList) - for _, t := range tableList { + tableListForUpdate := extractTableListForUpdate(tableList, update.List) + for _, t := range tableListForUpdate { dbName := t.Schema.L if dbName == "" { dbName = b.ctx.GetSessionVars().CurrentDB @@ -2435,6 +2436,33 @@ func extractTableList(node ast.ResultSetNode, input []*ast.TableName) []*ast.Tab return input } +func extractTableListForUpdate(refTables []*ast.TableName, sets []*ast.Assignment) []*ast.TableName { + // if there's only one tbl , we won't bother + if len(refTables) == 1 { + return refTables + } + tbls := make([]*ast.TableName, 0, len(sets)) + var tblsMap map[string]*ast.TableName + + for _, a := range sets { + c := a.Column + + if c.Table.L != "" { + for _, tbl := range refTables { + if strings.EqualFold(tbl.Name.L, c.Name.L) { + tblsMap["`"+tbl.Schema.L+"`.`"+tbl.Name.L+"`"] = tbl + } + } + } + } + + for _, tbl := range tblsMap { + tbls = append(tbls, tbl) + } + + return tbls +} + // extractTableSourceAsNames extracts TableSource.AsNames from node. // if onlySelectStmt is set to be true, only extracts AsNames when TableSource.Source.(type) == *ast.SelectStmt func extractTableSourceAsNames(node ast.ResultSetNode, input []string, onlySelectStmt bool) []string { diff --git a/session/session_test.go b/session/session_test.go index 74c890b246541..285f6edf34ff2 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -2367,3 +2367,26 @@ func (s *testSessionSuite) TestTxnGoString(c *C) { tk.MustExec("rollback") c.Assert(fmt.Sprintf("%#v", txn), Equals, "Txn{state=invalid}") } + +func (s *testSessionSuite) TestUpdatePrivilege(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("create user xxx;") + + tk1 := testkit.NewTestKitWithInit(c, s.store) + + // Fix issue 10028 + tk.MustExec("create database ap") + tk.MustExec("create database tp") + tk.MustExec("grant all privileges on ap.* to xxx") + tk.MustExec("grant select on tp.* to xxx") + tk.MustExec("flush privileges") + tk.MustExec("create table tp.record( id int,name varchar(128),age int)") + tk.MustExec("insert into tp.record (id,name,age) values (1,'john',18),(2,'lary',19),(3,'lily',18)") + tk.MustExec("create table ap.record( id int,name varchar(128),age int)") + tk.MustExec("insert into ap.record(id) values(1)") + c.Assert(tk1.Se.Auth(&auth.UserIdentity{Username: "xxx", Hostname: "localhost"}, + []byte(""), + []byte("")), IsTrue) + _, err2 := tk1.Exec("update ap.record t inner join tp.record tt on t.id=tt.id set t.name=tt.name") + c.Assert(err2, IsNil) +}