Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate IMDS, command completion, and region validation to aws-sdk-go-v2 #16430

Merged
merged 3 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions cmd/kops/create_cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
"sort"
"strings"

"github.com/aws/aws-sdk-go/service/ec2"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/google/go-containerregistry/pkg/name"
"github.com/google/go-containerregistry/pkg/v1/remote"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -1052,7 +1052,12 @@ func completeSecurityGroup(cmd *cobra.Command, args []string, toComplete string)
}

func completeTenancy(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
return ec2.Tenancy_Values(), cobra.ShellCompDirectiveNoFileComp
tenancies := ec2types.Tenancy("").Values()
values := make([]string, len(tenancies))
for i, v := range tenancies {
values[i] = string(v)
}
return values, cobra.ShellCompDirectiveNoFileComp
}

func completeSSLCertificate(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
Expand Down
16 changes: 7 additions & 9 deletions upup/pkg/fi/cloudup/awsup/aws_authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ import (
"encoding/json"
"fmt"

awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
Expand All @@ -41,20 +42,17 @@ var _ bootstrap.Authenticator = &awsAuthenticator{}

// RegionFromMetadata returns the current region from the aws metdata
func RegionFromMetadata(ctx context.Context) (string, error) {
config := aws.NewConfig()
config = config.WithCredentialsChainVerboseErrors(true)

s, err := session.NewSession(config)
cfg, err := awsconfig.LoadDefaultConfig(ctx)
if err != nil {
return "", err
return "", fmt.Errorf("failed to load default aws config: %w", err)
}
metadata := ec2metadata.New(s, config)
metadata := imds.NewFromConfig(cfg)

region, err := metadata.RegionWithContext(ctx)
resp, err := metadata.GetRegion(ctx, &imds.GetRegionInput{})
if err != nil {
return "", fmt.Errorf("failed to get region from ec2 metadata: %w", err)
}
return region, nil
return resp.Region, nil
}

func NewAWSAuthenticator(region string) (bootstrap.Authenticator, error) {
Expand Down
47 changes: 15 additions & 32 deletions upup/pkg/fi/cloudup/awsup/aws_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@ limitations under the License.
package awsup

import (
"context"
"fmt"
"os"
"strings"
"sync"

awsv2 "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/arn"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
ec2v2 "github.com/aws/aws-sdk-go-v2/service/ec2"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/autoscaling"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/elb"
Expand All @@ -37,62 +40,42 @@ import (
)

// allRegions is the list of all regions; tests will set the values
var allRegions []*ec2.Region
var allRegions []ec2types.Region
var allRegionsMutex sync.Mutex

// isRegionCompiledInToAWSSDK checks if the specified region is in the AWS SDK
func isRegionCompiledInToAWSSDK(region string) bool {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is no longer possible, see aws/aws-sdk-go-v2#2586

resolver := endpoints.DefaultResolver()
partitions := resolver.(endpoints.EnumPartitions).Partitions()
for _, p := range partitions {
for _, r := range p.Regions() {
if r.ID() == region {
return true
}
}
}
return false
}

// ValidateRegion checks that an AWS region name is valid
func ValidateRegion(region string) error {
if isRegionCompiledInToAWSSDK(region) {
return nil
}

func ValidateRegion(ctx context.Context, region string) error {
allRegionsMutex.Lock()
defer allRegionsMutex.Unlock()

if allRegions == nil {
klog.V(2).Infof("Querying EC2 for all valid regions")

request := &ec2.DescribeRegionsInput{}
request := &ec2v2.DescribeRegionsInput{}
awsRegion := os.Getenv("AWS_REGION")
if awsRegion == "" {
awsRegion = "us-east-1"
}
config := aws.NewConfig().WithRegion(awsRegion)
config = config.WithCredentialsChainVerboseErrors(true)
cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(awsRegion))
if err != nil {
return fmt.Errorf("error loading AWS config: %v", err)
}

sess, err := session.NewSessionWithOptions(session.Options{
Config: *config,
SharedConfigState: session.SharedConfigEnable,
})
if err != nil {
return fmt.Errorf("error starting a new AWS session: %v", err)
}

client := ec2.New(sess, config)
client := ec2v2.NewFromConfig(cfg)

response, err := client.DescribeRegions(request)
response, err := client.DescribeRegions(ctx, request)
if err != nil {
return fmt.Errorf("got an error while querying for valid regions (verify your AWS credentials?): %v", err)
}
allRegions = response.Regions
}

for _, r := range allRegions {
name := aws.StringValue(r.RegionName)
name := awsv2.ToString(r.RegionName)
if name == region {
return nil
}
Expand Down
9 changes: 6 additions & 3 deletions upup/pkg/fi/cloudup/awsup/aws_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@ limitations under the License.
package awsup

import (
"context"
"reflect"
"testing"

ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"k8s.io/kops/pkg/apis/kops"
)

func TestValidateRegion(t *testing.T) {
allRegions = []*ec2.Region{
ctx := context.Background()
allRegions = []ec2types.Region{
{
RegionName: aws.String("us-test-1"),
},
Expand All @@ -35,14 +38,14 @@ func TestValidateRegion(t *testing.T) {
},
}
for _, region := range []string{"us-test-1", "us-test-2"} {
err := ValidateRegion(region)
err := ValidateRegion(ctx, region)
if err != nil {
t.Fatalf("unexpected error validating region %q: %v", region, err)
}
}

for _, region := range []string{"is-lost-1", "no-road-2", "no-real-3"} {
err := ValidateRegion(region)
err := ValidateRegion(ctx, region)
if err == nil {
t.Fatalf("expected error validating region %q", region)
}
Expand Down
3 changes: 2 additions & 1 deletion upup/pkg/fi/cloudup/awsup/mock_aws_cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"fmt"

ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/autoscaling/autoscalingiface"
Expand Down Expand Up @@ -57,7 +58,7 @@ var _ fi.Cloud = (*MockAWSCloud)(nil)
func InstallMockAWSCloud(region string, zoneLetters string) *MockAWSCloud {
i := BuildMockAWSCloud(region, zoneLetters)
updateAwsCloudInstances(region, i)
allRegions = []*ec2.Region{
allRegions = []ec2types.Region{
{RegionName: aws.String(region)},
}
return i
Expand Down
4 changes: 3 additions & 1 deletion upup/pkg/fi/cloudup/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package cloudup

import (
"context"
"fmt"
"strings"

Expand All @@ -36,6 +37,7 @@ import (

func BuildCloud(cluster *kops.Cluster) (fi.Cloud, error) {
var cloud fi.Cloud
ctx := context.TODO()

region := ""
project := ""
Expand Down Expand Up @@ -75,7 +77,7 @@ func BuildCloud(cluster *kops.Cluster) (fi.Cloud, error) {
return nil, err
}

err = awsup.ValidateRegion(region)
err = awsup.ValidateRegion(ctx, region)
if err != nil {
return nil, err
}
Expand Down
37 changes: 24 additions & 13 deletions upup/pkg/fi/nodeup/nodetasks/prefix.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@ limitations under the License.
package nodetasks

import (
"context"
"errors"
"fmt"
"io"
"net/http"
"path"
"strings"

awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"k8s.io/klog/v2"
"k8s.io/kops/pkg/apis/kops"
Expand Down Expand Up @@ -54,12 +56,12 @@ func (e *Prefix) Find(c *fi.NodeupContext) (*Prefix, error) {
return nil, fmt.Errorf("unsupported cloud provider: %s", c.T.BootConfig.CloudProvider)
}

mac, err := getInstanceMetadataFirstValue("mac")
mac, err := getInstanceMetadataFirstValue(c.Context(), "mac")
if err != nil {
return nil, err
}

prefixes, err := getInstanceMetadataList(path.Join("network/interfaces/macs/", mac, "/ipv6-prefix"))
prefixes, err := getInstanceMetadataList(c.Context(), path.Join("network/interfaces/macs/", mac, "/ipv6-prefix"))
if err != nil {
return nil, err
}
Expand All @@ -83,12 +85,13 @@ func (_ *Prefix) CheckChanges(a, e, changes *Prefix) error {
}

func (_ *Prefix) RenderLocal(t *local.LocalTarget, a, e, changes *Prefix) error {
mac, err := getInstanceMetadataFirstValue("mac")
ctx := context.TODO()
mac, err := getInstanceMetadataFirstValue(ctx, "mac")
if err != nil {
return err
}

interfaceId, err := getInstanceMetadataFirstValue(path.Join("network/interfaces/macs/", mac, "/interface-id"))
interfaceId, err := getInstanceMetadataFirstValue(ctx, path.Join("network/interfaces/macs/", mac, "/interface-id"))
if err != nil {
return err
}
Expand All @@ -105,8 +108,8 @@ func (_ *Prefix) RenderLocal(t *local.LocalTarget, a, e, changes *Prefix) error
return nil
}

func getInstanceMetadataFirstValue(category string) (string, error) {
values, err := getInstanceMetadataList(category)
func getInstanceMetadataFirstValue(ctx context.Context, category string) (string, error) {
values, err := getInstanceMetadataList(ctx, category)
if err != nil {
return "", err
}
Expand All @@ -117,10 +120,13 @@ func getInstanceMetadataFirstValue(category string) (string, error) {
return values[0], nil
}

func getInstanceMetadataList(category string) ([]string, error) {
sess := session.Must(session.NewSession())
metadata := ec2metadata.New(sess)
linesStr, err := metadata.GetMetadata(category)
func getInstanceMetadataList(ctx context.Context, category string) ([]string, error) {
cfg, err := awsconfig.LoadDefaultConfig(ctx)
if err != nil {
return nil, fmt.Errorf("failed to load aws config: %v", err)
}
metadata := imds.NewFromConfig(cfg)
resp, err := metadata.GetMetadata(ctx, &imds.GetMetadataInput{Path: category})
if err != nil {
var aerr awserr.RequestFailure
if errors.As(err, &aerr) && aerr.StatusCode() == http.StatusNotFound {
Expand All @@ -129,9 +135,14 @@ func getInstanceMetadataList(category string) ([]string, error) {
return nil, fmt.Errorf("failed to get %q from ec2 meta-data: %v", category, err)
}
}
defer resp.Content.Close()
lines, err := io.ReadAll(resp.Content)
if err != nil {
return nil, fmt.Errorf("failed to read %q from ec2 meta-data: %v", category, err)
}

var values []string
for _, line := range strings.Split(linesStr, "\n") {
for _, line := range strings.Split(string(lines), "\n") {
line = strings.TrimSpace(line)
if len(line) > 0 {
values = append(values, line)
Expand Down
Loading