Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

privilege, executor: support SET DEFAULT ROLE #9949

Merged
merged 14 commits into from
Apr 17, 2019
141 changes: 141 additions & 0 deletions executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,152 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.RecordBatch) (err erro
err = e.executeSetRole(x)
case *ast.RevokeRoleStmt:
err = e.executeRevokeRole(x)
case *ast.SetDefaultRoleStmt:
err = e.executeSetDefaultRole(x)
}
e.done = true
return err
}

func (e *SimpleExec) setDefaultRoleNone(s *ast.SetDefaultRoleStmt) error {
sqlExecutor := e.ctx.(sqlexec.SQLExecutor)
if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil {
return err
}
for _, u := range s.UserList {
if u.Hostname == "" {
u.Hostname = "%"
}
sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", u.Username, u.Hostname)
if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil {
logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql))
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return err
}
}
if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil {
return err
}
return nil
}

func (e *SimpleExec) setDefaultRoleRegular(s *ast.SetDefaultRoleStmt) error {
for _, user := range s.UserList {
exists, err := userExists(e.ctx, user.Username, user.Hostname)
if err != nil {
return err
}
if !exists {
return ErrCannotUser.GenWithStackByArgs("SET DEFAULT ROLE", user.String())
}
}
for _, role := range s.RoleList {
exists, err := userExists(e.ctx, role.Username, role.Hostname)
if err != nil {
return err
}
if !exists {
return ErrCannotUser.GenWithStackByArgs("SET DEFAULT ROLE", role.String())
}
}
sqlExecutor := e.ctx.(sqlexec.SQLExecutor)
if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil {
return err
}
for _, user := range s.UserList {
if user.Hostname == "" {
user.Hostname = "%"
}
sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname)
if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil {
logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql))
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return err
}
for _, role := range s.RoleList {
sql := fmt.Sprintf("INSERT IGNORE INTO mysql.default_roles values('%s', '%s', '%s', '%s');", user.Hostname, user.Username, role.Hostname, role.Username)
checker := privilege.GetPrivilegeManager(e.ctx)
ok := checker.FindEdge(e.ctx, role, user)
if ok {
if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil {
logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql))
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return err
}
} else {
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return ErrRoleNotGranted.GenWithStackByArgs(role.String(), user.String())
}
}
}
if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil {
return err
}
return nil
}

func (e *SimpleExec) setDefaultRoleAll(s *ast.SetDefaultRoleStmt) error {
for _, user := range s.UserList {
exists, err := userExists(e.ctx, user.Username, user.Hostname)
if err != nil {
return err
}
if !exists {
return ErrCannotUser.GenWithStackByArgs("SET DEFAULT ROLE", user.String())
}
}
sqlExecutor := e.ctx.(sqlexec.SQLExecutor)
if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil {
return err
}
for _, user := range s.UserList {
if user.Hostname == "" {
user.Hostname = "%"
}
sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname)
if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil {
logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql))
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return err
}
sql = fmt.Sprintf("INSERT IGNORE INTO mysql.default_roles(HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) "+
"SELECT TO_HOST,TO_USER,FROM_HOST,FROM_USER FROM mysql.role_edges WHERE TO_HOST='%s' AND TO_USER='%s';", user.Hostname, user.Username)
if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil {
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return err
}
}
if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil {
return err
}
return nil
}

func (e *SimpleExec) executeSetDefaultRole(s *ast.SetDefaultRoleStmt) error {
switch s.SetRoleOpt {
case ast.SetRoleAll:
return e.setDefaultRoleAll(s)
case ast.SetRoleNone:
return e.setDefaultRoleNone(s)
case ast.SetRoleRegular:
return e.setDefaultRoleRegular(s)
}
err := domain.GetDomain(e.ctx).PrivilegeHandle().Update(e.ctx.(sessionctx.Context))
return err
}

