diff --git a/README.md b/README.md index 286100e41..fedaadcaa 100644 --- a/README.md +++ b/README.md @@ -183,6 +183,8 @@ Commands: --config=CONFIG Path/filename of saml2aws config file (env: SAML2AWS_CONFIGFILE) --cache-saml Caches the SAML response (env: SAML2AWS_CACHE_SAML) --cache-file=CACHE-FILE The location of the SAML cache file (env: SAML2AWS_SAML_CACHE_FILE) + --disable-sessions Do not use Okta sessions. Uses Okta sessions by default. (env: SAML2AWS_OKTA_DISABLE_SESSIONS) + --disable-remember-device Do not remember Okta MFA device. Remembers MFA device by default. (env: SAML2AWS_OKTA_DISABLE_REMEMBER_DEVICE) login [] Login to a SAML 2.0 IDP and convert the SAML assertion to an STS token. @@ -199,7 +201,8 @@ Commands: The file that will cache the credentials retrieved from AWS. When not specified, will use the default AWS credentials file location. (env: SAML2AWS_CREDENTIALS_FILE) --cache-saml Caches the SAML response (env: SAML2AWS_CACHE_SAML) --cache-file=CACHE-FILE The location of the SAML cache file (env: SAML2AWS_SAML_CACHE_FILE) - + --disable-sessions Do not use Okta sessions. Uses Okta sessions by default. (env: SAML2AWS_OKTA_DISABLE_SESSIONS) + --disable-remember-device Do not remember Okta MFA device. Remembers MFA device by default. (env: SAML2AWS_OKTA_DISABLE_REMEMBER_DEVICE) exec [] [...] Exec the supplied command with env vars from STS token. @@ -677,6 +680,23 @@ there is a file per saml2aws profile, the cache directory is called `saml2aws` a You can toggle `--cache-saml` during `login` or during `list-roles`, and you can set it once during `configure` and use it implicitly. +# Okta Sessions + +This requires the use of the keychain (local credentials store). If you disabled the keychain using `--disable-keychain`, Okta sessions will also be disabled. + +Okta sessions are enabled by default. This will store the Okta session locally and save your device for MFA. This means that if the session has not yet expired, you will not be prompted for MFA. + +* To disable remembering the device, you can toggle `--disable-remember-device` during `login` or `configure` commands. +* To disable using Okta sessions, you can toggle `--disable-sessions` during `login` or `configure` commands. + * This will also disable the Okta MFA remember device feature + +Use the `--force` flag during `login` command to prompt for AWS role selection. + +If Okta sessions are disabled via any of the methods mentioned above, the login process will default to the standard authentication process (without using sessions). + +Please note that your Okta session duration and MFA policies are governed by your Okta host organization. + + # License This code is Copyright (c) 2018 [Versent](http://versent.com.au) and released under the MIT license. All rights not explicitly granted in the MIT license are reserved. See the included LICENSE.md file for more details. diff --git a/cmd/saml2aws/commands/login.go b/cmd/saml2aws/commands/login.go index 91b0c6585..418a38872 100644 --- a/cmd/saml2aws/commands/login.go +++ b/cmd/saml2aws/commands/login.go @@ -6,6 +6,7 @@ import ( "fmt" "log" "os" + "strings" "time" "github.com/aws/aws-sdk-go/aws" @@ -187,6 +188,11 @@ func resolveLoginDetails(account *cfg.IDPAccount, loginFlags *flags.LoginExecFla return nil, errors.Wrap(err, "error loading saved password") } } + } else { // if user disabled keychain, dont use Okta sessions & dont remember Okta MFA device + if strings.ToLower(account.Provider) == "okta" { + account.DisableSessions = true + account.DisableRememberDevice = true + } } // log.Printf("%s %s", savedUsername, savedPassword) diff --git a/cmd/saml2aws/commands/login_test.go b/cmd/saml2aws/commands/login_test.go index 762ccf424..0e29f3640 100644 --- a/cmd/saml2aws/commands/login_test.go +++ b/cmd/saml2aws/commands/login_test.go @@ -1,6 +1,7 @@ package commands import ( + "fmt" "testing" "time" @@ -29,6 +30,39 @@ func TestResolveLoginDetailsWithFlags(t *testing.T) { assert.Equal(t, &creds.LoginDetails{Username: "wolfeidau", Password: "testtestlol", URL: "https://id.example.com", MFAToken: "123456"}, loginDetails) } +func TestOktaResolveLoginDetailsWithFlags(t *testing.T) { + + // Default state - user did not supply values for DisableSessions and DisableSessions + commonFlags := &flags.CommonFlags{URL: "https://id.example.com", Username: "testuser", Password: "testtestlol", MFAToken: "123456", SkipPrompt: true} + loginFlags := &flags.LoginExecFlags{CommonFlags: commonFlags} + + idpa := &cfg.IDPAccount{ + URL: "https://id.example.com", + MFA: "none", + Provider: "Okta", + Username: "testuser", + } + loginDetails, err := resolveLoginDetails(idpa, loginFlags) + + assert.Nil(t, err) + assert.False(t, idpa.DisableSessions, fmt.Errorf("default state, DisableSessions should be false")) + assert.False(t, idpa.DisableRememberDevice, fmt.Errorf("default state, DisableRememberDevice should be false")) + assert.Equal(t, &creds.LoginDetails{Username: "testuser", Password: "testtestlol", URL: "https://id.example.com", MFAToken: "123456"}, loginDetails) + + // User disabled keychain, resolveLoginDetails should set the account's DisableSessions and DisableSessions fields to true + + commonFlags = &flags.CommonFlags{URL: "https://id.example.com", Username: "testuser", Password: "testtestlol", MFAToken: "123456", SkipPrompt: true, DisableKeychain: true} + loginFlags = &flags.LoginExecFlags{CommonFlags: commonFlags} + + loginDetails, err = resolveLoginDetails(idpa, loginFlags) + + assert.Nil(t, err) + assert.True(t, idpa.DisableSessions, fmt.Errorf("user disabled keychain, DisableSessions should be true")) + assert.True(t, idpa.DisableRememberDevice, fmt.Errorf("user disabled keychain, DisableRememberDevice should be true")) + assert.Equal(t, &creds.LoginDetails{Username: "testuser", Password: "testtestlol", URL: "https://id.example.com", MFAToken: "123456"}, loginDetails) + +} + func TestResolveRoleSingleEntry(t *testing.T) { adminRole := &saml2aws.AWSRole{ diff --git a/cmd/saml2aws/main.go b/cmd/saml2aws/main.go index 603654db1..712d89c00 100644 --- a/cmd/saml2aws/main.go +++ b/cmd/saml2aws/main.go @@ -80,7 +80,7 @@ func main() { app.Flag("aws-urn", "The URN used by SAML when you login. (env: SAML2AWS_AWS_URN)").Envar("SAML2AWS_AWS_URN").StringVar(&commonFlags.AmazonWebservicesURN) app.Flag("skip-prompt", "Skip prompting for parameters during login.").BoolVar(&commonFlags.SkipPrompt) app.Flag("session-duration", "The duration of your AWS Session. (env: SAML2AWS_SESSION_DURATION)").Envar("SAML2AWS_SESSION_DURATION").IntVar(&commonFlags.SessionDuration) - app.Flag("disable-keychain", "Do not use keychain at all.").Envar("SAML2AWS_DISABLE_KEYCHAIN").BoolVar(&commonFlags.DisableKeychain) + app.Flag("disable-keychain", "Do not use keychain at all. This will also disable Okta sessions & remembering MFA device. (env: SAML2AWS_DISABLE_KEYCHAIN)").Envar("SAML2AWS_DISABLE_KEYCHAIN").BoolVar(&commonFlags.DisableKeychain) app.Flag("region", "AWS region to use for API requests, e.g. us-east-1, us-gov-west-1, cn-north-1 (env: SAML2AWS_REGION)").Envar("SAML2AWS_REGION").Short('r').StringVar(&commonFlags.Region) // `configure` command and settings @@ -94,6 +94,8 @@ func main() { cmdConfigure.Flag("credentials-file", "The file that will cache the credentials retrieved from AWS. When not specified, will use the default AWS credentials file location. (env: SAML2AWS_CREDENTIALS_FILE)").Envar("SAML2AWS_CREDENTIALS_FILE").StringVar(&commonFlags.CredentialsFile) cmdConfigure.Flag("cache-saml", "Caches the SAML response (env: SAML2AWS_CACHE_SAML)").Envar("SAML2AWS_CACHE_SAML").BoolVar(&commonFlags.SAMLCache) cmdConfigure.Flag("cache-file", "The location of the SAML cache file (env: SAML2AWS_SAML_CACHE_FILE)").Envar("SAML2AWS_SAML_CACHE_FILE").StringVar(&commonFlags.SAMLCacheFile) + cmdConfigure.Flag("disable-sessions", "Do not use Okta sessions. Uses Okta sessions by default. (env: SAML2AWS_OKTA_DISABLE_SESSIONS)").Envar("SAML2AWS_OKTA_DISABLE_SESSIONS").BoolVar(&commonFlags.DisableSessions) + cmdConfigure.Flag("disable-remember-device", "Do not remember Okta MFA device. Remembers MFA device by default. (env: SAML2AWS_OKTA_DISABLE_REMEMBER_DEVICE)").Envar("SAML2AWS_OKTA_DISABLE_REMEMBER_DEVICE").BoolVar(&commonFlags.DisableRememberDevice) configFlags := commonFlags // `login` command and settings @@ -109,6 +111,8 @@ func main() { cmdLogin.Flag("credentials-file", "The file that will cache the credentials retrieved from AWS. When not specified, will use the default AWS credentials file location. (env: SAML2AWS_CREDENTIALS_FILE)").Envar("SAML2AWS_CREDENTIALS_FILE").StringVar(&commonFlags.CredentialsFile) cmdLogin.Flag("cache-saml", "Caches the SAML response (env: SAML2AWS_CACHE_SAML)").Envar("SAML2AWS_CACHE_SAML").BoolVar(&commonFlags.SAMLCache) cmdLogin.Flag("cache-file", "The location of the SAML cache file (env: SAML2AWS_SAML_CACHE_FILE)").Envar("SAML2AWS_SAML_CACHE_FILE").StringVar(&commonFlags.SAMLCacheFile) + cmdLogin.Flag("disable-sessions", "Do not use Okta sessions. Uses Okta sessions by default. (env: SAML2AWS_OKTA_DISABLE_SESSIONS)").Envar("SAML2AWS_OKTA_DISABLE_SESSIONS").BoolVar(&commonFlags.DisableSessions) + cmdLogin.Flag("disable-remember-device", "Do not remember Okta MFA device. Remembers MFA device by default. (env: SAML2AWS_OKTA_DISABLE_REMEMBER_DEVICE)").Envar("SAML2AWS_OKTA_DISABLE_REMEMBER_DEVICE").BoolVar(&commonFlags.DisableRememberDevice) // `exec` command and settings cmdExec := app.Command("exec", "Exec the supplied command with env vars from STS token.") diff --git a/go.mod b/go.mod index 3e7bd9e5c..155eec4aa 100644 --- a/go.mod +++ b/go.mod @@ -34,8 +34,8 @@ require ( github.com/tidwall/gjson v1.1.1 github.com/tidwall/match v1.0.0 // indirect golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad // indirect - golang.org/x/net v0.0.0-20210119194325-5f4716e94777 - golang.org/x/sys v0.0.0-20210218155724-8ebf48af031b // indirect + golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 + golang.org/x/sys v0.0.0-20210603125802-9665404d3644 // indirect golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf // indirect golang.org/x/text v0.3.5 // indirect gopkg.in/ini.v1 v1.62.0 diff --git a/go.sum b/go.sum index 01956e54f..7fb42cc3c 100644 --- a/go.sum +++ b/go.sum @@ -218,8 +218,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20210119194325-5f4716e94777 h1:003p0dJM77cxMSyCPFphvZf/Y5/NXf5fzg6ufd1/Oew= -golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -239,8 +239,9 @@ golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210218155724-8ebf48af031b h1:lAZ0/chPUDWwjqosYR0X4M490zQhMsiJ4K3DbA7o+3g= -golang.org/x/sys v0.0.0-20210218155724-8ebf48af031b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210603125802-9665404d3644 h1:CA1DEQ4NdKphKeL70tvsWNdT5oFh1lOjihRcEDROi0I= +golang.org/x/sys v0.0.0-20210603125802-9665404d3644/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf h1:MZ2shdL+ZM/XzY3ZGOnh4Nlpnxz5GSOhOmtHo3iPU6M= diff --git a/helper/credentials/saml.go b/helper/credentials/saml.go index 7f080ed9c..0f3ba65a8 100644 --- a/helper/credentials/saml.go +++ b/helper/credentials/saml.go @@ -17,6 +17,14 @@ func LookupCredentials(loginDetails *creds.LoginDetails, provider string) error loginDetails.Username = username loginDetails.Password = password + // If the provider is Okta, check for existing Okta Session Cookie (sid) + if provider == "Okta" { + _, oktaSessionCookie, err := CurrentHelper.Get(loginDetails.URL + "/sessionCookie") + if err == nil { + loginDetails.OktaSessionCookie = oktaSessionCookie + } + } + if provider == "OneLogin" { id, secret, err := CurrentHelper.Get(path.Join(loginDetails.URL, "/auth/oauth2/v2/token")) if err != nil { diff --git a/pkg/cfg/cfg.go b/pkg/cfg/cfg.go index b2ad37471..b0c53f5b9 100644 --- a/pkg/cfg/cfg.go +++ b/pkg/cfg/cfg.go @@ -30,32 +30,35 @@ const ( // IDPAccount saml IDP account type IDPAccount struct { - Name string `ini:"name"` - AppID string `ini:"app_id"` // used by OneLogin and AzureAD - URL string `ini:"url"` - Username string `ini:"username"` - Provider string `ini:"provider"` - MFA string `ini:"mfa"` - SkipVerify bool `ini:"skip_verify"` - Timeout int `ini:"timeout"` - AmazonWebservicesURN string `ini:"aws_urn"` - SessionDuration int `ini:"aws_session_duration"` - Profile string `ini:"aws_profile"` - ResourceID string `ini:"resource_id"` // used by F5APM - Subdomain string `ini:"subdomain"` // used by OneLogin - RoleARN string `ini:"role_arn"` - Region string `ini:"region"` - HttpAttemptsCount string `ini:"http_attempts_count"` - HttpRetryDelay string `ini:"http_retry_delay"` - CredentialsFile string `ini:"credentials_file"` - SAMLCache bool `ini:"saml_cache"` - SAMLCacheFile string `ini:"saml_cache_file"` - TargetURL string `ini:"target_url"` + Name string `ini:"name"` + AppID string `ini:"app_id"` // used by OneLogin and AzureAD + URL string `ini:"url"` + Username string `ini:"username"` + Provider string `ini:"provider"` + MFA string `ini:"mfa"` + SkipVerify bool `ini:"skip_verify"` + Timeout int `ini:"timeout"` + AmazonWebservicesURN string `ini:"aws_urn"` + SessionDuration int `ini:"aws_session_duration"` + Profile string `ini:"aws_profile"` + ResourceID string `ini:"resource_id"` // used by F5APM + Subdomain string `ini:"subdomain"` // used by OneLogin + RoleARN string `ini:"role_arn"` + Region string `ini:"region"` + HttpAttemptsCount string `ini:"http_attempts_count"` + HttpRetryDelay string `ini:"http_retry_delay"` + CredentialsFile string `ini:"credentials_file"` + SAMLCache bool `ini:"saml_cache"` + SAMLCacheFile string `ini:"saml_cache_file"` + TargetURL string `ini:"target_url"` + DisableRememberDevice bool `ini:"disable_remember_device"` // used by Okta + DisableSessions bool `ini:"disable_sessions"` // used by Okta } func (ia IDPAccount) String() string { var appID string var policyID string + var oktaCfg string switch ia.Provider { case "OneLogin": appID = fmt.Sprintf(` @@ -66,9 +69,13 @@ func (ia IDPAccount) String() string { case "AzureAD": appID = fmt.Sprintf(` AppID: %s`, ia.AppID) + case "Okta": + oktaCfg = fmt.Sprintf(` + DisableSessions: %v + DisableRememberDevice: %v`, ia.DisableSessions, ia.DisableSessions) } - return fmt.Sprintf(`account {%s%s + return fmt.Sprintf(`account {%s%s%s URL: %s Username: %s Provider: %s @@ -79,7 +86,7 @@ func (ia IDPAccount) String() string { Profile: %s RoleARN: %s Region: %s -}`, appID, policyID, ia.URL, ia.Username, ia.Provider, ia.MFA, ia.SkipVerify, ia.AmazonWebservicesURN, ia.SessionDuration, ia.Profile, ia.RoleARN, ia.Region) +}`, appID, policyID, oktaCfg, ia.URL, ia.Username, ia.Provider, ia.MFA, ia.SkipVerify, ia.AmazonWebservicesURN, ia.SessionDuration, ia.Profile, ia.RoleARN, ia.Region) } // Validate validate the required / expected fields are set diff --git a/pkg/creds/creds.go b/pkg/creds/creds.go index ed08611c4..e006216ba 100644 --- a/pkg/creds/creds.go +++ b/pkg/creds/creds.go @@ -2,12 +2,13 @@ package creds // LoginDetails used to authenticate type LoginDetails struct { - ClientID string // used by OneLogin - ClientSecret string // used by OneLogin - Username string - Password string - MFAToken string - DuoMFAOption string - URL string - StateToken string // used by Okta + ClientID string // used by OneLogin + ClientSecret string // used by OneLogin + Username string + Password string + MFAToken string + DuoMFAOption string + URL string + StateToken string // used by Okta + OktaSessionCookie string // used by Okta } diff --git a/pkg/flags/flags.go b/pkg/flags/flags.go index a84b0ec3e..54ab38c82 100644 --- a/pkg/flags/flags.go +++ b/pkg/flags/flags.go @@ -6,30 +6,32 @@ import ( // CommonFlags flags common to all of the `saml2aws` commands (except `help`) type CommonFlags struct { - AppID string - ClientID string - ClientSecret string - ConfigFile string - IdpAccount string - IdpProvider string - MFA string - MFAToken string - URL string - Username string - Password string - RoleArn string - AmazonWebservicesURN string - SessionDuration int - SkipPrompt bool - SkipVerify bool - Profile string - Subdomain string - ResourceID string - DisableKeychain bool - Region string - CredentialsFile string - SAMLCache bool - SAMLCacheFile string + AppID string + ClientID string + ClientSecret string + ConfigFile string + IdpAccount string + IdpProvider string + MFA string + MFAToken string + URL string + Username string + Password string + RoleArn string + AmazonWebservicesURN string + SessionDuration int + SkipPrompt bool + SkipVerify bool + Profile string + Subdomain string + ResourceID string + DisableKeychain bool + Region string + CredentialsFile string + SAMLCache bool + SAMLCacheFile string + DisableRememberDevice bool + DisableSessions bool } // LoginExecFlags flags for the Login / Exec commands @@ -106,4 +108,10 @@ func ApplyFlagOverrides(commonFlags *CommonFlags, account *cfg.IDPAccount) { if commonFlags.SAMLCacheFile != "" { account.SAMLCacheFile = commonFlags.SAMLCacheFile } + if commonFlags.DisableRememberDevice { + account.DisableRememberDevice = commonFlags.DisableRememberDevice + } + if commonFlags.DisableSessions { + account.DisableSessions = commonFlags.DisableSessions + } } diff --git a/pkg/provider/okta/okta.go b/pkg/provider/okta/okta.go index 7361e31d4..9855fa85b 100644 --- a/pkg/provider/okta/okta.go +++ b/pkg/provider/okta/okta.go @@ -13,6 +13,7 @@ import ( "net/http/cookiejar" "net/url" "regexp" + "strconv" "strings" "time" @@ -21,6 +22,7 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" + "github.com/versent/saml2aws/v2/helper/credentials" "github.com/versent/saml2aws/v2/pkg/cfg" "github.com/versent/saml2aws/v2/pkg/creds" "github.com/versent/saml2aws/v2/pkg/page" @@ -59,9 +61,11 @@ var ( type Client struct { provider.ValidateBase - client *provider.HTTPClient - mfa string - targetURL string + client *provider.HTTPClient + mfa string + targetURL string + disableSessions bool + rememberDevice bool } // AuthRequest represents an mfa okta request @@ -73,8 +77,18 @@ type AuthRequest struct { // VerifyRequest represents an mfa verify request type VerifyRequest struct { - StateToken string `json:"stateToken"` - PassCode string `json:"passCode,omitempty"` + StateToken string `json:"stateToken"` + PassCode string `json:"passCode,omitempty"` + RememberDevice string `json:"rememberDevice,omitempty"` // This is needed to remember Okta MFA device +} + +// Articles referencing the Okta MFA + remembering device +// https://developer.okta.com/docs/reference/api/authn/#verify-security-question-factor +// https://devforum.okta.com/t/how-per-device-remember-me-api-works/3955/3 + +// SessionRequst holds the SessionToken used to create an Okta Session +type SessionRequst struct { + SessionToken string `json:"sessionToken"` } // mfaChallengeContext is used to hold MFA challenge context in a simple struct. @@ -106,35 +120,286 @@ func New(idpAccount *cfg.IDPAccount) (*Client, error) { } client.Jar = jar + disableSessions := idpAccount.DisableSessions + rememberDevice := !idpAccount.DisableRememberDevice + + if idpAccount.DisableSessions { // if user disabled sessions, also dont remember device + rememberDevice = false + } + + // Debug the disableSessions and rememberDevice values + logger.Debugf("okta | disableSessions: %v", disableSessions) + logger.Debugf("okta | rememberDevice: %v", rememberDevice) + return &Client{ - client: client, - mfa: idpAccount.MFA, - targetURL: idpAccount.TargetURL, + client: client, + mfa: idpAccount.MFA, + targetURL: idpAccount.TargetURL, + disableSessions: disableSessions, + rememberDevice: rememberDevice, }, nil } type ctxKey string -// Authenticate logs into Okta and returns a SAML response -func (oc *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) { +// createSession calls the Okta sessions API to create a new session using the sessionToken passed in +func (oc *Client) createSession(loginDetails *creds.LoginDetails, sessionToken string) (string, string, error) { + logger.Debug("create session func called") + if loginDetails == nil || sessionToken == "" { + logger.Debugf("unable to create an Okta session, nil input | loginDetails: %v | sessionToken: %s", loginDetails, sessionToken) + return "", "", fmt.Errorf("unable to create an okta session, nil input") + } oktaURL, err := url.Parse(loginDetails.URL) if err != nil { - return "", errors.Wrap(err, "error building oktaURL") + return "", "", errors.Wrap(err, "error building okta url") } oktaOrgHost := oktaURL.Host + //authenticate via okta api + sessionReq := SessionRequst{SessionToken: sessionToken} + sessionReqBody := new(bytes.Buffer) + err = json.NewEncoder(sessionReqBody).Encode(sessionReq) + if err != nil { + return "", "", errors.Wrap(err, "error encoding session req") + } + + sessionReqURL := fmt.Sprintf("https://%s/api/v1/sessions", oktaOrgHost) + + req, err := http.NewRequest("POST", sessionReqURL, sessionReqBody) + if err != nil { + return "", "", errors.Wrap(err, "error building new session request") + } + + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + + res, err := oc.client.Do(req) + if err != nil { + return "", "", errors.Wrap(err, "error retrieving session response") + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return "", "", errors.Wrap(err, "error retrieving body from response") + } + + if res.StatusCode != 200 { // https://developer.okta.com/docs/reference/api/sessions/#response-parameters + if res.StatusCode == 401 { + return "", "", fmt.Errorf("unable to create an Okta session, invalid sessionToken") + } + return "", "", fmt.Errorf("unable to create an Okta session, HTTP Code: %d", res.StatusCode) + } + + resp := string(body) + + oktaSessionExpiresAtStr := gjson.Get(resp, "expiresAt").String() + logger.Debugf("okta session expires at: %s", oktaSessionExpiresAtStr) + + oktaSessionCookie := gjson.Get(resp, "id").String() + + err = credentials.SaveCredentials(loginDetails.URL+"/sessionCookie", loginDetails.Username, oktaSessionCookie) + if err != nil { + return "", "", fmt.Errorf("error storing okta session token | err: %v", err) + } + + oktaSessionToken := gjson.Get(resp, "sessionToken").String() + sessionResponseStatus := gjson.Get(resp, "status").String() + switch sessionResponseStatus { + case "ACTIVE": + logger.Debug("okta session established") + case "MFA_REQUIRED": + oktaSessionToken, err = verifyMfa(oc, oktaOrgHost, loginDetails, resp) + if err != nil { + return "", "", errors.Wrap(err, "error verifying MFA") + } + case "MFA_ENROLL": + // Not yet fully implemented, most likely no need, so just return the status as the error string... + return "", "", fmt.Errorf("MFA_ENROLL") + } + + return oktaSessionCookie, oktaSessionToken, nil +} + +// validateSession calls the Okta session API to check if the session is valid +// returns an error if the session is NOT valid +func (oc *Client) validateSession(loginDetails *creds.LoginDetails) error { + logger.Debug("validate session func called") + + if loginDetails == nil { + logger.Debug("unable to validate the okta session, nil input") + return fmt.Errorf("unable to validate the okta session, nil input") + } + + sessionCookie := loginDetails.OktaSessionCookie + + oktaURL, err := url.Parse(loginDetails.URL) + if err != nil { + return errors.Wrap(err, "error building oktaURL") + } + + oktaOrgHost := oktaURL.Host + + sessionReqURL := fmt.Sprintf("https://%s/api/v1/sessions/me", oktaOrgHost) // This api endpoint returns user details + sessionReqBody := new(bytes.Buffer) + + req, err := http.NewRequest("GET", sessionReqURL, sessionReqBody) + if err != nil { + return errors.Wrap(err, "error building new session request") + } + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + req.Header.Add("Cookie", fmt.Sprintf("sid=%s", sessionCookie)) + + res, err := oc.client.Do(req) + if err != nil { + return errors.Wrap(err, "error retrieving session response") + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return errors.Wrap(err, "error retrieving body from response") + } + + resp := string(body) + + if res.StatusCode != 200 { + logger.Debug("invalid okta session") + return fmt.Errorf("invalid okta session") + } else { + sessionResponseStatus := gjson.Get(resp, "status").String() + switch sessionResponseStatus { + case "ACTIVE": + logger.Debug("okta session established") + case "MFA_REQUIRED": + _, err := verifyMfa(oc, oktaOrgHost, loginDetails, resp) + if err != nil { + return errors.Wrap(err, "error verifying MFA") + } + case "MFA_ENROLL": + // Not yet fully implemented, so just return the status as the error string... + return fmt.Errorf("MFA_ENROLL") + } + } + + logger.Debug("valid okta session") + return nil +} + +// authWithSession authenticates user via sessions API -> direct to target URL using follow func +func (oc *Client) authWithSession(loginDetails *creds.LoginDetails) (string, error) { + logger.Debug("auth with session func called") + sessionCookie := loginDetails.OktaSessionCookie + err := oc.validateSession(loginDetails) + if err != nil { + modifiedLoginDetails := loginDetails + modifiedLoginDetails.OktaSessionCookie = "" + return oc.Authenticate(modifiedLoginDetails) + } + + req, err := http.NewRequest("GET", loginDetails.URL, nil) + if err != nil { + return "", errors.Wrap(err, "error building authWithSession request") + } + + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + req.Header.Add("Cookie", fmt.Sprintf("sid=%s", sessionCookie)) + + ctx := context.WithValue(context.Background(), ctxKey("authWithSession"), loginDetails) + + res, err := oc.client.Do(req) + if err != nil { + logger.Debugf("error authing with session: %v", err) + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + logger.Debugf("error reading body for auth with session: %v", err) + } + + // This usually happens if using an active session (> 5 mins) but MFA was NOT remembered + if strings.Contains(string(body), "/login/step-up/") { // https://developer.okta.com/docs/reference/api/authn/#step-up-authentication-with-okta-session + logger.Debug("okta step-up prompted, need mfa...") + stateToken, err := getStateTokenFromOktaPageBody(string(body)) + if err != nil { + return "", errors.Wrap(err, "error retrieving saml response") + } + loginDetails.StateToken = stateToken + return oc.Authenticate(loginDetails) + } + + return oc.follow(ctx, req, loginDetails) +} + +// getDeviceTokenFromOkta creates a dummy HTTP call to Okta and returns the device token +// cookie value +// This function is not currently used and but can be used in the future +func (oc *Client) getDeviceTokenFromOkta(loginDetails *creds.LoginDetails) (string, error) { //dummy request to set device token cookie ("dt") req, err := http.NewRequest("GET", loginDetails.URL, nil) if err != nil { return "", errors.Wrap(err, "error building device token request") } - _, err = oc.client.Do(req) + resp, err := oc.client.Do(req) if err != nil { return "", errors.Wrap(err, "error retrieving device token") } + for _, c := range resp.Cookies() { + if c.Name == "DT" { // Device token + return c.Value, nil + } + } + + return "", fmt.Errorf("unable to get a device token from okta") +} + +// setDeviceTokenCookie sets the DT cookie in the HTTP Client cookie jar +// using the okta__saml2aws, we reduce making an extra api call +// this func can be uplifted in the future to set custom device tokens or used with +// getDeviceTokenFromOkta function +func (oc *Client) setDeviceTokenCookie(loginDetails *creds.LoginDetails) error { + + // getDeviceTokenFromOkta is not used but doing this to keep the function code + // uncommented (avoid linting issues) + if false { + dt, _ := oc.getDeviceTokenFromOkta(loginDetails) + logger.Debugf("getDeviceTokenFromOkta is not yet implemented: dt: %s", dt) + } + + oktaURL, err := url.Parse(loginDetails.URL) + if err != nil { + return errors.Wrap(err, "error building oktaURL to set device token cookie") + } + oktaURLScheme := oktaURL.Scheme + oktaURLHost := oktaURL.Host + baseURL := &url.URL{Scheme: oktaURLScheme, Host: oktaURLHost, Path: "/"} + + var cookies []*http.Cookie + cookie := http.Cookie{ + Name: "DT", + Secure: true, + Expires: time.Now().Add(time.Hour * 24 * 30), // 30 Days -> this time might not matter as this cookie is set on every saml2aws login request + Value: fmt.Sprintf("okta_%s_saml2aws", loginDetails.Username), // Okta recommends using an UUID but this should be unique enough. Also, this is key to remembering Okta MFA device + } + cookies = append(cookies, &cookie) + oc.client.Jar.SetCookies(baseURL, cookies) + + return nil +} + +// primaryAuth creates the Okta Primary Authentication request +// returns the authStatus, sessionToken, http response and a error +func (oc *Client) primaryAuth(loginDetails *creds.LoginDetails) (string, string, string, error) { + + oktaURL, err := url.Parse(loginDetails.URL) + if err != nil { + return "", "", "", errors.Wrap(err, "error building oktaURL") + } + + oktaOrgHost := oktaURL.Host //authenticate via okta api authReq := AuthRequest{Username: loginDetails.Username, Password: loginDetails.Password} if loginDetails.StateToken != "" { @@ -143,14 +408,14 @@ func (oc *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) authBody := new(bytes.Buffer) err = json.NewEncoder(authBody).Encode(authReq) if err != nil { - return "", errors.Wrap(err, "error encoding authreq") + return "", "", "", errors.Wrap(err, "error encoding authreq") } authSubmitURL := fmt.Sprintf("https://%s/api/v1/authn", oktaOrgHost) - req, err = http.NewRequest("POST", authSubmitURL, authBody) + req, err := http.NewRequest("POST", authSubmitURL, authBody) if err != nil { - return "", errors.Wrap(err, "error building authentication request") + return "", "", "", errors.Wrap(err, "error building authentication request") } req.Header.Add("Content-Type", "application/json") @@ -158,12 +423,12 @@ func (oc *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) res, err := oc.client.Do(req) if err != nil { - return "", errors.Wrap(err, "error retrieving auth response") + return "", "", "", errors.Wrap(err, "error retrieving auth response") } body, err := ioutil.ReadAll(res.Body) if err != nil { - return "", errors.Wrap(err, "error retrieving body from response") + return "", "", "", errors.Wrap(err, "error retrieving body from response") } resp := string(body) @@ -171,39 +436,100 @@ func (oc *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) authStatus := gjson.Get(resp, "status").String() oktaSessionToken := gjson.Get(resp, "sessionToken").String() + return authStatus, oktaSessionToken, resp, nil +} + +// Authenticate logs into Okta and returns a SAML response +func (oc *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) { + + // Set Okta device token + err := oc.setDeviceTokenCookie(loginDetails) + if err != nil { + return "", errors.Wrap(err, "error setting device token in cookie jar") + } + + // Get Okta session cookie (sid) from login details (if found via login.go) + oktaSessionCookie := loginDetails.OktaSessionCookie + + // If user disabled sessions, do not use sessions API + if !oc.disableSessions { + // If Okta session cookie is not empty + // Note on checking StateToken: StateToken is set in the follow func + // if the follow func calls this function (Authenticate), it means the session requires MFA to continue + // so don't call authWithSession, instead flow through to create the primary authentication call + if oktaSessionCookie != "" && loginDetails.StateToken == "" { + return oc.authWithSession(loginDetails) + } + } + + oktaURL, err := url.Parse(loginDetails.URL) + if err != nil { + return "", errors.Wrap(err, "error building oktaURL") + } + + oktaOrgHost := oktaURL.Host + + authStatus, oktaSessionToken, primaryAuthResp, err := oc.primaryAuth(loginDetails) + if err != nil { + return "", err + } + // mfa required if authStatus == "MFA_REQUIRED" { - oktaSessionToken, err = verifyMfa(oc, oktaOrgHost, loginDetails, resp) + oktaSessionToken, err = verifyMfa(oc, oktaOrgHost, loginDetails, primaryAuthResp) if err != nil { return "", errors.Wrap(err, "error verifying MFA") } } - //now call saml endpoint - oktaSessionRedirectURL := fmt.Sprintf("https://%s/login/sessionCookieRedirect", oktaOrgHost) + // if user disabled sessions, default to using standard login WITHOUT sessions + if oc.disableSessions { + //now call saml endpoint + oktaSessionRedirectURL := fmt.Sprintf("https://%s/login/sessionCookieRedirect", oktaOrgHost) - req, err = http.NewRequest("GET", oktaSessionRedirectURL, nil) - if err != nil { - return "", errors.Wrap(err, "error building authentication request") + req, err := http.NewRequest("GET", oktaSessionRedirectURL, nil) + if err != nil { + return "", errors.Wrap(err, "error building authentication request") + } + q := req.URL.Query() + q.Add("checkAccountSetupComplete", "true") + q.Add("token", oktaSessionToken) + q.Add("redirectUrl", loginDetails.URL) + req.URL.RawQuery = q.Encode() + + ctx := context.WithValue(context.Background(), ctxKey("login"), loginDetails) + return oc.follow(ctx, req, loginDetails) } - q := req.URL.Query() - q.Add("checkAccountSetupComplete", "true") - q.Add("token", oktaSessionToken) - q.Add("redirectUrl", loginDetails.URL) - req.URL.RawQuery = q.Encode() - ctx := context.WithValue(context.Background(), ctxKey("login"), loginDetails) - return oc.follow(ctx, req, loginDetails) + // Only reaches here if user DID NOT DISABLE okta sessions + if oktaSessionCookie == "" { + oktaSessionCookie, _, err = oc.createSession(loginDetails, oktaSessionToken) + if err != nil { + return "", err + } + loginDetails.OktaSessionCookie = oktaSessionCookie + } + + return oc.authWithSession(loginDetails) } func (oc *Client) follow(ctx context.Context, req *http.Request, loginDetails *creds.LoginDetails) (string, error) { + if ctx.Value(ctxKey("follow")) != nil { + logger.Debug("follow func called from itself") + } + + if ctx.Value(ctxKey("authWithSession")) != nil { + logger.Debug("follow func called from auth with session func") + } res, err := oc.client.Do(req) if err != nil { + logger.Debug("ERROR FOLLOWING") return "", errors.Wrap(err, "error following") } doc, err := goquery.NewDocumentFromReader(res.Body) if err != nil { + logger.Debug("FAILED TO BUILD DOC FROM RESP") return "", errors.Wrap(err, "failed to build document from response") } @@ -348,7 +674,7 @@ func getMfaChallengeContext(oc *Client, mfaOption int, resp string) (*mfaChallen } // get signature & callback - verifyReq := VerifyRequest{StateToken: stateToken} + verifyReq := VerifyRequest{StateToken: stateToken, RememberDevice: strconv.FormatBool(oc.rememberDevice)} verifyBody := new(bytes.Buffer) // Login flow is different for YubiKeys ( of course ) @@ -426,7 +752,7 @@ func verifyMfa(oc *Client, oktaOrgHost string, loginDetails *creds.LoginDetails, if verifyCode == "" { verifyCode = prompter.StringRequired("Enter verification code") } - tokenReq := VerifyRequest{StateToken: stateToken, PassCode: verifyCode} + tokenReq := VerifyRequest{StateToken: stateToken, PassCode: verifyCode, RememberDevice: strconv.FormatBool(oc.rememberDevice)} tokenBody := new(bytes.Buffer) err = json.NewEncoder(tokenBody).Encode(tokenReq) if err != nil { @@ -465,6 +791,7 @@ func verifyMfa(oc *Client, oktaOrgHost string, loginDetails *creds.LoginDetails, // on 'success' status if gjson.Get(body, "status").String() == "SUCCESS" { fmt.Printf(" Approved\n\n") + logger.Debugf("func verifyMfa | okta exiry: %s", gjson.Get(body, "expiresAt").String()) // DEBUG return gjson.Get(body, "sessionToken").String(), nil } @@ -738,7 +1065,7 @@ func verifyMfa(oc *Client, oktaOrgHost string, loginDetails *creds.LoginDetails, // extract okta session token - verifyReq := VerifyRequest{StateToken: stateToken} + verifyReq := VerifyRequest{StateToken: stateToken, RememberDevice: strconv.FormatBool(oc.rememberDevice)} verifyBody := new(bytes.Buffer) err = json.NewEncoder(verifyBody).Encode(verifyReq) if err != nil { diff --git a/pkg/provider/okta/okta_test.go b/pkg/provider/okta/okta_test.go index 4eadaf56e..9bae1e4e9 100644 --- a/pkg/provider/okta/okta_test.go +++ b/pkg/provider/okta/okta_test.go @@ -2,9 +2,13 @@ package okta import ( "errors" + "fmt" + "net/url" "testing" "github.com/stretchr/testify/assert" + "github.com/versent/saml2aws/v2/pkg/cfg" + "github.com/versent/saml2aws/v2/pkg/creds" ) type stateTokenTests struct { @@ -47,3 +51,66 @@ func TestGetStateTokenFromOktaPageBody(t *testing.T) { }) } } + +func TestSetDeviceTokenCookie(t *testing.T) { + idpAccount := cfg.NewIDPAccount() + idpAccount.URL = "https://idp.example.com/abcd" + idpAccount.Username = "user@example.com" + + loginDetails := &creds.LoginDetails{ + Username: "user@example.com", + Password: "abc123", + URL: "https://idp.example.com/abcd", + } + + oc, err := New(idpAccount) + assert.Nil(t, err) + + err = oc.setDeviceTokenCookie(loginDetails) + assert.Nil(t, err) + + expectedDT := fmt.Sprintf("okta_%s_saml2aws", loginDetails.Username) + actualDT := "" + for _, c := range oc.client.Jar.Cookies(&url.URL{Scheme: "https", Host: "idp.example.com", Path: "/abc"}) { + if c.Name == "DT" { + actualDT = c.Value + } + } + assert.NotEqual(t, actualDT, "") + assert.Equal(t, expectedDT, actualDT) + +} + +func TestOktaCfgFlagsDefaultState(t *testing.T) { + idpAccount := cfg.NewIDPAccount() + idpAccount.URL = "https://idp.example.com/abcd" + idpAccount.Username = "user@example.com" + + oc, err := New(idpAccount) + assert.Nil(t, err) + + assert.False(t, oc.disableSessions, fmt.Errorf("disableSessions should be false by default")) + assert.True(t, oc.rememberDevice, fmt.Errorf("rememberDevice should be true by default")) +} + +func TestOktaCfgFlagsCustomState(t *testing.T) { + idpAccount := cfg.NewIDPAccount() + idpAccount.URL = "https://idp.example.com/abcd" + idpAccount.Username = "user@example.com" + + idpAccount.DisableRememberDevice = true + oc, err := New(idpAccount) + assert.Nil(t, err) + + assert.False(t, oc.disableSessions, fmt.Errorf("disableSessions should be false by default")) + assert.False(t, oc.rememberDevice, fmt.Errorf("DisableRememberDevice was set to true, so rememberDevice should be false")) + + idpAccount.DisableSessions = true + + oc, err = New(idpAccount) + assert.Nil(t, err) + + assert.True(t, oc.disableSessions, fmt.Errorf("DisableSessions was set to true so disableSessions should be true")) + assert.False(t, oc.rememberDevice, fmt.Errorf("DisablDisableSessionseRememberDevice was set to true, so rememberDevice should be false")) + +}