Skip to content

Commit

Permalink
r/aws_vpc_endpoint_security_group_association: Use internal 'finder' …
Browse files Browse the repository at this point in the history
…package.

Acceptance test output:

$ make testacc TEST=./aws/ TESTARGS='-run=TestAccAWSVpcEndpointSecurityGroupAssociation_' ACCTEST_PARALLELISM=2
==> Checking that code complies with gofmt requirements...
TF_ACC=1 go test ./aws -v -count 1 -parallel 2 -run=TestAccAWSVpcEndpointSecurityGroupAssociation_ -timeout 120m
=== RUN   TestAccAWSVpcEndpointSecurityGroupAssociation_basic
=== PAUSE TestAccAWSVpcEndpointSecurityGroupAssociation_basic
=== RUN   TestAccAWSVpcEndpointSecurityGroupAssociation_disappears
=== PAUSE TestAccAWSVpcEndpointSecurityGroupAssociation_disappears
=== RUN   TestAccAWSVpcEndpointSecurityGroupAssociation_multiple
=== PAUSE TestAccAWSVpcEndpointSecurityGroupAssociation_multiple
=== RUN   TestAccAWSVpcEndpointSecurityGroupAssociation_ReplaceDefaultAssociation
=== PAUSE TestAccAWSVpcEndpointSecurityGroupAssociation_ReplaceDefaultAssociation
=== CONT  TestAccAWSVpcEndpointSecurityGroupAssociation_basic
=== CONT  TestAccAWSVpcEndpointSecurityGroupAssociation_ReplaceDefaultAssociation
--- PASS: TestAccAWSVpcEndpointSecurityGroupAssociation_ReplaceDefaultAssociation (104.16s)
=== CONT  TestAccAWSVpcEndpointSecurityGroupAssociation_multiple
--- PASS: TestAccAWSVpcEndpointSecurityGroupAssociation_basic (111.52s)
=== CONT  TestAccAWSVpcEndpointSecurityGroupAssociation_disappears
    resource_aws_vpc_endpoint_security_group_association_test.go:41: [INFO] Got non-empty plan, as expected
--- PASS: TestAccAWSVpcEndpointSecurityGroupAssociation_multiple (61.28s)
--- PASS: TestAccAWSVpcEndpointSecurityGroupAssociation_disappears (63.64s)
PASS
ok  	github.com/terraform-providers/terraform-provider-aws/aws	175.194s
  • Loading branch information
ewbankkit committed Nov 16, 2020
1 parent ec1ca86 commit d6a3b16
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 100 deletions.
44 changes: 44 additions & 0 deletions aws/internal/service/ec2/finder/finder.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,30 @@ func ClientVpnRouteByID(conn *ec2.EC2, routeID string) (*ec2.DescribeClientVpnRo
return ClientVpnRoute(conn, endpointID, targetSubnetID, destinationCidr)
}

// DefaultSecurityGroup returns the default security group for the specified VPC.
// Returns nil and potentially an error if no default security group is found.
func DefaultSecurityGroup(conn *ec2.EC2, vpcID string) (*ec2.SecurityGroup, error) {
filters := map[string]string{
"group-name": "default",
"vpc-id": vpcID,
}

input := &ec2.DescribeSecurityGroupsInput{
Filters: tfec2.BuildAttributeFilterList(filters),
}

output, err := conn.DescribeSecurityGroups(input)
if err != nil {
return nil, err
}

if output == nil || len(output.SecurityGroups) == 0 {
return nil, nil
}

return output.SecurityGroups[0], nil
}

