Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New resource aws_vpc_endpoint_security_group_association #13737

Merged
7 changes: 7 additions & 0 deletions .changelog/13737.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
```release-note:new-resource
aws_vpc_endpoint_security_group_association
```

```release-note:enhancement
resource/aws_vpc_endpoint: The `security_group_ids` attribute can now be empty when the resource is created. In this case the VPC's default security is associated with the VPC endpoint
```
1 change: 1 addition & 0 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -1266,6 +1266,7 @@ func Provider() *schema.Provider {
"aws_vpc_endpoint_connection_notification": ec2.ResourceVPCEndpointConnectionNotification(),
"aws_vpc_endpoint_policy": ec2.ResourceVPCEndpointPolicy(),
"aws_vpc_endpoint_route_table_association": ec2.ResourceVPCEndpointRouteTableAssociation(),
"aws_vpc_endpoint_security_group_association": ec2.ResourceVPCEndpointSecurityGroupAssociation(),
"aws_vpc_endpoint_service": ec2.ResourceVPCEndpointService(),
"aws_vpc_endpoint_service_allowed_principal": ec2.ResourceVPCEndpointServiceAllowedPrincipal(),
"aws_vpc_endpoint_subnet_association": ec2.ResourceVPCEndpointSubnetAssociation(),
Expand Down
85 changes: 65 additions & 20 deletions internal/service/ec2/find.go
Original file line number Diff line number Diff line change
Expand Up @@ -1796,8 +1796,55 @@ func FindVPCMainRouteTable(conn *ec2.EC2, id string) (*ec2.RouteTable, error) {
return FindRouteTable(conn, input)
}

// FindVPCEndpointByID returns the VPC endpoint corresponding to the specified identifier.
// Returns NotFoundError if no VPC endpoint is found.
func FindVPCEndpoint(conn *ec2.EC2, input *ec2.DescribeVpcEndpointsInput) (*ec2.VpcEndpoint, error) {
output, err := FindVPCEndpoints(conn, input)

if err != nil {
return nil, err
}

if len(output) == 0 || output[0] == nil {
return nil, tfresource.NewEmptyResultError(input)
}

if count := len(output); count > 1 {
return nil, tfresource.NewTooManyResultsError(count, input)
}

return output[0], nil
}

func FindVPCEndpoints(conn *ec2.EC2, input *ec2.DescribeVpcEndpointsInput) ([]*ec2.VpcEndpoint, error) {
var output []*ec2.VpcEndpoint

err := conn.DescribeVpcEndpointsPages(input, func(page *ec2.DescribeVpcEndpointsOutput, lastPage bool) bool {
if page == nil {
return !lastPage
}

for _, v := range page.VpcEndpoints {
if v != nil {
output = append(output, v)
}
}

return !lastPage
})

if tfawserr.ErrCodeEquals(err, ErrCodeInvalidVpcEndpointIdNotFound) {
return nil, &resource.NotFoundError{
LastError: err,
LastRequest: input,
}
}

if err != nil {
return nil, err
}

return output, nil
}

func FindVPCEndpointByID(conn *ec2.EC2, vpcEndpointID string) (*ec2.VpcEndpoint, error) {
input := &ec2.DescribeVpcEndpointsInput{
VpcEndpointIds: aws.StringSlice([]string{vpcEndpointID}),
Expand Down Expand Up @@ -1826,43 +1873,41 @@ func FindVPCEndpointByID(conn *ec2.EC2, vpcEndpointID string) (*ec2.VpcEndpoint,
return output, nil
}

func FindVPCEndpoint(conn *ec2.EC2, input *ec2.DescribeVpcEndpointsInput) (*ec2.VpcEndpoint, error) {
output, err := conn.DescribeVpcEndpoints(input)

if tfawserr.ErrCodeEquals(err, ErrCodeInvalidVpcEndpointIdNotFound) {
return nil, &resource.NotFoundError{
LastError: err,
LastRequest: input,
}
}
// FindVPCEndpointRouteTableAssociationExists returns NotFoundError if no association for the specified VPC endpoint and route table IDs is found.
func FindVPCEndpointRouteTableAssociationExists(conn *ec2.EC2, vpcEndpointID string, routeTableID string) error {
vpcEndpoint, err := FindVPCEndpointByID(conn, vpcEndpointID)

if err != nil {
return nil, err
return err
}

if output == nil || len(output.VpcEndpoints) == 0 || output.VpcEndpoints[0] == nil {
return nil, tfresource.NewEmptyResultError(input)
for _, vpcEndpointRouteTableID := range vpcEndpoint.RouteTableIds {
if aws.StringValue(vpcEndpointRouteTableID) == routeTableID {
return nil
}
}

return output.VpcEndpoints[0], nil
return &resource.NotFoundError{
LastError: fmt.Errorf("VPC Endpoint (%s) Route Table (%s) Association not found", vpcEndpointID, routeTableID),
}
}

// FindVPCEndpointRouteTableAssociationExists returns NotFoundError if no association for the specified VPC endpoint and route table IDs is found.
func FindVPCEndpointRouteTableAssociationExists(conn *ec2.EC2, vpcEndpointID string, routeTableID string) error {
// FindVPCEndpointSecurityGroupAssociationExists returns NotFoundError if no association for the specified VPC endpoint and security group IDs is found.
func FindVPCEndpointSecurityGroupAssociationExists(conn *ec2.EC2, vpcEndpointID, securityGroupID string) error {
vpcEndpoint, err := FindVPCEndpointByID(conn, vpcEndpointID)

if err != nil {
return err
}

for _, vpcEndpointRouteTableID := range vpcEndpoint.RouteTableIds {
if aws.StringValue(vpcEndpointRouteTableID) == routeTableID {
for _, group := range vpcEndpoint.Groups {
if aws.StringValue(group.GroupId) == securityGroupID {
return nil
}
}

return &resource.NotFoundError{
LastError: fmt.Errorf("VPC Endpoint Route Table Association (%s/%s) not found", vpcEndpointID, routeTableID),
LastError: fmt.Errorf("VPC Endpoint (%s) Security Group (%s) Association not found", vpcEndpointID, securityGroupID),
}
}

Expand Down
4 changes: 4 additions & 0 deletions internal/service/ec2/id.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ func VPCEndpointRouteTableAssociationCreateID(vpcEndpointID, routeTableID string
return fmt.Sprintf("a-%s%d", vpcEndpointID, create.StringHashcode(routeTableID))
}

func VPCEndpointSecurityGroupAssociationCreateID(vpcEndpointID, securityGroupID string) string {
return fmt.Sprintf("a-%s%d", vpcEndpointID, create.StringHashcode(securityGroupID))
}

func VPCEndpointSubnetAssociationCreateID(vpcEndpointID, subnetID string) string {
return fmt.Sprintf("a-%s%d", vpcEndpointID, create.StringHashcode(subnetID))
}
Expand Down
6 changes: 0 additions & 6 deletions internal/service/ec2/vpc_endpoint.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package ec2

import (
"errors"
"fmt"
"log"
"time"
Expand Down Expand Up @@ -157,11 +156,6 @@ func ResourceVPCEndpoint() *schema.Resource {
}

func resourceVPCEndpointCreate(d *schema.ResourceData, meta interface{}) error {
if d.Get("vpc_endpoint_type").(string) == ec2.VpcEndpointTypeInterface &&
d.Get("security_group_ids").(*schema.Set).Len() == 0 {
return errors.New("An Interface VPC Endpoint must always have at least one Security Group")
}

conn := meta.(*conns.AWSClient).EC2Conn
defaultTagsConfig := meta.(*conns.AWSClient).DefaultTagsConfig
tags := defaultTagsConfig.MergeTags(tftags.New(d.Get("tags").(map[string]interface{})))
Expand Down
195 changes: 195 additions & 0 deletions internal/service/ec2/vpc_endpoint_security_group_association.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
package ec2

import (
"fmt"
"log"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/hashicorp/aws-sdk-go-base/v2/awsv1shim/v2/tfawserr"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-provider-aws/internal/conns"
"github.com/hashicorp/terraform-provider-aws/internal/tfresource"
)

func ResourceVPCEndpointSecurityGroupAssociation() *schema.Resource {
return &schema.Resource{
Create: resourceVPCEndpointSecurityGroupAssociationCreate,
Read: resourceVPCEndpointSecurityGroupAssociationRead,
Delete: resourceVPCEndpointSecurityGroupAssociationDelete,

Schema: map[string]*schema.Schema{
"replace_default_association": {
Type: schema.TypeBool,
Optional: true,
Default: false,
ForceNew: true,
},
"security_group_id": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
},
"vpc_endpoint_id": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
},
},
}
}

func resourceVPCEndpointSecurityGroupAssociationCreate(d *schema.ResourceData, meta interface{}) error {
conn := meta.(*conns.AWSClient).EC2Conn

vpcEndpointID := d.Get("vpc_endpoint_id").(string)
securityGroupID := d.Get("security_group_id").(string)
replaceDefaultAssociation := d.Get("replace_default_association").(bool)

defaultSecurityGroupID := ""
if replaceDefaultAssociation {
vpcEndpoint, err := FindVPCEndpointByID(conn, vpcEndpointID)

if err != nil {
return fmt.Errorf("error reading VPC Endpoint (%s): %w", vpcEndpointID, err)
}

vpcID := aws.StringValue(vpcEndpoint.VpcId)

defaultSecurityGroup, err := FindVPCDefaultSecurityGroup(conn, vpcID)

if err != nil {
return fmt.Errorf("error reading EC2 VPC (%s) default Security Group: %w", vpcID, err)
}

defaultSecurityGroupID = aws.StringValue(defaultSecurityGroup.GroupId)

if defaultSecurityGroupID == securityGroupID {
return fmt.Errorf("%s is the default Security Group for EC2 VPC (%s)", securityGroupID, vpcID)
}

foundDefaultAssociation := false

for _, group := range vpcEndpoint.Groups {
if aws.StringValue(group.GroupId) == defaultSecurityGroupID {
foundDefaultAssociation = true
break
}
}

if !foundDefaultAssociation {
return fmt.Errorf("no association of default Security Group (%s) with VPC Endpoint (%s)", defaultSecurityGroupID, vpcEndpointID)
}
}

err := createVpcEndpointSecurityGroupAssociation(conn, vpcEndpointID, securityGroupID)

if err != nil {
return err
}

d.SetId(VPCEndpointSecurityGroupAssociationCreateID(vpcEndpointID, securityGroupID))

if replaceDefaultAssociation {
// Delete the existing VPC endpoint/default security group association.
if err := deleteVpcEndpointSecurityGroupAssociation(conn, vpcEndpointID, defaultSecurityGroupID); err != nil {
return err
}
}

return resourceVPCEndpointSecurityGroupAssociationRead(d, meta)
}

func resourceVPCEndpointSecurityGroupAssociationRead(d *schema.ResourceData, meta interface{}) error {
conn := meta.(*conns.AWSClient).EC2Conn

vpcEndpointID := d.Get("vpc_endpoint_id").(string)
securityGroupID := d.Get("security_group_id").(string)
// Human friendly ID for error messages since d.Id() is non-descriptive
id := fmt.Sprintf("%s/%s", vpcEndpointID, securityGroupID)

err := FindVPCEndpointSecurityGroupAssociationExists(conn, vpcEndpointID, securityGroupID)

if !d.IsNewResource() && tfresource.NotFound(err) {
log.Printf("[WARN] VPC Endpoint Security Group Association (%s) not found, removing from state", id)
d.SetId("")
return nil
}

if err != nil {
return fmt.Errorf("error reading VPC Security Group Association (%s): %w", id, err)
}

return nil
}

func resourceVPCEndpointSecurityGroupAssociationDelete(d *schema.ResourceData, meta interface{}) error {
conn := meta.(*conns.AWSClient).EC2Conn

vpcEndpointID := d.Get("vpc_endpoint_id").(string)
securityGroupID := d.Get("security_group_id").(string)
replaceDefaultAssociation := d.Get("replace_default_association").(bool)

if replaceDefaultAssociation {
vpcEndpoint, err := FindVPCEndpointByID(conn, vpcEndpointID)

if err != nil {
return fmt.Errorf("error reading VPC Endpoint (%s): %w", vpcEndpointID, err)
}

vpcID := aws.StringValue(vpcEndpoint.VpcId)

defaultSecurityGroup, err := FindVPCDefaultSecurityGroup(conn, vpcID)

if err != nil {
return fmt.Errorf("error reading EC2 VPC (%s) default Security Group: %w", vpcID, err)
}

// Add back the VPC endpoint/default security group association.
err = createVpcEndpointSecurityGroupAssociation(conn, vpcEndpointID, aws.StringValue(defaultSecurityGroup.GroupId))

if err != nil {
return err
}
}

return deleteVpcEndpointSecurityGroupAssociation(conn, vpcEndpointID, securityGroupID)
}

// createVpcEndpointSecurityGroupAssociation creates the specified VPC endpoint/security group association.
func createVpcEndpointSecurityGroupAssociation(conn *ec2.EC2, vpcEndpointID, securityGroupID string) error {
input := &ec2.ModifyVpcEndpointInput{
VpcEndpointId: aws.String(vpcEndpointID),
AddSecurityGroupIds: aws.StringSlice([]string{securityGroupID}),
}

log.Printf("[DEBUG] Creating VPC Endpoint Security Group Association: %s", input)
_, err := conn.ModifyVpcEndpoint(input)

if err != nil {
return fmt.Errorf("error creating VPC Endpoint (%s) Security Group (%s) Association: %w", vpcEndpointID, securityGroupID, err)
}

return nil
}

// deleteVpcEndpointSecurityGroupAssociation deletes the specified VPC endpoint/security group association.
func deleteVpcEndpointSecurityGroupAssociation(conn *ec2.EC2, vpcEndpointID, securityGroupID string) error {
input := &ec2.ModifyVpcEndpointInput{
VpcEndpointId: aws.String(vpcEndpointID),
RemoveSecurityGroupIds: aws.StringSlice([]string{securityGroupID}),
}

log.Printf("[DEBUG] Deleting VPC Endpoint Security Group Association: %s", input)
_, err := conn.ModifyVpcEndpoint(input)

if tfawserr.ErrCodeEquals(err, ErrCodeInvalidVpcEndpointIdNotFound, ErrCodeInvalidGroupNotFound, ErrCodeInvalidParameter) {
return nil
}

if err != nil {
return fmt.Errorf("error deleting VPC Endpoint (%s) Security Group (%s) Association: %w", vpcEndpointID, securityGroupID, err)
}

return nil
}
Loading