Skip to content

Commit

Permalink
Merge pull request #92 from GESkunkworks/fix/ram-overload
Browse files Browse the repository at this point in the history
bulk login - non force fix
  • Loading branch information
MichaelPalmer1 authored Dec 9, 2020
2 parents fbe53d4 + a444cca commit a2733fa
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 26 deletions.
57 changes: 31 additions & 26 deletions cmd/gossamer3/commands/bulk_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type SecondaryRoleOutput struct {
}

// Assume assumes a primary role, returning the credentials to assume secondary role if needed
func (input *PrimaryRoleInput) Assume(roleSessionName string, force bool) {
func (input *PrimaryRoleInput) Assume(roleSessionName string, force bool, sharedCredsFile *awsconfig.CredentialsFile) {
var creds *awsconfig.AWSCredentials
var err error
existingCreds := false
Expand Down Expand Up @@ -101,18 +101,17 @@ func (input *PrimaryRoleInput) Assume(roleSessionName string, force bool) {

// Not forcing credential refresh, pull from file
if !force {
sharedCreds := awsconfig.NewSharedCredentials(input.RoleConfig.Profile)

// Check if credentials are expired
if !sharedCreds.Expired() {
creds, err = sharedCreds.Load()
if !sharedCredsFile.Expired(input.RoleConfig.Profile) {
creds, err = sharedCredsFile.Load(input.RoleConfig.Profile)
existingCreds = creds != nil
}
}
}

// Get new credentials
if force || err != nil || creds == nil || input.RoleConfig.Profile == "" {
if force || err != nil || !existingCreds || input.RoleConfig.Profile == "" {
// If session duration is defined at a role level, use that instead of the idp account level
var sessDur = input.Account.SessionDuration
if input.RoleConfig.SessionDuration > 0 {
Expand Down Expand Up @@ -163,7 +162,7 @@ func (input *PrimaryRoleInput) Assume(roleSessionName string, force bool) {
wg.Add(1)

// Perform secondary assumption
go secondaryInput.Assume(roleSessionName, force)
go secondaryInput.Assume(roleSessionName, force, sharedCredsFile)
}

// Create a channel to wait for completion of the wait group
Expand Down Expand Up @@ -201,7 +200,7 @@ func (input *PrimaryRoleInput) Assume(roleSessionName string, force bool) {
}

// Assume assumes a secondary role using the PrimaryRoleInput parent object
func (input *SecondaryRoleInput) Assume(roleSessionName string, force bool) {
func (input *SecondaryRoleInput) Assume(roleSessionName string, force bool, sharedCredsFile *awsconfig.CredentialsFile) {
var creds *awsconfig.AWSCredentials
var err error
existingCreds := false
Expand Down Expand Up @@ -239,17 +238,15 @@ func (input *SecondaryRoleInput) Assume(roleSessionName string, force bool) {

// Not forcing credential refresh, pull from file
if !force {
sharedCreds := awsconfig.NewSharedCredentials(input.RoleAssumption.Profile)

// Check if credentials are expired
if !sharedCreds.Expired() {
creds, err = sharedCreds.Load()
if !sharedCredsFile.Expired(input.RoleAssumption.Profile) {
creds, err = sharedCredsFile.Load(input.RoleAssumption.Profile)
existingCreds = creds != nil
}
}

// Get new credentials if forced, error encountered, or credentials are not found
if force || err != nil || creds == nil {
if force || err != nil || !existingCreds {
creds, err = assumeRole(input.PrimaryCredentials, input.RoleAssumption.RoleArn, roleSessionName, region)
}

Expand Down Expand Up @@ -297,6 +294,12 @@ func BulkLogin(loginFlags *flags.LoginExecFlags) error {
// Check if any credentials need to be refreshed - only run when force is false
logger.Debug("Check if credentials exist")

// Load entire file from single location
sharedCredsFile, err := awsconfig.LoadCredentialsFile()
if err != nil {
logger.Fatalln(errors.Wrap(err, "couldnt load aws credentials file"))
}

// Not forced, and not assuming all roles
if !loginFlags.Force && !roleConfig.AssumeAllRoles {
var primaryExpired = false // Only prompt login if one of the parent credentials are expired
Expand All @@ -308,8 +311,7 @@ func BulkLogin(loginFlags *flags.LoginExecFlags) error {
for _, primary := range roleConfig.Roles {
// Only check for expiration of parent role
if primary.Profile != "" {
sharedCreds := awsconfig.NewSharedCredentials(primary.Profile)
if sharedCreds.Expired() {
if sharedCredsFile.Expired(primary.Profile) {
logger.WithField("Role", primary.PrimaryRoleArn).Debugf("Creds have expired")
primaryExpired = true
noCredsExpired = false
Expand All @@ -318,7 +320,7 @@ func BulkLogin(loginFlags *flags.LoginExecFlags) error {

// Not expired, set the session role name. Only need this once since it should always be the same
if sessionRoleName == "" {
creds, err := sharedCreds.Load()
creds, err := sharedCredsFile.Load(primary.Profile)
if err != nil {
return errors.Wrap(err, "error creating shared creds")
}
Expand Down Expand Up @@ -351,8 +353,7 @@ func BulkLogin(loginFlags *flags.LoginExecFlags) error {
secondary.Profile = fmt.Sprintf("%s/%s", arnParts[4], strings.TrimPrefix(arnParts[5], "role/"))
}

sharedCreds := awsconfig.NewSharedCredentials(secondary.Profile)
if sharedCreds.Expired() {
if sharedCredsFile.Expired(secondary.Profile) {
logger.WithField("SecondaryRole", secondary.RoleArn).Debugf("Creds have expired")

// Secondary is expired. Add secondary role to primary
Expand Down Expand Up @@ -386,7 +387,7 @@ func BulkLogin(loginFlags *flags.LoginExecFlags) error {
// Get the role session name
logger.Infof("Using primary role session to assume child roles. No need to login")

return bulkAssumeAsync(rolesToAssume, account, roleConfig, sessionRoleName, false, true, "")
return bulkAssumeAsync(sharedCredsFile, rolesToAssume, account, roleConfig, sessionRoleName, false, true, "")
}
}

Expand Down Expand Up @@ -494,12 +495,14 @@ func BulkLogin(loginFlags *flags.LoginExecFlags) error {
logger.Debugf("Got groups: %+v", roleConfig.Roles)
}

return bulkAssumeAsync(roleConfig.Roles, account, roleConfig, roleSessionName, loginFlags.Force, false, samlAssertion)
return bulkAssumeAsync(sharedCredsFile, roleConfig.Roles, account, roleConfig, roleSessionName, loginFlags.Force, false, samlAssertion)
}

// bulkAssumeAsync assumes all primary and secondary roles given in the roles slice. If useExistingCreds is true, the samlAssertion
// is NOT needed
func bulkAssumeAsync(roles []cfg.RoleConfig, account *cfg.IDPAccount, roleConfig *cfg.BulkRoleConfig, roleSessionName string, force, useExistingCreds bool, samlAssertion string) error {
func bulkAssumeAsync(sharedCredsFile *awsconfig.CredentialsFile, roles []cfg.RoleConfig,
account *cfg.IDPAccount, roleConfig *cfg.BulkRoleConfig, roleSessionName string, force, useExistingCreds bool, samlAssertion string) error {

logger := logrus.WithFields(logrus.Fields{
"Action": "Bulk Assume",
"UseExistingCreds": useExistingCreds,
Expand Down Expand Up @@ -534,7 +537,7 @@ func bulkAssumeAsync(roles []cfg.RoleConfig, account *cfg.IDPAccount, roleConfig
wg.Add(1)

// Perform role assumption
go input.Assume(roleSessionName, force)
go input.Assume(roleSessionName, force, sharedCredsFile)
}

// Done channel
Expand All @@ -555,27 +558,29 @@ func bulkAssumeAsync(roles []cfg.RoleConfig, account *cfg.IDPAccount, roleConfig

// Handle primary creds, only need to save primary if NOT using existing creds
if creds.Input.RoleConfig.Profile != "" && !useExistingCreds {
sharedCreds := awsconfig.NewSharedCredentials(creds.Input.RoleConfig.Profile)
if err := sharedCreds.Save(creds.PrimaryCredentials); err != nil {
if err := sharedCredsFile.StoreCreds(creds.Input.RoleConfig.Profile, creds.PrimaryCredentials); err != nil {
return errors.Wrap(err, "error saving credentials")
}
}

// Handle secondary creds
for _, childCreds := range creds.Output {
sharedCreds := awsconfig.NewSharedCredentials(childCreds.Input.RoleAssumption.Profile)
if err := sharedCreds.Save(childCreds.Credentials); err != nil {
return errors.Wrap(err, "error saving credentials")
if err := sharedCredsFile.StoreCreds(childCreds.Input.RoleAssumption.Profile, childCreds.Credentials); err != nil {
return errors.Wrap(err, "error saving child credentials")
}
}
wg.Done()

case <-done:
if err := sharedCredsFile.SaveFile(); err != nil {
log.Fatalf("Error storing new credentials: %v", err)
}
logger.Infof("Done!")
return nil

// Timeout
case <-time.After(time.Second * time.Duration(account.Timeout)):
// TODO: Should i store what i have so far?
logger.Errorf("Timed out after %v seconds", account.Timeout)
return errors.New("timed out while assuming roles")
}
Expand Down
124 changes: 124 additions & 0 deletions pkg/awsconfig/awsconfig_bulk.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package awsconfig

import (
"os"
"path/filepath"
"time"

"github.com/mitchellh/go-homedir"
"gopkg.in/ini.v1"
)

// CredentialsFile holds the original ini file to eliminate reading it multiple times throughout the process
type CredentialsFile struct {
File *ini.File

fileLoc string
}

// LoadCredentialsFile loads the AWS credentials file and keeps it in a config object
// with an optional fileName parameter override
func LoadCredentialsFile(fileName ...string) (*CredentialsFile, error) {
var file string
if len(fileName) > 0 {
// Filename was passed in as an arg
expanded, err := homedir.Expand(fileName[0])
if err != nil {
return nil, err
}
file = expanded
} else {
// otherwise use default
f, err := locateConfigFile()
if err != nil {
return nil, err
}
file = f
}

logger.WithField("filename", file).Debug("ensureCredentialsExists")

if err := ensureCredentialsExist(file); err != nil {
return nil, err
}

// File exists, read it and load it into an ini config
credsFile, err := ini.Load(file)
if err != nil {
return nil, err
}

return &CredentialsFile{
File: credsFile,
fileLoc: file,
}, nil
}

// SaveFile saves the credentials file to where it was loaded from
func (creds *CredentialsFile) SaveFile() error {
logger.WithField("filename", creds.fileLoc).Debug("storing file")
return creds.File.SaveTo(creds.fileLoc)
}

// Expired checks to see if a profile is expired or not
func (creds *CredentialsFile) Expired(profile string) bool {
cred, err := creds.Load(profile)
if err != nil {
return true
}

return time.Now().After(cred.Expires)
}

// Load loads a credentials file from the
func (creds *CredentialsFile) Load(profile string) (*AWSCredentials, error) {
iniProfile, err := creds.File.GetSection(profile)
if err != nil {
return nil, ErrCredentialsNotFound
}

awsCreds := new(AWSCredentials)

if err := iniProfile.MapTo(awsCreds); err != nil {
return nil, ErrCredentialsNotFound
}

return awsCreds, nil
}

// StoreCreds takes a profile and the awsCreds to store. This does NOT save the file, that needs to be called later
func (creds *CredentialsFile) StoreCreds(profile string, awsCreds *AWSCredentials) error {
iniProfile, err := creds.File.NewSection(profile)
if err != nil {
return err
}

if err := iniProfile.ReflectFrom(awsCreds); err != nil {
return err
}

return nil
}

func ensureCredentialsExist(file string) error {
if _, err := os.Stat(file); err != nil {
if os.IsNotExist(err) {
// File does not exist, create it
dir := filepath.Dir(file)

if err := os.MkdirAll(dir, os.ModePerm); err != nil {
return err
}

logger.WithField("dir", dir).Debugf("Dir created")

if _, err := os.Create(file); err != nil {
return err
}

logger.WithField("file", file).Debugf("File created")
}
return nil
}
return nil
}

0 comments on commit a2733fa

Please sign in to comment.