Skip to content

Commit

Permalink
Minor fixes (#2)
Browse files Browse the repository at this point in the history
* Marking token flag as required
* Switched Identity call
* Updated Readme
  • Loading branch information
cameronnewman authored Jun 18, 2019
1 parent 2086256 commit f1dc4d7
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 126 deletions.
22 changes: 14 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ Simple CLI tool which enables you to login and retrieve [AWS](https://aws.amazon

## Requirements

* AWS IAM User account
* Access Key and Secret Key stores $HOME/.aws/credentials
* AWS IAM User account

## Install

Expand Down Expand Up @@ -50,28 +50,32 @@ Use "shell [command] --help" for more information about a command.

If the `shell` sub-command is called, `mfa4aws` will output the following temporary security credentials:
```
export AWS_ACCESS_KEY_ID="DDFHAFG....UOCA"
export AWS_SECRET_ACCESS_KEY="JSKA...HJ2F"
export AWS_SESSION_TOKEN="ZQ...1VVQ=="
export AWS_SECURITY_TOKEN="ZQ...1VVQ=="
export AWS_ACCESS_KEY_ID=DDFHAFG....UOCA
export AWS_SECRET_ACCESS_KEY="JSKA...HJ2F
export AWS_SESSION_TOKEN=ZQ...1VVQ==
export AWS_SECURITY_TOKEN=ZQ...1VVQ==
export X_PRINCIPAL_ARN=arn:aws:iam::3678236812376:user/johnsmith
export EXPIRES=59m58.593852s
```

If you use `eval $(mfa4aws shell)` frequently, you may want to create a alias for it:

zsh:
```
alias m4a="function(){eval $( $(command mfa4aws) shell --shell=bash --profile=$@);}"
alias m4a="function(){eval $( $(command mfa4aws) shell --token=$@);}"
```

bash:
```
function m4a { eval $( $(which mfa4aws) shell --shell=bash --profile=$@); }
function m4a { eval $( $(which mfa4aws) shell --token=$@); }
```


## Building

TBA
```
make build
```

## Environment vars

Expand All @@ -81,6 +85,8 @@ The exec sub command will export the following environment variables.
* AWS_SECRET_ACCESS_KEY
* AWS_SESSION_TOKEN
* AWS_SECURITY_TOKEN
* X_PRINCIPAL_ARN
* EXPIRES

# License

Expand Down
67 changes: 38 additions & 29 deletions internal/pkg/awssts/awssts.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ type AWSCredentials struct {
Expires time.Duration `ini:"x_security_token_expires"`
}

//STSIdentity represents the STS Identity
type STSIdentity struct {
Account string
ARN string
UserID string
}

//GenerateSTSCredentials created STS Credentials
func GenerateSTSCredentials(profile string, tokenCode string) (*AWSCredentials, error) {

Expand All @@ -48,11 +55,11 @@ func GenerateSTSCredentials(profile string, tokenCode string) (*AWSCredentials,
return nil, ErrAWSCredentialsFileNotFound
}

if err := checkProfile(f, profile); err != nil {
if err := validateProfile(f, profile); err != nil {
return nil, err
}

if err := checkToken(tokenCode); err != nil {
if err := validateToken(tokenCode); err != nil {
return nil, err
}

Expand All @@ -62,34 +69,34 @@ func GenerateSTSCredentials(profile string, tokenCode string) (*AWSCredentials,

iamInstance := iam.New(awsSession)

iamUser, err := getIAMUserDetails(iamInstance)
mfaSerialNumber, err := getIAMUserMFADevice(iamInstance)
if err != nil {
return nil, err
}

mfaSerialNumber, err := getIAMUserMFADevice(iamInstance)
stsInstance := sts.New(awsSession)

stsSessionCredentials, err := getSTSSessionToken(stsInstance, tokenCode, mfaSerialNumber)
if err != nil {
return nil, err
}

stsInstance := sts.New(awsSession)

stsSessionCredentials, err := generateSTSSessionCredentials(stsInstance, tokenCode, mfaSerialNumber)
identity, err := getSTSIdentity(stsInstance)
if err != nil {
return nil, err
}

return &AWSCredentials{
AWSAccessKeyID: *stsSessionCredentials.AccessKeyId,
AWSAccessKeyID: *stsSessionCredentials.AccessKeyId,
AWSSecretAccessKey: *stsSessionCredentials.SecretAccessKey,
AWSSessionToken: *stsSessionCredentials.SecretAccessKey,
AWSSecurityToken: *stsSessionCredentials.SecretAccessKey,
PrincipalARN: *iamUser.Arn,
Expires: time.Until(*stsSessionCredentials.Expiration),
AWSSessionToken: *stsSessionCredentials.SecretAccessKey,
AWSSecurityToken: *stsSessionCredentials.SecretAccessKey,
PrincipalARN: identity.ARN,
Expires: time.Until(*stsSessionCredentials.Expiration),
}, nil
}

func checkProfile(file []byte, profile string) error {
func validateProfile(file []byte, profile string) error {
const (
profileDefault string = "default"

Expand Down Expand Up @@ -129,28 +136,14 @@ func openFile(path string) ([]byte, error) {
return buf.Bytes(), nil
}

func checkToken(token string) error {

func validateToken(token string) error {
if len(token) <= 5 {
return ErrInvalidToken
}

return nil
}

func getIAMUserDetails(iamInstance iamiface.IAMAPI) (*iam.User, error) {

identity, err := iamInstance.GetUser(&iam.GetUserInput{})
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
return nil, fmt.Errorf("Unable to retrive user - %v", aerr.Message())
}
return nil, fmt.Errorf("unknown error occurred, %v", err)
}

return identity.User, nil
}

func getIAMUserMFADevice(iamInstance iamiface.IAMAPI) (string, error) {
devices, err := iamInstance.ListMFADevices(&iam.ListMFADevicesInput{})
if err != nil {
Expand All @@ -167,7 +160,7 @@ func getIAMUserMFADevice(iamInstance iamiface.IAMAPI) (string, error) {
return *devices.MFADevices[0].SerialNumber, nil
}

func generateSTSSessionCredentials(stsInstance stsiface.STSAPI, tokenCode string, mfaDeviceSerialNumber string) (*sts.Credentials, error) {
func getSTSSessionToken(stsInstance stsiface.STSAPI, tokenCode string, mfaDeviceSerialNumber string) (*sts.Credentials, error) {
stsSession, err := stsInstance.GetSessionToken(&sts.GetSessionTokenInput{
TokenCode: &tokenCode,
SerialNumber: &mfaDeviceSerialNumber,
Expand All @@ -188,3 +181,19 @@ func generateSTSSessionCredentials(stsInstance stsiface.STSAPI, tokenCode string

return stsSession.Credentials, nil
}

func getSTSIdentity(stsInstance stsiface.STSAPI) (*STSIdentity, error) {
identity, err := stsInstance.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
return nil, fmt.Errorf("Unable to retrive user - %v", aerr.Message())
}
return nil, fmt.Errorf("unknown error occurred, %v", err)
}

return &STSIdentity{
Account: *identity.Account,
ARN: *identity.Arn,
UserID: *identity.UserId,
}, nil
}
Loading

0 comments on commit f1dc4d7

Please sign in to comment.