Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(console) Refactor token / config checks to login correctly #412

Merged
merged 1 commit into from
Jan 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 57 additions & 19 deletions cmd/saml2aws/commands/console.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package commands
import (
"encoding/json"
"fmt"
"github.com/versent/saml2aws/pkg/cfg"
"io/ioutil"
"net/http"
"net/url"
Expand All @@ -14,6 +15,10 @@ import (
"github.com/versent/saml2aws/pkg/flags"
)

const (
federationURL = "https://signin.aws.amazon.com/federation"
)

// Exec execute the supplied command after seeding the environment
func Console(consoleFlags *flags.LoginExecFlags) error {

Expand All @@ -36,22 +41,10 @@ func Console(consoleFlags *flags.LoginExecFlags) error {
return nil
}

awsCreds, err := sharedCreds.Load()
awsCreds, err := loadOrLogin(account, sharedCreds, consoleFlags)
if err != nil {
return errors.Wrap(err, "error loading credentials")
}

if awsCreds.Expires.Sub(time.Now()) < 0 {
return errors.New("error aws credentials have expired")
}

ok, err := checkToken(account.Profile)
if err != nil {
return errors.Wrap(err, "error validating token")
}

if !ok {
err = Login(consoleFlags)
return errors.Wrap(err,
fmt.Sprintf("error loading credentials for profile: %s", consoleFlags.ExecProfile))
}
if err != nil {
return errors.Wrap(err, "error logging in")
Expand All @@ -66,11 +59,55 @@ func Console(consoleFlags *flags.LoginExecFlags) error {
}
}

fmt.Printf("Opening console for profile %s ...\n", account.Profile)

fmt.Printf("Presenting credentials for %s to %s\n", account.Profile, federationURL)
return federatedLogin(awsCreds, consoleFlags)
}

func loadOrLogin(account *cfg.IDPAccount, sharedCreds *awsconfig.CredentialsProvider, execFlags *flags.LoginExecFlags) (*awsconfig.AWSCredentials, error) {

var err error

if execFlags.Force {
fmt.Println("force login requested")
return loginRefreshCredentials(sharedCreds, execFlags)
}

awsCreds, err := sharedCreds.Load()
if err != nil {
if err != awsconfig.ErrCredentialsNotFound {
return nil, errors.Wrap(err, "failed to load credentials")
}
fmt.Println("credentials not found triggering login")
return loginRefreshCredentials(sharedCreds, execFlags)
}

if awsCreds.Expires.Sub(time.Now()) < 0 {
fmt.Println("expired credentials triggering login")
return loginRefreshCredentials(sharedCreds, execFlags)
}

ok, err := checkToken(account.Profile)
if err != nil {
return nil, errors.Wrap(err, "error validating token")
}

if !ok {
fmt.Println("aws rejected credentials triggering login")
return loginRefreshCredentials(sharedCreds, execFlags)
}

return awsCreds, nil
}

func loginRefreshCredentials(sharedCreds *awsconfig.CredentialsProvider, execFlags *flags.LoginExecFlags) (*awsconfig.AWSCredentials, error) {
err := Login(execFlags)
if err != nil {
return nil, errors.Wrap(err, "error logging in")
}

return sharedCreds.Load()
}

func federatedLogin(creds *awsconfig.AWSCredentials, consoleFlags *flags.LoginExecFlags) error {
jsonBytes, err := json.Marshal(map[string]string{
"sessionId": creds.AWSAccessKey,
Expand All @@ -81,7 +118,7 @@ func federatedLogin(creds *awsconfig.AWSCredentials, consoleFlags *flags.LoginEx
return err
}

req, err := http.NewRequest("GET", "https://signin.aws.amazon.com/federation", nil)
req, err := http.NewRequest("GET", federationURL, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -119,7 +156,8 @@ func federatedLogin(creds *awsconfig.AWSCredentials, consoleFlags *flags.LoginEx
destination := "https://console.aws.amazon.com/"

loginURL := fmt.Sprintf(
"https://signin.aws.amazon.com/federation?Action=login&Issuer=aws-okta&Destination=%s&SigninToken=%s",
"%s?Action=login&Issuer=aws-okta&Destination=%s&SigninToken=%s",
federationURL,
url.QueryEscape(destination),
url.QueryEscape(signinToken),
)
Expand Down
1 change: 1 addition & 0 deletions cmd/saml2aws/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ func main() {
consoleFlags := new(flags.LoginExecFlags)
consoleFlags.CommonFlags = commonFlags
cmdConsole.Flag("profile", "The AWS profile to save the temporary credentials. (env: SAML2AWS_PROFILE)").Envar("SAML2AWS_PROFILE").Short('p').StringVar(&commonFlags.Profile)
cmdConsole.Flag("force", "Refresh credentials even if not expired.").BoolVar(&consoleFlags.Force)

// `list` command and settings
cmdListRoles := app.Command("list-roles", "List available role ARNs.")
Expand Down