// SecurityGroupByID looks up a security group by ID. When not found, returns nil and potentially an API error.
func SecurityGroupByID(conn *ec2.EC2, id string) (*ec2.SecurityGroup, error) {
req := &ec2.DescribeSecurityGroupsInput{
Expand Down Expand Up @@ -91,6 +115,26 @@ func VpcEndpointByID(conn *ec2.EC2, id string) (*ec2.VpcEndpoint, error) {
return output.VpcEndpoints[0], nil
}

// VpcEndpointSecurityGroupAssociationExists returns whether the specified VPC endpoint/security group association exists.
func VpcEndpointSecurityGroupAssociationExists(conn *ec2.EC2, vpcEndpointID, securityGroupID string) (bool, error) {
vpcEndpoint, err := VpcEndpointByID(conn, vpcEndpointID)
if err != nil {
return false, err
}

if vpcEndpoint == nil {
return false, nil
}

for _, group := range vpcEndpoint.Groups {
if aws.StringValue(group.GroupId) == securityGroupID {
return true, nil
}
}

return false, nil
}

// VpcPeeringConnectionByID returns the VPC peering connection corresponding to the specified identifier.
// Returns nil and potentially an error if no VPC peering connection is found.
func VpcPeeringConnectionByID(conn *ec2.EC2, id string) (*ec2.VpcPeeringConnection, error) {
Expand Down
4 changes: 4 additions & 0 deletions aws/internal/service/ec2/id.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ func ClientVpnRouteParseID(id string) (string, string, string, error) {
"target-subnet-id"+clientVpnRouteIDSeparator+"destination-cidr-block", id)
}

func VpcEndpointSecurityGroupAssociationCreateID(vpceID, sgID string) string {
return fmt.Sprintf("a-%s%d", vpceID, hashcode.String(sgID))
}

func VpnGatewayVpcAttachmentCreateID(vpnGatewayID, vpcID string) string {
return fmt.Sprintf("vpn-attachment-%x", hashcode.String(fmt.Sprintf("%s-%s", vpcID, vpnGatewayID)))
}
19 changes: 0 additions & 19 deletions aws/resource_aws_vpc_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,22 +523,3 @@ func flattenVpcEndpointSecurityGroupIds(groups []*ec2.SecurityGroupIdentifier) *

return schema.NewSet(schema.HashString, vSecurityGroupIds)
}

// readVpcEndpoint returns the specified VPC endpoint.
func readVpcEndpoint(conn *ec2.EC2, vpceId string) (*ec2.VpcEndpoint, error) {
input := &ec2.DescribeVpcEndpointsInput{
VpcEndpointIds: aws.StringSlice([]string{vpceId}),
}

output, err := conn.DescribeVpcEndpoints(input)

if err != nil {
return nil, fmt.Errorf("error reading VPC endpoint (%s): %w", vpceId, err)
}

if n := len(output.VpcEndpoints); n != 1 {
return nil, fmt.Errorf("found %d VPC endpoints (%s), expected 1", n, vpceId)
}

return output.VpcEndpoints[0], nil
}
84 changes: 30 additions & 54 deletions aws/resource_aws_vpc_endpoint_security_group_association.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/terraform-providers/terraform-provider-aws/aws/internal/hashcode"
tfec2 "github.com/terraform-providers/terraform-provider-aws/aws/internal/service/ec2"
"github.com/terraform-providers/terraform-provider-aws/aws/internal/service/ec2/finder"
)