func (e *SimpleExec) executeSetRole(s *ast.SetRoleStmt) error {
checkDup := make(map[string]*auth.RoleIdentity, len(s.RoleList))
// Check whether RoleNameList contain duplicate role name.
Expand Down
50 changes: 50 additions & 0 deletions executor/simple_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,56 @@ func (s *testSuite3) TestRole(c *C) {
tk.MustExec(dropRoleSQL)
}

func (s *testSuite3) TestDefaultRole(c *C) {
tk := testkit.NewTestKit(c, s.store)

createRoleSQL := `CREATE ROLE r_1, r_2, r_3, u_1;`
tk.MustExec(createRoleSQL)

tk.MustExec("insert into mysql.role_edges (FROM_HOST,FROM_USER,TO_HOST,TO_USER) values ('%','r_1','%','u_1')")
tk.MustExec("insert into mysql.role_edges (FROM_HOST,FROM_USER,TO_HOST,TO_USER) values ('%','r_2','%','u_1')")

tk.MustExec("flush privileges;")

setRoleSQL := `SET DEFAULT ROLE r_3 TO u_1;`
_, err := tk.Exec(setRoleSQL)
c.Check(err, NotNil)

setRoleSQL = `SET DEFAULT ROLE r_1 TO u_1000;`
_, err = tk.Exec(setRoleSQL)
c.Check(err, NotNil)

setRoleSQL = `SET DEFAULT ROLE r_1, r_3 TO u_1;`
_, err = tk.Exec(setRoleSQL)
c.Check(err, NotNil)

setRoleSQL = `SET DEFAULT ROLE r_1 TO u_1;`
_, err = tk.Exec(setRoleSQL)
c.Check(err, IsNil)
result := tk.MustQuery(`SELECT DEFAULT_ROLE_USER FROM mysql.default_roles WHERE USER="u_1"`)
result.Check(testkit.Rows("r_1"))
setRoleSQL = `SET DEFAULT ROLE r_2 TO u_1;`
_, err = tk.Exec(setRoleSQL)
c.Check(err, IsNil)
result = tk.MustQuery(`SELECT DEFAULT_ROLE_USER FROM mysql.default_roles WHERE USER="u_1"`)
result.Check(testkit.Rows("r_2"))

setRoleSQL = `SET DEFAULT ROLE ALL TO u_1;`
_, err = tk.Exec(setRoleSQL)
c.Check(err, IsNil)
result = tk.MustQuery(`SELECT DEFAULT_ROLE_USER FROM mysql.default_roles WHERE USER="u_1"`)
result.Check(testkit.Rows("r_1", "r_2"))

setRoleSQL = `SET DEFAULT ROLE NONE TO u_1;`
_, err = tk.Exec(setRoleSQL)
c.Check(err, IsNil)
result = tk.MustQuery(`SELECT DEFAULT_ROLE_USER FROM mysql.default_roles WHERE USER="u_1"`)
result.Check(nil)

dropRoleSQL := `DROP USER r_1, r_2, r_3, u_1;`
tk.MustExec(dropRoleSQL)
}

func (s *testSuite3) TestUser(c *C) {
tk := testkit.NewTestKit(c, s.store)
// Make sure user test not in mysql.User.
Expand Down
4 changes: 2 additions & 2 deletions planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func (b *PlanBuilder) Build(node ast.Node) (Plan, error) {
case *ast.BinlogStmt, *ast.FlushStmt, *ast.UseStmt,
*ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt,
*ast.GrantStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt,
*ast.GrantRoleStmt, *ast.RevokeRoleStmt, *ast.SetRoleStmt:
*ast.GrantRoleStmt, *ast.RevokeRoleStmt, *ast.SetRoleStmt, *ast.SetDefaultRoleStmt:
return b.buildSimple(node.(ast.StmtNode))
case ast.DDLNode:
return b.buildDDL(x)
Expand Down Expand Up @@ -1095,7 +1095,7 @@ func (b *PlanBuilder) buildSimple(node ast.StmtNode) (Plan, error) {
err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER")
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err)
}
case *ast.AlterUserStmt:
case *ast.AlterUserStmt, *ast.SetDefaultRoleStmt:
err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER")
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err)
case *ast.GrantStmt:
Expand Down
6 changes: 6 additions & 0 deletions privilege/privilege.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ type Manager interface {
// ActiveRoles active roles for current session.
// The first illegal role will be returned.
ActiveRoles(ctx sessionctx.Context, roleList []*auth.RoleIdentity) (bool, string)

// FindEdge find if there is an edge between role and user.
FindEdge(ctx sessionctx.Context, role *auth.RoleIdentity, user *auth.UserIdentity) bool

// GetDefaultRoles returns all default roles for certain user.
GetDefaultRoles(user, host string) []*auth.RoleIdentity
}

const key keyType = 0
Expand Down
73 changes: 66 additions & 7 deletions privilege/privileges/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ type dbRecord struct {
User string
Privileges mysql.PrivilegeType

// patChars is compiled from Host and DB, cached for pattern match performance.
// hostPatChars is compiled from Host and DB, cached for pattern match performance.
hostPatChars []byte
hostPatTypes []byte

Expand Down Expand Up @@ -105,7 +105,19 @@ type columnsPrivRecord struct {
patTypes []byte
}

// RoleGraphEdgesTable is used to cache relationship between and role.
// defaultRoleRecord is used to cache mysql.default_roles
type defaultRoleRecord struct {
Host string
User string
DefaultRoleUser string
DefaultRoleHost string

// patChars is compiled from Host, cached for pattern match performance.
patChars []byte
patTypes []byte
}

// roleGraphEdgesTable is used to cache relationship between and role.
type roleGraphEdgesTable struct {
roleList map[string]bool
}
Expand All @@ -125,11 +137,12 @@ func (g roleGraphEdgesTable) Find(user, host string) bool {

// MySQLPrivilege is the in-memory cache of mysql privilege tables.
type MySQLPrivilege struct {
User []UserRecord
DB []dbRecord
TablesPriv []tablesPrivRecord
ColumnsPriv []columnsPrivRecord
RoleGraph map[string]roleGraphEdgesTable
User []UserRecord
DB []dbRecord
TablesPriv []tablesPrivRecord
ColumnsPriv []columnsPrivRecord
DefaultRoles []defaultRoleRecord
RoleGraph map[string]roleGraphEdgesTable
}

// FindRole is used to detect whether there is edges between users and roles.
Expand Down Expand Up @@ -166,6 +179,14 @@ func (p *MySQLPrivilege) LoadAll(ctx sessionctx.Context) error {
log.Warn("mysql.tables_priv missing")
}

err = p.LoadDefaultRoles(ctx)
if err != nil {
if !noSuchTable(err) {
return errors.Trace(err)
}
log.Warn("mysql.default_roles missing")
}

err = p.LoadColumnsPrivTable(ctx)
if err != nil {
if !noSuchTable(err) {
Expand Down Expand Up @@ -316,6 +337,11 @@ func (p *MySQLPrivilege) LoadColumnsPrivTable(ctx sessionctx.Context) error {
return p.loadTable(ctx, "select HIGH_PRIORITY Host,DB,User,Table_name,Column_name,Timestamp,Column_priv from mysql.columns_priv", p.decodeColumnsPrivTableRow)
}

// LoadDefaultRoles loads the mysql.columns_priv table from database.
func (p *MySQLPrivilege) LoadDefaultRoles(ctx sessionctx.Context) error {
return p.loadTable(ctx, "select HOST, USER, DEFAULT_ROLE_HOST, DEFAULT_ROLE_USER from mysql.default_roles", p.decodeDefaultRoleTableRow)
}

func (p *MySQLPrivilege) loadTable(sctx sessionctx.Context, sql string,
decodeTableRow func(chunk.Row, []*ast.ResultField) error) error {
ctx := context.Background()
Expand Down Expand Up @@ -455,6 +481,25 @@ func (p *MySQLPrivilege) decodeRoleEdgesTable(row chunk.Row, fs []*ast.ResultFie
return nil
}

func (p *MySQLPrivilege) decodeDefaultRoleTableRow(row chunk.Row, fs []*ast.ResultField) error {
var value defaultRoleRecord
for i, f := range fs {
switch {
case f.ColumnAsName.L == "host":
value.Host = row.GetString(i)
value.patChars, value.patTypes = stringutil.CompilePattern(value.Host, '\\')
case f.ColumnAsName.L == "user":
value.User = row.GetString(i)
case f.ColumnAsName.L == "default_role_host":
value.DefaultRoleHost = row.GetString(i)
case f.ColumnAsName.L == "default_role_user":
value.DefaultRoleUser = row.GetString(i)
}
}
p.DefaultRoles = append(p.DefaultRoles, value)
return nil
}

func (p *MySQLPrivilege) decodeColumnsPrivTableRow(row chunk.Row, fs []*ast.ResultField) error {
var value columnsPrivRecord
for i, f := range fs {
Expand Down Expand Up @@ -522,6 +567,10 @@ func (record *columnsPrivRecord) match(user, host, db, table, col string) bool {
patternMatch(host, record.patChars, record.patTypes)
}

func (record *defaultRoleRecord) match(user, host string) bool {
return record.User == user && patternMatch(host, record.patChars, record.patTypes)
}

// patternMatch matches "%" the same way as ".*" in regular expression, for example,
// "10.0.%" would match "10.0.1" "10.0.1.118" ...
func patternMatch(str string, patChars, patTypes []byte) bool {
Expand Down Expand Up @@ -766,6 +815,16 @@ func appendUserPrivilegesTableRow(rows [][]types.Datum, user UserRecord) [][]typ
return rows
}

func (p *MySQLPrivilege) getDefaultRoles(user, host string) []*auth.RoleIdentity {
ret := make([]*auth.RoleIdentity, 0)
for _, r := range p.DefaultRoles {
if r.match(user, host) {
ret = append(ret, &auth.RoleIdentity{Username: r.DefaultRoleUser, Hostname: r.DefaultRoleHost})
}
}
return ret
}

// Handle wraps MySQLPrivilege providing thread safe access.
type Handle struct {
priv atomic.Value
Expand Down
19 changes: 19 additions & 0 deletions privilege/privileges/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,25 @@ func (s *testCacheSuite) TestLoadColumnsPrivTable(c *C) {
c.Assert(p.ColumnsPriv[1].ColumnPriv, Equals, mysql.SelectPriv)
}

func (s *testCacheSuite) TestLoadDefaultRoleTable(c *C) {
se, err := session.CreateSession4Test(s.store)
c.Assert(err, IsNil)
defer se.Close()
mustExec(c, se, "use mysql;")
mustExec(c, se, "truncate table default_roles")

mustExec(c, se, `INSERT INTO mysql.default_roles VALUES ("%", "test_default_roles", "localhost", "r_1")`)
mustExec(c, se, `INSERT INTO mysql.default_roles VALUES ("%", "test_default_roles", "localhost", "r_2")`)
var p privileges.MySQLPrivilege
err = p.LoadDefaultRoles(se)
c.Assert(err, IsNil)
c.Assert(p.DefaultRoles[0].Host, Equals, `%`)
c.Assert(p.DefaultRoles[0].User, Equals, "test_default_roles")
c.Assert(p.DefaultRoles[0].DefaultRoleHost, Equals, "localhost")
c.Assert(p.DefaultRoles[0].DefaultRoleUser, Equals, "r_1")
c.Assert(p.DefaultRoles[1].DefaultRoleHost, Equals, "localhost")
}

func (s *testCacheSuite) TestPatternMatch(c *C) {
se, err := session.CreateSession4Test(s.store)
c.Assert(err, IsNil)
Expand Down
Loading