diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 80021de840a36..9a6cb8dc520d1 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -1244,7 +1244,7 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { func tblInfoFromCol(from ast.ResultSetNode, col *expression.Column) *model.TableInfo { var tableList []*ast.TableName - tableList = extractTableList(from, tableList) + tableList = extractTableList(from, tableList, true) for _, field := range tableList { if field.Name.L == col.TblName.L { return field.TableInfo @@ -2144,7 +2144,7 @@ func (b *planBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { } var tableList []*ast.TableName - tableList = extractTableList(sel.From.TableRefs, tableList) + tableList = extractTableList(sel.From.TableRefs, tableList, false) for _, t := range tableList { dbName := t.Schema.L if dbName == "" { @@ -2262,6 +2262,15 @@ func (b *planBuilder) buildUpdateLists(tableList []*ast.TableName, list []*ast.A p = np newList = append(newList, &expression.Assignment{Col: col, Expr: newExpr}) } + for _, assign := range newList { + col := assign.Col + + dbName := col.DBName.L + if dbName == "" { + dbName = b.ctx.GetSessionVars().CurrentDB + } + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.UpdatePriv, dbName, col.OrigTblName.L, "") + } return newList, p, nil } @@ -2363,7 +2372,7 @@ func (b *planBuilder) buildDelete(delete *ast.DeleteStmt) (Plan, error) { del.SetSchema(expression.NewSchema()) var tableList []*ast.TableName - tableList = extractTableList(delete.TableRefs.TableRefs, tableList) + tableList = extractTableList(delete.TableRefs.TableRefs, tableList, true) // Collect visitInfo. if delete.Tables != nil { @@ -2416,14 +2425,16 @@ func (b *planBuilder) buildDelete(delete *ast.DeleteStmt) (Plan, error) { } // extractTableList extracts all the TableNames from node. -func extractTableList(node ast.ResultSetNode, input []*ast.TableName) []*ast.TableName { +// If asName is true, extract AsName prior to OrigName. +// Privilege check should use OrigName, while expression may use AsName. +func extractTableList(node ast.ResultSetNode, input []*ast.TableName, asName bool) []*ast.TableName { switch x := node.(type) { case *ast.Join: - input = extractTableList(x.Left, input) - input = extractTableList(x.Right, input) + input = extractTableList(x.Left, input, asName) + input = extractTableList(x.Right, input, asName) case *ast.TableSource: if s, ok := x.Source.(*ast.TableName); ok { - if x.AsName.L != "" { + if x.AsName.L != "" && asName { newTableName := *s newTableName.Name = x.AsName input = append(input, &newTableName) diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index 2f12ca8d7c7b4..a548235777cdd 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -1601,6 +1601,13 @@ func (s *testPlanSuite) TestVisitInfo(c *C) { {mysql.SelectPriv, "test", "t", ""}, }, }, + { + sql: "update t a1 set a1.a = a1.a + 1", + ans: []visitInfo{ + {mysql.UpdatePriv, "test", "t", ""}, + {mysql.SelectPriv, "test", "t", ""}, + }, + }, { sql: "select a, sum(e) from t group by a", ans: []visitInfo{ diff --git a/session/session_test.go b/session/session_test.go index 74c890b246541..7b7fc125d594c 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -2348,6 +2348,24 @@ func (s *testSessionSuite) TestSetGroupConcatMaxLen(c *C) { c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue, Commentf("err %v", err)) } +func (s *testSessionSuite) TestUpdatePrivilege(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + + // Fix issue 8911 + tk.MustExec("create database weperk") + tk.MustExec("use weperk") + tk.MustExec("create table tb_wehub_server (id int, active_count int, used_count int)") + tk.MustExec("create user 'weperk'") + tk.MustExec("grant all privileges on weperk.* to 'weperk'@'%'") + tk.MustExec("flush privileges;") + + tk1 := testkit.NewTestKitWithInit(c, s.store) + c.Assert(tk1.Se.Auth(&auth.UserIdentity{Username: "weperk", Hostname: "%"}, + []byte(""), []byte("")), IsTrue) + tk1.MustExec("use weperk") + tk1.MustExec("update tb_wehub_server a set a.active_count=a.active_count+1,a.used_count=a.used_count+1 where id=1") +} + func (s *testSessionSuite) TestTxnGoString(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) tk.MustExec("drop table if exists gostr;")