Skip to content

Commit

Permalink
Merge pull request #16930 from ewbankkit/b-various-aws_route-fixes-Ma…
Browse files Browse the repository at this point in the history
…rkVII

r/aws_route: Correctly handle update of route target
  • Loading branch information
YakDriver authored Mar 25, 2021
2 parents 9216b97 + 5e0d0f6 commit e1ecea3
Show file tree
Hide file tree
Showing 13 changed files with 1,734 additions and 869 deletions.
7 changes: 7 additions & 0 deletions .changelog/16930.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
```release-note:enhancement
resource/aws_route: Validate route destination and target attributes
```

```release-note:bug
resource/aws_route: Correctly handle updates to the route target attributes (`egress_only_gateway_id`, `gateway_id`, `instance_id`, `local_gateway_id`, `nat_gateway_id`, `network_interface_id`, `transit_gateway_id`, `vpc_peering_connection_id`)
```
15 changes: 15 additions & 0 deletions aws/data_source_aws_route.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"fmt"
"log"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
tfec2 "github.com/terraform-providers/terraform-provider-aws/aws/internal/service/ec2"
)

func dataSourceAwsRoute() *schema.Resource {
Expand Down Expand Up @@ -202,3 +204,16 @@ func getRoutes(table *ec2.RouteTable, d *schema.ResourceData) []*ec2.Route {
}
return routes
}

// Helper: Create an ID for a route
func resourceAwsRouteID(d *schema.ResourceData, r *ec2.Route) string {
routeTableID := d.Get("route_table_id").(string)

if destination := aws.StringValue(r.DestinationCidrBlock); destination != "" {
return tfec2.RouteCreateID(routeTableID, destination)
} else if destination := aws.StringValue(r.DestinationIpv6CidrBlock); destination != "" {
return tfec2.RouteCreateID(routeTableID, destination)
}

return ""
}
23 changes: 23 additions & 0 deletions aws/internal/net/cidr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package net

import (
"net"
)

// CIDRBlocksEqual returns whether or not two CIDR blocks are equal:
// - Both CIDR blocks parse to an IP address and network
// - The string representation of the IP addresses are equal
// - The string representation of the networks are equal
// This function is especially useful for IPv6 CIDR blocks which have multiple valid representations.
func CIDRBlocksEqual(cidr1, cidr2 string) bool {
ip1, ipnet1, err := net.ParseCIDR(cidr1)
if err != nil {
return false
}
ip2, ipnet2, err := net.ParseCIDR(cidr2)
if err != nil {
return false
}

return ip2.String() == ip1.String() && ipnet2.String() == ipnet1.String()
}
26 changes: 26 additions & 0 deletions aws/internal/net/cidr_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package net

import (
"testing"
)

func Test_CIDRBlocksEqual(t *testing.T) {
for _, ts := range []struct {
cidr1 string
cidr2 string
equal bool
}{
{"10.2.2.0/24", "10.2.2.0/24", true},
{"10.2.2.0/1234", "10.2.2.0/24", false},
{"10.2.2.0/24", "10.2.2.0/1234", false},
{"2001::/15", "2001::/15", true},
{"::/0", "2001::/15", false},
{"::/0", "::0/0", true},
{"", "", false},
} {
equal := CIDRBlocksEqual(ts.cidr1, ts.cidr2)
if ts.equal != equal {
t.Fatalf("CIDRBlocksEqual(%q, %q) should be: %t", ts.cidr1, ts.cidr2, ts.equal)
}
}
}
8 changes: 7 additions & 1 deletion aws/internal/service/ec2/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import (
)

const (
ErrCodeInvalidParameterValue = "InvalidParameterValue"
ErrCodeInvalidParameterException = "InvalidParameterException"
ErrCodeInvalidParameterValue = "InvalidParameterValue"
)

