Skip to content

Commit

Permalink
Merge pull request #38309 from hashicorp/f-wafv2_acl_rule
Browse files Browse the repository at this point in the history
Add JSON web ACL rule attribute
  • Loading branch information
jar-b authored Aug 1, 2024
2 parents 220214f + e461f53 commit 153d847
Show file tree
Hide file tree
Showing 7 changed files with 392 additions and 117 deletions.
3 changes: 3 additions & 0 deletions .changelog/38309.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/aws_wafv2_web_acl: Add `rule_json` attribute to allow raw JSON for rules.
```
19 changes: 19 additions & 0 deletions internal/service/wafv2/flex.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
package wafv2

import (
"encoding/json"
"fmt"
"reflect"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
Expand Down Expand Up @@ -980,6 +983,22 @@ func expandHeaderMatchPattern(l []interface{}) *awstypes.HeaderMatchPattern {
return f
}

func expandWebACLRulesJSON(rawRules string) ([]awstypes.Rule, error) {
var rules []awstypes.Rule

err := json.Unmarshal([]byte(rawRules), &rules)
if err != nil {
return nil, fmt.Errorf("decoding JSON: %s", err)
}

for i, r := range rules {
if reflect.DeepEqual(r, awstypes.Rule{}) {
return nil, fmt.Errorf("invalid ACL Rule supplied at index (%d)", i)
}
}
return rules, nil
}

func expandWebACLRules(l []interface{}) []awstypes.Rule {
if len(l) == 0 || l[0] == nil {
return nil
Expand Down
98 changes: 98 additions & 0 deletions internal/service/wafv2/flex_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package wafv2

import (
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
awstypes "github.com/aws/aws-sdk-go-v2/service/wafv2/types"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)

func Test_expandWebACLRulesJSON(t *testing.T) {
t.Parallel()

testCases := map[string]struct {
rawRules string
want []awstypes.Rule
wantErr bool
}{
"empty string": {
rawRules: "",
wantErr: true,
},
"empty array": {
rawRules: "[]",
want: []awstypes.Rule{},
},
"single empty object": {
rawRules: "[{}]",
wantErr: true,
},
"single null object": {
rawRules: "[null]",
wantErr: true,
},
"valid object": {
rawRules: `[{"Action":{"Count":{}},"Name":"rule-1","Priority":1,"Statement":{"RateBasedStatement":{"AggregateKeyType":"IP","EvaluationWindowSec":600,"Limit":10000,"ScopeDownStatement":{"GeoMatchStatement":{"CountryCodes":["US","NL"]}}}},"VisibilityConfig":{"CloudwatchMetricsEnabled":false,"MetricName":"friendly-rule-metric-name","SampledRequestsEnabled":false}}]`,
want: []awstypes.Rule{
{
Name: aws.String("rule-1"),
Priority: 1,
Action: &awstypes.RuleAction{
Count: &awstypes.CountAction{},
},
Statement: &awstypes.Statement{
RateBasedStatement: &awstypes.RateBasedStatement{
Limit: aws.Int64(10000),
AggregateKeyType: awstypes.RateBasedStatementAggregateKeyType("IP"),
EvaluationWindowSec: 600,
ScopeDownStatement: &awstypes.Statement{
GeoMatchStatement: &awstypes.GeoMatchStatement{
CountryCodes: []awstypes.CountryCode{"US", "NL"},
},
},
},
},
VisibilityConfig: &awstypes.VisibilityConfig{
CloudWatchMetricsEnabled: false,
MetricName: aws.String("friendly-rule-metric-name"),
SampledRequestsEnabled: false,
},
},
},
},
"valid and empty object": {
rawRules: `[{"Action":{"Count":{}},"Name":"rule-1","Priority":1,"Statement":{"RateBasedStatement":{"AggregateKeyType":"IP","EvaluationWindowSec":600,"Limit":10000,"ScopeDownStatement":{"GeoMatchStatement":{"CountryCodes":["US","NL"]}}}},"VisibilityConfig":{"CloudwatchMetricsEnabled":false,"MetricName":"friendly-rule-metric-name","SampledRequestsEnabled":false}},{}]`,
wantErr: true,
},
}

ignoreExportedOpts := cmpopts.IgnoreUnexported(
awstypes.Rule{},
awstypes.RuleAction{},
awstypes.CountAction{},
awstypes.Statement{},
awstypes.RateBasedStatement{},
awstypes.GeoMatchStatement{},
awstypes.VisibilityConfig{},
)

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()

got, err := expandWebACLRulesJSON(tc.rawRules)
if (err != nil) != tc.wantErr {
t.Errorf("expandWebACLRulesJSON() error = %v, wantErr %v", err, tc.wantErr)
return
}
if diff := cmp.Diff(got, tc.want, ignoreExportedOpts); diff != "" {
t.Errorf("unexpected diff (+wanted, -got): %s", diff)
}
})
}
}
64 changes: 57 additions & 7 deletions internal/service/wafv2/web_acl.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/retry"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/structure"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
"github.com/hashicorp/terraform-provider-aws/internal/conns"
"github.com/hashicorp/terraform-provider-aws/internal/enum"
Expand Down Expand Up @@ -101,9 +102,21 @@ func resourceWebACL() *schema.Resource {
validation.StringMatch(regexache.MustCompile(`^[0-9A-Za-z_-]+$`), "must contain only alphanumeric hyphen and underscore characters"),
),
},
"rule_json": {
Type: schema.TypeString,
Optional: true,
ConflictsWith: []string{names.AttrRule},
ValidateFunc: validation.StringIsJSON,
DiffSuppressFunc: verify.SuppressEquivalentJSONDiffs,
StateFunc: func(v interface{}) string {
json, _ := structure.NormalizeJsonString(v)
return json
},
},
names.AttrRule: {
Type: schema.TypeSet,
Optional: true,
Type: schema.TypeSet,
Optional: true,
ConflictsWith: []string{"rule_json"},
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
names.AttrAction: {
Expand Down Expand Up @@ -179,18 +192,30 @@ func resourceWebACLCreate(ctx context.Context, d *schema.ResourceData, meta inte
conn := meta.(*conns.AWSClient).WAFV2Client(ctx)

name := d.Get(names.AttrName).(string)

input := &wafv2.CreateWebACLInput{
AssociationConfig: expandAssociationConfig(d.Get("association_config").([]interface{})),
CaptchaConfig: expandCaptchaConfig(d.Get("captcha_config").([]interface{})),
ChallengeConfig: expandChallengeConfig(d.Get("challenge_config").([]interface{})),
DefaultAction: expandDefaultAction(d.Get(names.AttrDefaultAction).([]interface{})),
Name: aws.String(name),
Rules: expandWebACLRules(d.Get(names.AttrRule).(*schema.Set).List()),
Scope: awstypes.Scope(d.Get(names.AttrScope).(string)),
Tags: getTagsIn(ctx),
VisibilityConfig: expandVisibilityConfig(d.Get("visibility_config").([]interface{})),
}

if v, ok := d.GetOk(names.AttrRule); ok {
input.Rules = expandWebACLRules(v.(*schema.Set).List())
}

if v, ok := d.GetOk("rule_json"); ok {
rules, err := expandWebACLRulesJSON(v.(string))
if err != nil {
return sdkdiag.AppendErrorf(diags, "setting rule: %s", err)
}
input.Rules = rules
}

if v, ok := d.GetOk("custom_response_body"); ok && v.(*schema.Set).Len() > 0 {
input.CustomResponseBodies = expandCustomResponseBodies(v.(*schema.Set).List())
}
Expand Down Expand Up @@ -259,10 +284,16 @@ func resourceWebACLRead(ctx context.Context, d *schema.ResourceData, meta interf
d.Set(names.AttrDescription, webACL.Description)
d.Set("lock_token", output.LockToken)
d.Set(names.AttrName, webACL.Name)
rules := filterWebACLRules(webACL.Rules, expandWebACLRules(d.Get(names.AttrRule).(*schema.Set).List()))
if err := d.Set(names.AttrRule, flattenWebACLRules(rules)); err != nil {
return sdkdiag.AppendErrorf(diags, "setting rule: %s", err)

if _, ok := d.GetOk(names.AttrRule); ok {
rules := filterWebACLRules(webACL.Rules, expandWebACLRules(d.Get(names.AttrRule).(*schema.Set).List()))
if err := d.Set(names.AttrRule, flattenWebACLRules(rules)); err != nil {
return sdkdiag.AppendErrorf(diags, "setting rule: %s", err)
}
}

d.Set("rule_json", d.Get("rule_json"))

d.Set("token_domains", aws.StringSlice(webACL.TokenDomains))
if err := d.Set("visibility_config", flattenVisibilityConfig(webACL.VisibilityConfig)); err != nil {
return sdkdiag.AppendErrorf(diags, "setting visibility_config: %s", err)
Expand All @@ -281,7 +312,9 @@ func resourceWebACLUpdate(ctx context.Context, d *schema.ResourceData, meta inte
aclLockToken := d.Get("lock_token").(string)
// Find the AWS managed ShieldMitigationRuleGroup group rule if existent and add it into the set of rules to update
// so that the provider will not remove the Shield rule when changes are applied to the WebACL.
rules := expandWebACLRules(d.Get(names.AttrRule).(*schema.Set).List())
var rules []awstypes.Rule

rules = expandWebACLRules(d.Get(names.AttrRule).(*schema.Set).List())
if sr := findShieldRule(rules); len(sr) == 0 {
output, err := findWebACLByThreePartKey(ctx, conn, d.Id(), aclName, aclScope)

Expand All @@ -292,6 +325,23 @@ func resourceWebACLUpdate(ctx context.Context, d *schema.ResourceData, meta inte
rules = append(rules, findShieldRule(output.WebACL.Rules)...)
}

if d.HasChange("rule_json") {
r, err := expandWebACLRulesJSON(d.Get("rule_json").(string))
if err != nil {
return sdkdiag.AppendErrorf(diags, "expanding WAFv2 WebACL JSON rule (%s): %s", d.Id(), err)
}
if sr := findShieldRule(rules); len(sr) == 0 {
output, err := findWebACLByThreePartKey(ctx, conn, d.Id(), aclName, aclScope)

if err != nil {
return sdkdiag.AppendErrorf(diags, "reading WAFv2 WebACL (%s): %s", d.Id(), err)
}

r = append(r, findShieldRule(output.WebACL.Rules)...)
}
rules = r
}

input := &wafv2.UpdateWebACLInput{
AssociationConfig: expandAssociationConfig(d.Get("association_config").([]interface{})),
CaptchaConfig: expandCaptchaConfig(d.Get("captcha_config").([]interface{})),
Expand Down
Loading

0 comments on commit 153d847

Please sign in to comment.