Skip to content

Commit

Permalink
Merge pull request #18561 from DrFaust92/r/apigw_stage_web_acl_arn
Browse files Browse the repository at this point in the history
r/apigateway_stage - add `waf_acl_arn` attribute + use waiter and finder
  • Loading branch information
ewbankkit authored Jan 21, 2022
2 parents 8c31926 + 69cb5b4 commit 6ece474
Show file tree
Hide file tree
Showing 11 changed files with 568 additions and 408 deletions.
3 changes: 3 additions & 0 deletions .changelog/18561.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/aws_api_gateway_stage: Add `web_acl_arn` attribute
```
30 changes: 12 additions & 18 deletions internal/service/apigateway/deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"github.com/hashicorp/aws-sdk-go-base/tfawserr"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-provider-aws/internal/conns"
"github.com/hashicorp/terraform-provider-aws/internal/flex"
"github.com/hashicorp/terraform-provider-aws/internal/tfresource"
)

func ResourceDeployment() *schema.Resource {
Expand Down Expand Up @@ -81,21 +83,15 @@ func resourceDeploymentCreate(d *schema.ResourceData, meta interface{}) error {
// Create the gateway
log.Printf("[DEBUG] Creating API Gateway Deployment")

variables := make(map[string]string)
for k, v := range d.Get("variables").(map[string]interface{}) {
variables[k] = v.(string)
}

var err error
deployment, err := conn.CreateDeployment(&apigateway.CreateDeploymentInput{
RestApiId: aws.String(d.Get("rest_api_id").(string)),
StageName: aws.String(d.Get("stage_name").(string)),
Description: aws.String(d.Get("description").(string)),
StageDescription: aws.String(d.Get("stage_description").(string)),
Variables: aws.StringMap(variables),
Variables: flex.ExpandStringMap(d.Get("variables").(map[string]interface{})),
})
if err != nil {
return fmt.Errorf("Error creating API Gateway Deployment: %s", err)
return fmt.Errorf("Error creating API Gateway Deployment: %w", err)
}

d.SetId(aws.StringValue(deployment.Id))
Expand Down Expand Up @@ -188,14 +184,12 @@ func resourceDeploymentDelete(d *schema.ResourceData, meta interface{}) error {
// InvalidParameter: 1 validation error(s) found.
// - minimum field size of 1, GetStageInput.StageName.
stageName := d.Get("stage_name").(string)
restApiId := d.Get("rest_api_id").(string)
if stageName != "" {
stage, err := conn.GetStage(&apigateway.GetStageInput{
StageName: aws.String(stageName),
RestApiId: aws.String(d.Get("rest_api_id").(string)),
})
stage, err := FindStageByName(conn, restApiId, stageName)

if err != nil && !tfawserr.ErrMessageContains(err, apigateway.ErrCodeNotFoundException, "") {
return fmt.Errorf("error getting referenced stage: %s", err)
if err != nil && !tfresource.NotFound(err) {
return fmt.Errorf("error getting referenced stage: %w", err)
}

if stage != nil && aws.StringValue(stage.DeploymentId) == d.Id() {
Expand All @@ -205,24 +199,24 @@ func resourceDeploymentDelete(d *schema.ResourceData, meta interface{}) error {

if shouldDeleteStage {
if _, err := conn.DeleteStage(&apigateway.DeleteStageInput{
StageName: aws.String(d.Get("stage_name").(string)),
RestApiId: aws.String(d.Get("rest_api_id").(string)),
StageName: aws.String(stageName),
RestApiId: aws.String(restApiId),
}); err == nil {
return nil
}
}

_, err := conn.DeleteDeployment(&apigateway.DeleteDeploymentInput{
DeploymentId: aws.String(d.Id()),
RestApiId: aws.String(d.Get("rest_api_id").(string)),
RestApiId: aws.String(restApiId),
})

if tfawserr.ErrMessageContains(err, apigateway.ErrCodeNotFoundException, "") {
return nil
}

if err != nil {
return fmt.Errorf("error deleting API Gateway Deployment (%s): %s", d.Id(), err)
return fmt.Errorf("error deleting API Gateway Deployment (%s): %w", d.Id(), err)
}

return nil
Expand Down
36 changes: 6 additions & 30 deletions internal/service/apigateway/deployment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func TestAccAPIGatewayDeployment_triggers(t *testing.T) {
Config: testAccDeploymentTriggersConfig("description1", "https://example.com"),
Check: resource.ComposeTestCheckFunc(
testAccCheckDeploymentExists(resourceName, &deployment1),
testAccCheckDeploymentStageExists(resourceName, &stage),
testAccCheckStageExists(resourceName, &stage),
resource.TestCheckResourceAttr(resourceName, "description", "description1"),
resource.TestCheckResourceAttr(resourceName, "stage_description", "description1"),
),
Expand All @@ -100,7 +100,7 @@ func TestAccAPIGatewayDeployment_triggers(t *testing.T) {
Check: resource.ComposeTestCheckFunc(
testAccCheckDeploymentExists(resourceName, &deployment2),
testAccCheckDeploymentRecreated(&deployment1, &deployment2),
testAccCheckDeploymentStageExists(resourceName, &stage),
testAccCheckStageExists(resourceName, &stage),
resource.TestCheckResourceAttr(resourceName, "description", "description1"),
resource.TestCheckResourceAttr(resourceName, "stage_description", "description1"),
),
Expand All @@ -110,7 +110,7 @@ func TestAccAPIGatewayDeployment_triggers(t *testing.T) {
Check: resource.ComposeTestCheckFunc(
testAccCheckDeploymentExists(resourceName, &deployment3),
testAccCheckDeploymentNotRecreated(&deployment2, &deployment3),
testAccCheckDeploymentStageExists(resourceName, &stage),
testAccCheckStageExists(resourceName, &stage),
resource.TestCheckResourceAttr(resourceName, "description", "description1"),
resource.TestCheckResourceAttr(resourceName, "stage_description", "description1"),
),
Expand All @@ -120,7 +120,7 @@ func TestAccAPIGatewayDeployment_triggers(t *testing.T) {
Check: resource.ComposeTestCheckFunc(
testAccCheckDeploymentExists(resourceName, &deployment4),
testAccCheckDeploymentRecreated(&deployment3, &deployment4),
testAccCheckDeploymentStageExists(resourceName, &stage),
testAccCheckStageExists(resourceName, &stage),
resource.TestCheckResourceAttr(resourceName, "description", "description2"),
resource.TestCheckResourceAttr(resourceName, "stage_description", "description2"),
),
Expand Down Expand Up @@ -172,7 +172,7 @@ func TestAccAPIGatewayDeployment_stageDescription(t *testing.T) {
Config: testAccDeploymentStageDescriptionConfig("description1"),
Check: resource.ComposeTestCheckFunc(
testAccCheckDeploymentExists(resourceName, &deployment),
testAccCheckDeploymentStageExists(resourceName, &stage),
testAccCheckStageExists(resourceName, &stage),
resource.TestCheckResourceAttr(resourceName, "stage_description", "description1"),
),
},
Expand All @@ -195,7 +195,7 @@ func TestAccAPIGatewayDeployment_stageName(t *testing.T) {
Config: testAccDeploymentStageNameConfig("test"),
Check: resource.ComposeTestCheckFunc(
testAccCheckDeploymentExists(resourceName, &deployment),
testAccCheckDeploymentStageExists(resourceName, &stage),
testAccCheckStageExists(resourceName, &stage),
resource.TestCheckResourceAttr(resourceName, "stage_name", "test"),
),
},
Expand Down Expand Up @@ -284,30 +284,6 @@ func testAccCheckDeploymentExists(n string, res *apigateway.Deployment) resource
}
}

func testAccCheckDeploymentStageExists(resourceName string, res *apigateway.Stage) resource.TestCheckFunc {
return func(s *terraform.State) error {
conn := acctest.Provider.Meta().(*conns.AWSClient).APIGatewayConn

rs, ok := s.RootModule().Resources[resourceName]
if !ok {
return fmt.Errorf("Deployment not found: %s", resourceName)
}

req := &apigateway.GetStageInput{
StageName: aws.String(rs.Primary.Attributes["stage_name"]),
RestApiId: aws.String(rs.Primary.Attributes["rest_api_id"]),
}
stage, err := conn.GetStage(req)
if err != nil {
return err
}

*res = *stage

return nil
}
}

func testAccCheckDeploymentDestroy(s *terraform.State) error {
conn := acctest.Provider.Meta().(*conns.AWSClient).APIGatewayConn

Expand Down
34 changes: 34 additions & 0 deletions internal/service/apigateway/find.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package apigateway

import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/apigateway"
"github.com/hashicorp/aws-sdk-go-base/tfawserr"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
"github.com/hashicorp/terraform-provider-aws/internal/tfresource"
)

func FindStageByName(conn *apigateway.APIGateway, restApiId, name string) (*apigateway.Stage, error) {
input := &apigateway.GetStageInput{
RestApiId: aws.String(restApiId),
StageName: aws.String(name),
}

output, err := conn.GetStage(input)
if tfawserr.ErrCodeEquals(err, apigateway.ErrCodeNotFoundException) {
return nil, &resource.NotFoundError{
LastError: err,
LastRequest: input,
}
}

if err != nil {
return nil, err
}

if output == nil {
return nil, tfresource.NewEmptyResultError(input)
}

return output, nil
}
22 changes: 7 additions & 15 deletions internal/service/apigateway/method_settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
"github.com/hashicorp/terraform-provider-aws/internal/conns"
"github.com/hashicorp/terraform-provider-aws/internal/tfresource"
)

func ResourceMethodSettings() *schema.Resource {
Expand Down Expand Up @@ -97,14 +98,10 @@ func ResourceMethodSettings() *schema.Resource {
Computed: true,
},
"unauthorized_cache_control_header_strategy": {
Type: schema.TypeString,
Optional: true,
ValidateFunc: validation.StringInSlice([]string{
apigateway.UnauthorizedCacheControlHeaderStrategyFailWith403,
apigateway.UnauthorizedCacheControlHeaderStrategySucceedWithResponseHeader,
apigateway.UnauthorizedCacheControlHeaderStrategySucceedWithoutResponseHeader,
}, false),
Computed: true,
Type: schema.TypeString,
Optional: true,
ValidateFunc: validation.StringInSlice(apigateway.UnauthorizedCacheControlHeaderStrategy_Values(), false),
Computed: true,
},
},
},
Expand Down Expand Up @@ -137,14 +134,9 @@ func flattenMethodSettings(settings *apigateway.MethodSetting) []interface{} {
func resourceMethodSettingsRead(d *schema.ResourceData, meta interface{}) error {
conn := meta.(*conns.AWSClient).APIGatewayConn

input := &apigateway.GetStageInput{
RestApiId: aws.String(d.Get("rest_api_id").(string)),
StageName: aws.String(d.Get("stage_name").(string)),
}

stage, err := conn.GetStage(input)
stage, err := FindStageByName(conn, d.Get("rest_api_id").(string), d.Get("stage_name").(string))

if !d.IsNewResource() && tfawserr.ErrCodeEquals(err, apigateway.ErrCodeNotFoundException) {
if !d.IsNewResource() && tfresource.NotFound(err) {
log.Printf("[WARN] API Gateway Stage Method Settings (%s) not found, removing from state", d.Id())
d.SetId("")
return nil
Expand Down
Loading

0 comments on commit 6ece474

Please sign in to comment.