Skip to content

Commit

Permalink
Support mysql_grant resource update
Browse files Browse the repository at this point in the history
  • Loading branch information
winebarrel committed Apr 9, 2020
1 parent 563342a commit 7558161
Showing 1 changed file with 68 additions and 3 deletions.
71 changes: 68 additions & 3 deletions mysql/resource_grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,75 @@ func ReadGrant(d *schema.ResourceData, meta interface{}) error {
}

func UpdateGrant(d *schema.ResourceData, meta interface{}) error {
db := meta.(*MySQLConfiguration).Db

hasRoles, err := supportsRoles(db)

if err != nil {
return err
}

userOrRole, _, err := userOrRole(
d.Get("user").(string),
d.Get("host").(string),
d.Get("role").(string),
hasRoles)

if err != nil {
return err
}

database := d.Get("database").(string)
table := d.Get("table").(string)

if d.HasChange("privileges") {
oldPrivs, newPrivs := d.GetChange("plaintext_password")
log.Printf("xxx old: %v\n", oldPrivs)
log.Printf("xxx new: %v\n", newPrivs)
err = updatePrivileges(d, db, userOrRole, database, table)

if err != nil {
return err
}
}

return nil
}

func updatePrivileges(d *schema.ResourceData, db *sql.DB, user string, database string, table string) error {
oldPrivsIf, newPrivsIf := d.GetChange("privileges")
oldPrivs := oldPrivsIf.(*schema.Set)
newPrivs := newPrivsIf.(*schema.Set)
grantIfs := newPrivs.Difference(oldPrivs).List()
revokeIfs := oldPrivs.Difference(newPrivs).List()

if len(grantIfs) > 0 {
grants := make([]string, len(grantIfs))

for i, v := range grantIfs {
grants[i] = v.(string)
}

sql := fmt.Sprintf("GRANT %s ON %s.%s TO %s", strings.Join(grants, ","), database, table, user)

log.Printf("[DEBUG] SQL: %s", sql)

if _, err := db.Exec(sql); err != nil {
return err
}
}

if len(revokeIfs) > 0 {
revokes := make([]string, len(revokeIfs))

for i, v := range revokeIfs {
revokes[i] = v.(string)
}

sql := fmt.Sprintf("REVOKE %s ON %s.%s FROM %s", strings.Join(revokes, ","), database, table, user)

log.Printf("[DEBUG] SQL: %s", sql)

if _, err := db.Exec(sql); err != nil {
return err
}
}

return nil
Expand Down

0 comments on commit 7558161

Please sign in to comment.