Skip to content

Commit

Permalink
Merge pull request #283 from AlexVulaj/generic-role-session-name-assu…
Browse files Browse the repository at this point in the history
…me-customer-role

Use md5 hash as role session name when assuming into customer's support role
  • Loading branch information
openshift-merge-bot[bot] authored Jan 9, 2024
2 parents 0b8e642 + 5d3d480 commit 4321f0b
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 22 deletions.
25 changes: 18 additions & 7 deletions cmd/ocm-backplane/cloud/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cloud

import (
"context"
//nolint:gosec
"encoding/json"
"errors"
"fmt"
Expand All @@ -24,14 +25,17 @@ import (
logger "github.com/sirupsen/logrus"
)

const OldFlowSupportRole = "role/RH-Technical-Support-Access"
const (
OldFlowSupportRole = "role/RH-Technical-Support-Access"
CustomerRoleArnName = "Target-Role-Arn"
)

var StsClient = awsutil.StsClient
var AssumeRoleWithJWT = awsutil.AssumeRoleWithJWT
var NewStaticCredentialsProvider = credentials.NewStaticCredentialsProvider
var AssumeRoleSequence = awsutil.AssumeRoleSequence

// Wrapper for the configuration needed for cloud requests
// QueryConfig Wrapper for the configuration needed for cloud requests
type QueryConfig struct {
config.BackplaneConfiguration
OcmConnection *ocmsdk.Connection
Expand Down Expand Up @@ -207,7 +211,8 @@ func (cfg *QueryConfig) getCloudCredentialsFromBackplaneAPI(ocmToken string) (bp
}

type assumeChainResponse struct {
AssumptionSequence []namedRoleArn `json:"assumptionSequence"`
AssumptionSequence []namedRoleArn `json:"assumptionSequence"`
CustomerRoleSessionName string `json:"customerRoleSessionName"`
}

type namedRoleArn struct {
Expand Down Expand Up @@ -263,17 +268,23 @@ func (cfg *QueryConfig) getIsolatedCredentials(ocmToken string) (aws.Credentials
return aws.Credentials{}, fmt.Errorf("failed to unmarshal response: %w", err)
}

roleAssumeSequence := make([]string, 0, len(roleChainResponse.AssumptionSequence))
for _, namedRoleArn := range roleChainResponse.AssumptionSequence {
roleAssumeSequence = append(roleAssumeSequence, namedRoleArn.Arn)
assumeRoleArnSessionSequence := make([]awsutil.RoleArnSession, 0, len(roleChainResponse.AssumptionSequence))
for _, namedRoleArnEntry := range roleChainResponse.AssumptionSequence {
roleArnSession := awsutil.RoleArnSession{RoleArn: namedRoleArnEntry.Arn}
if namedRoleArnEntry.Name == CustomerRoleArnName {
roleArnSession.RoleSessionName = roleChainResponse.CustomerRoleSessionName
} else {
roleArnSession.RoleSessionName = email
}
assumeRoleArnSessionSequence = append(assumeRoleArnSessionSequence, roleArnSession)
}

seedClient := sts.NewFromConfig(aws.Config{
Region: "us-east-1",
Credentials: NewStaticCredentialsProvider(seedCredentials.AccessKeyID, seedCredentials.SecretAccessKey, seedCredentials.SessionToken),
})

targetCredentials, err := AssumeRoleSequence(email, seedClient, roleAssumeSequence, cfg.BackplaneConfiguration.ProxyURL, awsutil.DefaultSTSClientProviderFunc)
targetCredentials, err := AssumeRoleSequence(seedClient, assumeRoleArnSessionSequence, cfg.BackplaneConfiguration.ProxyURL, awsutil.DefaultSTSClientProviderFunc)
if err != nil {
return aws.Credentials{}, fmt.Errorf("failed to assume role sequence: %w", err)
}
Expand Down
26 changes: 16 additions & 10 deletions pkg/awsutil/sts.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func AssumeRoleWithJWT(jwt string, roleArn string, stsClient stscreds.AssumeRole
return result, nil
}

func AssumeRole(roleSessionName string, stsClient stscreds.AssumeRoleAPIClient, roleArn string) (aws.Credentials, error) {
func AssumeRole(stsClient stscreds.AssumeRoleAPIClient, roleSessionName string, roleArn string) (aws.Credentials, error) {
assumeRoleProvider := stscreds.NewAssumeRoleProvider(stsClient, roleArn, func(options *stscreds.AssumeRoleOptions) {
options.RoleSessionName = roleSessionName
})
Expand All @@ -113,16 +113,21 @@ var DefaultSTSClientProviderFunc STSClientProviderFunc = func(optnFns ...func(op
return sts.NewFromConfig(cfg), nil
}

func AssumeRoleSequence(roleSessionName string, seedClient stscreds.AssumeRoleAPIClient, roleArnSequence []string, proxyURL *string, stsClientProviderFunc STSClientProviderFunc) (aws.Credentials, error) {
if len(roleArnSequence) == 0 {
type RoleArnSession struct {
RoleSessionName string
RoleArn string
}

func AssumeRoleSequence(seedClient stscreds.AssumeRoleAPIClient, roleArnSessionSequence []RoleArnSession, proxyURL *string, stsClientProviderFunc STSClientProviderFunc) (aws.Credentials, error) {
if len(roleArnSessionSequence) == 0 {
return aws.Credentials{}, errors.New("role ARN sequence cannot be empty")
}

nextClient := seedClient
var lastCredentials aws.Credentials

for i, roleArn := range roleArnSequence {
result, err := AssumeRole(roleSessionName, nextClient, roleArn)
for i, roleArnSession := range roleArnSessionSequence {
result, err := AssumeRole(nextClient, roleArnSession.RoleSessionName, roleArnSession.RoleArn)
retryCount := 0
for err != nil {
// IAM policy updates can take a few seconds to resolve, and the sts.Client in AWS' Go SDK doesn't refresh itself on retries.
Expand All @@ -132,20 +137,21 @@ func AssumeRoleSequence(roleSessionName string, seedClient stscreds.AssumeRoleAP
time.Sleep(assumeRoleRetryBackoff)
nextClient, err = createAssumeRoleSequenceClient(stsClientProviderFunc, lastCredentials, proxyURL)
if err != nil {
return aws.Credentials{}, fmt.Errorf("failed to create client with credentials for role %v: %w", roleArn, err)
return aws.Credentials{}, fmt.Errorf("failed to create client with credentials for role %v: %w", roleArnSession.RoleArn, err)
}
result, err = AssumeRole(roleSessionName, nextClient, roleArn)

result, err = AssumeRole(nextClient, roleArnSession.RoleSessionName, roleArnSession.RoleArn)
retryCount++
} else {
return aws.Credentials{}, fmt.Errorf("failed to assume role %v: %w", roleArn, err)
return aws.Credentials{}, fmt.Errorf("failed to assume role %v: %w", roleArnSession.RoleArn, err)
}
}
lastCredentials = result

if i < len(roleArnSequence)-1 {
if i < len(roleArnSessionSequence)-1 {
nextClient, err = createAssumeRoleSequenceClient(stsClientProviderFunc, lastCredentials, proxyURL)
if err != nil {
return aws.Credentials{}, fmt.Errorf("failed to create client with credentials for role %v: %w", roleArn, err)
return aws.Credentials{}, fmt.Errorf("failed to create client with credentials for role %v: %w", roleArnSession.RoleArn, err)
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/awsutil/sts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func TestAssumeRole(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := AssumeRole("", tt.stsClient, "")
got, err := AssumeRole(tt.stsClient, "", "")
if (err != nil) != tt.wantErr {
t.Errorf("AssumeRole() error = %v, wantErr %v", err, tt.wantErr)
return
Expand All @@ -165,7 +165,7 @@ func TestAssumeRole(t *testing.T) {
func TestAssumeRoleSequence(t *testing.T) {
type args struct {
seedClient stscreds.AssumeRoleAPIClient
roleArnSequence []string
roleArnSequence []RoleArnSession
stsClientProviderFunc STSClientProviderFunc
}
tests := []struct {
Expand All @@ -184,15 +184,15 @@ func TestAssumeRoleSequence(t *testing.T) {
{
name: "role arn sequence is empty",
args: args{
roleArnSequence: []string{},
roleArnSequence: []RoleArnSession{},
},
wantErr: true,
},
{
name: "single role arn in sequence",
args: args{
seedClient: defaultSuccessMockSTSClient(),
roleArnSequence: []string{"a"},
roleArnSequence: []RoleArnSession{{RoleArn: "a"}},
stsClientProviderFunc: func(optFns ...func(*config.LoadOptions) error) (stscreds.AssumeRoleAPIClient, error) {
return defaultSuccessMockSTSClient(), nil
},
Expand All @@ -209,7 +209,7 @@ func TestAssumeRoleSequence(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := AssumeRoleSequence("", tt.args.seedClient, tt.args.roleArnSequence, nil, tt.args.stsClientProviderFunc)
got, err := AssumeRoleSequence(tt.args.seedClient, tt.args.roleArnSequence, nil, tt.args.stsClientProviderFunc)
if (err != nil) != tt.wantErr {
t.Errorf("AssumeRoleSequence() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down

0 comments on commit 4321f0b

Please sign in to comment.