diff --git a/aws/saml.go b/aws/saml.go index 70bb3223..e0d53c88 100644 --- a/aws/saml.go +++ b/aws/saml.go @@ -103,7 +103,7 @@ func ParseSAMLResponse(base64Response string) (*SAMLResponse, error) { } // ExtractRoleArnAndPrincipalArn extracts role ARN and principal ARN from SAML response -func ExtractRoleArnAndPrincipalArn(samlResponse SAMLResponse) (string, string, error) { +func ExtractRoleArnAndPrincipalArn(samlResponse SAMLResponse, roleName string) (string, string, error) { for _, attr := range samlResponse.Assertion.AttributeStatement.Attributes { if attr.Name != roleAttributeName { continue @@ -113,6 +113,9 @@ func ExtractRoleArnAndPrincipalArn(samlResponse SAMLResponse) (string, string, e s := strings.Split(v.Value, ",") roleArn := s[0] principalArn := s[1] + if roleName != "" && strings.Split(roleArn, "/")[1] != roleName { + continue + } return roleArn, principalArn, nil } } diff --git a/aws/saml_test.go b/aws/saml_test.go index 19a32b1e..331da383 100644 --- a/aws/saml_test.go +++ b/aws/saml_test.go @@ -138,6 +138,7 @@ func TestParseSAMLResponse(t *testing.T) { func TestExtractRoleArnAndPrincipalArn(t *testing.T) { type args struct { samlResponse SAMLResponse + roleName string } tests := []struct { name string @@ -173,10 +174,81 @@ func TestExtractRoleArnAndPrincipalArn(t *testing.T) { }, }, }, + roleName: "", }, wantRoleArn: "arn:aws:iam::012345678901:role/TestRole", wantPrincipalArn: "arn:aws:iam::012345678901:saml-provider/TestProvider", }, + { + name: "returns first role when role attribute are multi and no roleName argument", + args: args{ + samlResponse: SAMLResponse{ + Assertion: Assertion{ + AttributeStatement: AttributeStatement{ + Attributes: []Attribute{ + { + Name: "dummy", + AttributeValues: []AttributeValue{ + { + Value: "dummy", + }, + }, + }, + { + Name: roleAttributeName, + AttributeValues: []AttributeValue{ + { + Value: "arn:aws:iam::012345678901:role/TestRole1,arn:aws:iam::012345678901:saml-provider/TestProvider1", + }, + { + Value: "arn:aws:iam::012345678901:role/TestRole2,arn:aws:iam::012345678901:saml-provider/TestProvider2", + }, + }, + }, + }, + }, + }, + }, + roleName: "", + }, + wantRoleArn: "arn:aws:iam::012345678901:role/TestRole1", + wantPrincipalArn: "arn:aws:iam::012345678901:saml-provider/TestProvider1", + }, + { + name: "returns specify role when role attribute are multi and roleName argument", + args: args{ + samlResponse: SAMLResponse{ + Assertion: Assertion{ + AttributeStatement: AttributeStatement{ + Attributes: []Attribute{ + { + Name: "dummy", + AttributeValues: []AttributeValue{ + { + Value: "dummy", + }, + }, + }, + { + Name: roleAttributeName, + AttributeValues: []AttributeValue{ + { + Value: "arn:aws:iam::012345678901:role/TestRole1,arn:aws:iam::012345678901:saml-provider/TestProvider1", + }, + { + Value: "arn:aws:iam::012345678901:role/TestRole2,arn:aws:iam::012345678901:saml-provider/TestProvider2", + }, + }, + }, + }, + }, + }, + }, + roleName: "TestRole2", + }, + wantRoleArn: "arn:aws:iam::012345678901:role/TestRole2", + wantPrincipalArn: "arn:aws:iam::012345678901:saml-provider/TestProvider2", + }, { name: "returns an error when role attribute does not exist", args: args{ @@ -196,13 +268,14 @@ func TestExtractRoleArnAndPrincipalArn(t *testing.T) { }, }, }, + roleName: "", }, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, got1, err := ExtractRoleArnAndPrincipalArn(tt.args.samlResponse) + got, got1, err := ExtractRoleArnAndPrincipalArn(tt.args.samlResponse, tt.args.roleName) if (err != nil) != tt.wantErr { t.Errorf("ExtractRoleArnAndPrincipalArn() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/cmd/root.go b/cmd/root.go index ffd870e4..df0fbce3 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -34,6 +34,7 @@ func Execute() { func newRootCmd() *cobra.Command { var configure bool + var roleName string var profile string var showVersion bool @@ -82,7 +83,7 @@ func newRootCmd() *cobra.Command { return err } - roleArn, principalArn, err := aws.ExtractRoleArnAndPrincipalArn(*response) + roleArn, principalArn, err := aws.ExtractRoleArnAndPrincipalArn(*response, roleName) if err != nil { return err } @@ -102,6 +103,7 @@ func newRootCmd() *cobra.Command { } cmd.PersistentFlags().BoolVarP(&configure, "configure", "c", false, "configure initial settings") cmd.PersistentFlags().StringVarP(&profile, "profile", "p", "default", "AWS profile") + cmd.PersistentFlags().StringVarP(&roleName, "role", "r", "", "AWS IAM role name") cmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "Show version") return cmd