From 62df0dba048187d4760c184bbec642c4e4120707 Mon Sep 17 00:00:00 2001 From: Peter Rifel Date: Sun, 21 Apr 2024 07:06:16 -0400 Subject: [PATCH 1/2] Migrate AWS Verifier to aws-sdk-go-v2 --- cmd/kops-controller/main.go | 2 +- nodeup/pkg/model/bootstrap_client.go | 2 +- .../pkg/fi/cloudup/awsup/aws_authenticator.go | 34 ++++---- upup/pkg/fi/cloudup/awsup/aws_cloud.go | 3 +- upup/pkg/fi/cloudup/awsup/aws_verifier.go | 77 ++++++++++--------- upup/pkg/fi/nodeup/command.go | 2 +- 6 files changed, 57 insertions(+), 63 deletions(-) diff --git a/cmd/kops-controller/main.go b/cmd/kops-controller/main.go index a72ca0ca3ec93..a82127d1113eb 100644 --- a/cmd/kops-controller/main.go +++ b/cmd/kops-controller/main.go @@ -130,7 +130,7 @@ func main() { var verifiers []bootstrap.Verifier var err error if opt.Server.Provider.AWS != nil { - verifier, err := awsup.NewAWSVerifier(opt.Server.Provider.AWS) + verifier, err := awsup.NewAWSVerifier(ctx, opt.Server.Provider.AWS) if err != nil { setupLog.Error(err, "unable to create verifier") os.Exit(1) diff --git a/nodeup/pkg/model/bootstrap_client.go b/nodeup/pkg/model/bootstrap_client.go index c397e29f40ca6..029068c1d2cf3 100644 --- a/nodeup/pkg/model/bootstrap_client.go +++ b/nodeup/pkg/model/bootstrap_client.go @@ -52,7 +52,7 @@ func (b BootstrapClientBuilder) Build(c *fi.NodeupModelBuilderContext) error { switch b.CloudProvider() { case kops.CloudProviderAWS: - a, err := awsup.NewAWSAuthenticator(b.Cloud.Region()) + a, err := awsup.NewAWSAuthenticator(c.Context(), b.Cloud.Region()) if err != nil { return err } diff --git a/upup/pkg/fi/cloudup/awsup/aws_authenticator.go b/upup/pkg/fi/cloudup/awsup/aws_authenticator.go index 68bc7df5ae3db..b40980a56494c 100644 --- a/upup/pkg/fi/cloudup/awsup/aws_authenticator.go +++ b/upup/pkg/fi/cloudup/awsup/aws_authenticator.go @@ -25,17 +25,15 @@ import ( awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/service/sts" + smithyhttp "github.com/aws/smithy-go/transport/http" "k8s.io/kops/pkg/bootstrap" ) const AWSAuthenticationTokenPrefix = "x-aws-sts " type awsAuthenticator struct { - sts *sts.STS + sts *sts.Client } var _ bootstrap.Authenticator = &awsAuthenticator{} @@ -55,32 +53,28 @@ func RegionFromMetadata(ctx context.Context) (string, error) { return resp.Region, nil } -func NewAWSAuthenticator(region string) (bootstrap.Authenticator, error) { - config := aws.NewConfig(). - WithCredentialsChainVerboseErrors(true). - WithRegion(region). - WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint) - sess, err := session.NewSession(config) +func NewAWSAuthenticator(ctx context.Context, region string) (bootstrap.Authenticator, error) { + config, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region)) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to load aws config: %w", err) } return &awsAuthenticator{ - sts: sts.New(sess, config), + sts: sts.NewFromConfig(config), }, nil } func (a *awsAuthenticator) CreateToken(body []byte) (string, error) { sha := sha256.Sum256(body) - stsRequest, _ := a.sts.GetCallerIdentityRequest(nil) + presignClient := sts.NewPresignClient(a.sts) // Ensure the signature is only valid for this particular body content. - stsRequest.HTTPRequest.Header.Add("X-Kops-Request-SHA", base64.RawStdEncoding.EncodeToString(sha[:])) + stsRequest, _ := presignClient.PresignGetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{}, func(po *sts.PresignOptions) { + po.ClientOptions = append(po.ClientOptions, func(o *sts.Options) { + o.APIOptions = append(o.APIOptions, smithyhttp.AddHeaderValue("X-Kops-Request-SHA", base64.RawStdEncoding.EncodeToString(sha[:]))) + }) + }) - if err := stsRequest.Sign(); err != nil { - return "", err - } - - headers, _ := json.Marshal(stsRequest.HTTPRequest.Header) + headers, _ := json.Marshal(stsRequest.SignedHeader) return AWSAuthenticationTokenPrefix + base64.StdEncoding.EncodeToString(headers), nil } diff --git a/upup/pkg/fi/cloudup/awsup/aws_cloud.go b/upup/pkg/fi/cloudup/awsup/aws_cloud.go index f4139ae61413a..c907a2c069f65 100644 --- a/upup/pkg/fi/cloudup/awsup/aws_cloud.go +++ b/upup/pkg/fi/cloudup/awsup/aws_cloud.go @@ -46,7 +46,6 @@ import ( "github.com/aws/aws-sdk-go-v2/service/iam" "github.com/aws/aws-sdk-go-v2/service/route53" "github.com/aws/aws-sdk-go-v2/service/sts" - ec2v1 "github.com/aws/aws-sdk-go/service/ec2" "k8s.io/klog/v2" v1 "k8s.io/api/core/v1" @@ -2358,7 +2357,7 @@ func GetRolesInInstanceProfile(c AWSCloud, profileName string) ([]string, error) // GetInstanceCertificateNames returns the instance hostname and addresses that should go into certificates. // The first value is the node name and any additional values are the DNS name and IP addresses. -func GetInstanceCertificateNames(instances *ec2v1.DescribeInstancesOutput) (addrs []string, err error) { +func GetInstanceCertificateNames(instances *ec2.DescribeInstancesOutput) (addrs []string, err error) { if len(instances.Reservations) != 1 { return nil, fmt.Errorf("too many reservations returned for the single instance-id") } diff --git a/upup/pkg/fi/cloudup/awsup/aws_verifier.go b/upup/pkg/fi/cloudup/awsup/aws_verifier.go index f00aac8ecf7ea..0e13225b83c2e 100644 --- a/upup/pkg/fi/cloudup/awsup/aws_verifier.go +++ b/upup/pkg/fi/cloudup/awsup/aws_verifier.go @@ -27,15 +27,15 @@ import ( "io" "net" "net/http" + "net/url" "strconv" "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/sts" "k8s.io/kops/pkg/bootstrap" nodeidentityaws "k8s.io/kops/pkg/nodeidentity/aws" "k8s.io/kops/pkg/wellknownports" @@ -53,39 +53,38 @@ type awsVerifier struct { partition string opt AWSVerifierOptions - ec2 *ec2.EC2 - sts *sts.STS + ec2 *ec2.Client + sts *sts.PresignClient client http.Client } var _ bootstrap.Verifier = &awsVerifier{} -func NewAWSVerifier(opt *AWSVerifierOptions) (bootstrap.Verifier, error) { - config := aws.NewConfig(). - WithCredentialsChainVerboseErrors(true). - WithRegion(opt.Region). - WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint) - sess, err := session.NewSession(config) +func NewAWSVerifier(ctx context.Context, opt *AWSVerifierOptions) (bootstrap.Verifier, error) { + config, err := awsconfig.LoadDefaultConfig( + ctx, + awsconfig.WithRegion(opt.Region), + ) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to load aws config: %w", err) } - stsClient := sts.New(sess, config) - identity, err := stsClient.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + stsClient := sts.NewFromConfig(config) + identity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) if err != nil { return nil, err } - partition := strings.Split(aws.StringValue(identity.Arn), ":")[1] + partition := strings.Split(aws.ToString(identity.Arn), ":")[1] - ec2Client := ec2.New(sess, config) + ec2Client := ec2.NewFromConfig(config) return &awsVerifier{ - accountId: aws.StringValue(identity.Account), + accountId: aws.ToString(identity.Account), partition: partition, opt: *opt, ec2: ec2Client, - sts: stsClient, + sts: sts.NewPresignClient(stsClient), client: http.Client{ Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -128,35 +127,37 @@ func (a awsVerifier) VerifyToken(ctx context.Context, rawRequest *http.Request, token = strings.TrimPrefix(token, AWSAuthenticationTokenPrefix) // We rely on the client and server using the same version of the same STS library. - stsRequest, _ := a.sts.GetCallerIdentityRequest(nil) - err := stsRequest.Sign() + stsRequest, err := a.sts.PresignGetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) if err != nil { return nil, fmt.Errorf("creating identity request: %v", err) } - stsRequest.HTTPRequest.Header = nil + stsRequest.SignedHeader = nil tokenBytes, err := base64.StdEncoding.DecodeString(token) if err != nil { return nil, fmt.Errorf("decoding authorization token: %v", err) } - err = json.Unmarshal(tokenBytes, &stsRequest.HTTPRequest.Header) + err = json.Unmarshal(tokenBytes, &stsRequest.SignedHeader) if err != nil { return nil, fmt.Errorf("unmarshalling authorization token: %v", err) } // Verify the token has signed the body content. sha := sha256.Sum256(body) - if stsRequest.HTTPRequest.Header.Get("X-Kops-Request-SHA") != base64.RawStdEncoding.EncodeToString(sha[:]) { + if stsRequest.SignedHeader.Get("X-Kops-Request-SHA") != base64.RawStdEncoding.EncodeToString(sha[:]) { return nil, fmt.Errorf("incorrect SHA") } - requestBytes, _ := io.ReadAll(stsRequest.Body) - _, _ = stsRequest.Body.Seek(0, io.SeekStart) - if stsRequest.HTTPRequest.Header.Get("Content-Length") != strconv.Itoa(len(requestBytes)) { - return nil, fmt.Errorf("incorrect content-length") + reqURL, err := url.Parse(stsRequest.URL) + if err != nil { + return nil, fmt.Errorf("parsing STS request URL: %v", err) } - - response, err := a.client.Do(stsRequest.HTTPRequest) + req := &http.Request{ + URL: reqURL, + Method: stsRequest.Method, + Header: stsRequest.SignedHeader, + } + response, err := a.client.Do(req) if err != nil { return nil, fmt.Errorf("sending STS request: %v", err) } @@ -217,8 +218,8 @@ func (a awsVerifier) VerifyToken(ctx context.Context, rawRequest *http.Request, } instanceID := resource[2] - instances, err := a.ec2.DescribeInstances(&ec2.DescribeInstancesInput{ - InstanceIds: aws.StringSlice([]string{instanceID}), + instances, err := a.ec2.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ + InstanceIds: []string{instanceID}, }) if err != nil { return nil, fmt.Errorf("describing instance for arn %q", arn) @@ -240,17 +241,17 @@ func (a awsVerifier) VerifyToken(ctx context.Context, rawRequest *http.Request, var challengeEndpoints []string for _, nic := range instance.NetworkInterfaces { - if ip := aws.StringValue(nic.PrivateIpAddress); ip != "" { + if ip := aws.ToString(nic.PrivateIpAddress); ip != "" { challengeEndpoints = append(challengeEndpoints, net.JoinHostPort(ip, strconv.Itoa(wellknownports.NodeupChallenge))) } for _, a := range nic.PrivateIpAddresses { - if ip := aws.StringValue(a.PrivateIpAddress); ip != "" { + if ip := aws.ToString(a.PrivateIpAddress); ip != "" { challengeEndpoints = append(challengeEndpoints, net.JoinHostPort(ip, strconv.Itoa(wellknownports.NodeupChallenge))) } } for _, a := range nic.Ipv6Addresses { - if ip := aws.StringValue(a.Ipv6Address); ip != "" { + if ip := aws.ToString(a.Ipv6Address); ip != "" { challengeEndpoints = append(challengeEndpoints, net.JoinHostPort(ip, strconv.Itoa(wellknownports.NodeupChallenge))) } } @@ -267,9 +268,9 @@ func (a awsVerifier) VerifyToken(ctx context.Context, rawRequest *http.Request, } for _, tag := range instance.Tags { - tagKey := aws.StringValue(tag.Key) + tagKey := aws.ToString(tag.Key) if tagKey == nodeidentityaws.CloudTagInstanceGroupName { - result.InstanceGroupName = aws.StringValue(tag.Value) + result.InstanceGroupName = aws.ToString(tag.Value) } } diff --git a/upup/pkg/fi/nodeup/command.go b/upup/pkg/fi/nodeup/command.go index 3d48bf1526c87..e7cdbece3f71b 100644 --- a/upup/pkg/fi/nodeup/command.go +++ b/upup/pkg/fi/nodeup/command.go @@ -625,7 +625,7 @@ func getNodeConfigFromServers(ctx context.Context, bootConfig *nodeup.BootConfig switch bootConfig.CloudProvider { case api.CloudProviderAWS: - a, err := awsup.NewAWSAuthenticator(region) + a, err := awsup.NewAWSAuthenticator(ctx, region) if err != nil { return nil, err } From 1e5ed58cd5afa8bdbb407225bb356b180b3aeb72 Mon Sep 17 00:00:00 2001 From: justinsb Date: Mon, 22 Apr 2024 10:19:36 -0400 Subject: [PATCH 2/2] Update token validation for aws-sdk-go v2 We pass the full request details, it's less dependent on client versions. --- go.mod | 2 +- .../pkg/fi/cloudup/awsup/aws_authenticator.go | 25 +++- .../cloudup/awsup/aws_authenticator_test.go | 141 ++++++++++++------ upup/pkg/fi/cloudup/awsup/aws_verifier.go | 140 ++++++++++++----- .../pkg/fi/cloudup/awsup/aws_verifier_test.go | 83 +++++++++++ 5 files changed, 306 insertions(+), 85 deletions(-) create mode 100644 upup/pkg/fi/cloudup/awsup/aws_verifier_test.go diff --git a/go.mod b/go.mod index f77f68434f4da..8bcc7c0989919 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,6 @@ require ( github.com/Masterminds/sprig/v3 v3.2.3 github.com/apparentlymart/go-cidr v1.1.0 github.com/aws/amazon-ec2-instance-selector/v2 v2.4.2-0.20231216170552-14d4dfcbaadf - github.com/aws/aws-sdk-go v1.52.1 github.com/aws/aws-sdk-go-v2 v1.26.1 github.com/aws/aws-sdk-go-v2/config v1.27.11 github.com/aws/aws-sdk-go-v2/credentials v1.17.11 @@ -111,6 +110,7 @@ require ( github.com/Microsoft/hcsshim v0.11.4 // indirect github.com/armon/go-metrics v0.4.1 // indirect github.com/atotto/clipboard v0.1.4 // indirect + github.com/aws/aws-sdk-go v1.52.1 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect diff --git a/upup/pkg/fi/cloudup/awsup/aws_authenticator.go b/upup/pkg/fi/cloudup/awsup/aws_authenticator.go index b40980a56494c..70cc3c2401d07 100644 --- a/upup/pkg/fi/cloudup/awsup/aws_authenticator.go +++ b/upup/pkg/fi/cloudup/awsup/aws_authenticator.go @@ -22,6 +22,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "net/http" awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" @@ -63,18 +64,36 @@ func NewAWSAuthenticator(ctx context.Context, region string) (bootstrap.Authenti }, nil } +type awsV2Token struct { + URL string `json:"url"` + Method string `json:"method"` + SignedHeader http.Header `json:"headers"` +} + func (a *awsAuthenticator) CreateToken(body []byte) (string, error) { sha := sha256.Sum256(body) presignClient := sts.NewPresignClient(a.sts) // Ensure the signature is only valid for this particular body content. - stsRequest, _ := presignClient.PresignGetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{}, func(po *sts.PresignOptions) { + stsRequest, err := presignClient.PresignGetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{}, func(po *sts.PresignOptions) { po.ClientOptions = append(po.ClientOptions, func(o *sts.Options) { o.APIOptions = append(o.APIOptions, smithyhttp.AddHeaderValue("X-Kops-Request-SHA", base64.RawStdEncoding.EncodeToString(sha[:]))) }) }) + if err != nil { + return "", fmt.Errorf("building AWS STS presigned request: %w", err) + } + + awsV2Token := &awsV2Token{ + URL: stsRequest.URL, + Method: stsRequest.Method, + SignedHeader: stsRequest.SignedHeader, + } + token, err := json.Marshal(awsV2Token) + if err != nil { + return "", fmt.Errorf("converting token to json: %w", err) + } - headers, _ := json.Marshal(stsRequest.SignedHeader) - return AWSAuthenticationTokenPrefix + base64.StdEncoding.EncodeToString(headers), nil + return AWSAuthenticationTokenPrefix + base64.StdEncoding.EncodeToString(token), nil } diff --git a/upup/pkg/fi/cloudup/awsup/aws_authenticator_test.go b/upup/pkg/fi/cloudup/awsup/aws_authenticator_test.go index daf51bb589b3b..3f2ade5fbe015 100644 --- a/upup/pkg/fi/cloudup/awsup/aws_authenticator_test.go +++ b/upup/pkg/fi/cloudup/awsup/aws_authenticator_test.go @@ -22,30 +22,23 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" "strings" "testing" - // "github.com/aws/aws-sdk-go-v2/aws" - // "github.com/aws/aws-sdk-go-v2/credentials" - // "github.com/aws/aws-sdk-go-v2/service/sts" - - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/sts" ) func TestAWSPresign(t *testing.T) { - // mockSTSServer := &mockSTSServer{t: t} - // awsConfig := aws.Config{} - // awsConfig.Region = "us-east-1" - // awsConfig.Credentials = credentials.NewStaticCredentialsProvider("accesskey", "secretkey", "") - // awsConfig.HTTPClient = mockSTSServer - // sts := sts.NewFromConfig(awsConfig) - - mySession := session.Must(session.NewSession()) - mySession.Config.Credentials = credentials.NewStaticCredentials("accesskey", "secretkey", "") - sts := sts.New(mySession) - mySession.Config.HTTPClient = &http.Client{Transport: &mockHTTPTransport{}} + mockSTSServer := &mockHTTPClient{t: t} + awsConfig := aws.Config{} + awsConfig.Region = "us-east-1" + awsConfig.Credentials = credentials.NewStaticCredentialsProvider("fakeaccesskey", "fakesecretkey", "") + awsConfig.HTTPClient = mockSTSServer + sts := sts.NewFromConfig(awsConfig) + a := &awsAuthenticator{ sts: sts, } @@ -68,24 +61,29 @@ func TestAWSPresign(t *testing.T) { if err != nil { t.Fatalf("decoding token as base64: %v", err) } - headers := make(map[string][]string) - if err := json.Unmarshal([]byte(data), &headers); err != nil { + decoded := &awsV2Token{} + if err := json.Unmarshal([]byte(data), &decoded); err != nil { t.Fatalf("decoding token as json: %v", err) } - t.Logf("headers: %+v", headers) + t.Logf("decoded: %+v", decoded) + + amzSignature := "" + amzSignedHeaders := "" + amzDate := "" + amzAlgorithm := "" + amzCredential := "" authorization := "" - for header, values := range headers { + for header, values := range decoded.SignedHeader { got := strings.Join(values, "||") switch header { case "User-Agent": // Ignore // TODO: Should we (can we) override the useragent? case "X-Amz-Date": - if len(got) < 10 { - t.Errorf("expected %q header of at least 10 characters, got %q", header, got) - } + amzDate = got + case "Content-Length": if want := "43"; got != want { t.Errorf("unexpected %q header: got %q, want %q", header, got, want) @@ -95,6 +93,11 @@ func TestAWSPresign(t *testing.T) { t.Errorf("unexpected %q header: got %q, want %q", header, got, want) } + case "Host": + if want := "sts.us-east-1.amazonaws.com"; got != want { + t.Errorf("unexpected %q header: got %q, want %q", header, got, want) + } + case "X-Kops-Request-Sha": if want := bodyHashBase64; got != want { t.Errorf("unexpected %q header: got %q, want %q", header, got, want) @@ -103,34 +106,86 @@ func TestAWSPresign(t *testing.T) { // Validated more deeply below authorization = got default: - t.Errorf("unexpected header %q", header) + t.Errorf("unexpected header %q: %q", header, got) } } - if !strings.HasPrefix(authorization, "AWS4-HMAC-SHA256 ") { - t.Errorf("unexpected authorization prefix, got %q", authorization) - } - for _, token := range strings.Split(strings.TrimPrefix(authorization, "AWS4-HMAC-SHA256 "), ", ") { - kv := strings.SplitN(token, "=", 2) - got := kv[1] - switch kv[0] { - case "Signature": - if len(got) < 10 { - t.Errorf("expected %q Authorization value of at least 10 characters, got %q", kv[0], got) + // TODO: This is only aws-sdk-go V1 + if authorization != "" { + if !strings.HasPrefix(authorization, "AWS4-HMAC-SHA256 ") { + t.Errorf("unexpected authorization prefix, got %q", authorization) + } + + for _, token := range strings.Split(strings.TrimPrefix(authorization, "AWS4-HMAC-SHA256 "), ", ") { + kv := strings.SplitN(token, "=", 2) + if len(kv) == 1 { + t.Errorf("invalid token %q in authorization header", token) + continue } - case "Credential": - if len(got) < 10 { - t.Errorf("expected %q Authorization value of at least 10 characters, got %q", kv[0], got) + got := kv[1] + switch kv[0] { + case "Signature": + amzSignature = got + case "Credential": + amzCredential = got + case "SignedHeaders": + amzSignedHeaders = got + default: + t.Errorf("unknown token %q in authorization header", token) } - case "SignedHeaders": - if want := "content-length;content-type;host;x-amz-date;x-kops-request-sha"; got != want { - t.Errorf("unexpected %q Authorization value: got %q, want %q", kv[0], got, want) + } + } + + u, err := url.Parse(decoded.URL) + if err != nil { + t.Errorf("error parsing url %q: %v", decoded.URL, err) + } + for k, values := range u.Query() { + got := strings.Join(values, "||") + + switch k { + case "Action": + if want := "GetCallerIdentity"; got != want { + t.Errorf("unexpected %q query param: got %q, want %q", k, got, want) + } + case "Version": + if want := "2011-06-15"; got != want { + t.Errorf("unexpected %q query param: got %q, want %q", k, got, want) } + case "X-Amz-Date": + amzDate = k + case "X-Amz-Signature": + amzSignature = k + case "X-Amz-Credential": + amzCredential = got + case "X-Amz-SignedHeaders": + amzSignedHeaders = got + case "X-Amz-Algorithm": + amzAlgorithm = got default: - t.Errorf("unknown token %q in authorization header", token) + t.Errorf("unknown token %q=%q in query", k, got) } } + if len(amzCredential) < 10 { + t.Errorf("expected amzCredential value of at least 10 characters, got %q", amzCredential) + } + + if len(amzDate) < 10 { + t.Errorf("expected amz-date of at least 10 characters, got %q", amzDate) + } + + if len(amzSignature) < 10 { + t.Errorf("expected amzSignature value of at least 10 characters, got %q", amzSignature) + } + + if want := "AWS4-HMAC-SHA256"; amzAlgorithm != want { + t.Errorf("unexpected amzAlgorithm: got %q, want %q", amzAlgorithm, want) + } + + if want := "host;x-kops-request-sha"; amzSignedHeaders != want { + t.Errorf("unexpected amzSignedHeaders: got %q, want %q", amzSignedHeaders, want) + } } type mockHTTPClient struct { diff --git a/upup/pkg/fi/cloudup/awsup/aws_verifier.go b/upup/pkg/fi/cloudup/awsup/aws_verifier.go index 0e13225b83c2e..3648fada658b6 100644 --- a/upup/pkg/fi/cloudup/awsup/aws_verifier.go +++ b/upup/pkg/fi/cloudup/awsup/aws_verifier.go @@ -36,6 +36,7 @@ import ( awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/sts" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/kops/pkg/bootstrap" nodeidentityaws "k8s.io/kops/pkg/nodeidentity/aws" "k8s.io/kops/pkg/wellknownports" @@ -54,8 +55,9 @@ type awsVerifier struct { opt AWSVerifierOptions ec2 *ec2.Client - sts *sts.PresignClient client http.Client + + stsRequestValidator *stsRequestValidator } var _ bootstrap.Verifier = &awsVerifier{} @@ -79,12 +81,17 @@ func NewAWSVerifier(ctx context.Context, opt *AWSVerifierOptions) (bootstrap.Ver ec2Client := ec2.NewFromConfig(config) + stsRequestValidator, err := buildSTSRequestValidator(ctx, stsClient) + if err != nil { + return nil, err + } + return &awsVerifier{ - accountId: aws.ToString(identity.Account), - partition: partition, - opt: *opt, - ec2: ec2Client, - sts: sts.NewPresignClient(stsClient), + accountId: aws.ToString(identity.Account), + partition: partition, + opt: *opt, + ec2: ec2Client, + stsRequestValidator: stsRequestValidator, client: http.Client{ Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -126,59 +133,38 @@ func (a awsVerifier) VerifyToken(ctx context.Context, rawRequest *http.Request, } token = strings.TrimPrefix(token, AWSAuthenticationTokenPrefix) - // We rely on the client and server using the same version of the same STS library. - stsRequest, err := a.sts.PresignGetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) - if err != nil { - return nil, fmt.Errorf("creating identity request: %v", err) - } - - stsRequest.SignedHeader = nil tokenBytes, err := base64.StdEncoding.DecodeString(token) if err != nil { return nil, fmt.Errorf("decoding authorization token: %v", err) } - err = json.Unmarshal(tokenBytes, &stsRequest.SignedHeader) - if err != nil { + var decoded awsV2Token + if err := json.Unmarshal(tokenBytes, &decoded); err != nil { return nil, fmt.Errorf("unmarshalling authorization token: %v", err) } // Verify the token has signed the body content. sha := sha256.Sum256(body) - if stsRequest.SignedHeader.Get("X-Kops-Request-SHA") != base64.RawStdEncoding.EncodeToString(sha[:]) { + if decoded.SignedHeader.Get("X-Kops-Request-SHA") != base64.RawStdEncoding.EncodeToString(sha[:]) { return nil, fmt.Errorf("incorrect SHA") } - reqURL, err := url.Parse(stsRequest.URL) + reqURL, err := url.Parse(decoded.URL) if err != nil { return nil, fmt.Errorf("parsing STS request URL: %v", err) } - req := &http.Request{ - URL: reqURL, - Method: stsRequest.Method, - Header: stsRequest.SignedHeader, - } - response, err := a.client.Do(req) - if err != nil { - return nil, fmt.Errorf("sending STS request: %v", err) - } - if response != nil { - defer response.Body.Close() + signedHeaders := sets.New(strings.Split(reqURL.Query().Get("X-Amz-SignedHeaders"), ";")...) + if !signedHeaders.Has("x-kops-request-sha") { + return nil, fmt.Errorf("unexpected signed headers value") } - responseBody, err := io.ReadAll(response.Body) - if err != nil { - return nil, fmt.Errorf("reading STS response: %v", err) - } - if response.StatusCode != 200 { - return nil, fmt.Errorf("received status code %d from STS: %s", response.StatusCode, string(responseBody)) + if !a.stsRequestValidator.IsValid(reqURL) { + return nil, fmt.Errorf("invalid STS url: host=%q, path=%q", reqURL.Host, reqURL.Path) } - callerIdentity := GetCallerIdentityResponse{} - err = xml.NewDecoder(bytes.NewReader(responseBody)).Decode(&callerIdentity) + callerIdentity, err := a.stsRequestValidator.GetCallerIdentity(ctx, &a.client, &decoded) if err != nil { - return nil, fmt.Errorf("decoding STS response: %v", err) + return nil, err } - if callerIdentity.GetCallerIdentityResult[0].Account != a.accountId { return nil, fmt.Errorf("incorrect account %s", callerIdentity.GetCallerIdentityResult[0].Account) } @@ -276,3 +262,81 @@ func (a awsVerifier) VerifyToken(ctx context.Context, rawRequest *http.Request, return result, nil } + +// stsRequestValidator describes valid STS Presigned URLs, and is used to validate client authentication requests. +type stsRequestValidator struct { + Host string +} + +// IsValid performs some basic pre-validation of the request URL. +func (s *stsRequestValidator) IsValid(u *url.URL) bool { + if u.Host != s.Host { + return false + } + if u.Path != "/" { + return false + } + if u.Query().Get("Action") != "GetCallerIdentity" { + return false + } + if len(u.Query()["Action"]) != 1 { + return false + } + + return true +} + +// GetCallerIdentity will request the presigned token URL, and decode the returned identity. +func (s *stsRequestValidator) GetCallerIdentity(ctx context.Context, httpClient *http.Client, decoded *awsV2Token) (*GetCallerIdentityResponse, error) { + reqURL, err := url.Parse(decoded.URL) + if err != nil { + return nil, fmt.Errorf("parsing STS request URL: %w", err) + } + + if !s.IsValid(reqURL) { + return nil, fmt.Errorf("url not valid for STS request") + } + + req := &http.Request{ + URL: reqURL, + Method: decoded.Method, + Header: decoded.SignedHeader, + } + response, err := httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("sending STS request: %v", err) + } + if response != nil { + defer response.Body.Close() + } + + responseBody, err := io.ReadAll(response.Body) + if err != nil { + return nil, fmt.Errorf("reading STS response: %v", err) + } + if response.StatusCode != 200 { + return nil, fmt.Errorf("received status code %d from STS: %s", response.StatusCode, string(responseBody)) + } + + callerIdentity := &GetCallerIdentityResponse{} + err = xml.NewDecoder(bytes.NewReader(responseBody)).Decode(callerIdentity) + if err != nil { + return nil, fmt.Errorf("decoding STS response: %v", err) + } + + return callerIdentity, nil +} + +// buildSTSRequestValidator determines the form of a valid STS presigned URL. +func buildSTSRequestValidator(ctx context.Context, stsClient *sts.Client) (*stsRequestValidator, error) { + // We build a presigned token ourselves, primarily to get the expected hostname for the endpoint. + signed, err := sts.NewPresignClient(stsClient).PresignGetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + if err != nil { + return nil, fmt.Errorf("building presigned request: %w", err) + } + u, err := url.Parse(signed.URL) + if err != nil { + return nil, fmt.Errorf("parsing presigned url: %w", err) + } + return &stsRequestValidator{Host: u.Host}, nil +} diff --git a/upup/pkg/fi/cloudup/awsup/aws_verifier_test.go b/upup/pkg/fi/cloudup/awsup/aws_verifier_test.go new file mode 100644 index 0000000000000..12608fefff29c --- /dev/null +++ b/upup/pkg/fi/cloudup/awsup/aws_verifier_test.go @@ -0,0 +1,83 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package awsup + +import ( + "context" + "net/url" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/sts" +) + +func TestGetSTSRequestInfo(t *testing.T) { + ctx := context.TODO() + + awsConfig := aws.Config{} + awsConfig.Region = "us-east-1" + awsConfig.Credentials = credentials.NewStaticCredentialsProvider("fakeaccesskey", "fakesecretkey", "") + sts := sts.NewFromConfig(awsConfig) + + stsRequestInfo, err := buildSTSRequestValidator(ctx, sts) + if err != nil { + t.Fatalf("error from getSTSRequestInfo: %v", err) + } + + if got, want := stsRequestInfo.Host, "sts.us-east-1.amazonaws.com"; got != want { + t.Errorf("unexpected host in sts request info; got %q, want %q", got, want) + } + + grid := []struct { + URL string + IsValid bool + }{ + { + URL: "https://sts.us-east-1.amazonaws.com/", + IsValid: false, + }, + { + URL: "https://sts.us-east-1.amazonaws.com/Foo", + IsValid: false, + }, + { + URL: "https://sts.us-east-1.amazonaws.com/?Action=GetCallerIdentity", + IsValid: true, + }, + { + URL: "https://sts.us-east-1.amazonaws.com/Foo?Action=GetCallerIdentity", + IsValid: false, + }, + { + URL: "https://sts.us-east-1.amazonaws.com/?Action=GetCallerIdentity&Action=GetCallerIdentity", + IsValid: false, + }, + } + + for _, g := range grid { + u, err := url.Parse(g.URL) + if err != nil { + t.Fatalf("parsing url %q: %v", g.URL, err) + } + got := stsRequestInfo.IsValid(u) + if got != g.IsValid { + t.Errorf("unexpected result for IsValid(%v); got %v, want %v", g.URL, got, g.IsValid) + } + } + +}