diff --git a/cmd/saml2aws/commands/console.go b/cmd/saml2aws/commands/console.go index ce46a17e5..c61ece7cf 100644 --- a/cmd/saml2aws/commands/console.go +++ b/cmd/saml2aws/commands/console.go @@ -3,6 +3,7 @@ package commands import ( "encoding/json" "fmt" + "github.com/versent/saml2aws/pkg/cfg" "io/ioutil" "net/http" "net/url" @@ -14,6 +15,10 @@ import ( "github.com/versent/saml2aws/pkg/flags" ) +const ( + federationURL = "https://signin.aws.amazon.com/federation" +) + // Exec execute the supplied command after seeding the environment func Console(consoleFlags *flags.LoginExecFlags) error { @@ -36,22 +41,10 @@ func Console(consoleFlags *flags.LoginExecFlags) error { return nil } - awsCreds, err := sharedCreds.Load() + awsCreds, err := loadOrLogin(account, sharedCreds, consoleFlags) if err != nil { - return errors.Wrap(err, "error loading credentials") - } - - if awsCreds.Expires.Sub(time.Now()) < 0 { - return errors.New("error aws credentials have expired") - } - - ok, err := checkToken(account.Profile) - if err != nil { - return errors.Wrap(err, "error validating token") - } - - if !ok { - err = Login(consoleFlags) + return errors.Wrap(err, + fmt.Sprintf("error loading credentials for profile: %s", consoleFlags.ExecProfile)) } if err != nil { return errors.Wrap(err, "error logging in") @@ -66,11 +59,55 @@ func Console(consoleFlags *flags.LoginExecFlags) error { } } - fmt.Printf("Opening console for profile %s ...\n", account.Profile) - + fmt.Printf("Presenting credentials for %s to %s\n", account.Profile, federationURL) return federatedLogin(awsCreds, consoleFlags) } +func loadOrLogin(account *cfg.IDPAccount, sharedCreds *awsconfig.CredentialsProvider, execFlags *flags.LoginExecFlags) (*awsconfig.AWSCredentials, error) { + + var err error + + if execFlags.Force { + fmt.Println("force login requested") + return loginRefreshCredentials(sharedCreds, execFlags) + } + + awsCreds, err := sharedCreds.Load() + if err != nil { + if err != awsconfig.ErrCredentialsNotFound { + return nil, errors.Wrap(err, "failed to load credentials") + } + fmt.Println("credentials not found triggering login") + return loginRefreshCredentials(sharedCreds, execFlags) + } + + if awsCreds.Expires.Sub(time.Now()) < 0 { + fmt.Println("expired credentials triggering login") + return loginRefreshCredentials(sharedCreds, execFlags) + } + + ok, err := checkToken(account.Profile) + if err != nil { + return nil, errors.Wrap(err, "error validating token") + } + + if !ok { + fmt.Println("aws rejected credentials triggering login") + return loginRefreshCredentials(sharedCreds, execFlags) + } + + return awsCreds, nil +} + +func loginRefreshCredentials(sharedCreds *awsconfig.CredentialsProvider, execFlags *flags.LoginExecFlags) (*awsconfig.AWSCredentials, error) { + err := Login(execFlags) + if err != nil { + return nil, errors.Wrap(err, "error logging in") + } + + return sharedCreds.Load() +} + func federatedLogin(creds *awsconfig.AWSCredentials, consoleFlags *flags.LoginExecFlags) error { jsonBytes, err := json.Marshal(map[string]string{ "sessionId": creds.AWSAccessKey, @@ -81,7 +118,7 @@ func federatedLogin(creds *awsconfig.AWSCredentials, consoleFlags *flags.LoginEx return err } - req, err := http.NewRequest("GET", "https://signin.aws.amazon.com/federation", nil) + req, err := http.NewRequest("GET", federationURL, nil) if err != nil { return err } @@ -119,7 +156,8 @@ func federatedLogin(creds *awsconfig.AWSCredentials, consoleFlags *flags.LoginEx destination := "https://console.aws.amazon.com/" loginURL := fmt.Sprintf( - "https://signin.aws.amazon.com/federation?Action=login&Issuer=aws-okta&Destination=%s&SigninToken=%s", + "%s?Action=login&Issuer=aws-okta&Destination=%s&SigninToken=%s", + federationURL, url.QueryEscape(destination), url.QueryEscape(signinToken), ) diff --git a/cmd/saml2aws/main.go b/cmd/saml2aws/main.go index 0cd4b7d9f..fb9707110 100644 --- a/cmd/saml2aws/main.go +++ b/cmd/saml2aws/main.go @@ -100,6 +100,7 @@ func main() { consoleFlags := new(flags.LoginExecFlags) consoleFlags.CommonFlags = commonFlags cmdConsole.Flag("profile", "The AWS profile to save the temporary credentials. (env: SAML2AWS_PROFILE)").Envar("SAML2AWS_PROFILE").Short('p').StringVar(&commonFlags.Profile) + cmdConsole.Flag("force", "Refresh credentials even if not expired.").BoolVar(&consoleFlags.Force) // `list` command and settings cmdListRoles := app.Command("list-roles", "List available role ARNs.")