Skip to content

Commit

Permalink
resource/cloudflare_ruleset: send ERE payload as one
Browse files Browse the repository at this point in the history
We don't need to restrict the entrypoint calls to "execute" calls; send it all as one payload.
  • Loading branch information
jacobbednarz committed Aug 29, 2021
1 parent f052274 commit 871aa14
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 26 deletions.
26 changes: 11 additions & 15 deletions cloudflare/resource_cloudflare_ruleset.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,23 +264,19 @@ func resourceCloudflareRulesetCreate(d *schema.ResourceData, meta interface{}) e
return errors.Wrap(err, fmt.Sprintf("error creating ruleset %s", d.Get("name").(string)))
}

for i, rule := range rules {
if rule.Action == string(cloudflare.RulesetRuleActionExecute) {
rulesetEntryPoint := cloudflare.Ruleset{
Description: d.Get("description").(string),
Rules: []cloudflare.RulesetRule{rules[i]},
}
rulesetEntryPoint := cloudflare.Ruleset{
Description: d.Get("description").(string),
Rules: rules,
}

if accountID != "" {
_, err = client.UpdateAccountRulesetPhase(context.Background(), accountID, rs.Phase, rulesetEntryPoint)
} else {
_, err = client.UpdateZoneRulesetPhase(context.Background(), zoneID, rs.Phase, rulesetEntryPoint)
}
if accountID != "" {
_, err = client.UpdateAccountRulesetPhase(context.Background(), accountID, rs.Phase, rulesetEntryPoint)
} else {
_, err = client.UpdateZoneRulesetPhase(context.Background(), zoneID, rs.Phase, rulesetEntryPoint)
}

if err != nil {
return errors.Wrap(err, fmt.Sprintf("error updating ruleset phase entrypoint %s", d.Get("name").(string)))
}
}
if err != nil {
return errors.Wrap(err, fmt.Sprintf("error updating ruleset phase entrypoint %s", d.Get("name").(string)))
}

d.SetId(ruleset.ID)
Expand Down
22 changes: 11 additions & 11 deletions cloudflare/resource_cloudflare_ruleset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,24 @@ func TestAccCloudflareRuleset_WAFBasic(t *testing.T) {

t.Parallel()
rnd := generateRandomResourceName()
accountID := os.Getenv("CLOUDFLARE_ACCOUNT_ID")
zoneID := os.Getenv("CLOUDFLARE_ZONE_ID")
resourceName := "cloudflare_ruleset." + rnd

resource.Test(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Providers: testAccProviders,
Steps: []resource.TestStep{
{
Config: testAccCheckCloudflareRulesetCustomWAFBasic(rnd, "my basic WAF ruleset", accountID),
Config: testAccCheckCloudflareRulesetCustomWAFBasic(rnd, "my basic WAF ruleset", zoneID),
Check: resource.ComposeTestCheckFunc(
resource.TestCheckResourceAttr(resourceName, "name", "my basic WAF ruleset"),
resource.TestCheckResourceAttr(resourceName, "description", rnd+" ruleset description"),
resource.TestCheckResourceAttr(resourceName, "kind", "custom"),
resource.TestCheckResourceAttr(resourceName, "kind", "zone"),
resource.TestCheckResourceAttr(resourceName, "phase", "http_request_firewall_custom"),

resource.TestCheckResourceAttr(resourceName, "rules.#", "1"),
resource.TestCheckResourceAttr(resourceName, "rules.0.action", "log"),
resource.TestCheckResourceAttr(resourceName, "rules.0.expression", "true"),
resource.TestCheckResourceAttr(resourceName, "rules.0.action", "challenge"),
resource.TestCheckResourceAttr(resourceName, "rules.0.expression", "(ip.geoip.country eq \"GB\" or ip.geoip.country eq \"FR\") or cf.threat_score > 0"),
resource.TestCheckResourceAttr(resourceName, "rules.0.description", rnd+" ruleset rule description"),
),
},
Expand Down Expand Up @@ -605,21 +605,21 @@ func testAccCheckCloudflareRulesetMagicTransitMultiple(rnd, name, accountID stri
}`, rnd, name, accountID)
}

func testAccCheckCloudflareRulesetCustomWAFBasic(rnd, name, accountID string) string {
func testAccCheckCloudflareRulesetCustomWAFBasic(rnd, name, zoneID string) string {
return fmt.Sprintf(`
resource "cloudflare_ruleset" "%[1]s" {
account_id = "%[3]s"
zone_id = "%[3]s"
name = "%[2]s"
description = "%[1]s ruleset description"
kind = "custom"
kind = "zone"
phase = "http_request_firewall_custom"
rules {
action = "log"
expression = "true"
action = "challenge"
expression = "(ip.geoip.country eq \"GB\" or ip.geoip.country eq \"FR\") or cf.threat_score > 0"
description = "%[1]s ruleset rule description"
}
}`, rnd, name, accountID)
}`, rnd, name, zoneID)
}

func testAccCheckCloudflareRulesetManagedWAF(rnd, name, zoneID, zoneName string) string {
Expand Down

0 comments on commit 871aa14

Please sign in to comment.