Skip to content
This repository has been archived by the owner on Feb 21, 2023. It is now read-only.

Commit

Permalink
Refresh our session as a way to test it between runs & add better deb…
Browse files Browse the repository at this point in the history
…ug output
  • Loading branch information
mipearson committed May 7, 2021
1 parent ec1a8f1 commit 6180520
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 13 deletions.
4 changes: 2 additions & 2 deletions cli/assume_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ var notARoleErrorMessage = `'%s' is neither an IAM role ARN nor a configured ali
Run 'yak --list-roles' to see which roles and aliases you can use.`

func AssumeRole(role string) (*sts.AssumeRoleWithSAMLOutput, error) {

creds := getAssumedRoleFromCache(role)

if creds == nil {
Expand All @@ -41,6 +40,8 @@ func AssumeRole(role string) (*sts.AssumeRoleWithSAMLOutput, error) {
return nil, err
}

log.WithField("role", creds).Debug("assume_role.go: Role assumption credentials from AWS")

cache.WriteDefault(role, creds)
cache.Export()
}
Expand All @@ -56,7 +57,6 @@ func getAssumedRoleFromCache(role string) *sts.AssumeRoleWithSAMLOutput {
}

return &data

}

func ResolveRole(roleName string) (string, error) {
Expand Down
49 changes: 41 additions & 8 deletions cli/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,39 @@ func GetRolesFromCache() ([]saml.LoginRole, bool) {
return roles, true
}

func oktaDomain() string {
return viper.GetString("okta.domain")
}

func oktaUsername() string {
return viper.GetString("okta.username")
}

func oktaSessionCacheKey() string {
return fmt.Sprintf("okta:sessionToken:%s:%s", viper.GetString("okta.domain"), viper.GetString("okta.username"))
return fmt.Sprintf("okta:sessionToken:%s:%s", oktaDomain(), oktaUsername())
}

func getOktaSessionFromCache() (*okta.OktaSession, bool) {
data, ok := cache.Check(oktaSessionCacheKey()).(okta.OktaSession)
return &data, ok
}

func checkOktaSession(session *okta.OktaSession) bool {
response, err := okta.GetSession(oktaDomain(), session)

// This needs explaining: Okta's "Create Session" API call gives
// us a session ID that we set as the `sid` cookie. Get & Refresh return a
// *different* ID that we can't use as the cookie, but they both
// extend the calling session.

if err == nil {
session.ExpiresAt = response.ExpiresAt
cache.Write(oktaSessionCacheKey(), *session, session.ExpiresAt.Sub(time.Now()))
}

return err == nil
}

func GetLoginDataWithTimeout() (saml.LoginData, error) {
errorChannel := make(chan error)
resultChannel := make(chan saml.LoginData)
Expand Down Expand Up @@ -93,11 +117,19 @@ func GetLoginDataWithTimeout() (saml.LoginData, error) {
func getLoginData() (saml.LoginData, error) {
session, gotSession := getOktaSessionFromCache()

if gotSession && session.ExpiresAt.After(time.Now()) {
log.Infof("Okta session found in cache (%s), expires %s", session.Id, session.ExpiresAt.String())
gotSession = checkOktaSession(session)
if gotSession {
log.Infof("Refreshed session, now expires %s", session.ExpiresAt.String())
}
}

if !gotSession {
var authResponse okta.OktaAuthResponse
var err error

log.Infof("Okta session not found in cache")
log.Infof("Okta session not in cache or no longer valid, re-authenticating")

if viper.GetBool("cache.cache_only") {
return saml.LoginData{}, errors.New("Could not find credentials in cache and --cache-only specified. Exiting.")
Expand Down Expand Up @@ -131,7 +163,7 @@ func getLoginData() (saml.LoginData, error) {

}

samlPayload, err := okta.AwsSamlLogin(viper.GetString("okta.domain"), viper.GetString("okta.aws_saml_endpoint"), *session)
samlPayload, err := okta.AwsSamlLogin(oktaDomain(), viper.GetString("okta.aws_saml_endpoint"), *session)
if err != nil {
return saml.LoginData{}, err
}
Expand All @@ -141,6 +173,7 @@ func getLoginData() (saml.LoginData, error) {
if err != nil {
return saml.LoginData{}, err
}
log.WithField("saml", samlResponse).Debug("okta.go: SAML response from Okta")

return saml.CreateLoginData(samlResponse, samlPayload), nil
}
Expand Down Expand Up @@ -197,11 +230,11 @@ func chooseMFA(authResponse okta.OktaAuthResponse) (okta.AuthResponseFactor, err
}

func getOktaSession(authResponse okta.OktaAuthResponse) (session *okta.OktaSession, err error) {
log.Infof("Creating new Okta session for %s", viper.GetString("okta.domain"))
session, err = okta.CreateSession(viper.GetString("okta.domain"), authResponse)
log.Infof("Creating new Okta session for %s", oktaDomain())
session, err = okta.CreateSession(oktaDomain(), authResponse)

if err == nil {
cache.WriteDefault(oktaSessionCacheKey(), *session)
cache.Write(oktaSessionCacheKey(), *session, session.ExpiresAt.Sub(time.Now()))
}

return
Expand Down Expand Up @@ -298,7 +331,7 @@ func promptLogin() (okta.OktaAuthResponse, error) {

for unauthorised && (retries < maxLoginRetries) {
retries++
username := viper.GetString("okta.username")
username := oktaUsername()
promptUsername := (username == "")

// Viper isn't used here because it's really hard to get Viper to not accept values through the config file
Expand Down Expand Up @@ -328,7 +361,7 @@ func promptLogin() (okta.OktaAuthResponse, error) {
}
}

authResponse, err = okta.Authenticate(viper.GetString("okta.domain"), okta.UserData{username, password})
authResponse, err = okta.Authenticate(oktaDomain(), okta.UserData{username, password})

if authResponse.YakStatusCode == okta.YAK_STATUS_UNAUTHORISED && retries < maxLoginRetries && !envPassword {
fmt.Fprintln(os.Stderr, "Sorry, try again.")
Expand Down
2 changes: 2 additions & 0 deletions cmd/list_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cmd
import (
"fmt"

log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/spf13/viper"

Expand All @@ -14,6 +15,7 @@ func listRolesCmd(cmd *cobra.Command, args []string) error {
roles, gotRoles := cli.GetRolesFromCache()

if !gotRoles {
log.Infof("Role list not in cache, grabbing from AWS")
loginData, err := cli.GetLoginDataWithTimeout()

if err != nil {
Expand Down
8 changes: 7 additions & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ These can be configured either in the [okta] section of ~/.config/yak/config.tom
return err
}

if viper.GetBool("verbose") {
if viper.GetBool("debug") {
log.SetLevel(log.DebugLevel)
} else if viper.GetBool("verbose") {
log.SetLevel(log.InfoLevel)
} else {
log.SetLevel(log.WarnLevel)
Expand Down Expand Up @@ -122,11 +124,15 @@ func init() {
rootCmd.PersistentFlags().Bool("clear-cache", false, "Delete all data from yak's cache. If no other arguments are given, exit without error")
rootCmd.PersistentFlags().Bool("version", false, "Print the current version and exit")
rootCmd.PersistentFlags().BoolP("verbose", "v", false, "Print our actions as we take them")
rootCmd.PersistentFlags().Bool("debug", false, "Print detailed debug information, including session tokens")

rootCmd.PersistentFlags().Bool("credits", false, "Print the contributing authors")
viper.BindPFlag("list-roles", rootCmd.PersistentFlags().Lookup("list-roles"))
viper.BindPFlag("clear-cache", rootCmd.PersistentFlags().Lookup("clear-cache"))
viper.BindPFlag("version", rootCmd.PersistentFlags().Lookup("version"))
viper.BindPFlag("credits", rootCmd.PersistentFlags().Lookup("credits"))
viper.BindPFlag("verbose", rootCmd.PersistentFlags().Lookup("verbose"))
viper.BindPFlag("debug", rootCmd.PersistentFlags().Lookup("debug"))

rootCmd.PersistentFlags().StringP("okta-username", "u", "", "Your Okta username")
rootCmd.PersistentFlags().String("okta-domain", "", "The domain to use for requests to Okta")
Expand Down
49 changes: 47 additions & 2 deletions okta/okta.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"os"
"time"

log "github.com/sirupsen/logrus"
"golang.org/x/net/html"
)

Expand Down Expand Up @@ -76,10 +77,11 @@ type OktaAuthResponse struct {
}

type OktaSession struct {
Id string `json:"id"`
ExpiresAt string `json:"expiresAt"`
Id string `json:"id"`
ExpiresAt time.Time `json:"expiresAt"`
}

// TODO: DRY
func CreateSession(oktaHref string, authResponse OktaAuthResponse) (*OktaSession, error) {
authBody, err := json.Marshal(map[string]string{"sessionToken": authResponse.SessionToken})
if err != nil {
Expand Down Expand Up @@ -108,9 +110,49 @@ func CreateSession(oktaHref string, authResponse OktaAuthResponse) (*OktaSession
if err := json.Unmarshal(body, &session); err != nil {
return nil, err
}
log.WithField("session", session).Debug("okta.go: Created Session from Okta")
return &session, nil
}

// TODO: DRY
func GetSession(oktaHref string, session *OktaSession) (*OktaSession, error) {
oktaUrl, err := url.Parse(oktaHref)
if err != nil {
return nil, err
}

sessionEndpoint, _ := url.Parse("/api/v1/sessions/me")
sessionUrl := oktaUrl.ResolveReference(sessionEndpoint)

jar, _ := cookiejar.New(nil)
jar.SetCookies(sessionUrl, []*http.Cookie{{Name: "sid", Value: session.Id}})

client := http.Client{
Jar: jar,
}

resp, err := client.Get(sessionUrl.String())
if err != nil {
return nil, err
}

if resp.StatusCode > 300 {
return nil, fmt.Errorf("Status code %d, expected < 2xx", resp.StatusCode)
}

body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}

newSession := OktaSession{}
if err := json.Unmarshal(body, &newSession); err != nil {
return nil, err
}
log.WithField("session", string(body)).Debug("okta.go: Retrieved Session from Okta")
return &newSession, nil
}

func Authenticate(oktaHref string, userData UserData) (OktaAuthResponse, error) {
authBody, err := json.Marshal(userData)

Expand All @@ -130,11 +172,13 @@ func Authenticate(oktaHref string, userData UserData) (OktaAuthResponse, error)
body, yakStatus, err := makeRequest(primaryAuthUrl.String(), bytes.NewBuffer(authBody))

if err != nil {
log.WithField("err", err).Debug("okta.go: Okta login error")
return OktaAuthResponse{YakStatusCode: yakStatus}, err
}

authResponse := OktaAuthResponse{YakStatusCode: YAK_STATUS_OK}
json.Unmarshal(body, &authResponse)
log.WithField("response", authResponse).Debug("okta.go: Auth response for Okta login")

return authResponse, nil
}
Expand Down Expand Up @@ -266,6 +310,7 @@ func AwsSamlLogin(oktaHref string, samlHref string, oktasession OktaSession) (st

func makeRequest(url string, body io.Reader) ([]byte, int, error) {
resp, err := http.Post(url, "application/json", body)
log.WithField("url", url).WithField("statusCode", resp.StatusCode).Debug("okta.go: Okta request")

if err != nil {
return []byte{}, YAK_STATUS_NET_ERROR, err
Expand Down

0 comments on commit 6180520

Please sign in to comment.