Skip to content

Commit

Permalink
Clean-up + better testing (#12)
Browse files Browse the repository at this point in the history
* Clean-up + better testing
* Updated the EXPIRES export
  • Loading branch information
cameronnewman authored Jun 24, 2019
1 parent 11602b2 commit dc42a9a
Show file tree
Hide file tree
Showing 7 changed files with 367 additions and 275 deletions.
114 changes: 8 additions & 106 deletions internal/pkg/aws/aws.go
Original file line number Diff line number Diff line change
@@ -1,77 +1,30 @@
package aws

import (
"bytes"
"os/user"
"path/filepath"
"regexp"
"time"

"gopkg.in/ini.v1"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"

"github.com/spf13/afero"
)

const (
tokenValidationRegex string = "^[0-9]+$"
)

var (
appFs = afero.NewOsFs()

tokenValidationRegexComplied = regexp.MustCompilePOSIX(tokenValidationRegex)
)

//Credentials represents the set of attributes used to authenticate to AWS with a short lived session
type Credentials struct {
AWSAccessKeyID string `ini:"aws_access_key_id"`
AWSSecretAccessKey string `ini:"aws_secret_access_key"`
AWSSessionToken string `ini:"aws_session_token"`
AWSSecurityToken string `ini:"aws_security_token"`
PrincipalARN string `ini:"x_principal_arn"`
Expires time.Duration `ini:"x_security_token_expires"`
AWSAccessKeyID string `ini:"aws_access_key_id"`
AWSSecretAccessKey string `ini:"aws_secret_access_key"`
AWSSessionToken string `ini:"aws_session_token"`
AWSSecurityToken string `ini:"aws_security_token"`
PrincipalARN string `ini:"x_principal_arn"`
Expires time.Time `ini:"x_security_token_expires"`
}

//GenerateSTSCredentials created STS Credentials
func GenerateSTSCredentials(profile string, tokenCode string) (*Credentials, error) {

const (
awsCredentialsFolder string = ".aws"
awsCredentialsFile string = "credentials"
)

user, err := user.Current()
if err != nil {
return nil, err
}

path := filepath.Join(user.HomeDir, awsCredentialsFolder, awsCredentialsFile)

f, err := openFile(path)
awsSession, err := createSession("", profile)
if err != nil {
return nil, ErrAWSCredentialsFileNotFound
}

if err := validateProfile(f, profile); err != nil {
return nil, err
}

if err := validateToken(tokenCode); err != nil {
return nil, err
}

awsSession := session.Must(session.NewSessionWithOptions(session.Options{
Config: aws.Config{
Credentials: credentials.NewSharedCredentials(path, profile),
},
}))

iamInstance := iam.New(awsSession)

mfaSerialNumber, err := getIAMUserMFADevice(iamInstance)
Expand All @@ -96,57 +49,6 @@ func GenerateSTSCredentials(profile string, tokenCode string) (*Credentials, err
AWSSessionToken: *stsSessionCredentials.SessionToken,
AWSSecurityToken: *stsSessionCredentials.SessionToken,
PrincipalARN: identity.ARN,
Expires: time.Until(*stsSessionCredentials.Expiration),
Expires: *stsSessionCredentials.Expiration,
}, nil
}

func validateProfile(file []byte, profile string) error {
const (
profileDefault string = "default"

credentialsAWSAccessKeyID string = "aws_access_key_id"
credentialsAWSSecretAccessKey string = "aws_secret_access_key"
)

if len(profile) == 0 {
profile = profileDefault
}

creds, err := ini.Load(file)
if err != nil {
return ErrInvalidAWSCredentialsFile
}

if !creds.Section(profile).HasKey(credentialsAWSAccessKeyID) ||
!creds.Section(profile).HasKey(credentialsAWSSecretAccessKey) {
return ErrInvalidAWSCredentialsFile
}

return nil
}

func openFile(path string) ([]byte, error) {
f, err := appFs.Open(path)
if err != nil {
return nil, err
}
defer f.Close()

buf := bytes.NewBuffer(nil)
_, err = buf.ReadFrom(f)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}

func validateToken(token string) error {
if len(token) <= 5 {
return ErrInvalidToken
}
if !tokenValidationRegexComplied.MatchString(token) {
return ErrInvalidToken
}

return nil
}
166 changes: 0 additions & 166 deletions internal/pkg/aws/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,169 +98,3 @@ func TestGenerateSTSCredentials(t *testing.T) {
})
}
}

func Test_validateProfile(t *testing.T) {
type args struct {
file []byte
profile string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
"Valid/NoProfileNameDefined",
args{
file: []byte(`
[default]
aws_access_key_id = blahblah
aws_secret_access_key = blahblah/blahblah`),
profile: "",
},
false,
},
{
"Valid/NonDefaultProfileNameDefined",
args{
file: []byte(`
[candycrush]
aws_access_key_id = blahblah
aws_secret_access_key = blahblah/blahblah`),
profile: "candycrush",
},
false,
},
{
"Invalid/InvalidCredentialsFile",
args{
file: []byte(`
-[default]
_aws_access_key_id = blahblah
a-ws_secret_access_key = blahblah/blahblah`),
profile: "",
},
true,
},
{
"Invalid/InvalidProfile",
args{
file: []byte(""),
profile: "blah",
},
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := validateProfile(tt.args.file, tt.args.profile); (err != nil) != tt.wantErr {
t.Errorf("validateProfile() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

func Test_openFile(t *testing.T) {

type args struct {
path string
}
tests := []struct {
name string
args args
want []byte
wantErr bool
}{
{
"Invalid/NonExistentFile",
args{
path: "/some/unknown/path",
},
nil,
true,
},
{
"Valid/FileExists",
args{
path: "/knowntestfile.txt",
},
[]byte(`test`),
false,
},
{
"Valid/EmptyFileExists",
args{
path: "/emptyknowntestfile.txt",
},
[]byte(""),
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := openFile(tt.args.path)
if (err != nil) != tt.wantErr {
t.Errorf("openFile() error = %v, wantErr %v", err, tt.wantErr)
return
}

if !reflect.DeepEqual(got, tt.want) {
t.Errorf("openFile() = %v, want %v", got, tt.want)
}
})
}
}

func Test_validateToken(t *testing.T) {
type args struct {
token string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
"Invalid/EmptyToken",
args{
token: "",
},
true,
},
{
"Invalid/ShortToken",
args{
token: "2321",
},
true,
},
{
"Invalid/NotNumbers",
args{
token: "23ss21",
},
true,
},
{
"Invalid/NotNumbersLong",
args{
token: "5364f'[73",
},
true,
},
{
"Valid/Token",
args{
token: "1234532",
},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := validateToken(tt.args.token); (err != nil) != tt.wantErr {
t.Errorf("validateToken() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
Loading

0 comments on commit dc42a9a

Please sign in to comment.