Skip to content

Commit

Permalink
Merge pull request #658 from kmala/release-1.25
Browse files Browse the repository at this point in the history
Add cluster details to the sts through headers
  • Loading branch information
k8s-ci-robot authored Sep 19, 2023
2 parents 5a08ab2 + 0807352 commit 0103f7f
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 3 deletions.
37 changes: 34 additions & 3 deletions pkg/providers/v1/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ const volumeAttachmentStuck = "VolumeAttachmentStuck"
// Indicates that a node has volumes stuck in attaching state and hence it is not fit for scheduling more pods
const nodeWithImpairedVolumes = "NodeWithImpairedVolumes"

const headerSourceArn = "x-amz-source-arn"
const headerSourceAccount = "x-amz-source-account"

const (
// volumeAttachmentConsecutiveErrorLimit is the number of consecutive errors we will ignore when waiting for a volume to attach/detach
volumeAttachmentStatusConsecutiveErrorLimit = 10
Expand Down Expand Up @@ -614,6 +617,11 @@ type CloudConfig struct {

// RoleARN is the IAM role to assume when interaction with AWS APIs.
RoleARN string
// SourceARN is value which is passed while assuming role specified by RoleARN. When a service
// assumes a role in your account, you can include the aws:SourceAccount and aws:SourceArn global
// condition context keys in your role trust policy to limit access to the role to only requests that are generated
// by expected resources. https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html
SourceARN string

// KubernetesClusterTag is the legacy cluster id we'll use to identify our cluster resources
KubernetesClusterTag string
Expand Down Expand Up @@ -1260,12 +1268,15 @@ func init() {

var creds *credentials.Credentials
if cfg.Global.RoleARN != "" {
klog.Infof("Using AWS assumed role %v", cfg.Global.RoleARN)
stsClient, err := getSTSClient(sess, cfg.Global.RoleARN, cfg.Global.SourceARN)
if err != nil {
return nil, fmt.Errorf("unable to create sts client, %v", err)
}
creds = credentials.NewChainCredentials(
[]credentials.Provider{
&credentials.EnvProvider{},
assumeRoleProvider(&stscreds.AssumeRoleProvider{
Client: sts.New(sess),
Client: stsClient,
RoleARN: cfg.Global.RoleARN,
}),
})
Expand All @@ -1276,13 +1287,33 @@ func init() {
})
}

func getSTSClient(sess *session.Session, roleARN, sourceARN string) (*sts.STS, error) {
klog.Infof("Using AWS assumed role %v", roleARN)
stsClient := sts.New(sess)
sourceAcct, err := GetSourceAccount(roleARN)
if err != nil {
return nil, err
}
reqHeaders := map[string]string{
headerSourceAccount: sourceAcct,
}
if sourceARN != "" {
reqHeaders[headerSourceArn] = sourceARN
}
stsClient.Handlers.Sign.PushFront(func(s *request.Request) {
s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders))
})
klog.V(4).Infof("configuring STS client with extra headers, %v", reqHeaders)
return stsClient, nil
}

// readAWSCloudConfig reads an instance of AWSCloudConfig from config reader.
func readAWSCloudConfig(config io.Reader) (*CloudConfig, error) {
var cfg CloudConfig
var err error

if config != nil {
err = gcfg.ReadInto(&cfg, config)
err = gcfg.FatalOnly(gcfg.ReadInto(&cfg, config))
if err != nil {
return nil, err
}
Expand Down
21 changes: 21 additions & 0 deletions pkg/providers/v1/aws_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ limitations under the License.
package aws

import (
"fmt"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/arn"

"k8s.io/apimachinery/pkg/util/sets"
)
Expand All @@ -43,3 +46,21 @@ func stringSetFromPointers(in []*string) sets.String {
}
return out
}

// GetSourceAccount constructs source acct and return them for use
func GetSourceAccount(roleARN string) (string, error) {
// ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html)
// arn:partition:service:region:account-id:resource-type/resource-id
// IAM format, region is always blank
// arn:aws:iam::account:role/role-name-with-path
if !arn.IsARN(roleARN) {
return "", fmt.Errorf("incorrect ARN format for role %s", roleARN)
}

parsedArn, err := arn.Parse(roleARN)
if err != nil {
return "", err
}

return parsedArn.AccountID, nil
}
60 changes: 60 additions & 0 deletions pkg/providers/v1/aws_utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package aws

import "testing"

func TestGetSourceAcctAndArn(t *testing.T) {
type args struct {
roleARN string
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "corect role arn",
args: args{
roleARN: "arn:aws:iam::123456789876:role/test-cluster",
},
want: "123456789876",
wantErr: false,
},
{
name: "incorect role arn",
args: args{
roleARN: "arn:aws:iam::123456789876",
},
want: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := GetSourceAccount(tt.args.roleARN)
if (err != nil) != tt.wantErr {
t.Errorf("GetSourceAccount() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("GetSourceAccount() got = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit 0103f7f

Please sign in to comment.