Skip to content

Commit

Permalink
reduce interface assert
Browse files Browse the repository at this point in the history
  • Loading branch information
imtbkcat committed Apr 17, 2019
1 parent b82747f commit 370ef6b
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,23 +95,24 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.RecordBatch) (err erro
}

func (e *SimpleExec) setDefaultRoleNone(s *ast.SetDefaultRoleStmt) error {
if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "begin"); err != nil {
sqlExecutor := e.ctx.(sqlexec.SQLExecutor)
if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil {
return err
}
for _, u := range s.UserList {
if u.Hostname == "" {
u.Hostname = "%"
}
sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", u.Username, u.Hostname)
if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql); err != nil {
if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil {
logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql))
if _, rollbackErr := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); rollbackErr != nil {
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return err
}
}
if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "commit"); err != nil {
if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil {
return err
}
return nil
Expand All @@ -136,17 +137,18 @@ func (e *SimpleExec) setDefaultRoleRegular(s *ast.SetDefaultRoleStmt) error {
return ErrCannotUser.GenWithStackByArgs("SET DEFAULT ROLE", role.String())
}
}
if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "begin"); err != nil {
sqlExecutor := e.ctx.(sqlexec.SQLExecutor)
if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil {
return err
}
for _, user := range s.UserList {
if user.Hostname == "" {
user.Hostname = "%"
}
sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname)
if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql); err != nil {
if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil {
logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql))
if _, rollbackErr := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); rollbackErr != nil {
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return err
Expand All @@ -156,22 +158,22 @@ func (e *SimpleExec) setDefaultRoleRegular(s *ast.SetDefaultRoleStmt) error {
checker := privilege.GetPrivilegeManager(e.ctx)
ok := checker.FindEdge(e.ctx, role, user)
if ok {
if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql); err != nil {
if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil {
logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql))
if _, rollbackErr := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); rollbackErr != nil {
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return err
}
} else {
if _, rollbackErr := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); rollbackErr != nil {
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return ErrRoleNotGranted.GenWithStackByArgs(role.String(), user.String())
}
}
}
if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "commit"); err != nil {
if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil {
return err
}
return nil
Expand All @@ -187,31 +189,32 @@ func (e *SimpleExec) setDefaultRoleAll(s *ast.SetDefaultRoleStmt) error {
return ErrCannotUser.GenWithStackByArgs("SET DEFAULT ROLE", user.String())
}
}
if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "begin"); err != nil {
sqlExecutor := e.ctx.(sqlexec.SQLExecutor)
if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil {
return err
}
for _, user := range s.UserList {
if user.Hostname == "" {
user.Hostname = "%"
}
sql := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname)
if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql); err != nil {
if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil {
logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql))
if _, rollbackErr := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); rollbackErr != nil {
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return err
}
sql = fmt.Sprintf("INSERT IGNORE INTO mysql.default_roles(HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) "+
"SELECT TO_HOST,TO_USER,FROM_HOST,FROM_USER FROM mysql.role_edges WHERE TO_HOST='%s' AND TO_USER='%s';", user.Hostname, user.Username)
if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql); err != nil {
if _, rollbackErr := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); rollbackErr != nil {
if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil {
if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil {
return rollbackErr
}
return err
}
}
if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "commit"); err != nil {
if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil {
return err
}
return nil
Expand Down

0 comments on commit 370ef6b

Please sign in to comment.