Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sourceArn to sts through headers #749

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/aws-iam-authenticator/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ func getConfig() (config.Config, error) {
PartitionID: viper.GetString("server.partition"),
ClusterID: viper.GetString("clusterID"),
ServerEC2DescribeInstancesRoleARN: viper.GetString("server.ec2DescribeInstancesRoleARN"),
SourceARN: viper.GetString("server.sourceARN"),
HostPort: viper.GetInt("server.port"),
Hostname: viper.GetString("server.hostname"),
GenerateKubeconfigPath: viper.GetString("server.generateKubeconfig"),
Expand Down
6 changes: 6 additions & 0 deletions pkg/config/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ type Config struct {
// running.
ServerEC2DescribeInstancesRoleARN string

// SourceARN is value which is passed while assuming role specified by ServerEC2DescribeInstancesRoleARN.
// When a service assumes a role in your account, you can include the aws:SourceAccount and aws:SourceArn global
// condition context keys in your role trust policy to limit access to the role to only requests that are generated
// by expected resources. https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html
SourceARN string

// Address defines the hostname or IP Address to bind the HTTPS server to listen to. This is useful when creating
// a local server to handle the authentication request for development.
Address string
Expand Down
52 changes: 48 additions & 4 deletions pkg/ec2provider/ec2provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/endpoints"
Expand Down Expand Up @@ -35,6 +36,11 @@ const (
// Maximum time in Milliseconds to wait for a new batch call this also depends on if the instance size has
// already become 100 then it will not respect this limit
maxWaitIntervalForBatch = 200

// Headers for STS request for source ARN
headerSourceArn = "x-amz-source-arn"
// Headers for STS request for source account
headerSourceAccount = "x-amz-source-account"
)

// Get a node name from instance ID
Expand All @@ -60,7 +66,7 @@ type ec2ProviderImpl struct {
instanceIdsChannel chan string
}

func New(roleARN, region string, qps int, burst int) EC2Provider {
func New(roleARN, sourceARN, region string, qps int, burst int) EC2Provider {
haoranleo marked this conversation as resolved.
Show resolved Hide resolved
dnsCache := ec2PrivateDNSCache{
cache: make(map[string]string),
lock: sync.RWMutex{},
Expand All @@ -70,7 +76,7 @@ func New(roleARN, region string, qps int, burst int) EC2Provider {
lock: sync.RWMutex{},
}
return &ec2ProviderImpl{
ec2: ec2.New(newSession(roleARN, region, qps, burst)),
ec2: ec2.New(newSession(roleARN, sourceARN, region, qps, burst)),
privateDNSCache: dnsCache,
ec2Requests: ec2Requests,
instanceIdsChannel: make(chan string, maxChannelSize),
Expand All @@ -81,7 +87,7 @@ func New(roleARN, region string, qps int, burst int) EC2Provider {
// the environment, shared credentials (~/.aws/credentials), or EC2 Instance
// Role.

func newSession(roleARN, region string, qps int, burst int) *session.Session {
func newSession(roleARN, sourceARN, region string, qps int, burst int) *session.Session {
sess := session.Must(session.NewSession())
sess.Handlers.Build.PushFrontNamed(request.NamedHandler{
Name: "authenticatorUserAgent",
Expand All @@ -103,8 +109,9 @@ func newSession(roleARN, region string, qps int, burst int) *session.Session {
logrus.Errorf("Getting error = %s while creating rate limited client ", err)
}

stsClient := applySTSRequestHeaders(sts.New(sess, aws.NewConfig().WithHTTPClient(rateLimitedClient).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)), sourceARN)
ap := &stscreds.AssumeRoleProvider{
Client: sts.New(sess, aws.NewConfig().WithHTTPClient(rateLimitedClient).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)),
Client: stsClient,
RoleARN: roleARN,
Duration: time.Duration(60) * time.Minute,
}
Expand Down Expand Up @@ -277,3 +284,40 @@ func (p *ec2ProviderImpl) getPrivateDnsAndPublishToCache(instanceIdList []string
p.unsetRequestInFlightForInstanceId(id)
}
}

func applySTSRequestHeaders(stsClient *sts.STS, sourceARN string) *sts.STS {
// parse both source account and source arn from the sourceARN, and add them as headers to the STS client
if sourceARN != "" {
sourceAcct, err := getSourceAccount(sourceARN)
if err != nil {
panic(fmt.Sprintf("%s is not a valid arn, err: %v", sourceARN, err))
}
reqHeaders := map[string]string{
headerSourceAccount: sourceAcct,
headerSourceArn: sourceARN,
}
stsClient.Handlers.Sign.PushFront(func(s *request.Request) {
s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders))
})
logrus.Infof("configuring STS client with extra headers, %v", reqHeaders)
}
return stsClient
}

// getSourceAccount constructs source acct and return them for use
func getSourceAccount(roleARN string) (string, error) {
// ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html)
// arn:partition:service:region:account-id:resource-type/resource-id
// IAM format, region is always blank
// arn:aws:iam::account:role/role-name-with-path
if !arn.IsARN(roleARN) {
return "", fmt.Errorf("incorrect ARN format for role %s", roleARN)
}

parsedArn, err := arn.Parse(roleARN)
if err != nil {
return "", err
}

return parsedArn.AccountID, nil
}
41 changes: 41 additions & 0 deletions pkg/ec2provider/ec2provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,44 @@ func prepare100InstanceOutput() []*ec2.Reservation {
return reservations

}

func TestGetSourceAcctAndArn(t *testing.T) {
type args struct {
roleARN string
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "corect role arn",
args: args{
roleARN: "arn:aws:iam::123456789876:role/test-cluster",
},
want: "123456789876",
wantErr: false,
},
{
name: "incorect role arn",
args: args{
roleARN: "arn:aws:iam::123456789876",
},
want: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getSourceAccount(tt.args.roleARN)
if (err != nil) != tt.wantErr {
t.Errorf("GetSourceAccount() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("GetSourceAccount() got = %v, want %v", got, tt.want)
}
})
}
}
2 changes: 1 addition & 1 deletion pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func (c *Server) getHandler(backendMapper BackendMapper, ec2DescribeQps int, ec2

h := &handler{
verifier: token.NewVerifier(c.ClusterID, c.PartitionID, instanceRegion),
ec2Provider: ec2provider.New(c.ServerEC2DescribeInstancesRoleARN, instanceRegion, ec2DescribeQps, ec2DescribeBurst),
ec2Provider: ec2provider.New(c.ServerEC2DescribeInstancesRoleARN, c.SourceARN, instanceRegion, ec2DescribeQps, ec2DescribeBurst),
clusterID: c.ClusterID,
backendMapper: backendMapper,
scrubbedAccounts: c.Config.ScrubbedAWSAccounts,
Expand Down