const (
Expand All @@ -21,9 +22,14 @@ const (
)

const (
ErrCodeInvalidRouteNotFound = "InvalidRoute.NotFound"
ErrCodeInvalidRouteTableIDNotFound = "InvalidRouteTableID.NotFound"
)

const (
ErrCodeInvalidTransitGatewayIDNotFound = "InvalidTransitGatewayID.NotFound"
)

const (
ErrCodeClientVpnEndpointIdNotFound = "InvalidClientVpnEndpointId.NotFound"
ErrCodeClientVpnAuthorizationRuleNotFound = "InvalidClientVpnEndpointAuthorizationRuleNotFound"
Expand Down
75 changes: 75 additions & 0 deletions aws/internal/service/ec2/finder/finder.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/hashicorp/aws-sdk-go-base/tfawserr"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
tfnet "github.com/terraform-providers/terraform-provider-aws/aws/internal/net"
tfec2 "github.com/terraform-providers/terraform-provider-aws/aws/internal/service/ec2"
)

Expand Down Expand Up @@ -95,6 +98,78 @@ func InstanceByID(conn *ec2.EC2, id string) (*ec2.Instance, error) {
return output.Reservations[0].Instances[0], nil
}

// RouteTableByID returns the route table corresponding to the specified identifier.
// Returns NotFoundError if no route table is found.
func RouteTableByID(conn *ec2.EC2, routeTableID string) (*ec2.RouteTable, error) {
input := &ec2.DescribeRouteTablesInput{
RouteTableIds: aws.StringSlice([]string{routeTableID}),
}

return RouteTable(conn, input)
}

func RouteTable(conn *ec2.EC2, input *ec2.DescribeRouteTablesInput) (*ec2.RouteTable, error) {
output, err := conn.DescribeRouteTables(input)

if tfawserr.ErrCodeEquals(err, tfec2.ErrCodeInvalidRouteTableIDNotFound) {
return nil, &resource.NotFoundError{
LastError: err,
LastRequest: input,
}
}

if err != nil {
return nil, err
}

if output == nil || len(output.RouteTables) == 0 || output.RouteTables[0] == nil {
return nil, &resource.NotFoundError{
Message: "Empty result",
LastRequest: input,
}
}

return output.RouteTables[0], nil
}

// RouteFinder returns the route corresponding to the specified destination.
// Returns NotFoundError if no route is found.
type RouteFinder func(*ec2.EC2, string, string) (*ec2.Route, error)

// RouteByIPv4Destination returns the route corresponding to the specified IPv4 destination.
// Returns NotFoundError if no route is found.
func RouteByIPv4Destination(conn *ec2.EC2, routeTableID, destinationCidr string) (*ec2.Route, error) {
routeTable, err := RouteTableByID(conn, routeTableID)
if err != nil {
return nil, err
}

for _, route := range routeTable.Routes {
if tfnet.CIDRBlocksEqual(aws.StringValue(route.DestinationCidrBlock), destinationCidr) {
return route, nil
}
}

return nil, &resource.NotFoundError{}
}

// RouteByIPv6Destination returns the route corresponding to the specified IPv6 destination.
// Returns NotFoundError if no route is found.
func RouteByIPv6Destination(conn *ec2.EC2, routeTableID, destinationIpv6Cidr string) (*ec2.Route, error) {
routeTable, err := RouteTableByID(conn, routeTableID)
if err != nil {
return nil, err
}

for _, route := range routeTable.Routes {
if tfnet.CIDRBlocksEqual(aws.StringValue(route.DestinationIpv6CidrBlock), destinationIpv6Cidr) {
return route, nil
}
}

return nil, &resource.NotFoundError{}
}

// SecurityGroupByID looks up a security group by ID. When not found, returns nil and potentially an API error.
func SecurityGroupByID(conn *ec2.EC2, id string) (*ec2.SecurityGroup, error) {
req := &ec2.DescribeSecurityGroupsInput{
Expand Down
5 changes: 5 additions & 0 deletions aws/internal/service/ec2/id.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ func ClientVpnRouteParseID(id string) (string, string, string, error) {
"target-subnet-id"+clientVpnRouteIDSeparator+"destination-cidr-block", id)
}

// RouteCreateID returns a route resource ID.
func RouteCreateID(routeTableID, destination string) string {
return fmt.Sprintf("r-%s%d", routeTableID, hashcode.String(destination))
}

const transitGatewayPrefixListReferenceSeparator = "_"

func TransitGatewayPrefixListReferenceCreateID(transitGatewayRouteTableID string, prefixListID string) string {
Expand Down
8 changes: 8 additions & 0 deletions aws/resource_aws_lb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
)

func init() {
RegisterServiceErrorCheckFunc(elbv2.EndpointsID, testAccErrorCheckSkipELBv2)

resource.AddTestSweepers("aws_lb", &resource.Sweeper{
Name: "aws_lb",
F: testSweepLBs,
Expand All @@ -25,6 +27,12 @@ func init() {
})
}

func testAccErrorCheckSkipELBv2(t *testing.T) resource.ErrorCheckFunc {
return testAccErrorCheckSkipMessagesContaining(t,
"ValidationError: Type must be one of: 'application, network'",
)
}

func testSweepLBs(region string) error {
client, err := sharedClientForRegion(region)
if err != nil {
Expand Down
Loading

0 comments on commit e1ecea3

Please sign in to comment.