Skip to content

Commit

Permalink
Merge pull request #16430 from rifelpet/aws-sdk-go-v2
Browse files Browse the repository at this point in the history
Migrate IMDS, command completion, and region validation to aws-sdk-go-v2
  • Loading branch information
k8s-ci-robot authored Mar 29, 2024
2 parents db03ce8 + 498bcc1 commit 2fcbd9e
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 61 deletions.
9 changes: 7 additions & 2 deletions cmd/kops/create_cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
"sort"
"strings"

"github.com/aws/aws-sdk-go/service/ec2"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/google/go-containerregistry/pkg/name"
"github.com/google/go-containerregistry/pkg/v1/remote"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -1052,7 +1052,12 @@ func completeSecurityGroup(cmd *cobra.Command, args []string, toComplete string)
}

func completeTenancy(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
return ec2.Tenancy_Values(), cobra.ShellCompDirectiveNoFileComp
tenancies := ec2types.Tenancy("").Values()
values := make([]string, len(tenancies))
for i, v := range tenancies {
values[i] = string(v)
}
return values, cobra.ShellCompDirectiveNoFileComp
}

func completeSSLCertificate(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
Expand Down
16 changes: 7 additions & 9 deletions upup/pkg/fi/cloudup/awsup/aws_authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ import (
"encoding/json"
"fmt"

awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
Expand All @@ -41,20 +42,17 @@ var _ bootstrap.Authenticator = &awsAuthenticator{}

// RegionFromMetadata returns the current region from the aws metdata
func RegionFromMetadata(ctx context.Context) (string, error) {
config := aws.NewConfig()
config = config.WithCredentialsChainVerboseErrors(true)

s, err := session.NewSession(config)
cfg, err := awsconfig.LoadDefaultConfig(ctx)
if err != nil {
return "", err
return "", fmt.Errorf("failed to load default aws config: %w", err)
}
metadata := ec2metadata.New(s, config)
metadata := imds.NewFromConfig(cfg)

region, err := metadata.RegionWithContext(ctx)
resp, err := metadata.GetRegion(ctx, &imds.GetRegionInput{})
if err != nil {
return "", fmt.Errorf("failed to get region from ec2 metadata: %w", err)
}
return region, nil
return resp.Region, nil
}

func NewAWSAuthenticator(region string) (bootstrap.Authenticator, error) {
Expand Down
47 changes: 15 additions & 32 deletions upup/pkg/fi/cloudup/awsup/aws_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@ limitations under the License.
package awsup

import (
"context"
"fmt"
"os"
"strings"
"sync"

awsv2 "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/arn"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
ec2v2 "github.com/aws/aws-sdk-go-v2/service/ec2"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/autoscaling"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/elb"
Expand All @@ -37,62 +40,42 @@ import (
)

// allRegions is the list of all regions; tests will set the values
var allRegions []*ec2.Region
var allRegions []ec2types.Region
var allRegionsMutex sync.Mutex

// isRegionCompiledInToAWSSDK checks if the specified region is in the AWS SDK
func isRegionCompiledInToAWSSDK(region string) bool {
resolver := endpoints.DefaultResolver()
partitions := resolver.(endpoints.EnumPartitions).Partitions()
for _, p := range partitions {
for _, r := range p.Regions() {
if r.ID() == region {
return true
}
}
}
return false
}

// ValidateRegion checks that an AWS region name is valid
func ValidateRegion(region string) error {
if isRegionCompiledInToAWSSDK(region) {
return nil
}

func ValidateRegion(ctx context.Context, region string) error {
allRegionsMutex.Lock()
defer allRegionsMutex.Unlock()

if allRegions == nil {
klog.V(2).Infof("Querying EC2 for all valid regions")

request := &ec2.DescribeRegionsInput{}
request := &ec2v2.DescribeRegionsInput{}
awsRegion := os.Getenv("AWS_REGION")
if awsRegion == "" {
awsRegion = "us-east-1"
}
config := aws.NewConfig().WithRegion(awsRegion)
config = config.WithCredentialsChainVerboseErrors(true)
cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(awsRegion))
if err != nil {
return fmt.Errorf("error loading AWS config: %v", err)
}

sess, err := session.NewSessionWithOptions(session.Options{
Config: *config,
SharedConfigState: session.SharedConfigEnable,
})
if err != nil {
return fmt.Errorf("error starting a new AWS session: %v", err)
}

client := ec2.New(sess, config)
client := ec2v2.NewFromConfig(cfg)

response, err := client.DescribeRegions(request)
response, err := client.DescribeRegions(ctx, request)
if err != nil {
return fmt.Errorf("got an error while querying for valid regions (verify your AWS credentials?): %v", err)
}
allRegions = response.Regions
}

for _, r := range allRegions {
name := aws.StringValue(r.RegionName)
name := awsv2.ToString(r.RegionName)
if name == region {
return nil
}
Expand Down
9 changes: 6 additions & 3 deletions upup/pkg/fi/cloudup/awsup/aws_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@ limitations under the License.
package awsup

import (
"context"
"reflect"
"testing"

ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"k8s.io/kops/pkg/apis/kops"
)

func TestValidateRegion(t *testing.T) {
allRegions = []*ec2.Region{
ctx := context.Background()
allRegions = []ec2types.Region{
{
RegionName: aws.String("us-test-1"),
},
Expand All @@ -35,14 +38,14 @@ func TestValidateRegion(t *testing.T) {
},
}
for _, region := range []string{"us-test-1", "us-test-2"} {
err := ValidateRegion(region)
err := ValidateRegion(ctx, region)
if err != nil {
t.Fatalf("unexpected error validating region %q: %v", region, err)
}
}

for _, region := range []string{"is-lost-1", "no-road-2", "no-real-3"} {
err := ValidateRegion(region)
err := ValidateRegion(ctx, region)
if err == nil {
t.Fatalf("expected error validating region %q", region)
}
Expand Down
3 changes: 2 additions & 1 deletion upup/pkg/fi/cloudup/awsup/mock_aws_cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"fmt"

ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/autoscaling/autoscalingiface"
Expand Down Expand Up @@ -57,7 +58,7 @@ var _ fi.Cloud = (*MockAWSCloud)(nil)
func InstallMockAWSCloud(region string, zoneLetters string) *MockAWSCloud {
i := BuildMockAWSCloud(region, zoneLetters)
updateAwsCloudInstances(region, i)
allRegions = []*ec2.Region{
allRegions = []ec2types.Region{
{RegionName: aws.String(region)},
}
return i
Expand Down
4 changes: 3 additions & 1 deletion upup/pkg/fi/cloudup/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package cloudup

import (
"context"
"fmt"
"strings"

Expand All @@ -36,6 +37,7 @@ import (

func BuildCloud(cluster *kops.Cluster) (fi.Cloud, error) {
var cloud fi.Cloud
ctx := context.TODO()

region := ""
project := ""
Expand Down Expand Up @@ -75,7 +77,7 @@ func BuildCloud(cluster *kops.Cluster) (fi.Cloud, error) {
return nil, err
}

err = awsup.ValidateRegion(region)
err = awsup.ValidateRegion(ctx, region)
if err != nil {
return nil, err
}
Expand Down
37 changes: 24 additions & 13 deletions upup/pkg/fi/nodeup/nodetasks/prefix.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@ limitations under the License.
package nodetasks

import (
"context"
"errors"
"fmt"
"io"
"net/http"
"path"
"strings"

awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"k8s.io/klog/v2"
"k8s.io/kops/pkg/apis/kops"
Expand Down Expand Up @@ -54,12 +56,12 @@ func (e *Prefix) Find(c *fi.NodeupContext) (*Prefix, error) {
return nil, fmt.Errorf("unsupported cloud provider: %s", c.T.BootConfig.CloudProvider)
}

mac, err := getInstanceMetadataFirstValue("mac")
mac, err := getInstanceMetadataFirstValue(c.Context(), "mac")
if err != nil {
return nil, err
}

prefixes, err := getInstanceMetadataList(path.Join("network/interfaces/macs/", mac, "/ipv6-prefix"))
prefixes, err := getInstanceMetadataList(c.Context(), path.Join("network/interfaces/macs/", mac, "/ipv6-prefix"))
if err != nil {
return nil, err
}
Expand All @@ -83,12 +85,13 @@ func (_ *Prefix) CheckChanges(a, e, changes *Prefix) error {
}

func (_ *Prefix) RenderLocal(t *local.LocalTarget, a, e, changes *Prefix) error {
mac, err := getInstanceMetadataFirstValue("mac")
ctx := context.TODO()
mac, err := getInstanceMetadataFirstValue(ctx, "mac")
if err != nil {
return err
}

interfaceId, err := getInstanceMetadataFirstValue(path.Join("network/interfaces/macs/", mac, "/interface-id"))
interfaceId, err := getInstanceMetadataFirstValue(ctx, path.Join("network/interfaces/macs/", mac, "/interface-id"))
if err != nil {
return err
}
Expand All @@ -105,8 +108,8 @@ func (_ *Prefix) RenderLocal(t *local.LocalTarget, a, e, changes *Prefix) error
return nil
}

func getInstanceMetadataFirstValue(category string) (string, error) {
values, err := getInstanceMetadataList(category)
func getInstanceMetadataFirstValue(ctx context.Context, category string) (string, error) {
values, err := getInstanceMetadataList(ctx, category)
if err != nil {
return "", err
}
Expand All @@ -117,10 +120,13 @@ func getInstanceMetadataFirstValue(category string) (string, error) {
return values[0], nil
}

func getInstanceMetadataList(category string) ([]string, error) {
sess := session.Must(session.NewSession())
metadata := ec2metadata.New(sess)
linesStr, err := metadata.GetMetadata(category)
func getInstanceMetadataList(ctx context.Context, category string) ([]string, error) {
cfg, err := awsconfig.LoadDefaultConfig(ctx)
if err != nil {
return nil, fmt.Errorf("failed to load aws config: %v", err)
}
metadata := imds.NewFromConfig(cfg)
resp, err := metadata.GetMetadata(ctx, &imds.GetMetadataInput{Path: category})
if err != nil {
var aerr awserr.RequestFailure
if errors.As(err, &aerr) && aerr.StatusCode() == http.StatusNotFound {
Expand All @@ -129,9 +135,14 @@ func getInstanceMetadataList(category string) ([]string, error) {
return nil, fmt.Errorf("failed to get %q from ec2 meta-data: %v", category, err)
}
}
defer resp.Content.Close()
lines, err := io.ReadAll(resp.Content)
if err != nil {
return nil, fmt.Errorf("failed to read %q from ec2 meta-data: %v", category, err)
}

var values []string
for _, line := range strings.Split(linesStr, "\n") {
for _, line := range strings.Split(string(lines), "\n") {
line = strings.TrimSpace(line)
if len(line) > 0 {
values = append(values, line)
Expand Down

0 comments on commit 2fcbd9e

Please sign in to comment.