diff --git a/aws/internal/service/ec2/finder/finder.go b/aws/internal/service/ec2/finder/finder.go index 709e7ef9962..4bc425d1529 100644 --- a/aws/internal/service/ec2/finder/finder.go +++ b/aws/internal/service/ec2/finder/finder.go @@ -55,6 +55,30 @@ func ClientVpnRouteByID(conn *ec2.EC2, routeID string) (*ec2.DescribeClientVpnRo return ClientVpnRoute(conn, endpointID, targetSubnetID, destinationCidr) } +// DefaultSecurityGroup returns the default security group for the specified VPC. +// Returns nil and potentially an error if no default security group is found. +func DefaultSecurityGroup(conn *ec2.EC2, vpcID string) (*ec2.SecurityGroup, error) { + filters := map[string]string{ + "group-name": "default", + "vpc-id": vpcID, + } + + input := &ec2.DescribeSecurityGroupsInput{ + Filters: tfec2.BuildAttributeFilterList(filters), + } + + output, err := conn.DescribeSecurityGroups(input) + if err != nil { + return nil, err + } + + if output == nil || len(output.SecurityGroups) == 0 { + return nil, nil + } + + return output.SecurityGroups[0], nil +} + // SecurityGroupByID looks up a security group by ID. When not found, returns nil and potentially an API error. func SecurityGroupByID(conn *ec2.EC2, id string) (*ec2.SecurityGroup, error) { req := &ec2.DescribeSecurityGroupsInput{ @@ -91,6 +115,26 @@ func VpcEndpointByID(conn *ec2.EC2, id string) (*ec2.VpcEndpoint, error) { return output.VpcEndpoints[0], nil } +// VpcEndpointSecurityGroupAssociationExists returns whether the specified VPC endpoint/security group association exists. +func VpcEndpointSecurityGroupAssociationExists(conn *ec2.EC2, vpcEndpointID, securityGroupID string) (bool, error) { + vpcEndpoint, err := VpcEndpointByID(conn, vpcEndpointID) + if err != nil { + return false, err + } + + if vpcEndpoint == nil { + return false, nil + } + + for _, group := range vpcEndpoint.Groups { + if aws.StringValue(group.GroupId) == securityGroupID { + return true, nil + } + } + + return false, nil +} + // VpcPeeringConnectionByID returns the VPC peering connection corresponding to the specified identifier. // Returns nil and potentially an error if no VPC peering connection is found. func VpcPeeringConnectionByID(conn *ec2.EC2, id string) (*ec2.VpcPeeringConnection, error) { diff --git a/aws/internal/service/ec2/id.go b/aws/internal/service/ec2/id.go index c3797b4a2ed..199d100360c 100644 --- a/aws/internal/service/ec2/id.go +++ b/aws/internal/service/ec2/id.go @@ -71,6 +71,10 @@ func ClientVpnRouteParseID(id string) (string, string, string, error) { "target-subnet-id"+clientVpnRouteIDSeparator+"destination-cidr-block", id) } +func VpcEndpointSecurityGroupAssociationCreateID(vpceID, sgID string) string { + return fmt.Sprintf("a-%s%d", vpceID, hashcode.String(sgID)) +} + func VpnGatewayVpcAttachmentCreateID(vpnGatewayID, vpcID string) string { return fmt.Sprintf("vpn-attachment-%x", hashcode.String(fmt.Sprintf("%s-%s", vpcID, vpnGatewayID))) } diff --git a/aws/resource_aws_vpc_endpoint.go b/aws/resource_aws_vpc_endpoint.go index 1e85a2ff011..eaf00baed33 100644 --- a/aws/resource_aws_vpc_endpoint.go +++ b/aws/resource_aws_vpc_endpoint.go @@ -523,22 +523,3 @@ func flattenVpcEndpointSecurityGroupIds(groups []*ec2.SecurityGroupIdentifier) * return schema.NewSet(schema.HashString, vSecurityGroupIds) } - -// readVpcEndpoint returns the specified VPC endpoint. -func readVpcEndpoint(conn *ec2.EC2, vpceId string) (*ec2.VpcEndpoint, error) { - input := &ec2.DescribeVpcEndpointsInput{ - VpcEndpointIds: aws.StringSlice([]string{vpceId}), - } - - output, err := conn.DescribeVpcEndpoints(input) - - if err != nil { - return nil, fmt.Errorf("error reading VPC endpoint (%s): %w", vpceId, err) - } - - if n := len(output.VpcEndpoints); n != 1 { - return nil, fmt.Errorf("found %d VPC endpoints (%s), expected 1", n, vpceId) - } - - return output.VpcEndpoints[0], nil -} diff --git a/aws/resource_aws_vpc_endpoint_security_group_association.go b/aws/resource_aws_vpc_endpoint_security_group_association.go index b46d8df9f5e..ea8e59829c8 100644 --- a/aws/resource_aws_vpc_endpoint_security_group_association.go +++ b/aws/resource_aws_vpc_endpoint_security_group_association.go @@ -7,7 +7,8 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" - "github.com/terraform-providers/terraform-provider-aws/aws/internal/hashcode" + tfec2 "github.com/terraform-providers/terraform-provider-aws/aws/internal/service/ec2" + "github.com/terraform-providers/terraform-provider-aws/aws/internal/service/ec2/finder" ) func resourceAwsVpcEndpointSecurityGroupAssociation() *schema.Resource { @@ -46,18 +47,26 @@ func resourceAwsVpcEndpointSecurityGroupAssociationCreate(d *schema.ResourceData defaultSecurityGroupId := "" if replaceDefaultAssociation { - vpce, err := readVpcEndpoint(conn, vpceId) + vpce, err := finder.VpcEndpointByID(conn, vpceId) if err != nil { - return err + return fmt.Errorf("error reading VPC endpoint (%s): %w", vpceId, err) + } + + if vpce == nil { + return fmt.Errorf("VPC endpoint (%s) not found", vpceId) } vpcId := aws.StringValue(vpce.VpcId) - defaultSecurityGroup, err := readDefaultSecurityGroup(conn, vpcId) + defaultSecurityGroup, err := finder.DefaultSecurityGroup(conn, vpcId) if err != nil { - return err + return fmt.Errorf("error reading default security group for VPC (%s): %w", vpcId, err) + } + + if defaultSecurityGroup == nil { + return fmt.Errorf("default security group for VPC (%s) not found", vpcId) } defaultSecurityGroupId = aws.StringValue(defaultSecurityGroup.GroupId) @@ -86,7 +95,7 @@ func resourceAwsVpcEndpointSecurityGroupAssociationCreate(d *schema.ResourceData return err } - d.SetId(vpcEndpointSecurityGroupAssociationId(vpceId, sgId)) + d.SetId(tfec2.VpcEndpointSecurityGroupAssociationCreateID(vpceId, sgId)) if replaceDefaultAssociation { // Delete the existing VPC endpoint/default security group association. @@ -106,7 +115,7 @@ func resourceAwsVpcEndpointSecurityGroupAssociationRead(d *schema.ResourceData, vpceId := d.Get("vpc_endpoint_id").(string) sgId := d.Get("security_group_id").(string) - found, err := readVpcEndpointSecurityGroupAssociation(conn, vpceId, sgId) + found, err := finder.VpcEndpointSecurityGroupAssociationExists(conn, vpceId, sgId) if isAWSErr(err, "InvalidVpcEndpointId.NotFound", "") { log.Printf("[WARN] VPC Endpoint (%s) not found, removing VPC Endpoint/Security Group association (%s) from state", vpceId, d.Id()) @@ -135,16 +144,26 @@ func resourceAwsVpcEndpointSecurityGroupAssociationDelete(d *schema.ResourceData replaceDefaultAssociation := d.Get("replace_default_association").(bool) if replaceDefaultAssociation { - vpce, err := readVpcEndpoint(conn, vpceId) + vpce, err := finder.VpcEndpointByID(conn, vpceId) if err != nil { - return err + return fmt.Errorf("error reading VPC endpoint (%s): %w", vpceId, err) + } + + if vpce == nil { + return fmt.Errorf("VPC endpoint (%s) not found", vpceId) } - defaultSecurityGroup, err := readDefaultSecurityGroup(conn, aws.StringValue(vpce.VpcId)) + vpcId := aws.StringValue(vpce.VpcId) + + defaultSecurityGroup, err := finder.DefaultSecurityGroup(conn, vpcId) if err != nil { - return err + return fmt.Errorf("error reading default security group for VPC (%s): %w", vpcId, err) + } + + if defaultSecurityGroup == nil { + return fmt.Errorf("default security group for VPC (%s) not found", vpcId) } // Add back the VPC endpoint/default security group association. @@ -158,49 +177,6 @@ func resourceAwsVpcEndpointSecurityGroupAssociationDelete(d *schema.ResourceData return deleteVpcEndpointSecurityGroupAssociation(conn, vpceId, sgId) } -func vpcEndpointSecurityGroupAssociationId(vpceId, sgId string) string { - return fmt.Sprintf("a-%s%d", vpceId, hashcode.String(sgId)) -} - -// readDefaultSecurityGroup returns the default security group for the specified VPC. -func readDefaultSecurityGroup(conn *ec2.EC2, vpcId string) (*ec2.SecurityGroup, error) { - input := &ec2.DescribeSecurityGroupsInput{ - Filters: buildEC2AttributeFilterList(map[string]string{ - "group-name": "default", - "vpc-id": vpcId, - }), - } - - output, err := conn.DescribeSecurityGroups(input) - - if err != nil { - return nil, fmt.Errorf("error reading default security group for VPC (%s): %w", vpcId, err) - } - - if n := len(output.SecurityGroups); n != 1 { - return nil, fmt.Errorf("found %d default security groups for VPC (%s), expected 1", n, vpcId) - } - - return output.SecurityGroups[0], nil -} - -// readVpcEndpointSecurityGroupAssociation returns the specified VPC endpoint/security group association. -func readVpcEndpointSecurityGroupAssociation(conn *ec2.EC2, vpceId, sgId string) (bool, error) { - vpce, err := readVpcEndpoint(conn, vpceId) - - if err != nil { - return false, err - } - - for _, group := range vpce.Groups { - if aws.StringValue(group.GroupId) == sgId { - return true, nil - } - } - - return false, nil -} - // createVpcEndpointSecurityGroupAssociation creates the specified VPC endpoint/security group association. func createVpcEndpointSecurityGroupAssociation(conn *ec2.EC2, vpceId, sgId string) error { input := &ec2.ModifyVpcEndpointInput{ diff --git a/aws/resource_aws_vpc_endpoint_security_group_association_test.go b/aws/resource_aws_vpc_endpoint_security_group_association_test.go index bffedc4a6d3..c9ec06a7ea9 100644 --- a/aws/resource_aws_vpc_endpoint_security_group_association_test.go +++ b/aws/resource_aws_vpc_endpoint_security_group_association_test.go @@ -5,11 +5,11 @@ import ( "testing" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/ec2" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/acctest" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" "github.com/hashicorp/terraform-plugin-sdk/v2/terraform" + "github.com/terraform-providers/terraform-provider-aws/aws/internal/service/ec2/finder" ) func TestAccAWSVpcEndpointSecurityGroupAssociation_basic(t *testing.T) { @@ -109,27 +109,20 @@ func testAccCheckVpcEndpointSecurityGroupAssociationDestroy(s *terraform.State) continue } - // Try to find the resource - resp, err := conn.DescribeVpcEndpoints(&ec2.DescribeVpcEndpointsInput{ - VpcEndpointIds: aws.StringSlice([]string{rs.Primary.Attributes["vpc_endpoint_id"]}), - }) + out, err := finder.VpcEndpointByID(conn, rs.Primary.Attributes["vpc_endpoint_id"]) + if isAWSErr(err, "InvalidVpcEndpointId.NotFound", "") { + continue + } if err != nil { - // Verify the error is what we want - ec2err, ok := err.(awserr.Error) - if !ok { - return err - } - if ec2err.Code() != "InvalidVpcEndpointId.NotFound" { - return err - } - return nil + return err + } + if out == nil { + continue } - vpce := resp.VpcEndpoints[0] // VPC Endpoint will always have 1 SG. - if len(vpce.Groups) > 1 { - return fmt.Errorf( - "VPC endpoint %s has security groups", *vpce.VpcEndpointId) + if len(out.Groups) > 1 { + return fmt.Errorf("VPC endpoint %s has security groups", aws.StringValue(out.VpcEndpointId)) } } @@ -148,24 +141,25 @@ func testAccCheckVpcEndpointSecurityGroupAssociationExists(n string, vpce *ec2.V } conn := testAccProvider.Meta().(*AWSClient).ec2conn - resp, err := conn.DescribeVpcEndpoints(&ec2.DescribeVpcEndpointsInput{ - VpcEndpointIds: aws.StringSlice([]string{rs.Primary.Attributes["vpc_endpoint_id"]}), - }) + out, err := finder.VpcEndpointByID(conn, rs.Primary.Attributes["vpc_endpoint_id"]) + if isAWSErr(err, "InvalidVpcEndpointId.NotFound", "") { + return fmt.Errorf("VPC Endpoint not found") + } if err != nil { return err } - if len(resp.VpcEndpoints) == 0 { - return fmt.Errorf("VPC endpoint not found") + if out == nil { + return fmt.Errorf("VPC Endpoint not found") } - *vpce = *resp.VpcEndpoints[0] - - if len(vpce.Groups) == 0 { + if len(out.Groups) == 0 { return fmt.Errorf("no security group associations") } - for _, group := range vpce.Groups { + for _, group := range out.Groups { if aws.StringValue(group.GroupId) == rs.Primary.Attributes["security_group_id"] { + *vpce = *out + return nil } }