Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bulk login - non force fix #92

Merged
merged 3 commits into from
Dec 9, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 31 additions & 26 deletions cmd/gossamer3/commands/bulk_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type SecondaryRoleOutput struct {
}

// Assume assumes a primary role, returning the credentials to assume secondary role if needed
func (input *PrimaryRoleInput) Assume(roleSessionName string, force bool) {
func (input *PrimaryRoleInput) Assume(roleSessionName string, force bool, sharedCreds *awsconfig.Credentials) {
Squwid marked this conversation as resolved.
Show resolved Hide resolved
var creds *awsconfig.AWSCredentials
var err error
existingCreds := false
Expand Down Expand Up @@ -101,18 +101,17 @@ func (input *PrimaryRoleInput) Assume(roleSessionName string, force bool) {

// Not forcing credential refresh, pull from file
if !force {
sharedCreds := awsconfig.NewSharedCredentials(input.RoleConfig.Profile)

// Check if credentials are expired
if !sharedCreds.Expired() {
creds, err = sharedCreds.Load()
if !sharedCreds.Expired(input.RoleConfig.Profile) {
creds, err = sharedCreds.Load(input.RoleConfig.Profile)
existingCreds = creds != nil
}
}
}

// Get new credentials
if force || err != nil || creds == nil || input.RoleConfig.Profile == "" {
if force || err != nil || !existingCreds || input.RoleConfig.Profile == "" {
// If session duration is defined at a role level, use that instead of the idp account level
var sessDur = input.Account.SessionDuration
if input.RoleConfig.SessionDuration > 0 {
Expand Down Expand Up @@ -163,7 +162,7 @@ func (input *PrimaryRoleInput) Assume(roleSessionName string, force bool) {
wg.Add(1)

// Perform secondary assumption
go secondaryInput.Assume(roleSessionName, force)
go secondaryInput.Assume(roleSessionName, force, sharedCreds)
}

// Create a channel to wait for completion of the wait group
Expand Down Expand Up @@ -201,7 +200,7 @@ func (input *PrimaryRoleInput) Assume(roleSessionName string, force bool) {
}

// Assume assumes a secondary role using the PrimaryRoleInput parent object
func (input *SecondaryRoleInput) Assume(roleSessionName string, force bool) {
func (input *SecondaryRoleInput) Assume(roleSessionName string, force bool, sharedCreds *awsconfig.Credentials) {
var creds *awsconfig.AWSCredentials
var err error
existingCreds := false
Expand Down Expand Up @@ -239,17 +238,15 @@ func (input *SecondaryRoleInput) Assume(roleSessionName string, force bool) {

// Not forcing credential refresh, pull from file
if !force {
sharedCreds := awsconfig.NewSharedCredentials(input.RoleAssumption.Profile)

// Check if credentials are expired
if !sharedCreds.Expired() {
creds, err = sharedCreds.Load()
if !sharedCreds.Expired(input.RoleAssumption.Profile) {
creds, err = sharedCreds.Load(input.RoleAssumption.Profile)
existingCreds = creds != nil
}
}

// Get new credentials if forced, error encountered, or credentials are not found
if force || err != nil || creds == nil {
if force || err != nil || !existingCreds {
creds, err = assumeRole(input.PrimaryCredentials, input.RoleAssumption.RoleArn, roleSessionName, region)
}

Expand Down Expand Up @@ -297,6 +294,12 @@ func BulkLogin(loginFlags *flags.LoginExecFlags) error {
// Check if any credentials need to be refreshed - only run when force is false
logger.Debug("Check if credentials exist")

// Load entire file from single location
sharedCreds, err := awsconfig.LoadCredentials()
if err != nil {
logger.Fatalln(errors.Wrap(err, "couldnt load aws credentials file"))
}

// Not forced, and not assuming all roles
if !loginFlags.Force && !roleConfig.AssumeAllRoles {
var primaryExpired = false // Only prompt login if one of the parent credentials are expired
Expand All @@ -308,8 +311,7 @@ func BulkLogin(loginFlags *flags.LoginExecFlags) error {
for _, primary := range roleConfig.Roles {
// Only check for expiration of parent role
if primary.Profile != "" {
sharedCreds := awsconfig.NewSharedCredentials(primary.Profile)
if sharedCreds.Expired() {
if sharedCreds.Expired(primary.Profile) {
logger.WithField("Role", primary.PrimaryRoleArn).Debugf("Creds have expired")
primaryExpired = true
noCredsExpired = false
Expand All @@ -318,7 +320,7 @@ func BulkLogin(loginFlags *flags.LoginExecFlags) error {

// Not expired, set the session role name. Only need this once since it should always be the same
if sessionRoleName == "" {
creds, err := sharedCreds.Load()
creds, err := sharedCreds.Load(primary.Profile)
if err != nil {
return errors.Wrap(err, "error creating shared creds")
}
Expand Down Expand Up @@ -351,8 +353,7 @@ func BulkLogin(loginFlags *flags.LoginExecFlags) error {
secondary.Profile = fmt.Sprintf("%s/%s", arnParts[4], strings.TrimPrefix(arnParts[5], "role/"))
}

sharedCreds := awsconfig.NewSharedCredentials(secondary.Profile)
if sharedCreds.Expired() {
if sharedCreds.Expired(secondary.Profile) {
logger.WithField("SecondaryRole", secondary.RoleArn).Debugf("Creds have expired")

// Secondary is expired. Add secondary role to primary
Expand Down Expand Up @@ -386,7 +387,7 @@ func BulkLogin(loginFlags *flags.LoginExecFlags) error {
// Get the role session name
logger.Infof("Using primary role session to assume child roles. No need to login")

return bulkAssumeAsync(rolesToAssume, account, roleConfig, sessionRoleName, false, true, "")
return bulkAssumeAsync(sharedCreds, rolesToAssume, account, roleConfig, sessionRoleName, false, true, "")
}
}

Expand Down Expand Up @@ -494,12 +495,14 @@ func BulkLogin(loginFlags *flags.LoginExecFlags) error {
logger.Debugf("Got groups: %+v", roleConfig.Roles)
}

return bulkAssumeAsync(roleConfig.Roles, account, roleConfig, roleSessionName, loginFlags.Force, false, samlAssertion)
return bulkAssumeAsync(sharedCreds, roleConfig.Roles, account, roleConfig, roleSessionName, loginFlags.Force, false, samlAssertion)
}

// bulkAssumeAsync assumes all primary and secondary roles given in the roles slice. If useExistingCreds is true, the samlAssertion
// is NOT needed
func bulkAssumeAsync(roles []cfg.RoleConfig, account *cfg.IDPAccount, roleConfig *cfg.BulkRoleConfig, roleSessionName string, force, useExistingCreds bool, samlAssertion string) error {
func bulkAssumeAsync(sharedCreds *awsconfig.Credentials, roles []cfg.RoleConfig,
account *cfg.IDPAccount, roleConfig *cfg.BulkRoleConfig, roleSessionName string, force, useExistingCreds bool, samlAssertion string) error {

logger := logrus.WithFields(logrus.Fields{
"Action": "Bulk Assume",
"UseExistingCreds": useExistingCreds,
Expand Down Expand Up @@ -534,7 +537,7 @@ func bulkAssumeAsync(roles []cfg.RoleConfig, account *cfg.IDPAccount, roleConfig
wg.Add(1)

// Perform role assumption
go input.Assume(roleSessionName, force)
go input.Assume(roleSessionName, force, sharedCreds)
}

// Done channel
Expand All @@ -555,27 +558,29 @@ func bulkAssumeAsync(roles []cfg.RoleConfig, account *cfg.IDPAccount, roleConfig

// Handle primary creds, only need to save primary if NOT using existing creds
if creds.Input.RoleConfig.Profile != "" && !useExistingCreds {
sharedCreds := awsconfig.NewSharedCredentials(creds.Input.RoleConfig.Profile)
if err := sharedCreds.Save(creds.PrimaryCredentials); err != nil {
if err := sharedCreds.StoreCreds(creds.Input.RoleConfig.Profile, creds.PrimaryCredentials); err != nil {
return errors.Wrap(err, "error saving credentials")
}
}

// Handle secondary creds
for _, childCreds := range creds.Output {
sharedCreds := awsconfig.NewSharedCredentials(childCreds.Input.RoleAssumption.Profile)
if err := sharedCreds.Save(childCreds.Credentials); err != nil {
return errors.Wrap(err, "error saving credentials")
if err := sharedCreds.StoreCreds(childCreds.Input.RoleAssumption.Profile, childCreds.Credentials); err != nil {
return errors.Wrap(err, "error saving child credentials")
}
}
wg.Done()

case <-done:
if err := sharedCreds.Save(); err != nil {
log.Fatalf("Error storing new credentials: %v", err)
}
logger.Infof("Done!")
return nil

// Timeout
case <-time.After(time.Second * time.Duration(account.Timeout)):
// TODO: Should i store what i have so far?
logger.Errorf("Timed out after %v seconds", account.Timeout)
return errors.New("timed out while assuming roles")
}
Expand Down
119 changes: 119 additions & 0 deletions pkg/awsconfig/awsconfig_bulk.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package awsconfig

import (
"os"
"path/filepath"
"time"

"gopkg.in/ini.v1"
)

// Credentials holds the original ini file to eliminate reading it multiple times throughout the process
type Credentials struct {
Squwid marked this conversation as resolved.
Show resolved Hide resolved
File *ini.File

fileLoc string
}

// LoadCredentials loads the AWS credentials file and keeps it in a config object
// with an optional fileName parameter override
func LoadCredentials(fileName ...string) (*Credentials, error) {
Squwid marked this conversation as resolved.
Show resolved Hide resolved
var file string
if len(fileName) > 0 {
// Filename was passed in as an arg
file = fileName[0]
Squwid marked this conversation as resolved.
Show resolved Hide resolved
} else {
// otherwise use default
f, err := locateConfigFile()
if err != nil {
return nil, err
}
file = f
}

logger.WithField("filename", file).Debug("ensureConfigExists")
Squwid marked this conversation as resolved.
Show resolved Hide resolved

if err := ensureCredentialsExist(file); err != nil {
return nil, err
}

// File exists, read it and load it into an ini config
credsFile, err := ini.Load(file)
if err != nil {
return nil, err
}

return &Credentials{
File: credsFile,
fileLoc: file,
}, nil
}

// Save saves the credentials file to where it was loaded from
func (creds *Credentials) Save() error {
logger.WithField("filename", creds.fileLoc).Debug("storing file")
return creds.File.SaveTo(creds.fileLoc)
}

// Expired checks to see if a profile is expired or not
func (creds *Credentials) Expired(profile string) bool {
cred, err := creds.Load(profile)
if err != nil {
return true
}

return time.Now().After(cred.Expires)
}

// Load loads a credentials file from the
func (creds *Credentials) Load(profile string) (*AWSCredentials, error) {
iniProfile, err := creds.File.GetSection(profile)
if err != nil {
return nil, ErrCredentialsNotFound
}

awsCreds := new(AWSCredentials)

if err := iniProfile.MapTo(awsCreds); err != nil {
return nil, ErrCredentialsNotFound
}

return awsCreds, nil
}

// StoreCreds takes a profile and the awsCreds to store. This does NOT save the file, that needs to be called later
func (creds *Credentials) StoreCreds(profile string, awsCreds *AWSCredentials) error {
iniProfile, err := creds.File.NewSection(profile)
if err != nil {
return err
}

if err := iniProfile.ReflectFrom(awsCreds); err != nil {
return err
}

return nil
}

func ensureCredentialsExist(file string) error {
if _, err := os.Stat(file); err != nil {
if os.IsNotExist(err) {
// File does not exist, create it
dir := filepath.Dir(file)

if err := os.MkdirAll(dir, os.ModePerm); err != nil {
return err
}

logger.WithField("dir", dir).Debugf("Dir created")

if _, err := os.Create(file); err != nil {
return err
}

logger.WithField("file", file).Debugf("File created")
}
return nil
}
return nil
}