func resourceAwsVpcEndpointSecurityGroupAssociation() *schema.Resource {
Expand Down Expand Up @@ -46,18 +47,26 @@ func resourceAwsVpcEndpointSecurityGroupAssociationCreate(d *schema.ResourceData

defaultSecurityGroupId := ""
if replaceDefaultAssociation {
vpce, err := readVpcEndpoint(conn, vpceId)
vpce, err := finder.VpcEndpointByID(conn, vpceId)

if err != nil {
return err
return fmt.Errorf("error reading VPC endpoint (%s): %w", vpceId, err)
}

if vpce == nil {
return fmt.Errorf("VPC endpoint (%s) not found", vpceId)
}

vpcId := aws.StringValue(vpce.VpcId)

defaultSecurityGroup, err := readDefaultSecurityGroup(conn, vpcId)
defaultSecurityGroup, err := finder.DefaultSecurityGroup(conn, vpcId)

if err != nil {
return err
return fmt.Errorf("error reading default security group for VPC (%s): %w", vpcId, err)
}

if defaultSecurityGroup == nil {
return fmt.Errorf("default security group for VPC (%s) not found", vpcId)
}

defaultSecurityGroupId = aws.StringValue(defaultSecurityGroup.GroupId)
Expand Down Expand Up @@ -86,7 +95,7 @@ func resourceAwsVpcEndpointSecurityGroupAssociationCreate(d *schema.ResourceData
return err
}

d.SetId(vpcEndpointSecurityGroupAssociationId(vpceId, sgId))
d.SetId(tfec2.VpcEndpointSecurityGroupAssociationCreateID(vpceId, sgId))

if replaceDefaultAssociation {
// Delete the existing VPC endpoint/default security group association.
Expand All @@ -106,7 +115,7 @@ func resourceAwsVpcEndpointSecurityGroupAssociationRead(d *schema.ResourceData,
vpceId := d.Get("vpc_endpoint_id").(string)
sgId := d.Get("security_group_id").(string)

found, err := readVpcEndpointSecurityGroupAssociation(conn, vpceId, sgId)
found, err := finder.VpcEndpointSecurityGroupAssociationExists(conn, vpceId, sgId)

if isAWSErr(err, "InvalidVpcEndpointId.NotFound", "") {
log.Printf("[WARN] VPC Endpoint (%s) not found, removing VPC Endpoint/Security Group association (%s) from state", vpceId, d.Id())
Expand Down Expand Up @@ -135,16 +144,26 @@ func resourceAwsVpcEndpointSecurityGroupAssociationDelete(d *schema.ResourceData
replaceDefaultAssociation := d.Get("replace_default_association").(bool)

if replaceDefaultAssociation {
vpce, err := readVpcEndpoint(conn, vpceId)
vpce, err := finder.VpcEndpointByID(conn, vpceId)

if err != nil {
return err
return fmt.Errorf("error reading VPC endpoint (%s): %w", vpceId, err)
}

if vpce == nil {
return fmt.Errorf("VPC endpoint (%s) not found", vpceId)
}

defaultSecurityGroup, err := readDefaultSecurityGroup(conn, aws.StringValue(vpce.VpcId))
vpcId := aws.StringValue(vpce.VpcId)

defaultSecurityGroup, err := finder.DefaultSecurityGroup(conn, vpcId)

if err != nil {
return err
return fmt.Errorf("error reading default security group for VPC (%s): %w", vpcId, err)
}

if defaultSecurityGroup == nil {
return fmt.Errorf("default security group for VPC (%s) not found", vpcId)
}

// Add back the VPC endpoint/default security group association.
Expand All @@ -158,49 +177,6 @@ func resourceAwsVpcEndpointSecurityGroupAssociationDelete(d *schema.ResourceData
return deleteVpcEndpointSecurityGroupAssociation(conn, vpceId, sgId)
}

func vpcEndpointSecurityGroupAssociationId(vpceId, sgId string) string {
return fmt.Sprintf("a-%s%d", vpceId, hashcode.String(sgId))
}

// readDefaultSecurityGroup returns the default security group for the specified VPC.
func readDefaultSecurityGroup(conn *ec2.EC2, vpcId string) (*ec2.SecurityGroup, error) {
input := &ec2.DescribeSecurityGroupsInput{
Filters: buildEC2AttributeFilterList(map[string]string{
"group-name": "default",
"vpc-id": vpcId,
}),
}

output, err := conn.DescribeSecurityGroups(input)

if err != nil {
return nil, fmt.Errorf("error reading default security group for VPC (%s): %w", vpcId, err)
}

if n := len(output.SecurityGroups); n != 1 {
return nil, fmt.Errorf("found %d default security groups for VPC (%s), expected 1", n, vpcId)
}

return output.SecurityGroups[0], nil
}

// readVpcEndpointSecurityGroupAssociation returns the specified VPC endpoint/security group association.
func readVpcEndpointSecurityGroupAssociation(conn *ec2.EC2, vpceId, sgId string) (bool, error) {
vpce, err := readVpcEndpoint(conn, vpceId)

if err != nil {
return false, err
}

for _, group := range vpce.Groups {
if aws.StringValue(group.GroupId) == sgId {
return true, nil
}
}

return false, nil
}

// createVpcEndpointSecurityGroupAssociation creates the specified VPC endpoint/security group association.
func createVpcEndpointSecurityGroupAssociation(conn *ec2.EC2, vpceId, sgId string) error {
input := &ec2.ModifyVpcEndpointInput{
Expand Down
48 changes: 21 additions & 27 deletions aws/resource_aws_vpc_endpoint_security_group_association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ import (
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/acctest"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
"github.com/hashicorp/terraform-plugin-sdk/v2/terraform"
"github.com/terraform-providers/terraform-provider-aws/aws/internal/service/ec2/finder"
)

func TestAccAWSVpcEndpointSecurityGroupAssociation_basic(t *testing.T) {
Expand Down Expand Up @@ -109,27 +109,20 @@ func testAccCheckVpcEndpointSecurityGroupAssociationDestroy(s *terraform.State)
continue
}

// Try to find the resource
resp, err := conn.DescribeVpcEndpoints(&ec2.DescribeVpcEndpointsInput{
VpcEndpointIds: aws.StringSlice([]string{rs.Primary.Attributes["vpc_endpoint_id"]}),
})
out, err := finder.VpcEndpointByID(conn, rs.Primary.Attributes["vpc_endpoint_id"])
if isAWSErr(err, "InvalidVpcEndpointId.NotFound", "") {
continue
}
if err != nil {
// Verify the error is what we want
ec2err, ok := err.(awserr.Error)
if !ok {
return err
}
if ec2err.Code() != "InvalidVpcEndpointId.NotFound" {
return err
}
return nil
return err
}
if out == nil {
continue
}

vpce := resp.VpcEndpoints[0]
// VPC Endpoint will always have 1 SG.
if len(vpce.Groups) > 1 {
return fmt.Errorf(
"VPC endpoint %s has security groups", *vpce.VpcEndpointId)
if len(out.Groups) > 1 {
return fmt.Errorf("VPC endpoint %s has security groups", aws.StringValue(out.VpcEndpointId))
}
}

Expand All @@ -148,24 +141,25 @@ func testAccCheckVpcEndpointSecurityGroupAssociationExists(n string, vpce *ec2.V
}

conn := testAccProvider.Meta().(*AWSClient).ec2conn
resp, err := conn.DescribeVpcEndpoints(&ec2.DescribeVpcEndpointsInput{
VpcEndpointIds: aws.StringSlice([]string{rs.Primary.Attributes["vpc_endpoint_id"]}),
})
out, err := finder.VpcEndpointByID(conn, rs.Primary.Attributes["vpc_endpoint_id"])
if isAWSErr(err, "InvalidVpcEndpointId.NotFound", "") {
return fmt.Errorf("VPC Endpoint not found")
}
if err != nil {
return err
}
if len(resp.VpcEndpoints) == 0 {
return fmt.Errorf("VPC endpoint not found")
if out == nil {
return fmt.Errorf("VPC Endpoint not found")
}

*vpce = *resp.VpcEndpoints[0]

if len(vpce.Groups) == 0 {
if len(out.Groups) == 0 {
return fmt.Errorf("no security group associations")
}

for _, group := range vpce.Groups {
for _, group := range out.Groups {
if aws.StringValue(group.GroupId) == rs.Primary.Attributes["security_group_id"] {
*vpce = *out

return nil
}
}
Expand Down

0 comments on commit d6a3b16

Please sign in to comment.