Skip to content

Commit

Permalink
suggested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
DerekTBrown authored and petoju committed Jan 23, 2024
1 parent 2dc2802 commit 309f64c
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 37 deletions.
20 changes: 13 additions & 7 deletions mysql/resource_grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,11 @@ func (t *TablePrivilegeGrant) GrantOption() bool {
}

func (t *TablePrivilegeGrant) GetDatabase() string {
if strings.Compare(t.Database, "*") != 0 && !strings.HasSuffix(t.Database, "`") {
if t.Database == "*" {
return "*"
} else {
return fmt.Sprintf("`%s`", t.Database)
}
return t.Database
}

func (t *TablePrivilegeGrant) GetTable() string {
Expand All @@ -133,14 +134,18 @@ func (t *TablePrivilegeGrant) SQLGrantStatement() string {
}

func (t *TablePrivilegeGrant) SQLRevokeStatement() string {
return fmt.Sprintf("REVOKE %s ON %s.%s FROM %s", strings.Join(t.Privileges, ", "), t.GetDatabase(), t.GetTable(), t.UserOrRole.SQLString())
privs := t.Privileges
if t.Grant {
privs = append(privs, "GRANT OPTION")
}
return fmt.Sprintf("REVOKE %s ON %s.%s FROM %s", strings.Join(privs, ", "), t.GetDatabase(), t.GetTable(), t.UserOrRole.SQLString())
}

func (t *TablePrivilegeGrant) SQLPartialRevokePrivilegesStatement(privilegesToRevoke []string) string {
stmt := fmt.Sprintf("REVOKE %s ON %s.%s FROM %s", strings.Join(privilegesToRevoke, ", "), t.GetDatabase(), t.GetTable(), t.UserOrRole.SQLString())
if t.Grant {
stmt += " WITH GRANT OPTION"
privilegesToRevoke = append(privilegesToRevoke, "GRANT OPTION")
}
stmt := fmt.Sprintf("REVOKE %s ON %s.%s FROM %s", strings.Join(privilegesToRevoke, ", "), t.GetDatabase(), t.GetTable(), t.UserOrRole.SQLString())
return stmt
}

Expand Down Expand Up @@ -189,10 +194,11 @@ func (t *ProcedurePrivilegeGrant) SQLGrantStatement() string {
}

func (t *ProcedurePrivilegeGrant) SQLRevokeStatement() string {
stmt := fmt.Sprintf("REVOKE %s ON %s %s.%s FROM %s", strings.Join(t.Privileges, ", "), t.ObjectT, t.GetDatabase(), t.CallableName, t.UserOrRole.SQLString())
privs := t.Privileges
if t.Grant {
stmt += " WITH GRANT OPTION"
privs = append(privs, "GRANT OPTION")
}
stmt := fmt.Sprintf("REVOKE %s ON %s %s.%s FROM %s", strings.Join(privs, ", "), t.ObjectT, t.GetDatabase(), t.CallableName, t.UserOrRole.SQLString())
return stmt
}

Expand Down
69 changes: 39 additions & 30 deletions mysql/resource_grant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestAccGrant(t *testing.T) {
{
Config: testAccGrantConfigBasic(dbName),
Check: resource.ComposeTestCheckFunc(
testAccPrivilege("mysql_grant.test", "SELECT", true),
testAccPrivilege("mysql_grant.test", "SELECT", true, false),
resource.TestCheckResourceAttr("mysql_grant.test", "user", fmt.Sprintf("jdoe-%s", dbName)),
resource.TestCheckResourceAttr("mysql_grant.test", "host", "example.com"),
resource.TestCheckResourceAttr("mysql_grant.test", "database", dbName),
Expand All @@ -33,7 +33,7 @@ func TestAccGrant(t *testing.T) {
{
Config: testAccGrantConfigBasic(dbName),
Check: resource.ComposeTestCheckFunc(
testAccPrivilege("mysql_grant.test", "SELECT", true),
testAccPrivilege("mysql_grant.test", "SELECT", true, false),
resource.TestCheckResourceAttr("mysql_grant.test", "user", fmt.Sprintf("jdoe-%s", dbName)),
resource.TestCheckResourceAttr("mysql_grant.test", "host", "example.com"),
resource.TestCheckResourceAttr("mysql_grant.test", "database", dbName),
Expand All @@ -53,19 +53,22 @@ func TestAccGrantWithGrantOption(t *testing.T) {
{
Config: testAccGrantConfigBasic(dbName),
Check: resource.ComposeTestCheckFunc(
testAccPrivilege("mysql_grant.test", "SELECT", true),
testAccPrivilege("mysql_grant.test", "SELECT", true, false),
resource.TestCheckResourceAttr("mysql_grant.test", "grant", "false"),
),
},
{
Config: testAccGrantConfigBasicWithGrant(dbName),
Check: resource.ComposeTestCheckFunc(
testAccPrivilege("mysql_grant.test", "SELECT", true),
testAccPrivilege("mysql_grant.test", "SELECT", true, true),
resource.TestCheckResourceAttr("mysql_grant.test", "grant", "true"),
),
},
{
Config: testAccGrantConfigBasic(dbName),
Check: resource.ComposeTestCheckFunc(
testAccPrivilege("mysql_grant.test", "SELECT", true),
testAccPrivilege("mysql_grant.test", "SELECT", true, false),
resource.TestCheckResourceAttr("mysql_grant.test", "grant", "false"),
),
},
},
Expand All @@ -82,7 +85,7 @@ func TestAccBroken(t *testing.T) {
{
Config: testAccGrantConfigBasic(dbName),
Check: resource.ComposeTestCheckFunc(
testAccPrivilege("mysql_grant.test", "SELECT", true),
testAccPrivilege("mysql_grant.test", "SELECT", true, false),
resource.TestCheckResourceAttr("mysql_grant.test", "user", fmt.Sprintf("jdoe-%s", dbName)),
resource.TestCheckResourceAttr("mysql_grant.test", "host", "example.com"),
resource.TestCheckResourceAttr("mysql_grant.test", "database", dbName),
Expand All @@ -93,7 +96,7 @@ func TestAccBroken(t *testing.T) {
Config: testAccGrantConfigBroken(dbName),
ExpectError: regexp.MustCompile("already has"),
Check: resource.ComposeTestCheckFunc(
testAccPrivilege("mysql_grant.test", "SELECT", true),
testAccPrivilege("mysql_grant.test", "SELECT", true, false),
resource.TestCheckResourceAttr("mysql_grant.test", "user", fmt.Sprintf("jdoe-%s", dbName)),
resource.TestCheckResourceAttr("mysql_grant.test", "host", "example.com"),
resource.TestCheckResourceAttr("mysql_grant.test", "database", dbName),
Expand All @@ -117,7 +120,7 @@ func TestAccDifferentHosts(t *testing.T) {
{
Config: testAccGrantConfigExtraHost(dbName, false),
Check: resource.ComposeTestCheckFunc(
testAccPrivilege("mysql_grant.test_all", "SELECT", true),
testAccPrivilege("mysql_grant.test_all", "SELECT", true, false),
resource.TestCheckResourceAttr("mysql_grant.test_all", "user", fmt.Sprintf("jdoe-%s", dbName)),
resource.TestCheckResourceAttr("mysql_grant.test_all", "host", "%"),
resource.TestCheckResourceAttr("mysql_grant.test_all", "table", "*"),
Expand All @@ -126,7 +129,7 @@ func TestAccDifferentHosts(t *testing.T) {
{
Config: testAccGrantConfigExtraHost(dbName, true),
Check: resource.ComposeTestCheckFunc(
testAccPrivilege("mysql_grant.test", "SELECT", true),
testAccPrivilege("mysql_grant.test", "SELECT", true, false),
resource.TestCheckResourceAttr("mysql_grant.test", "user", fmt.Sprintf("jdoe-%s", dbName)),
resource.TestCheckResourceAttr("mysql_grant.test", "host", "10.1.2.3"),
resource.TestCheckResourceAttr("mysql_grant.test", "table", "*"),
Expand Down Expand Up @@ -156,7 +159,7 @@ func TestAccGrantComplex(t *testing.T) {
{
Config: testAccGrantConfigWithPrivs(dbName, `"SELECT (c1, c2)"`),
Check: resource.ComposeTestCheckFunc(
testAccPrivilege("mysql_grant.test", "SELECT (c1,c2)", true),
testAccPrivilege("mysql_grant.test", "SELECT (c1,c2)", true, false),
resource.TestCheckResourceAttr("mysql_grant.test", "user", fmt.Sprintf("jdoe-%s", dbName)),
resource.TestCheckResourceAttr("mysql_grant.test", "host", "example.com"),
resource.TestCheckResourceAttr("mysql_grant.test", "database", dbName),
Expand All @@ -166,10 +169,10 @@ func TestAccGrantComplex(t *testing.T) {
{
Config: testAccGrantConfigWithPrivs(dbName, `"DROP", "SELECT (c1)", "INSERT(c3, c4)", "REFERENCES(c5)"`),
Check: resource.ComposeTestCheckFunc(
testAccPrivilege("mysql_grant.test", "INSERT (c3,c4)", true),
testAccPrivilege("mysql_grant.test", "SELECT (c1)", true),
testAccPrivilege("mysql_grant.test", "SELECT (c1,c2)", false),
testAccPrivilege("mysql_grant.test", "REFERENCES (c5)", true),
testAccPrivilege("mysql_grant.test", "INSERT (c3,c4)", true, false),
testAccPrivilege("mysql_grant.test", "SELECT (c1)", true, false),
testAccPrivilege("mysql_grant.test", "SELECT (c1,c2)", false, false),
testAccPrivilege("mysql_grant.test", "REFERENCES (c5)", true, false),
resource.TestCheckResourceAttr("mysql_grant.test", "user", fmt.Sprintf("jdoe-%s", dbName)),
resource.TestCheckResourceAttr("mysql_grant.test", "host", "example.com"),
resource.TestCheckResourceAttr("mysql_grant.test", "database", dbName),
Expand All @@ -179,7 +182,7 @@ func TestAccGrantComplex(t *testing.T) {
{
Config: testAccGrantConfigWithPrivs(dbName, `"DROP", "SELECT (c1)", "INSERT(c4, c3, c2)"`),
Check: resource.ComposeTestCheckFunc(
testAccPrivilege("mysql_grant.test", "REFERENCES (c5)", false),
testAccPrivilege("mysql_grant.test", "REFERENCES (c5)", false, false),
resource.TestCheckResourceAttr("mysql_grant.test", "user", fmt.Sprintf("jdoe-%s", dbName)),
resource.TestCheckResourceAttr("mysql_grant.test", "host", "example.com"),
resource.TestCheckResourceAttr("mysql_grant.test", "database", dbName),
Expand All @@ -189,7 +192,7 @@ func TestAccGrantComplex(t *testing.T) {
{
Config: testAccGrantConfigWithPrivs(dbName, `"ALL PRIVILEGES"`),
Check: resource.ComposeTestCheckFunc(
testAccPrivilege("mysql_grant.test", "ALL", true),
testAccPrivilege("mysql_grant.test", "ALL", true, false),
resource.TestCheckResourceAttr("mysql_grant.test", "user", fmt.Sprintf("jdoe-%s", dbName)),
resource.TestCheckResourceAttr("mysql_grant.test", "host", "example.com"),
resource.TestCheckResourceAttr("mysql_grant.test", "database", dbName),
Expand All @@ -199,7 +202,7 @@ func TestAccGrantComplex(t *testing.T) {
{
Config: testAccGrantConfigWithPrivs(dbName, `"ALL"`),
Check: resource.ComposeTestCheckFunc(
testAccPrivilege("mysql_grant.test", "ALL", true),
testAccPrivilege("mysql_grant.test", "ALL", true, false),
resource.TestCheckResourceAttr("mysql_grant.test", "user", fmt.Sprintf("jdoe-%s", dbName)),
resource.TestCheckResourceAttr("mysql_grant.test", "host", "example.com"),
resource.TestCheckResourceAttr("mysql_grant.test", "database", dbName),
Expand All @@ -209,11 +212,11 @@ func TestAccGrantComplex(t *testing.T) {
{
Config: testAccGrantConfigWithPrivs(dbName, `"DROP", "SELECT (c1, c2)", "INSERT(c5)", "REFERENCES(c1)"`),
Check: resource.ComposeTestCheckFunc(
testAccPrivilege("mysql_grant.test", "ALL", false),
testAccPrivilege("mysql_grant.test", "DROP", true),
testAccPrivilege("mysql_grant.test", "SELECT(c1,c2)", true),
testAccPrivilege("mysql_grant.test", "INSERT(c5)", true),
testAccPrivilege("mysql_grant.test", "REFERENCES(c1)", true),
testAccPrivilege("mysql_grant.test", "ALL", false, false),
testAccPrivilege("mysql_grant.test", "DROP", true, false),
testAccPrivilege("mysql_grant.test", "SELECT(c1,c2)", true, false),
testAccPrivilege("mysql_grant.test", "INSERT(c5)", true, false),
testAccPrivilege("mysql_grant.test", "REFERENCES(c1)", true, false),
resource.TestCheckResourceAttr("mysql_grant.test", "user", fmt.Sprintf("jdoe-%s", dbName)),
resource.TestCheckResourceAttr("mysql_grant.test", "host", "example.com"),
resource.TestCheckResourceAttr("mysql_grant.test", "database", dbName),
Expand Down Expand Up @@ -246,9 +249,9 @@ func TestAccGrantComplexMySQL8(t *testing.T) {
{
Config: testAccGrantConfigWithDynamicMySQL8(dbName),
Check: resource.ComposeTestCheckFunc(
testAccPrivilege("mysql_grant.test", "SHOW DATABASES", true),
testAccPrivilege("mysql_grant.test", "CONNECTION_ADMIN", true),
testAccPrivilege("mysql_grant.test", "SELECT", true),
testAccPrivilege("mysql_grant.test", "SHOW DATABASES", true, false),
testAccPrivilege("mysql_grant.test", "CONNECTION_ADMIN", true, false),
testAccPrivilege("mysql_grant.test", "SELECT", true, false),
),
},
},
Expand Down Expand Up @@ -277,6 +280,7 @@ func TestAccGrant_role(t *testing.T) {
Config: testAccGrantConfigRoleWithGrantOption(dbName, roleName),
Check: resource.ComposeTestCheckFunc(
resource.TestCheckResourceAttr("mysql_grant.test", "role", roleName),
resource.TestCheckResourceAttr("mysql_grant.test", "grant", "true"),
),
},
{
Expand Down Expand Up @@ -346,7 +350,7 @@ func prepareTable(dbname string) resource.TestCheckFunc {
}

// Test privilege - one can condition it exists or that it doesn't exist.
func testAccPrivilege(rn string, privilege string, expectExists bool) resource.TestCheckFunc {
func testAccPrivilege(rn string, privilege string, expectExists bool, expectGrant bool) resource.TestCheckFunc {
return func(s *terraform.State) error {
rs, ok := s.RootModule().Resources[rn]
if !ok {
Expand Down Expand Up @@ -385,7 +389,7 @@ func testAccPrivilege(rn string, privilege string, expectExists bool) resource.T

privilegeNorm := normalizePerms([]string{privilege})[0]

haveGrant := false
var expectedGrant MySQLGrant

Outer:
for _, grant := range grants {
Expand All @@ -396,20 +400,24 @@ func testAccPrivilege(rn string, privilege string, expectExists bool) resource.T
for _, priv := range grantWithPrivs.GetPrivileges() {
log.Printf("[DEBUG] Checking grant %s against %s", priv, privilegeNorm)
if priv == privilegeNorm {
haveGrant = true
expectedGrant = grant
break Outer
}
}
}

if expectExists != haveGrant {
if haveGrant {
if expectExists != (expectedGrant != nil) {
if expectedGrant != nil {
return fmt.Errorf("grant %s found but it was not requested for %s", privilege, userOrRole)
} else {
return fmt.Errorf("grant %s not found for %s", privilegeNorm, userOrRole)
}
}

if expectedGrant != nil && expectedGrant.GrantOption() != expectGrant {
return fmt.Errorf("grant %s found but had incorrect grant option", privilege)
}

// We match expectations.
return nil
}
Expand Down Expand Up @@ -751,6 +759,7 @@ func testAccGrantConfigComplexRoleGrants(user string) string {
privileges = ["SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "RELOAD", "PROCESS", "REFERENCES", "INDEX", "ALTER", "SHOW DATABASES", "CREATE TEMPORARY TABLES", "LOCK TABLES", "EXECUTE", "REPLICATION SLAVE", "REPLICATION CLIENT", "CREATE VIEW", "SHOW VIEW", "CREATE ROUTINE", "ALTER ROUTINE", "CREATE USER", "EVENT", "TRIGGER"]
}`, user)
}

func prepareProcedure(dbname string, procedureName string) resource.TestCheckFunc {
return func(s *terraform.State) error {
ctx := context.Background()
Expand Down

0 comments on commit 309f64c

Please sign in to comment.