diff --git a/aws/aws.go b/aws/aws.go index 4783e83..e873cda 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -21,12 +21,13 @@ func AssumeRole(login saml.LoginData, role saml.LoginRole, duration int64) (*sts return stsClient.AssumeRoleWithSAML(&input) } -func EnvironmentVariables(credentials *sts.Credentials) map[string]string { +func EnvironmentVariables(stsOutput *sts.AssumeRoleWithSAMLOutput) map[string]string { subject := make(map[string]string) - subject["AWS_ACCESS_KEY_ID"] = *credentials.AccessKeyId - subject["AWS_SECRET_ACCESS_KEY"] = *credentials.SecretAccessKey - subject["AWS_SESSION_TOKEN"] = *credentials.SessionToken + subject["AWS_ACCESS_KEY_ID"] = *stsOutput.Credentials.AccessKeyId + subject["AWS_SECRET_ACCESS_KEY"] = *stsOutput.Credentials.SecretAccessKey + subject["AWS_SESSION_TOKEN"] = *stsOutput.Credentials.SessionToken + subject["AWS_METADATA_USER_ARN"] = *stsOutput.AssumedRoleUser.Arn return subject } diff --git a/aws/aws_test.go b/aws/aws_test.go index 392bf8e..4251c9a 100644 --- a/aws/aws_test.go +++ b/aws/aws_test.go @@ -9,11 +9,17 @@ func TestEnvironmentVariables(t *testing.T) { accessKeyId := "llama" secretAccessKey := "alpaca" sessionToken := "guanaco" + assumedRoleArn := "arn:aws:iam::1234123123:role/sso-vicuña-role" - creds := sts.Credentials{ - AccessKeyId: &accessKeyId, - SecretAccessKey: &secretAccessKey, - SessionToken: &sessionToken, + creds := sts.AssumeRoleWithSAMLOutput{ + AssumedRoleUser: &sts.AssumedRoleUser{ + Arn: &assumedRoleArn, + }, + Credentials: &sts.Credentials{ + AccessKeyId: &accessKeyId, + SecretAccessKey: &secretAccessKey, + SessionToken: &sessionToken, + }, } subject := EnvironmentVariables(&creds) @@ -41,4 +47,12 @@ func TestEnvironmentVariables(t *testing.T) { t.Logf("Got: %s", subject["AWS_SESSION_TOKEN"]) t.Fail() } + + if subject["AWS_METADATA_USER_ARN"] != assumedRoleArn { + t.Log("---------------") + t.Log("Did not correctly set AWS_METADATA_USER_ARN") + t.Logf("Expected: %s", assumedRoleArn) + t.Logf("Got: %s", subject["AWS_METADATA_USER_ARN"]) + t.Fail() + } } diff --git a/cmd/root.go b/cmd/root.go index 2e81322..4542f41 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -79,7 +79,7 @@ These can be configured either in the [okta] section of ~/.config/yak/config.tom go func() { <-channel fmt.Fprintln(os.Stderr, "Received termination signal, exiting...") - if(stateErr == nil) { + if stateErr == nil { terminal.Restore(int(syscall.Stdin), state) } diff --git a/cmd/shim.go b/cmd/shim.go index 352ad9b..5d77007 100644 --- a/cmd/shim.go +++ b/cmd/shim.go @@ -34,7 +34,7 @@ func shimCmd(cmd *cobra.Command, args []string) error { return cli.Exec( command, cli.EnrichedEnvironment( - aws.EnvironmentVariables(creds.Credentials), + aws.EnvironmentVariables(creds), ), ) } diff --git a/format/format.go b/format/format.go index 76fc2a8..74baf15 100644 --- a/format/format.go +++ b/format/format.go @@ -27,7 +27,7 @@ var outputFormatters map[string]func(*sts.AssumeRoleWithSAMLOutput) (string, err outputFormat = "export %s=%s\n" } - for key, value := range aws.EnvironmentVariables(creds.Credentials) { + for key, value := range aws.EnvironmentVariables(creds) { output.WriteString(fmt.Sprintf(outputFormat, key, value)) } diff --git a/format/format_test.go b/format/format_test.go index 010bd3d..ba37f09 100644 --- a/format/format_test.go +++ b/format/format_test.go @@ -13,6 +13,7 @@ import ( var accessKeyId string = "llama" var secretAccessKey string = "alpaca" var sessionToken string = "guanaco" +var assumedRoleArn string = "arn:aws:iam::1234123123:role/sso-vicuña-role" var innerCreds sts.Credentials = sts.Credentials{ AccessKeyId: &accessKeyId, @@ -21,6 +22,9 @@ var innerCreds sts.Credentials = sts.Credentials{ } var creds sts.AssumeRoleWithSAMLOutput = sts.AssumeRoleWithSAMLOutput{ + AssumedRoleUser: &sts.AssumedRoleUser{ + Arn: &assumedRoleArn, + }, Credentials: &innerCreds, } @@ -37,6 +41,7 @@ func TestDefaultEnvCredentials(t *testing.T) { fmt.Sprintf(`export AWS_ACCESS_KEY_ID=%s`, accessKeyId), fmt.Sprintf(`export AWS_SECRET_ACCESS_KEY=%s`, secretAccessKey), fmt.Sprintf(`export AWS_SESSION_TOKEN=%s`, sessionToken), + fmt.Sprintf(`export AWS_METADATA_USER_ARN=%s`, assumedRoleArn), }, setUp: func() { os.Unsetenv("PSModulePath") @@ -49,6 +54,7 @@ func TestDefaultEnvCredentials(t *testing.T) { fmt.Sprintf(`$env:AWS_ACCESS_KEY_ID = "%s"`, accessKeyId), fmt.Sprintf(`$env:AWS_SECRET_ACCESS_KEY = "%s"`, secretAccessKey), fmt.Sprintf(`$env:AWS_SESSION_TOKEN = "%s"`, sessionToken), + fmt.Sprintf(`$env:AWS_METADATA_USER_ARN = "%s"`, assumedRoleArn), }, setUp: func() { os.Setenv("PSModulePath", "something")