diff --git a/README.md b/README.md index 2453c60..f41ed21 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,8 @@ Simple CLI tool which enables you to login and retrieve [AWS](https://aws.amazon ## Requirements -* AWS IAM User account * Access Key and Secret Key stores $HOME/.aws/credentials +* AWS IAM User account ## Install @@ -50,28 +50,32 @@ Use "shell [command] --help" for more information about a command. If the `shell` sub-command is called, `mfa4aws` will output the following temporary security credentials: ``` -export AWS_ACCESS_KEY_ID="DDFHAFG....UOCA" -export AWS_SECRET_ACCESS_KEY="JSKA...HJ2F" -export AWS_SESSION_TOKEN="ZQ...1VVQ==" -export AWS_SECURITY_TOKEN="ZQ...1VVQ==" +export AWS_ACCESS_KEY_ID=DDFHAFG....UOCA +export AWS_SECRET_ACCESS_KEY="JSKA...HJ2F +export AWS_SESSION_TOKEN=ZQ...1VVQ== +export AWS_SECURITY_TOKEN=ZQ...1VVQ== +export X_PRINCIPAL_ARN=arn:aws:iam::3678236812376:user/johnsmith +export EXPIRES=59m58.593852s ``` If you use `eval $(mfa4aws shell)` frequently, you may want to create a alias for it: zsh: ``` -alias m4a="function(){eval $( $(command mfa4aws) shell --shell=bash --profile=$@);}" +alias m4a="function(){eval $( $(command mfa4aws) shell --token=$@);}" ``` bash: ``` -function m4a { eval $( $(which mfa4aws) shell --shell=bash --profile=$@); } +function m4a { eval $( $(which mfa4aws) shell --token=$@); } ``` ## Building -TBA +``` +make build +``` ## Environment vars @@ -81,6 +85,8 @@ The exec sub command will export the following environment variables. * AWS_SECRET_ACCESS_KEY * AWS_SESSION_TOKEN * AWS_SECURITY_TOKEN +* X_PRINCIPAL_ARN +* EXPIRES # License diff --git a/internal/pkg/awssts/awssts.go b/internal/pkg/awssts/awssts.go index 45ae24e..aa96616 100644 --- a/internal/pkg/awssts/awssts.go +++ b/internal/pkg/awssts/awssts.go @@ -33,6 +33,13 @@ type AWSCredentials struct { Expires time.Duration `ini:"x_security_token_expires"` } +//STSIdentity represents the STS Identity +type STSIdentity struct { + Account string + ARN string + UserID string +} + //GenerateSTSCredentials created STS Credentials func GenerateSTSCredentials(profile string, tokenCode string) (*AWSCredentials, error) { @@ -48,11 +55,11 @@ func GenerateSTSCredentials(profile string, tokenCode string) (*AWSCredentials, return nil, ErrAWSCredentialsFileNotFound } - if err := checkProfile(f, profile); err != nil { + if err := validateProfile(f, profile); err != nil { return nil, err } - if err := checkToken(tokenCode); err != nil { + if err := validateToken(tokenCode); err != nil { return nil, err } @@ -62,34 +69,34 @@ func GenerateSTSCredentials(profile string, tokenCode string) (*AWSCredentials, iamInstance := iam.New(awsSession) - iamUser, err := getIAMUserDetails(iamInstance) + mfaSerialNumber, err := getIAMUserMFADevice(iamInstance) if err != nil { return nil, err } - mfaSerialNumber, err := getIAMUserMFADevice(iamInstance) + stsInstance := sts.New(awsSession) + + stsSessionCredentials, err := getSTSSessionToken(stsInstance, tokenCode, mfaSerialNumber) if err != nil { return nil, err } - stsInstance := sts.New(awsSession) - - stsSessionCredentials, err := generateSTSSessionCredentials(stsInstance, tokenCode, mfaSerialNumber) + identity, err := getSTSIdentity(stsInstance) if err != nil { return nil, err } return &AWSCredentials{ - AWSAccessKeyID: *stsSessionCredentials.AccessKeyId, + AWSAccessKeyID: *stsSessionCredentials.AccessKeyId, AWSSecretAccessKey: *stsSessionCredentials.SecretAccessKey, - AWSSessionToken: *stsSessionCredentials.SecretAccessKey, - AWSSecurityToken: *stsSessionCredentials.SecretAccessKey, - PrincipalARN: *iamUser.Arn, - Expires: time.Until(*stsSessionCredentials.Expiration), + AWSSessionToken: *stsSessionCredentials.SecretAccessKey, + AWSSecurityToken: *stsSessionCredentials.SecretAccessKey, + PrincipalARN: identity.ARN, + Expires: time.Until(*stsSessionCredentials.Expiration), }, nil } -func checkProfile(file []byte, profile string) error { +func validateProfile(file []byte, profile string) error { const ( profileDefault string = "default" @@ -129,8 +136,7 @@ func openFile(path string) ([]byte, error) { return buf.Bytes(), nil } -func checkToken(token string) error { - +func validateToken(token string) error { if len(token) <= 5 { return ErrInvalidToken } @@ -138,19 +144,6 @@ func checkToken(token string) error { return nil } -func getIAMUserDetails(iamInstance iamiface.IAMAPI) (*iam.User, error) { - - identity, err := iamInstance.GetUser(&iam.GetUserInput{}) - if err != nil { - if aerr, ok := err.(awserr.Error); ok { - return nil, fmt.Errorf("Unable to retrive user - %v", aerr.Message()) - } - return nil, fmt.Errorf("unknown error occurred, %v", err) - } - - return identity.User, nil -} - func getIAMUserMFADevice(iamInstance iamiface.IAMAPI) (string, error) { devices, err := iamInstance.ListMFADevices(&iam.ListMFADevicesInput{}) if err != nil { @@ -167,7 +160,7 @@ func getIAMUserMFADevice(iamInstance iamiface.IAMAPI) (string, error) { return *devices.MFADevices[0].SerialNumber, nil } -func generateSTSSessionCredentials(stsInstance stsiface.STSAPI, tokenCode string, mfaDeviceSerialNumber string) (*sts.Credentials, error) { +func getSTSSessionToken(stsInstance stsiface.STSAPI, tokenCode string, mfaDeviceSerialNumber string) (*sts.Credentials, error) { stsSession, err := stsInstance.GetSessionToken(&sts.GetSessionTokenInput{ TokenCode: &tokenCode, SerialNumber: &mfaDeviceSerialNumber, @@ -188,3 +181,19 @@ func generateSTSSessionCredentials(stsInstance stsiface.STSAPI, tokenCode string return stsSession.Credentials, nil } + +func getSTSIdentity(stsInstance stsiface.STSAPI) (*STSIdentity, error) { + identity, err := stsInstance.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + if err != nil { + if aerr, ok := err.(awserr.Error); ok { + return nil, fmt.Errorf("Unable to retrive user - %v", aerr.Message()) + } + return nil, fmt.Errorf("unknown error occurred, %v", err) + } + + return &STSIdentity{ + Account: *identity.Account, + ARN: *identity.Arn, + UserID: *identity.UserId, + }, nil +} diff --git a/internal/pkg/awssts/awssts_test.go b/internal/pkg/awssts/awssts_test.go index 2927178..1f42dec 100644 --- a/internal/pkg/awssts/awssts_test.go +++ b/internal/pkg/awssts/awssts_test.go @@ -6,11 +6,11 @@ import ( "reflect" "testing" + "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam/iamiface" "github.com/aws/aws-sdk-go/service/sts" "github.com/aws/aws-sdk-go/service/sts/stsiface" - "github.com/aws/aws-sdk-go/aws/awserr" "github.com/spf13/afero" "mfa4aws/internal/pkg/awssts/mock/iammock" @@ -65,7 +65,7 @@ func TestGenerateSTSCredentials(t *testing.T) { } } -func Test_checkProfile(t *testing.T) { +func Test_validateProfile(t *testing.T) { type args struct { file []byte profile string @@ -119,8 +119,8 @@ func Test_checkProfile(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := checkProfile(tt.args.file, tt.args.profile); (err != nil) != tt.wantErr { - t.Errorf("checkProfile() error = %v, wantErr %v", err, tt.wantErr) + if err := validateProfile(tt.args.file, tt.args.profile); (err != nil) != tt.wantErr { + t.Errorf("validateProfile() error = %v, wantErr %v", err, tt.wantErr) } }) } @@ -177,7 +177,7 @@ func Test_openFile(t *testing.T) { } } -func Test_checkToken(t *testing.T) { +func Test_validateToken(t *testing.T) { type args struct { token string } @@ -210,73 +210,8 @@ func Test_checkToken(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := checkToken(tt.args.token); (err != nil) != tt.wantErr { - t.Errorf("checkToken() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func Test_getIAMUserDetails(t *testing.T) { - type args struct { - iamInstance iamiface.IAMAPI - } - tests := []struct { - name string - args args - want *iam.User - wantErr bool - }{ - { - "Vaild/EmptyUser", - args{ - iamInstance: &iammock.IAMAPIMock{ - GetUserFunc: func(in1 *iam.GetUserInput) (*iam.GetUserOutput, error) { - - output := &iam.GetUserOutput{ - User: &iam.User{}, - } - return output, nil - }, - }, - }, - &iam.User{}, - false, - }, - { - "Invaild/EmptyUser/Error", - args{ - iamInstance: &iammock.IAMAPIMock{ - GetUserFunc: func(in1 *iam.GetUserInput) (*iam.GetUserOutput, error) { - return nil, xerrors.New("blah") - }, - }, - }, - nil, - true, - }, - { - "Invaild/EmptyUser/awserrError", - args{ - iamInstance: &iammock.IAMAPIMock{ - GetUserFunc: func(in1 *iam.GetUserInput) (*iam.GetUserOutput, error) { - return nil, awserr.New("5000", "blah", xerrors.New("blah")) - }, - }, - }, - nil, - true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := getIAMUserDetails(tt.args.iamInstance) - if (err != nil) != tt.wantErr { - t.Errorf("getIAMUserDetails() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("getIAMUserDetails() = %v, want %v", got, tt.want) + if err := validateToken(tt.args.token); (err != nil) != tt.wantErr { + t.Errorf("validateToken() error = %v, wantErr %v", err, tt.wantErr) } }) } @@ -340,7 +275,7 @@ func Test_getIAMUserMFADevice(t *testing.T) { args{ iamInstance: &iammock.IAMAPIMock{ ListMFADevicesFunc: func(in1 *iam.ListMFADevicesInput) (*iam.ListMFADevicesOutput, error) { - + output := &iam.ListMFADevicesOutput{ MFADevices: nil, } @@ -366,7 +301,7 @@ func Test_getIAMUserMFADevice(t *testing.T) { } } -func Test_generateSTSSessionCredentials(t *testing.T) { +func Test_getSTSSessionToken(t *testing.T) { type args struct { stsInstance stsiface.STSAPI tokenCode string @@ -386,7 +321,7 @@ func Test_generateSTSSessionCredentials(t *testing.T) { return &sts.GetSessionTokenOutput{}, nil }, }, - tokenCode: "123456", + tokenCode: "123456", mfaDeviceSerialNumber: "sfagstfey", }, nil, @@ -400,7 +335,7 @@ func Test_generateSTSSessionCredentials(t *testing.T) { return nil, awserr.New("5000", "blah", xerrors.New("blah")) }, }, - tokenCode: "123456", + tokenCode: "123456", mfaDeviceSerialNumber: "sfagstfey", }, nil, @@ -414,7 +349,7 @@ func Test_generateSTSSessionCredentials(t *testing.T) { return nil, awserr.New(sts.ErrCodeExpiredTokenException, "Blah", xerrors.New("blah")) }, }, - tokenCode: "123456", + tokenCode: "123456", mfaDeviceSerialNumber: "sfagstfey", }, nil, @@ -428,7 +363,7 @@ func Test_generateSTSSessionCredentials(t *testing.T) { return nil, awserr.New(sts.ErrCodeInvalidIdentityTokenException, "Blah", xerrors.New("blah")) }, }, - tokenCode: "123456", + tokenCode: "123456", mfaDeviceSerialNumber: "sfagstfey", }, nil, @@ -442,7 +377,7 @@ func Test_generateSTSSessionCredentials(t *testing.T) { return nil, xerrors.New("blah") }, }, - tokenCode: "123456", + tokenCode: "123456", mfaDeviceSerialNumber: "sfagstfey", }, nil, @@ -451,13 +386,87 @@ func Test_generateSTSSessionCredentials(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := generateSTSSessionCredentials(tt.args.stsInstance, tt.args.tokenCode, tt.args.mfaDeviceSerialNumber) + got, err := getSTSSessionToken(tt.args.stsInstance, tt.args.tokenCode, tt.args.mfaDeviceSerialNumber) + if (err != nil) != tt.wantErr { + t.Errorf("getSTSSessionToken() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("getSTSSessionToken() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getSTSIdentity(t *testing.T) { + type args struct { + stsInstance stsiface.STSAPI + } + tests := []struct { + name string + args args + want *STSIdentity + wantErr bool + }{ + { + "Valid/User", + args{ + stsInstance: &stsmock.STSAPIMock{ + GetCallerIdentityFunc: func(in1 *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) { + + account := "342563637373" + arn := "ashgajsdhgajsdg" + userID := "asjkdhkasdhaksd" + + return &sts.GetCallerIdentityOutput{ + Account: &account, + Arn: &arn, + UserId: &userID, + }, nil + }, + }, + }, + &STSIdentity{ + Account: "342563637373", + ARN: "ashgajsdhgajsdg", + UserID: "asjkdhkasdhaksd", + }, + false, + }, + { + "Invaild/Error", + args{ + stsInstance: &stsmock.STSAPIMock{ + GetCallerIdentityFunc: func(in1 *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) { + return nil, xerrors.New("blah") + }, + }, + }, + nil, + true, + }, + { + "Invaild/awserrError", + args{ + stsInstance: &stsmock.STSAPIMock{ + GetCallerIdentityFunc: func(in1 *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) { + return nil, awserr.New("askjdhaksjhd", "Blah", xerrors.New("blah")) + }, + }, + }, + nil, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := getSTSIdentity(tt.args.stsInstance) if (err != nil) != tt.wantErr { - t.Errorf("generateSTSSessionCredentials() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("getSTSIdentity() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("generateSTSSessionCredentials() = %v, want %v", got, tt.want) + t.Errorf("getSTSIdentity() = %v, want %v", got, tt.want) } }) } diff --git a/internal/pkg/cmd/root.go b/internal/pkg/cmd/root.go index d253076..362c680 100644 --- a/internal/pkg/cmd/root.go +++ b/internal/pkg/cmd/root.go @@ -18,13 +18,15 @@ var rootCmd = &cobra.Command{Use: "shell"} // Execute is the entry point for the MFA command func Execute(version string) { - rootCmd.PersistentFlags().StringVarP(&awsProfile, "profile", "p", "default", "AWS Profile name in $HOME/.aws/credentials") - rootCmd.PersistentFlags().StringVarP(&mfaToken, "token", "t", "", "Current MFA value to use for STS generation") + persistentFlags := rootCmd.PersistentFlags() + persistentFlags.StringVarP(&awsProfile, "profile", "p", "default", "AWS Profile name in $HOME/.aws/credentials") + persistentFlags.StringVarP(&mfaToken, "token", "t", "", "Current MFA value to use for STS generation") + cobra.MarkFlagRequired(persistentFlags, "token") - releaseVersion = version + releaseVersion = version - if err := rootCmd.Execute(); err != nil { - fmt.Println(err) - os.Exit(1) - } + if err := rootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) + } }